"""Gradio front-end for point-prompted and automatic image segmentation.

The UI offers two workflows:

* "Interactive Segment": the user clicks sample points on the input image
  (labelled Positive/Negative) and the selected points are segmented.
* "Auto Segment": the whole image is segmented automatically, controlled by
  a prediction-threshold slider and a points-per-side slider.

All callbacks and UI text (``title``, ``description``, ``descriptionend``,
``get_select_coords``, ``interactive_seg``, ``auto_seg``, ``clear_point``,
``reset_state``) come from ``utils.py``. That file is downloaded from the
Hugging Face dataset repo ``JunchuanYu/files_for_segmentRS`` at startup.
"""
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import login
from huggingface_hub import Repository
from huggingface_hub import hf_hub_download

# The helper repo is accessed with this token; startup fails with KeyError if
# the HUB_TOKEN environment variable is not set.
token = os.environ['HUB_TOKEN']
loc = hf_hub_download(
    repo_id="JunchuanYu/files_for_segmentRS",
    filename="utils.py",
    repo_type="dataset",
    local_dir='.',
    token=token,
)
# `loc` is the path of the downloaded *file*. sys.path entries must be
# directories, so register the file's parent directory. Put it first so the
# downloaded utils.py wins over any other module named `utils` on the path.
sys.path.insert(0, os.path.dirname(os.path.abspath(loc)))
# Wildcard import kept on purpose: the UI below relies on several names
# defined in utils.py (see module docstring).
from utils import *

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # Per-session click history for interactive segmentation: x/y pixel
    # coordinates and the Positive/Negative label of each clicked point.
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        with gr.Column():
            # `gr.inputs.*` is deprecated (removed in Gradio 4); `value=`
            # replaces the old `default=` keyword.
            mode = gr.Radio(
                ['Positive', 'Negative'],
                value='Positive',
                type="value",
                label='Types of sampling methods',
            )
        with gr.Column():
            clear_bn = gr.Button("Clear Selection")
            interseg_button = gr.Button("Interactive Segment", variant='primary')

    with gr.Row():
        input_img = gr.Image(label="Input")
        gallery = gr.Image(label="Selected Sample Points")
    # Each click on the input image records one sample point and redraws the
    # preview with all points selected so far.
    input_img.select(
        get_select_coords,
        [input_img, mode, x, y, label],
        [gallery, x, y, label],
    )

    with gr.Row():
        output_img = gr.Image(label="Result")
        mask_img = gr.Image(label="Mask")

    with gr.Row():
        with gr.Column():
            pred_iou_thresh = gr.Slider(
                minimum=0.8, maximum=1, value=0.90, step=0.01,
                interactive=True, label="Prediction Thresh",
            )
        with gr.Column():
            points_per_side = gr.Slider(
                minimum=16, maximum=96, value=32, step=16,
                interactive=True, label="Points Per Side",
            )
            autoseg_button = gr.Button("Auto Segment", variant="primary")
            emptyBtn = gr.Button("Restart", variant="secondary")

    interseg_button.click(
        interactive_seg,
        inputs=[input_img, x, y, label],
        outputs=[output_img, mask_img],
    )
    autoseg_button.click(
        auto_seg,
        inputs=[input_img, pred_iou_thresh, points_per_side],
        outputs=[mask_img],
    )
    clear_bn.click(clear_point, outputs=[gallery, x, y, label], show_progress=True)
    emptyBtn.click(
        reset_state,
        outputs=[input_img, gallery, output_img, mask_img, x, y, label],
        show_progress=True,
    )

    # Sorted so the example gallery order does not depend on the filesystem.
    # The outputs target mask_img, matching the "Auto Segment" button's
    # wiring of the same `auto_seg` callback.
    example = gr.Examples(
        examples=[[s, 0.88, 32] for s in sorted(glob.glob('./images/*'))],
        fn=auto_seg,
        inputs=[input_img, pred_iou_thresh, points_per_side],
        outputs=[mask_img],
        cache_examples=False,
        examples_per_page=5,
    )
    gr.Markdown(descriptionend)

if __name__ == "__main__":
    demo.launch(debug=False, show_api=False)

# NOTE(review): legacy standalone implementation, kept for reference only;
# the live app above delegates segmentation to utils.py.
# matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
# #setup model
# sam_checkpoint = "sam_vit_h_4b8939.pth"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
# model_type = "default"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# mask_generator = SamAutomaticMaskGenerator(sam)
# predictor = SamPredictor(sam)
# def show_anns(anns):
#     if len(anns) == 0:
#         return
#     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
#     ax = plt.gca()
#     ax.set_autoscale_on(False)
#     polygons = []
#     color = []
#     for ann in sorted_anns:
#         m = ann['segmentation']
#         img = np.ones((m.shape[0], m.shape[1], 3))
#         color_mask = np.random.random((1, 3)).tolist()[0]
#         for i in range(3):
#             img[:,:,i] = color_mask[i]
#         ax.imshow(np.dstack((img, m*0.35)))
# def segment_image(image):
#     masks = mask_generator.generate(image)
#     plt.clf()
#     ppi = 100
#     height, width, _ = image.shape
#     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
#     plt.imshow(image)
#     show_anns(masks)
#     plt.axis('off')
#     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
#     output = cv2.imread('output.png')
#     return Image.fromarray(output)
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         # Segment Anything Model (SAM)
#         ### A test on remote sensing data
#         - Paper:[(https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
#         - Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
#         - Dataset:https://ai.facebook.com/datasets/segment-anything-downloads/(https://ai.facebook.com/datasets/segment-anything-downloads/)
#         - Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
#         """
#     )
#     with gr.Row():
#         image = gr.Image()
#         image_output = gr.Image()
#         # print(image.shape)
#     segment_image_button = gr.Button("Segment")
#     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
#     gr.Examples(glob.glob('./images/*'),image,image_output,segment_image)
#     gr.Markdown("""
#     ###