Spaces:
Runtime error
Runtime error
JunchuanYu
commited on
Commit
·
e1466f1
1
Parent(s):
9d05532
Update app.py
Browse files
app.py
CHANGED
@@ -9,76 +9,144 @@ import glob
|
|
9 |
import gradio as gr
|
10 |
from PIL import Image
|
11 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
18 |
#setup model
|
19 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
21 |
-
model_type = "
|
22 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
23 |
sam.to(device=device)
|
24 |
-
mask_generator = SamAutomaticMaskGenerator(sam)
|
25 |
predictor = SamPredictor(sam)
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
color = []
|
35 |
-
for ann in sorted_anns:
|
36 |
-
m = ann['segmentation']
|
37 |
-
img = np.ones((m.shape[0], m.shape[1], 3))
|
38 |
-
color_mask = np.random.random((1, 3)).tolist()[0]
|
39 |
-
for i in range(3):
|
40 |
-
img[:,:,i] = color_mask[i]
|
41 |
-
ax.imshow(np.dstack((img, m*0.35)))
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
return Image.fromarray(output)
|
55 |
|
56 |
-
with gr.Blocks() as demo:
|
57 |
-
gr.Markdown(
|
58 |
-
"""
|
59 |
-
# Segment Anything Model (SAM)
|
60 |
-
### A test on remote sensing data
|
61 |
-
- Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
|
62 |
-
- Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
|
63 |
-
- Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
|
64 |
-
- Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
|
65 |
-
"""
|
66 |
-
)
|
67 |
with gr.Row():
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import gradio as gr
|
10 |
from PIL import Image
|
11 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
12 |
+
import logging
|
13 |
+
from huggingface_hub import login
|
14 |
+
from huggingface_hub import Repository
|
15 |
+
# os.system("python -m pip install --upgrade pip")
|
16 |
+
# os.system("pip uninstall -y gradio")
|
17 |
+
# os.system("pip install gradio==3.27.0")
|
18 |
|
19 |
+
login(token = os.environ['HUB_TOKEN'])
|
20 |
+
repo = Repository(
|
21 |
+
local_dir="files",
|
22 |
+
repo_type="dataset",
|
23 |
+
clone_from="JunchuanYu/files_for_segmentRS",
|
24 |
+
token=True
|
25 |
+
)
|
26 |
+
repo.git_pull()
|
27 |
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
28 |
#setup model
|
29 |
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
30 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
31 |
+
model_type = "vit_h"
|
32 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
33 |
sam.to(device=device)
|
|
|
34 |
predictor = SamPredictor(sam)
|
35 |
+
logging.basicConfig(filename="app.log", level=logging.INFO)
|
36 |
|
37 |
+
with gr.Blocks(theme='gradio/soft') as demo:
|
38 |
+
gr.Markdown(title)
|
39 |
+
with gr.Accordion("Instructions For User 👉", open=False):
|
40 |
+
gr.Markdown(description)
|
41 |
+
x=gr.State(value=[])
|
42 |
+
y=gr.State(value=[])
|
43 |
+
label=gr.State(value=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
with gr.Row():
|
46 |
+
with gr.Column():
|
47 |
+
mode=gr.inputs.Radio(['Positive','Negative'], type="value",default='Positive',label='Types of sampling methods')
|
48 |
+
with gr.Column():
|
49 |
+
clear_bn=gr.Button("Clear Selection")
|
50 |
+
interseg_button = gr.Button("Interactive Segment",variant='primary')
|
51 |
+
with gr.Row():
|
52 |
+
input_img = gr.Image(label="Input")
|
53 |
+
gallery = gr.Image(label="Selected Sample Points")
|
54 |
+
|
55 |
+
input_img.select(get_select_coords, [input_img, mode,x,y,label], [gallery,x,y,label])
|
|
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
with gr.Row():
|
58 |
+
output_img = gr.Image(label="Result")
|
59 |
+
mask_img = gr.Image(label="Mask")
|
60 |
+
with gr.Row():
|
61 |
+
with gr.Column():
|
62 |
+
pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
|
63 |
+
with gr.Column():
|
64 |
+
points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
|
65 |
+
autoseg_button = gr.Button("Auto Segment",variant="primary")
|
66 |
+
emptyBtn = gr.Button("Restart",variant="secondary")
|
67 |
+
|
68 |
+
interseg_button.click(interactive_seg, inputs=[input_img,x,y,label], outputs=[output_img,mask_img])
|
69 |
+
autoseg_button.click(auto_seg, inputs=[input_img,pred_iou_thresh,points_per_side], outputs=[mask_img])
|
70 |
+
|
71 |
+
clear_bn.click(clear_point,outputs=[gallery,x,y,label],show_progress=True)
|
72 |
+
emptyBtn.click(reset_state,outputs=[input_img,gallery,output_img,mask_img,x,y,label],show_progress=True,)
|
73 |
+
|
74 |
+
example = gr.Examples(
|
75 |
+
examples=[[s,0.88,32] for s in glob.glob('./images/*')],
|
76 |
+
fn=auto_seg,
|
77 |
+
inputs=[input_img,pred_iou_thresh,points_per_side],
|
78 |
+
outputs=[output_img],
|
79 |
+
cache_examples=True,examples_per_page=5)
|
80 |
+
|
81 |
+
gr.Markdown(descriptionend)
|
82 |
+
if __name__ == "__main__":
|
83 |
+
demo.launch(debug=False,show_api=False,Share=True)
|
84 |
+
|
85 |
+
# matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
86 |
+
# #setup model
|
87 |
+
# sam_checkpoint = "sam_vit_h_4b8939.pth"
|
88 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
89 |
+
# model_type = "default"
|
90 |
+
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
91 |
+
# sam.to(device=device)
|
92 |
+
# mask_generator = SamAutomaticMaskGenerator(sam)
|
93 |
+
# predictor = SamPredictor(sam)
|
94 |
+
|
95 |
+
# def show_anns(anns):
|
96 |
+
# if len(anns) == 0:
|
97 |
+
# return
|
98 |
+
# sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
99 |
+
# ax = plt.gca()
|
100 |
+
# ax.set_autoscale_on(False)
|
101 |
+
# polygons = []
|
102 |
+
# color = []
|
103 |
+
# for ann in sorted_anns:
|
104 |
+
# m = ann['segmentation']
|
105 |
+
# img = np.ones((m.shape[0], m.shape[1], 3))
|
106 |
+
# color_mask = np.random.random((1, 3)).tolist()[0]
|
107 |
+
# for i in range(3):
|
108 |
+
# img[:,:,i] = color_mask[i]
|
109 |
+
# ax.imshow(np.dstack((img, m*0.35)))
|
110 |
+
|
111 |
+
# def segment_image(image):
|
112 |
+
# masks = mask_generator.generate(image)
|
113 |
+
# plt.clf()
|
114 |
+
# ppi = 100
|
115 |
+
# height, width, _ = image.shape
|
116 |
+
# plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
|
117 |
+
# plt.imshow(image)
|
118 |
+
# show_anns(masks)
|
119 |
+
# plt.axis('off')
|
120 |
+
# plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
|
121 |
+
# output = cv2.imread('output.png')
|
122 |
+
# return Image.fromarray(output)
|
123 |
+
|
124 |
+
# with gr.Blocks() as demo:
|
125 |
+
# gr.Markdown(
|
126 |
+
# """
|
127 |
+
# # Segment Anything Model (SAM)
|
128 |
+
# ### A test on remote sensing data
|
129 |
+
# - Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
|
130 |
+
# - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
|
131 |
+
# - Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
|
132 |
+
# - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
|
133 |
+
# """
|
134 |
+
# )
|
135 |
+
# with gr.Row():
|
136 |
+
# image = gr.Image()
|
137 |
+
# image_output = gr.Image()
|
138 |
+
# # print(image.shape)
|
139 |
+
# segment_image_button = gr.Button("Segment")
|
140 |
+
# segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
|
141 |
+
# gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
|
142 |
+
# gr.Markdown("""
|
143 |
+
# ### <div align=center>you can follow the WeChat public account [45度科研人] and leave me a message! </div>
|
144 |
+
# <br />
|
145 |
+
# <br />
|
146 |
+
# <div style="display:flex; justify-content:center;">
|
147 |
+
# <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
|
148 |
+
# <div style="width:25px;"></div>
|
149 |
+
# <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
|
150 |
+
# </div>
|
151 |
+
# """)
|
152 |
+
# demo.launch(debug=True)
|