JunchuanYu committed on
Commit 268e766 · 1 Parent(s): 5fd8804

Update run.py

Files changed (1)
  1. run.py +56 -55
run.py CHANGED
@@ -1,4 +1,3 @@
- import sys
  import os
  import cv2
  import matplotlib
@@ -10,61 +9,63 @@ import glob
  import gradio as gr
  from PIL import Image
  from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
- import logging
- from huggingface_hub import hf_hub_download
 
- # token = os.environ['HUB_TOKEN']
- # loc =hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py",repo_type="dataset",local_dir='.',token=token)
- # sys.path.append(loc)
- # from utils import *
 
- with gr.Blocks(theme='gradio/soft') as demo:
-     gr.Markdown(title)
-     with gr.Accordion("Instructions For User 👉", open=False):
-         gr.Markdown(description)
-     x=gr.State(value=[])
-     y=gr.State(value=[])
-     label=gr.State(value=[])
-     with gr.Row():
-         with gr.Column(scale=13):
-             with gr.Row():
-                 with gr.Column():
-                     mode=gr.inputs.Radio(['Positive','Negative'], type="value",default='Positive',label='Types of sampling methods')
-                 with gr.Column():
-                     clear_bn=gr.Button("Clear Selection")
-                     interseg_button = gr.Button("Interactive Segment",variant='primary')
-             with gr.Row():
-                 input_img = gr.Image(label="Input")
-                 gallery = gr.Image(label="Points")
-
-             input_img.select(get_select_coords, [input_img, mode,x,y,label], [gallery,x,y,label])
-
-             with gr.Row():
-                 output_img = gr.Image(label="Result")
-                 mask_img = gr.Image(label="Mask")
-             with gr.Row():
-                 with gr.Column():
-                     thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Threshhold")
-                 with gr.Column():
-                     points = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points/Side")
-
-         with gr.Column(scale=2,min_width=8):
-             example = gr.Examples(
-                 examples=[[s,0.9,32] for s in glob.glob('./images/*')],
-                 fn=auto_seg,
-                 inputs=[input_img,thresh,points],
-                 outputs=[output_img],
-                 cache_examples=False,examples_per_page=5)
-
-             autoseg_button = gr.Button("Auto Segment",variant="primary")
-             emptyBtn = gr.Button("Restart",variant="secondary")
-
-     interseg_button.click(interactive_seg, inputs=[input_img,x,y,label], outputs=[output_img,mask_img])
-     autoseg_button.click(auto_seg, inputs=[input_img,thresh,points], outputs=[mask_img])
-
-     clear_bn.click(clear_point,outputs=[gallery,mode,x,y,label],show_progress=True)
-     emptyBtn.click(reset_state,outputs=[input_img,gallery,output_img,mask_img,thresh,points,mode,x,y,label],show_progress=True,)
-
-     gr.Markdown(descriptionend)
- if __name__ == "__main__":
-     demo.launch(debug=False,show_api=False)
+ matplotlib.pyplot.switch_backend('Agg')  # for matplotlib to work in gradio
+ # setup model
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
+ model_type = "default"
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ mask_generator = SamAutomaticMaskGenerator(sam)
+ predictor = SamPredictor(sam)
+
+ def show_anns(anns):
+     if len(anns) == 0:
+         return
+     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+     ax = plt.gca()
+     ax.set_autoscale_on(False)
+     polygons = []
+     color = []
+     for ann in sorted_anns:
+         m = ann['segmentation']
+         img = np.ones((m.shape[0], m.shape[1], 3))
+         color_mask = np.random.random((1, 3)).tolist()[0]
+         for i in range(3):
+             img[:,:,i] = color_mask[i]
+         ax.imshow(np.dstack((img, m*0.35)))
+
+ def segment_image(image):
+     masks = mask_generator.generate(image)
+     plt.clf()
+     ppi = 100
+     height, width, _ = image.shape
+     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
+     plt.imshow(image)
+     show_anns(masks)
+     plt.axis('off')
+     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
+     output = cv2.imread('output.png')
+     return Image.fromarray(output)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # Segment Anything Model (SAM)
+         ### A test on remote sensing data (version 2.0 will add interactive segmentation; follow the WeChat official account for the latest updates)
+         - Paper: [https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
+         - Github: [https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
+         - Dataset: [https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
+         - Official Demo: [https://segment-anything.com/demo](https://segment-anything.com/demo)
+         """
+     )
+     with gr.Row():
+         image = gr.Image()
+         image_output = gr.Image()
+     print(image.shape)
+     segment_image_button = gr.Button("Segment")
+     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
+     gr.Examples(glob.glob('./images/*'), image, image_output, segment_image)
+
+ demo.launch(debug=False)
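
The automatic pipeline added in this commit can also be exercised outside the Gradio UI. The sketch below is a minimal usage example, not part of the diff: it reuses the same checkpoint and mask-generator setup as run.py, and "example.png" is a placeholder image path.

# Minimal offline sketch (assumes the sam_vit_h_4b8939.pth checkpoint from the diff is present locally)
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")  # ViT-H weights, as in run.py
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)  # SAM expects an RGB uint8 array
masks = mask_generator.generate(image)  # each mask is a dict with 'segmentation' (bool array) and 'area'
print(f"{len(masks)} masks generated")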