mrtlive commited on
Commit
381c83f
1 Parent(s): 79b5495

cuda cpu if

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -12,11 +12,10 @@ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_
12
 
13
  matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
14
 
15
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
16
 
17
  #setup model
18
  sam_checkpoint = "sam_vit_h_4b8939.pth"
19
- device = "cuda"
20
  model_type = "default"
21
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
22
  sam.to(device=device)
@@ -56,10 +55,10 @@ with gr.blocks() as demo:
56
  gr.MArkdown("## Segment-anything Demo")
57
 
58
  with gr.Row():
59
- image_input = gr.Image()
60
  image_output = gr.Image()
61
 
62
  segment_image_button = gr.Button("Segment Image")
63
- segment_image_button.click(segment_image, image_input, image_output)
64
 
65
- demo.launch()
 
12
 
13
  matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
14
 
 
15
 
16
  #setup model
17
  sam_checkpoint = "sam_vit_h_4b8939.pth"
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
19
  model_type = "default"
20
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
21
  sam.to(device=device)
 
55
  gr.MArkdown("## Segment-anything Demo")
56
 
57
  with gr.Row():
58
+ image = gr.Image()
59
  image_output = gr.Image()
60
 
61
  segment_image_button = gr.Button("Segment Image")
62
+ segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
63
 
64
+ demo.launch()