Spaces:
Runtime error
Runtime error
JunchuanYu
commited on
Commit
·
35ac044
1
Parent(s):
9c4d40a
Update app.py
Browse files
app.py
CHANGED
@@ -4,29 +4,15 @@ import matplotlib
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
import torch
|
|
|
7 |
import glob
|
8 |
import gradio as gr
|
9 |
from PIL import Image
|
10 |
-
|
11 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
12 |
-
from torchvision.utils import draw_segmentation_masks
|
13 |
|
14 |
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
15 |
-
|
16 |
-
gr.Markdown(
|
17 |
-
"""
|
18 |
-
# Segment Anything Model (SAM)
|
19 |
-
### A test on remote sensing data
|
20 |
-
|
21 |
-
- Paper:[Link](https://scontent-fml2-1.xx.fbcdn.net/v/t39.2365-6/10000000_900554171201033_1602411987825904100_n.pdf?_nc_cat=100&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=Ald4OYhL6hgAX-ZcGmS&_nc_ht=scontent-fml2-1.xx&oh=00_AfDk4FvyiDYeXgflANA2CbdV6HSS8CcJmrvjSfTqsgUmog&oe=643500E7)
|
22 |
-
- Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
|
23 |
-
- Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
|
24 |
-
- Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
|
25 |
-
"""
|
26 |
-
)
|
27 |
-
|
28 |
#setup model
|
29 |
-
sam_checkpoint = "
|
30 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
31 |
model_type = "default"
|
32 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
@@ -62,17 +48,24 @@ def segment_image(image):
|
|
62 |
plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
|
63 |
output = cv2.imread('output.png')
|
64 |
return Image.fromarray(output)
|
65 |
-
|
66 |
with gr.Blocks() as demo:
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
with gr.Row():
|
70 |
image = gr.Image()
|
71 |
image_output = gr.Image()
|
72 |
-
|
73 |
segment_image_button = gr.Button("Segment")
|
74 |
segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
|
75 |
gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
|
76 |
-
|
77 |
|
78 |
-
demo.launch()
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
+
import torchvision
|
8 |
import glob
|
9 |
import gradio as gr
|
10 |
from PIL import Image
|
|
|
11 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
|
|
12 |
|
13 |
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
#setup model
|
15 |
+
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
16 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
17 |
model_type = "default"
|
18 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
|
|
48 |
plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
|
49 |
output = cv2.imread('output.png')
|
50 |
return Image.fromarray(output)
|
51 |
+
|
52 |
with gr.Blocks() as demo:
|
53 |
+
gr.Markdown(
|
54 |
+
"""
|
55 |
+
# Segment Anything Model (SAM)
|
56 |
+
### A test on remote sensing data
|
57 |
+
- Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
|
58 |
+
- Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
|
59 |
+
- Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
|
60 |
+
- Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
|
61 |
+
"""
|
62 |
+
)
|
63 |
with gr.Row():
|
64 |
image = gr.Image()
|
65 |
image_output = gr.Image()
|
66 |
+
print(image.shape)
|
67 |
segment_image_button = gr.Button("Segment")
|
68 |
segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
|
69 |
gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
|
|
|
70 |
|
71 |
+
demo.launch(debug=True)
|