# SegRS / run.py
# Uploaded by JunchuanYu ("Upload run.py", commit 72efb9f, 2.93 kB).
import sys
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
from huggingface_hub import hf_hub_download
# Fetch the helper module (UI text and callbacks) from the Hub dataset repo.
# When HUB_TOKEN is unset, pass token=None so huggingface_hub can fall back to
# a locally stored login token instead of failing with a bare KeyError here.
token = os.environ.get("HUB_TOKEN")
loc = hf_hub_download(
    repo_id="JunchuanYu/files_for_segmentRS",
    filename="utils.py",
    repo_type="dataset",
    local_dir=".",
    token=token,
)
# hf_hub_download returns the path of the downloaded *file*. sys.path entries
# must be *directories*; appending the file path never made `utils` importable.
# The import below previously only worked when the download directory was
# already on sys.path, e.g. when run from the script's own directory.
sys.path.append(os.path.dirname(os.path.abspath(loc)))
# Import only the names this script uses, instead of a wildcard import.
from utils import (  # noqa: E402  (module is only available after download)
    auto_seg,
    clear_point,
    description,
    descriptionend,
    get_select_coords,
    interactive_seg,
    reset_state,
    title,
)
# Build the UI. A wide left column holds the interactive-segmentation controls,
# the image panes and the auto-segment sliders. A narrow right column holds
# example images and the Auto Segment / Restart buttons. `title`,
# `description`, `descriptionend` and all callbacks come from `utils`.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)

    # Per-session state passed into and returned from the point-selection
    # callbacks. Presumably x/y click coordinates and positive/negative point
    # labels (TODO: confirm against utils.get_select_coords).
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])

    with gr.Row():
        with gr.Column(scale=13):
            with gr.Row():
                with gr.Column():
                    # gr.Radio replaces the deprecated gr.inputs.Radio; the
                    # initial selection is given via `value` (formerly `default`).
                    mode = gr.Radio(
                        ["Positive", "Negative"],
                        type="value",
                        value="Positive",
                        label="Types of sampling methods",
                    )
                with gr.Column():
                    clear_bn = gr.Button("Clear Selection")
                    interseg_button = gr.Button("Interactive Segment", variant="primary")
            with gr.Row():
                input_img = gr.Image(label="Input")
                gallery = gr.Image(label="Points")
            with gr.Row():
                output_img = gr.Image(label="Result")
                mask_img = gr.Image(label="Mask")
            with gr.Row():
                with gr.Column():
                    thresh = gr.Slider(
                        minimum=0.8,
                        maximum=1,
                        value=0.90,
                        step=0.01,
                        interactive=True,
                        label="Threshold",
                    )
                with gr.Column():
                    points = gr.Slider(
                        minimum=16,
                        maximum=96,
                        value=32,
                        step=16,
                        interactive=True,
                        label="Points/Side",
                    )
        with gr.Column(scale=2, min_width=8):
            # sorted() gives a stable example order; raw glob order is
            # filesystem-dependent. With cache_examples=False the fn is not
            # executed on click. The outputs still match the Auto Segment
            # button (mask_img) so the wiring stays consistent if caching or
            # run_on_click is ever enabled.
            example = gr.Examples(
                examples=[[path, 0.9, 32] for path in sorted(glob.glob("./images/*"))],
                fn=auto_seg,
                inputs=[input_img, thresh, points],
                outputs=[mask_img],
                cache_examples=False,
                examples_per_page=5,
            )
            autoseg_button = gr.Button("Auto Segment", variant="primary")
            emptyBtn = gr.Button("Restart", variant="secondary")

    # Event wiring.
    input_img.select(
        get_select_coords,
        [input_img, mode, x, y, label],
        [gallery, x, y, label],
    )
    interseg_button.click(
        interactive_seg,
        inputs=[input_img, x, y, label],
        outputs=[output_img, mask_img],
    )
    autoseg_button.click(auto_seg, inputs=[input_img, thresh, points], outputs=[mask_img])
    clear_bn.click(
        clear_point,
        outputs=[gallery, mode, x, y, label],
        show_progress=True,
    )
    emptyBtn.click(
        reset_state,
        outputs=[input_img, gallery, output_img, mask_img, thresh, points, mode, x, y, label],
        show_progress=True,
    )
    gr.Markdown(descriptionend)
if __name__ == "__main__":
    # Serve the Blocks app built above; show_api=False hides the auto-generated
    # API docs link, and debug mode stays off.
    demo.launch(show_api=False, debug=False)