File size: 5,623 Bytes
6c92cdf
d7c2a2b
 
 
 
 
 
35ac044
0ebdb47
8901ad6
d7c2a2b
 
e1466f1
6c92cdf
 
 
6f6b218
6c92cdf
 
 
 
6798448
e1466f1
 
 
 
 
 
 
d7c2a2b
e1466f1
 
 
 
 
 
 
 
 
 
 
35ac044
8901ad6
e1466f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aae2bd
e1466f1
 
 
9aae2bd
e1466f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import login
from huggingface_hub import Repository
from huggingface_hub import hf_hub_download

# Fetch the helper module `utils.py` (UI texts and segmentation callbacks used
# below) from the Hugging Face dataset repo into the current directory, then
# make it importable.
# HUB_TOKEN is optional: with token=None, huggingface_hub falls back to a
# locally cached login (or anonymous access) instead of failing with KeyError.
token = os.environ.get('HUB_TOKEN')
loc = hf_hub_download(
    repo_id="JunchuanYu/files_for_segmentRS",
    filename="utils.py",
    repo_type="dataset",
    local_dir='.',
    token=token,
)
# sys.path entries must be directories. `loc` is the path of utils.py itself,
# so register its parent directory; appending the file path only worked when
# the working directory happened to coincide with the script directory.
utils_dir = os.path.dirname(os.path.abspath(loc))
if utils_dir not in sys.path:
    sys.path.append(utils_dir)
# Import exactly the names this script uses instead of a wildcard import.
from utils import (
    title,
    description,
    descriptionend,
    get_select_coords,
    interactive_seg,
    auto_seg,
    clear_point,
    reset_state,
)

# Build the Gradio UI. The Markdown texts (title, description, descriptionend)
# and the callbacks (get_select_coords, interactive_seg, auto_seg, clear_point,
# reset_state) come from the downloaded `utils` module.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # State shared between the image-click handler and the segmentation
    # callbacks; presumably the clicked x/y coordinates and each point's
    # Positive/Negative label — confirm against utils.get_select_coords.
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        with gr.Column():
            # `gr.inputs.Radio(default=...)` is the deprecated legacy API
            # (removed in Gradio 4); `gr.Radio(value=...)` is the supported form.
            mode = gr.Radio(
                ['Positive', 'Negative'],
                type="value",
                value='Positive',
                label='Types of sampling methods',
            )
        with gr.Column():
            clear_bn = gr.Button("Clear Selection")
            interseg_button = gr.Button("Interactive Segment", variant='primary')
    with gr.Row():
        input_img = gr.Image(label="Input")
        gallery = gr.Image(label="Selected Sample Points")

    # Clicking the input image calls get_select_coords with the current mode
    # and state; its image output is shown in `gallery`.
    input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])

    with gr.Row():
        output_img = gr.Image(label="Result")
        mask_img = gr.Image(label="Mask")
    with gr.Row():
        with gr.Column():
            pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
        with gr.Column():
            points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
    autoseg_button = gr.Button("Auto Segment", variant="primary")
    emptyBtn = gr.Button("Restart", variant="secondary")

    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
    autoseg_button.click(auto_seg, inputs=[input_img, pred_iou_thresh, points_per_side], outputs=[mask_img])

    clear_bn.click(clear_point, outputs=[gallery, x, y, label], show_progress=True)
    emptyBtn.click(reset_state, outputs=[input_img, gallery, output_img, mask_img, x, y, label], show_progress=True)

    # Sample images. glob returns entries in arbitrary filesystem order, so
    # sort them to keep the example gallery stable across machines.
    example = gr.Examples(
        examples=[[s, 0.88, 32] for s in sorted(glob.glob('./images/*'))],
        fn=auto_seg,
        inputs=[input_img, pred_iou_thresh, points_per_side],
        # Same target component as the "Auto Segment" button above.
        outputs=[mask_img],
        cache_examples=False,
        examples_per_page=5,
    )

    gr.Markdown(descriptionend)

if __name__ == "__main__":
    demo.launch(debug=False, show_api=False)

# matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
# #setup model
# sam_checkpoint = "sam_vit_h_4b8939.pth"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
# model_type = "default"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# mask_generator = SamAutomaticMaskGenerator(sam)
# predictor = SamPredictor(sam)

# def show_anns(anns):
#     if len(anns) == 0:
#         return
#     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
#     ax = plt.gca()
#     ax.set_autoscale_on(False)
#     polygons = []
#     color = []
#     for ann in sorted_anns:
#         m = ann['segmentation']
#         img = np.ones((m.shape[0], m.shape[1], 3))
#         color_mask = np.random.random((1, 3)).tolist()[0]
#         for i in range(3):
#             img[:,:,i] = color_mask[i]
#         ax.imshow(np.dstack((img, m*0.35)))

# def segment_image(image):
#     masks = mask_generator.generate(image)
#     plt.clf()
#     ppi = 100
#     height, width, _ = image.shape
#     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
#     plt.imshow(image)
#     show_anns(masks)
#     plt.axis('off')
#     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
#     output = cv2.imread('output.png')
#     return Image.fromarray(output)
    
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#     # Segment Anything Model (SAM) 
#     ### A test on remote sensing data
#     - Paper:[https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
#     - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
#     - Dataset:[https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
#     - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
#     """
#     )
#     with gr.Row():
#         image = gr.Image()
#         image_output = gr.Image()
#     # print(image.shape)
#     segment_image_button = gr.Button("Segment")
#     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
#     gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
#     gr.Markdown("""
# ###  <div align=center>you can follow the WeChat public account [45度科研人] and leave me a message!  </div>
# <br />
# <br />
# <div style="display:flex; justify-content:center;">
#     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
#     <div style="width:25px;"></div>
#     <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
# </div>
# """)
# demo.launch(debug=True)