JunchuanYu committed on
Commit 35ac044 · 1 Parent(s): 9c4d40a

Update app.py

Files changed (1)
  1. app.py +15 -22
app.py CHANGED
@@ -4,29 +4,15 @@ import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+import torchvision
 import glob
 import gradio as gr
 from PIL import Image
-
 from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
-from torchvision.utils import draw_segmentation_masks
 
 matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
-
-gr.Markdown(
-    """
-    # Segment Anything Model (SAM)
-    ### A test on remote sensing data
-
-    - Paper:[Link](https://scontent-fml2-1.xx.fbcdn.net/v/t39.2365-6/10000000_900554171201033_1602411987825904100_n.pdf?_nc_cat=100&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=Ald4OYhL6hgAX-ZcGmS&_nc_ht=scontent-fml2-1.xx&oh=00_AfDk4FvyiDYeXgflANA2CbdV6HSS8CcJmrvjSfTqsgUmog&oe=643500E7)
-    - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
-    - Dataset:[https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
-    - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
-    """
-)
-
 #setup model
-sam_checkpoint = "meta-model.pth"
+sam_checkpoint = "sam_vit_h_4b8939.pth"
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
 model_type = "default"
 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
@@ -62,17 +48,24 @@ def segment_image(image):
     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
     output = cv2.imread('output.png')
     return Image.fromarray(output)
-
+
 with gr.Blocks() as demo:
-    # gr.Markdown("## Segment-anything Demo")
-
+    gr.Markdown(
+        """
+        # Segment Anything Model (SAM)
+        ### A test on remote sensing data
+        - Paper:[https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
+        - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
+        - Dataset:[https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
+        - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
+        """
+    )
     with gr.Row():
         image = gr.Image()
         image_output = gr.Image()
-
+    print(image.shape)
     segment_image_button = gr.Button("Segment")
    segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
     gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
-
 
-demo.launch()
+demo.launch(debug=True)
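For context, the diff omits the middle of app.py: the mask-generator setup and the body of segment_image(). The sketch below is a minimal, hypothetical reconstruction of that unchanged section, based only on what the commit shows (the checkpoint name, the SamAutomaticMaskGenerator import, and the last three lines of segment_image()); the show_masks() helper and the plotting details are illustrative assumptions, not the author's code.

# Hypothetical sketch of the portion of app.py not shown in this diff.
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

plt.switch_backend('Agg')

sam_checkpoint = "sam_vit_h_4b8939.pth"   # checkpoint named in the diff
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)                      # assumed: move the model to the selected device
mask_generator = SamAutomaticMaskGenerator(sam)   # assumed: automatic (promptless) mask generation

def show_masks(anns):
    # Assumed helper: draw each predicted mask as a translucent random colour on the current axes.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        m = ann['segmentation']                       # boolean HxW mask
        overlay = np.zeros((*m.shape, 4))
        overlay[m] = np.concatenate([np.random.random(3), [0.6]])
        plt.gca().imshow(overlay)

def segment_image(image):
    # `image` arrives from gr.Image() as an RGB numpy array.
    masks = mask_generator.generate(image)
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_masks(masks)
    plt.axis('off')
    # The three lines below are the ones visible in the diff; note that
    # cv2.imread() returns BGR, so the returned image keeps that channel order.
    plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
    output = cv2.imread('output.png')
    return Image.fromarray(output)

With something like this in place, the segment_image_button.click(...) wiring shown in the diff passes the uploaded image through mask_generator.generate() and returns the rendered mask overlay to the output gr.Image().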