JunchuanYu committed
Commit e1466f1 · 1 Parent(s): 9d05532

Update app.py

Files changed (1)
  1. app.py +129 -61
app.py CHANGED
@@ -9,76 +9,144 @@ import glob
  import gradio as gr
  from PIL import Image
  from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
+ import logging
+ from huggingface_hub import login
+ from huggingface_hub import Repository
+ # os.system("python -m pip install --upgrade pip")
+ # os.system("pip uninstall -y gradio")
+ # os.system("pip install gradio==3.27.0")

- os.system("python -m pip install --upgrade pip")
- os.system("pip uninstall -y gradio")
- os.system("pip install gradio==3.27.0")
-
+ login(token = os.environ['HUB_TOKEN'])
+ repo = Repository(
+     local_dir="files",
+     repo_type="dataset",
+     clone_from="JunchuanYu/files_for_segmentRS",
+     token=True
+ )
+ repo.git_pull()
  matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
  #setup model
  sam_checkpoint = "sam_vit_h_4b8939.pth"
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
- model_type = "default"
+ model_type = "vit_h"
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  sam.to(device=device)
- mask_generator = SamAutomaticMaskGenerator(sam)
  predictor = SamPredictor(sam)
+ logging.basicConfig(filename="app.log", level=logging.INFO)

- def show_anns(anns):
-     if len(anns) == 0:
-         return
-     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
-     ax = plt.gca()
-     ax.set_autoscale_on(False)
-     polygons = []
-     color = []
-     for ann in sorted_anns:
-         m = ann['segmentation']
-         img = np.ones((m.shape[0], m.shape[1], 3))
-         color_mask = np.random.random((1, 3)).tolist()[0]
-         for i in range(3):
-             img[:,:,i] = color_mask[i]
-         ax.imshow(np.dstack((img, m*0.35)))
+ with gr.Blocks(theme='gradio/soft') as demo:
+     gr.Markdown(title)
+     with gr.Accordion("Instructions For User 👉", open=False):
+         gr.Markdown(description)
+     x=gr.State(value=[])
+     y=gr.State(value=[])
+     label=gr.State(value=[])

- def segment_image(image):
-     masks = mask_generator.generate(image)
-     plt.clf()
-     ppi = 100
-     height, width, _ = image.shape
-     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
-     plt.imshow(image)
-     show_anns(masks)
-     plt.axis('off')
-     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
-     output = cv2.imread('output.png')
-     return Image.fromarray(output)
+     with gr.Row():
+         with gr.Column():
+             mode=gr.inputs.Radio(['Positive','Negative'], type="value",default='Positive',label='Types of sampling methods')
+         with gr.Column():
+             clear_bn=gr.Button("Clear Selection")
+             interseg_button = gr.Button("Interactive Segment",variant='primary')
+     with gr.Row():
+         input_img = gr.Image(label="Input")
+         gallery = gr.Image(label="Selected Sample Points")
+
+     input_img.select(get_select_coords, [input_img, mode,x,y,label], [gallery,x,y,label])

- with gr.Blocks() as demo:
-     gr.Markdown(
-         """
-         # Segment Anything Model (SAM)
-         ### A test on remote sensing data
-         - Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
-         - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
-         - Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
-         - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
-         """
-     )
      with gr.Row():
-         image = gr.Image()
-         image_output = gr.Image()
-     # print(image.shape)
-     segment_image_button = gr.Button("Segment")
-     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
-     gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
-     gr.Markdown("""
-     ### <div align=center>you can follow the WeChat public account [45度科研人] and leave me a message! </div>
-     <br />
-     <br />
-     <div style="display:flex; justify-content:center;">
-     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
-     <div style="width:25px;"></div>
-     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
-     </div>
-     """)
- demo.launch(debug=True)
+         output_img = gr.Image(label="Result")
+         mask_img = gr.Image(label="Mask")
+     with gr.Row():
+         with gr.Column():
+             pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
+         with gr.Column():
+             points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
+     autoseg_button = gr.Button("Auto Segment",variant="primary")
+     emptyBtn = gr.Button("Restart",variant="secondary")
+
+     interseg_button.click(interactive_seg, inputs=[input_img,x,y,label], outputs=[output_img,mask_img])
+     autoseg_button.click(auto_seg, inputs=[input_img,pred_iou_thresh,points_per_side], outputs=[mask_img])
+
+     clear_bn.click(clear_point,outputs=[gallery,x,y,label],show_progress=True)
+     emptyBtn.click(reset_state,outputs=[input_img,gallery,output_img,mask_img,x,y,label],show_progress=True,)
+
+     example = gr.Examples(
+         examples=[[s,0.88,32] for s in glob.glob('./images/*')],
+         fn=auto_seg,
+         inputs=[input_img,pred_iou_thresh,points_per_side],
+         outputs=[output_img],
+         cache_examples=True,examples_per_page=5)
+
+     gr.Markdown(descriptionend)
+ if __name__ == "__main__":
+     demo.launch(debug=False,show_api=False,share=True)
+
+ # matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
+ # #setup model
+ # sam_checkpoint = "sam_vit_h_4b8939.pth"
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
+ # model_type = "default"
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ # sam.to(device=device)
+ # mask_generator = SamAutomaticMaskGenerator(sam)
+ # predictor = SamPredictor(sam)
+
+ # def show_anns(anns):
+ #     if len(anns) == 0:
+ #         return
+ #     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+ #     ax = plt.gca()
+ #     ax.set_autoscale_on(False)
+ #     polygons = []
+ #     color = []
+ #     for ann in sorted_anns:
+ #         m = ann['segmentation']
+ #         img = np.ones((m.shape[0], m.shape[1], 3))
+ #         color_mask = np.random.random((1, 3)).tolist()[0]
+ #         for i in range(3):
+ #             img[:,:,i] = color_mask[i]
+ #         ax.imshow(np.dstack((img, m*0.35)))
+
+ # def segment_image(image):
+ #     masks = mask_generator.generate(image)
+ #     plt.clf()
+ #     ppi = 100
+ #     height, width, _ = image.shape
+ #     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
+ #     plt.imshow(image)
+ #     show_anns(masks)
+ #     plt.axis('off')
+ #     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
+ #     output = cv2.imread('output.png')
+ #     return Image.fromarray(output)
+
+ # with gr.Blocks() as demo:
+ #     gr.Markdown(
+ #         """
+ #         # Segment Anything Model (SAM)
+ #         ### A test on remote sensing data
+ #         - Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
+ #         - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
+ #         - Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
+ #         - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
+ #         """
+ #     )
+ #     with gr.Row():
+ #         image = gr.Image()
+ #         image_output = gr.Image()
+ #     # print(image.shape)
+ #     segment_image_button = gr.Button("Segment")
+ #     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
+ #     gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
+ #     gr.Markdown("""
+ #     ### <div align=center>you can follow the WeChat public account [45度科研人] and leave me a message! </div>
+ #     <br />
+ #     <br />
+ #     <div style="display:flex; justify-content:center;">
+ #     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
+ #     <div style="width:25px;"></div>
+ #     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
+ #     </div>
+ #     """)
+ # demo.launch(debug=True)
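The new Blocks UI wires image clicks and the two sliders into handlers (get_select_coords, interactive_seg, auto_seg, clear_point, reset_state) that are defined earlier in app.py and fall outside this hunk. The sketch below is only a rough illustration, under assumed function names and placeholder overlay logic, of how such handlers typically feed point prompts to SamPredictor and the slider values to SamAutomaticMaskGenerator; it is not the app's actual implementation.

import numpy as np
from segment_anything import SamAutomaticMaskGenerator, SamPredictor

def interactive_seg_sketch(predictor: SamPredictor, image, xs, ys, labels):
    """Hypothetical point-prompt segmentation: label 1 for 'Positive' clicks, 0 for 'Negative'."""
    predictor.set_image(image)                                 # image: HxWx3 uint8 RGB array
    point_coords = np.stack([xs, ys], axis=1).astype(np.float32)
    point_labels = np.asarray(labels)
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,                                 # SAM returns three candidate masks
    )
    best = masks[np.argmax(scores)]                            # keep the highest-scoring mask
    overlay = image.copy()
    overlay[best] = (0.5 * overlay[best] + 0.5 * np.array([30, 144, 255])).astype(np.uint8)
    return overlay, (best * 255).astype(np.uint8)              # roughly the [output_img, mask_img] pair

def auto_seg_sketch(sam, image, pred_iou_thresh=0.88, points_per_side=32):
    """Hypothetical grid-prompted segmentation driven by the two sliders."""
    generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=int(points_per_side),                  # density of the automatic prompt grid
        pred_iou_thresh=float(pred_iou_thresh),                # discard masks below this predicted IoU
    )
    anns = generator.generate(image)                           # list of dicts with 'segmentation', 'area', ...
    canvas = np.zeros(image.shape[:2], dtype=np.uint8)
    for ann in sorted(anns, key=lambda a: a["area"], reverse=True):
        canvas[ann["segmentation"]] = np.random.randint(1, 255)  # crude per-mask shading
    return canvas                                              # roughly the mask_img output

Note that because the example gallery is built with cache_examples=True, Gradio runs the auto-segmentation function once per example image at startup, which is why each example path is bundled with a default threshold (0.88) and grid size (32).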