JunchuanYu committed
Commit 57bcd95 · 1 Parent(s): 2f550e0

Update run.py

Files changed (1)
  1. run.py +88 -51
run.py CHANGED
@@ -1,3 +1,4 @@
+import sys
 import os
 import cv2
 import matplotlib
@@ -9,63 +10,99 @@ import glob
 import gradio as gr
 from PIL import Image
 from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
+import logging
+from huggingface_hub import hf_hub_download

-matplotlib.pyplot.switch_backend('Agg')  # for matplotlib to work in gradio
-# setup model
-sam_checkpoint = "sam_vit_h_4b8939.pth"
+token = os.environ['HUB_TOKEN']
+loc = hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py", repo_type="dataset", local_dir='.', token=token)
+sys.path.append(loc)
+from utils import *
+
+sam_checkpoint = "sam_vit_b_01ec64.pth"
+# sam_checkpoint = "sam_vit_h_4b8939.pth"
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
-model_type = "default"
+model_type = "vit_b"
 sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 sam.to(device=device)
-mask_generator = SamAutomaticMaskGenerator(sam)
 predictor = SamPredictor(sam)
+logging.basicConfig(filename="app.log", level=logging.INFO)

-def show_anns(anns):
-    if len(anns) == 0:
-        return
-    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
-    ax = plt.gca()
-    ax.set_autoscale_on(False)
-    polygons = []
-    color = []
-    for ann in sorted_anns:
-        m = ann['segmentation']
-        img = np.ones((m.shape[0], m.shape[1], 3))
-        color_mask = np.random.random((1, 3)).tolist()[0]
-        for i in range(3):
-            img[:, :, i] = color_mask[i]
-        ax.imshow(np.dstack((img, m * 0.35)))
-
-def segment_image(image):
-    masks = mask_generator.generate(image)
-    plt.clf()
-    ppi = 100
-    height, width, _ = image.shape
-    plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
-    plt.imshow(image)
-    show_anns(masks)
-    plt.axis('off')
-    plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
-    output = cv2.imread('output.png')
-    return Image.fromarray(output)
-
-with gr.Blocks() as demo:
-    gr.Markdown(
+title = (
+    """
+    # <p align="center"> Segment-RS 🛰️ <b>
+    ## <p align="center"> A remote sensing interactive interpretation tool based on segment-anything (SAM 👍) <b>
+    ### <p align="center"> YJC (yujunchuan@mail.cgs.gov.cn) 📧<b>
+    """
+)
+description = (
+    """
+    Segment-RS is an interactive remote sensing interpretation tool built on [SAM](https://github.com/facebookresearch/segment-anything). It extracts various remote sensing targets in real time through user interaction and provides two interpretation modes: interactive extraction and automatic extraction.
+    * Interactive extraction involves manually selecting positive and negative sample points on the image to extract a salient target. Note that this interaction is intended for extracting a single independent target in the scene; extracting multiple targets of the same type at once is not yet supported, as the feature is still under development.
+    * Automatic extraction requires no interaction: simply click the "Auto Segment" button to get the segmentation result. The accuracy and granularity of the segmentation can be adjusted through "Prediction Thresh" and "Points Per Side".
     """
-# Segment Anything Model (SAM)
-### A test on remote sensing data (version 2.0 with interactive features is coming; follow the WeChat public account for the latest news)
-- Paper: [https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
-- Github: [https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
-- Dataset: [https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
-- Official Demo: [https://segment-anything.com/demo](https://segment-anything.com/demo)
-    """
+)
+descriptionend = (
+    """
+    <div align=center><img src="https://em-content.zobj.net/source/microsoft-teams/337/robot_1f916.png" style="width:5%;"></div>
+    <br />
+    <div align=center>You can follow the WeChat public account [45度科研人] and leave me a message!</div>
+    <br />
+    <div style="display:flex; justify-content:center;">
+    <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
+    <div style="width:25px;"></div>
+    <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
+    </div>
+    """
 )
+
+with gr.Blocks(theme='gradio/soft') as demo:
+    gr.Markdown(title)
+    with gr.Accordion("Instructions For User 👉", open=False):
+        gr.Markdown(description)
+    x = gr.State(value=[])
+    y = gr.State(value=[])
+    label = gr.State(value=[])
     with gr.Row():
-        image = gr.Image()
-        image_output = gr.Image()
-        print(image.shape)
-        segment_image_button = gr.Button("Segment")
-        segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
-        gr.Examples(glob.glob('./images/*'), image, image_output, segment_image)
-
-demo.launch(debug=False)
+        with gr.Column(scale=13):
+            with gr.Row():
+                with gr.Column():
+                    mode = gr.Radio(['Positive', 'Negative'], value='Positive', label='Types of sampling methods')
+                with gr.Column():
+                    clear_bn = gr.Button("Clear Selection")
+                    interseg_button = gr.Button("Interactive Segment", variant='primary')
+            with gr.Row():
+                input_img = gr.Image(label="Input")
+                gallery = gr.Image(label="Points")
+
+            input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])
+
+            with gr.Row():
+                output_img = gr.Image(label="Result")
+                mask_img = gr.Image(label="Mask")
+            with gr.Row():
+                with gr.Column():
+                    thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Threshold")
+                with gr.Column():
+                    points = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points/Side")
+
+        with gr.Column(scale=2, min_width=8):
+            example = gr.Examples(
+                examples=[[s, 0.9, 32] for s in glob.glob('./images/*')],
+                fn=auto_seg,
+                inputs=[input_img, thresh, points],
+                outputs=[output_img],
+                cache_examples=False, examples_per_page=5)
+
+            autoseg_button = gr.Button("Auto Segment", variant="primary")
+            emptyBtn = gr.Button("Restart", variant="secondary")
+
+    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
+    autoseg_button.click(auto_seg, inputs=[input_img, thresh, points], outputs=[mask_img])
+
+    clear_bn.click(clear_point, outputs=[gallery, mode, x, y, label], show_progress=True)
+    emptyBtn.click(reset_state, outputs=[input_img, gallery, output_img, mask_img, thresh, points, mode, x, y, label], show_progress=True)
+
+    gr.Markdown(descriptionend)
+
+if __name__ == "__main__":
+    demo.launch(debug=False, show_api=False)
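
Note that the event wiring in this diff calls get_select_coords, interactive_seg, auto_seg, clear_point, and reset_state, all imported via `from utils import *` from the utils.py that the script downloads at startup, so their bodies are not part of this commit. For orientation, here is a minimal sketch of what the click-collection handler might look like: the trailing gr.SelectData argument follows Gradio's select-event convention, but the body is an assumption, not the repository's actual code. It reuses the script's existing cv2 and gradio imports.

# Hypothetical sketch only: the real get_select_coords lives in the downloaded
# utils.py and is not shown in this commit.
def get_select_coords(img, mode, x, y, label, evt: gr.SelectData):
    col, row = evt.index                              # pixel position of the click
    x.append(col)
    y.append(row)
    label.append(1 if mode == 'Positive' else 0)      # SAM labels: 1=foreground, 0=background
    canvas = img.copy()
    for cx, cy, lb in zip(x, y, label):
        color = (0, 255, 0) if lb == 1 else (255, 0, 0)   # green positive, red negative
        cv2.circle(canvas, (int(cx), int(cy)), 5, color, -1)
    return canvas, x, y, label                        # matches outputs=[gallery, x, y, label]

Keeping x, y, and label in gr.State lists is what lets clicks accumulate across events without global variables.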
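interactive_seg then turns the collected points into a SAM point prompt. A minimal sketch under the same caveat, assuming it drives the global predictor built above: set_image, predict, point_coords, point_labels, and multimask_output are real SamPredictor APIs, while the overlay styling and return layout are illustrative guesses. It assumes numpy is imported as np, as in the removed code.

# Hypothetical sketch: assumes interactive_seg prompts the global SamPredictor
# with the collected points.
def interactive_seg(image, x, y, label):
    predictor.set_image(image)                        # compute the image embedding once
    point_coords = np.array(list(zip(x, y)), dtype=np.float32)
    point_labels = np.array(label, dtype=np.int32)
    masks, _, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,                       # one mask for one intended target
    )
    mask = masks[0].astype(np.uint8) * 255            # bool HxW -> displayable grayscale
    overlay = image.copy()
    overlay[masks[0]] = (0.5 * overlay[masks[0]]
                         + 0.5 * np.array([30, 144, 255])).astype(np.uint8)
    return overlay, mask                              # matches outputs=[output_img, mask_img]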
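For the automatic path, a plausible sketch of auto_seg, assuming the "Threshold" slider feeds SamAutomaticMaskGenerator's pred_iou_thresh and "Points/Side" its points_per_side; both are genuine parameters of the segment-anything API, but the mapping is an assumption. The generator is rebuilt per call because the slider values can change between clicks.

# Hypothetical sketch: the real auto_seg is in the downloaded utils.py.
def auto_seg(image, thresh, points):
    generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=int(points),                  # sampling grid density, 16..96 slider
        pred_iou_thresh=float(thresh),                # mask quality cutoff, 0.8..1.0 slider
    )
    anns = generator.generate(image)                  # dicts with 'segmentation', 'area', ...
    canvas = np.zeros(image.shape[:2] + (3,), dtype=np.uint8)
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        canvas[ann['segmentation']] = np.random.randint(0, 255, 3, dtype=np.uint8)
    return canvas                                     # single image, as in outputs=[mask_img]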