import sys
import os

# Pin the gradio version before importing it; running pip after the import
# would have no effect on the already-loaded module.
# os.system("python -m pip install --upgrade pip")
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.27.0")

import cv2
import glob
import logging
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from huggingface_hub import hf_hub_download
token = os.environ['HUB_TOKEN']
# Fetch the helper module (get_select_coords, interactive_seg, auto_seg, clear_point,
# reset_state) from the Hub.
loc = hf_hub_download(repo_id="JunchuanYu/files_for_segmentRS", filename="utils.py",
                      repo_type="dataset", local_dir='.', token=token)
print(loc)
sys.path.append(os.path.dirname(loc))  # append the containing directory, not the file itself
from utils import *
matplotlib.pyplot.switch_backend('Agg')  # non-interactive backend so matplotlib works inside gradio

# set up the SAM model
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
logging.basicConfig(filename="app.log", level=logging.INFO)
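
# For reference: interactive_seg comes from the downloaded utils.py. Below is a
# minimal sketch of what a point-prompted SAM prediction looks like; the function
# name and return format are illustrative, not the actual utils.py implementation.
def _example_point_prediction(image, xs, ys, labels):
    # image: HxWx3 uint8 RGB array; xs/ys: pixel coordinates of the sampled points;
    # labels: 1 for positive (foreground) samples, 0 for negative (background) ones.
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array(list(zip(xs, ys))),
        point_labels=np.array(labels),
        multimask_output=False,
    )
    return masks[0], scores[0]  # boolean mask and its predicted quality score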
title = (
"""
# Segment-RS 🛰️
## A remote sensing interactive interpretation tool based on segment-anything (SAM 👍)
### YJC (yujunchuan@mail.cgs.gov.cn) 📧
"""
)
description = (
"""
Segment-RS is an interactive remote sensing interpretation tool built on [SAM](https://github.com/facebookresearch/segment-anything). It extracts various remote sensing targets in real time through user interaction and offers two interpretation modes: interactive extraction and automatic extraction.
* Interactive extraction: manually select positive and negative sample points on the image to extract a clearly visible target. Note that this mode is designed for extracting a single, independent target in the scene; extracting multiple targets of the same type at once is not yet supported, as the feature is still under development.
* Automatic extraction: no interaction is required; simply click the "Auto Segment" button to get the segmentation result. The accuracy and granularity of the segmentation can be tuned with "Prediction Thresh" and "Points Per Side", as sketched in the code below.
"""
)
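
# For reference: auto_seg comes from the downloaded utils.py. Below is a minimal
# sketch of how the two exposed sliders plausibly map onto SamAutomaticMaskGenerator;
# the actual utils.py implementation may differ.
def _example_auto_segment(image, pred_iou_thresh=0.88, points_per_side=32):
    # points_per_side controls how densely the image is seeded with prompt points
    # (granularity); pred_iou_thresh discards masks whose predicted quality score
    # falls below the threshold (accuracy).
    generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
    )
    return generator.generate(image)  # list of dicts with 'segmentation', 'area', ...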
descriptionend = (
"""
You can follow the WeChat public account [45度科研人] and leave me a message!
"""
)
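
# For reference: get_select_coords comes from the downloaded utils.py. In gradio 3.x,
# a .select() handler can take a gr.SelectData argument whose .index holds the clicked
# pixel. A minimal sketch of collecting prompt points this way (illustrative only):
def _example_select_handler(img, mode, xs, ys, labels, evt: gr.SelectData):
    px, py = evt.index  # (x, y) pixel coordinates of the click
    xs.append(px)
    ys.append(py)
    labels.append(1 if mode == 'Positive' else 0)
    marked = img.copy()
    # draw the sample point: green for positive, red for negative
    cv2.circle(marked, (px, py), 8, (0, 255, 0) if labels[-1] == 1 else (255, 0, 0), -1)
    return marked, xs, ys, labels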
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown(title)
    with gr.Accordion("Instructions For User 👉", open=False):
        gr.Markdown(description)
    # per-session state holding the sampled point coordinates and their labels
    x = gr.State(value=[])
    y = gr.State(value=[])
    label = gr.State(value=[])
    with gr.Row():
        with gr.Column():
            mode = gr.Radio(['Positive', 'Negative'], value='Positive', label='Sample point type')
        with gr.Column():
            clear_bn = gr.Button("Clear Selection")
            interseg_button = gr.Button("Interactive Segment", variant='primary')
    with gr.Row():
        input_img = gr.Image(label="Input")
        gallery = gr.Image(label="Selected Sample Points")
    input_img.select(get_select_coords, [input_img, mode, x, y, label], [gallery, x, y, label])
    with gr.Row():
        output_img = gr.Image(label="Result")
        mask_img = gr.Image(label="Mask")
    with gr.Row():
        with gr.Column():
            pred_iou_thresh = gr.Slider(minimum=0.8, maximum=1, value=0.90, step=0.01, interactive=True, label="Prediction Thresh")
        with gr.Column():
            points_per_side = gr.Slider(minimum=16, maximum=96, value=32, step=16, interactive=True, label="Points Per Side")
            autoseg_button = gr.Button("Auto Segment", variant="primary")
            emptyBtn = gr.Button("Restart", variant="secondary")
    interseg_button.click(interactive_seg, inputs=[input_img, x, y, label], outputs=[output_img, mask_img])
    autoseg_button.click(auto_seg, inputs=[input_img, pred_iou_thresh, points_per_side], outputs=[mask_img])
    clear_bn.click(clear_point, outputs=[gallery, x, y, label], show_progress=True)
    emptyBtn.click(reset_state, outputs=[input_img, gallery, output_img, mask_img, x, y, label], show_progress=True)
    example = gr.Examples(
        examples=[[s, 0.88, 32] for s in glob.glob('./images/*')],
        fn=auto_seg,
        inputs=[input_img, pred_iou_thresh, points_per_side],
        outputs=[mask_img],  # auto_seg writes to the Mask image, matching the button wiring above
        cache_examples=True,
        examples_per_page=5)
    gr.Markdown(descriptionend)

if __name__ == "__main__":
    demo.launch(debug=False, show_api=False, share=True)
# Earlier draft of the demo (kept for reference): automatic segmentation only,
# rendered through matplotlib.
# matplotlib.pyplot.switch_backend('Agg')  # for matplotlib to work in gradio
# # set up the model
# sam_checkpoint = "sam_vit_h_4b8939.pth"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
# model_type = "default"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# mask_generator = SamAutomaticMaskGenerator(sam)
# predictor = SamPredictor(sam)
#
# def show_anns(anns):
#     if len(anns) == 0:
#         return
#     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
#     ax = plt.gca()
#     ax.set_autoscale_on(False)
#     for ann in sorted_anns:
#         m = ann['segmentation']
#         img = np.ones((m.shape[0], m.shape[1], 3))
#         color_mask = np.random.random((1, 3)).tolist()[0]
#         for i in range(3):
#             img[:, :, i] = color_mask[i]
#         ax.imshow(np.dstack((img, m * 0.35)))
#
# def segment_image(image):
#     masks = mask_generator.generate(image)
#     plt.clf()
#     ppi = 100
#     height, width, _ = image.shape
#     plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
#     plt.imshow(image)
#     show_anns(masks)
#     plt.axis('off')
#     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
#     output = cv2.cvtColor(cv2.imread('output.png'), cv2.COLOR_BGR2RGB)  # cv2 reads BGR; convert before handing to PIL
#     return Image.fromarray(output)
#
# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         # Segment Anything Model (SAM)
#         ### A test on remote sensing data
#         - Paper: [https://arxiv.org/abs/2304.02643](https://arxiv.org/abs/2304.02643)
#         - Github: [https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
#         - Dataset: [https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
#         - Official Demo: [https://segment-anything.com/demo](https://segment-anything.com/demo)
#         """
#     )
#     with gr.Row():
#         image = gr.Image()
#         image_output = gr.Image()
#     segment_image_button = gr.Button("Segment")
#     segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
#     gr.Examples(glob.glob('./images/*'), image, image_output, segment_image)
#     gr.Markdown("""
#     ### You can follow the WeChat public account [45度科研人] and leave me a message!
#     """)
# demo.launch(debug=True)