# SegRS / app.py
# Author: JunchuanYu — commit 0ebdb47 ("Update app.py"), 2.66 kB
import glob
import io
import os

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
matplotlib.pyplot.switch_backend('Agg')  # for matplotlib to work in gradio (no display)

# Page header. It is rendered with gr.Markdown inside the gr.Blocks context below;
# gradio exposes the component as `gr.Markdown` and components only appear on the
# page when created inside a Blocks context.
# The paper link points at arXiv (the previous signed CDN URL carried an expiry),
# and the dataset entry is written as a proper Markdown link.
DESCRIPTION = """
# Segment Anything Model (SAM)
### A test on remote sensing data
- Paper:[Link](https://arxiv.org/abs/2304.02643)
- Github:[https://github.com/facebookresearch/segment-anything](https://github.com/facebookresearch/segment-anything)
- Dataset:[https://ai.facebook.com/datasets/segment-anything-downloads/](https://ai.facebook.com/datasets/segment-anything-downloads/)
- Official Demo:[https://segment-anything.com/demo](https://segment-anything.com/demo)
"""

# setup model
sam_checkpoint = "meta-model.pth"  # local checkpoint file loaded at startup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
model_type = "default"  # key into segment_anything's sam_model_registry
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)  # used by segment_image
predictor = SamPredictor(sam)  # NOTE(review): not used by the visible code


def show_anns(anns, alpha=0.35):
    """Overlay SAM masks on the current matplotlib axes, each in a random colour.

    Masks are layered largest-area first, so smaller masks end up on top. All
    layers are composited into a single RGBA image before drawing, which avoids
    keeping one full-resolution float RGBA array per mask alive in the figure.

    Args:
        anns: list of dicts as returned by SamAutomaticMaskGenerator.generate;
            each needs an 'area' and a 2-D 'segmentation' mask. All masks are
            assumed to share one shape (TODO confirm for non-default output modes).
        alpha: opacity of each individual mask layer.
    """
    if not anns:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # Porter-Duff "over" with premultiplied colour: each new layer is placed on
    # top of everything accumulated so far, matching the original draw order.
    premult = np.zeros((height, width, 3))
    coverage = np.zeros((height, width))
    for ann in sorted_anns:
        layer_alpha = np.asarray(ann['segmentation'], dtype=float) * alpha
        color = np.random.random(3)
        keep = 1.0 - layer_alpha
        premult = premult * keep[..., None] + color * layer_alpha[..., None]
        coverage = coverage * keep + layer_alpha

    # Un-premultiply so imshow's straight-alpha blending reproduces the stack;
    # uncovered pixels (coverage 0) stay fully transparent.
    rgb = np.divide(premult, coverage[..., None],
                    out=np.zeros_like(premult), where=coverage[..., None] > 0)
    ax.imshow(np.dstack((rgb, coverage)))


def segment_image(image):
    """Run SAM automatic mask generation on *image* and return a rendered overlay.

    Args:
        image: array from the gr.Image input, or None when nothing was uploaded.

    Returns:
        A PIL RGB image of the input with all generated masks overlaid, or None
        when no image was given.
    """
    if image is None:
        return None
    masks = mask_generator.generate(image)

    ppi = 100
    height, width = image.shape[:2]
    # Figure size proportional to the input resolution (ppi pixels per inch).
    fig = plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
    buffer = io.BytesIO()
    try:
        plt.imshow(image)
        show_anns(masks)
        plt.axis('off')
        # Render in memory: a shared 'output.png' on disk races between
        # concurrent requests, and reading it back with cv2 yields BGR data
        # that PIL would interpret as RGB (swapped red/blue channels).
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)  # pyplot keeps every figure alive until it is closed
    buffer.seek(0)
    return Image.open(buffer).convert('RGB')


# One example row per file: gr.Examples expects a list of rows, each holding
# one value per input component (here a single image input).
example_paths = sorted(glob.glob('./images/*'))

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        image = gr.Image()
        image_output = gr.Image()
    segment_image_button = gr.Button("Segment")
    segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
    if example_paths:  # skip the Examples panel when ./images is empty
        gr.Examples(
            examples=[[path] for path in example_paths],
            inputs=image,
            outputs=image_output,
            fn=segment_image,
        )
demo.launch()