import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import hf_hub_download
token = os.environ['HUB_TOKEN']  # Hugging Face access token, supplied as a Space secret
loc = hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py",
                      repo_type="dataset", local_dir='.', token=token)
sys.path.append(os.path.dirname(os.path.abspath(loc)))  # add the containing directory, not the file path itself
from utils import *
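# Note: utils.py (downloaded above) is assumed to provide everything the UI below
# references: the markdown strings `title`, `description`, and `descriptionend`,
# plus the callbacks `get_select_coords`, `interactive_seg`, `auto_seg`,
# `clear_point`, and `reset_state` (and, presumably, the SAM model setup that
# the commented-out legacy version at the bottom of this file used to do inline).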
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)
    # Session state for the prompt points: x/y pixel coordinates and their
    # positive/negative labels, accumulated across clicks
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])
    with gr.Row():
        with gr.Column():
            mode = gr.Radio(['Positive', 'Negative'], value='Positive', label='Types of sampling methods')
        with gr.Column():
            clear_bn = gr.Button("Clear Selection")
            interseg_button = gr.Button("Interactive Segment", variant='primary')
    with gr.Row():
        input_img = gr.Image(label="Input")
        gallery = gr.Image(label="Selected Sample Points")
    # Each click on the input image records a prompt point (positive or negative,
    # per `mode`) and redraws the point overlay in `gallery`
    input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])
    with gr.Row():
        output_img = gr.Image(label="Result")
        mask_img = gr.Image(label="Mask")
    with gr.Row():
        with gr.Column():
            pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
        with gr.Column():
            points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
    autoseg_button = gr.Button("Auto Segment", variant="primary")
    emptyBtn = gr.Button("Restart", variant="secondary")
    # Wire up the two segmentation modes plus the clear/reset helpers
    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
    autoseg_button.click(auto_seg, inputs=[input_img, pred_iou_thresh, points_per_side], outputs=[mask_img])
    clear_bn.click(clear_point, outputs=[gallery, x, y, label], show_progress=True)
    emptyBtn.click(reset_state, outputs=[input_img, gallery, output_img, mask_img, x, y, label], show_progress=True)
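    # The example gallery below runs each bundled image through auto_seg with a
    # fixed prediction threshold (0.88) and points-per-side (32)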
    example = gr.Examples(
        examples=[[s, 0.88, 32] for s in glob.glob('./images/*')],
        fn=auto_seg,
        inputs=[input_img, pred_iou_thresh, points_per_side],
        outputs=[output_img],
        cache_examples=False, examples_per_page=5)
    gr.Markdown(descriptionend)
if __name__ == "__main__":
    demo.launch(debug=False, show_api=False)
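# ---------------------------------------------------------------------------
# Legacy standalone version of this demo, apparently kept commented out for
# reference. It loaded the SAM checkpoint directly and ran automatic mask
# generation only, without the interactive point prompts used above.
# ---------------------------------------------------------------------------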
# matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
# # Set up the SAM model
# sam_checkpoint = "sam_vit_h_4b8939.pth"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
# model_type = "default"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# mask_generator = SamAutomaticMaskGenerator(sam)
# predictor = SamPredictor(sam)
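# # ("default" in sam_model_registry maps to the ViT-H build, which matches the
# # sam_vit_h_4b8939.pth checkpoint above)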
# def show_anns(anns):
#     if len(anns) == 0:
#         return
#     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
#     ax = plt.gca()
#     ax.set_autoscale_on(False)
#     for ann in sorted_anns:
#         m = ann['segmentation']
#         img = np.ones((m.shape[0], m.shape[1], 3))
#         color_mask = np.random.random((1, 3)).tolist()[0]
#         for i in range(3):
#             img[:, :, i] = color_mask[i]
#         ax.imshow(np.dstack((img, m * 0.35)))  # overlay each mask at 35% opacity
# def segment_image(image):
#     masks = mask_generator.generate(image)
#     plt.clf()
#     ppi = 100
#     height, width, _ = image.shape
#     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
#     plt.imshow(image)
#     show_anns(masks)
#     plt.axis('off')
#     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
#     output = cv2.imread('output.png')  # cv2 loads BGR; convert before handing to PIL
#     return Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         # Segment Anything Model (SAM)
#         ### A test on remote sensing data
#         - Paper: [https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
#         - Github: [https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
#         - Dataset: [https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
#         - Official Demo: [https://segment-anything.com/demo](https://segment-anything.com/demo)
#         """
#     )
#     with gr.Row():
#         image = gr.Image()
#         image_output = gr.Image()
#     segment_image_button = gr.Button("Segment")
#     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
#     gr.Examples(glob.glob('./images/*'), image, image_output, segment_image)
#     gr.Markdown("""
#     ### <div align=center>You can follow the WeChat public account [45度科研人] and leave me a message!</div>
#     <br />
#     <br />
#     <div style="display:flex; justify-content:center;">
#         <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
#         <div style="width:25px;"></div>
#         <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
#     </div>
#     """)
# demo.launch(debug=True)