import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
plt.switch_backend('Agg')  # use a non-interactive backend so matplotlib works inside Gradio
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
logging.basicConfig(filename="app.log", level=logging.INFO)
title = (
    """
# Segment-RS 🛰️
## A remote sensing interactive interpretation tool based on segment-anything (SAM 👍)
### YJC (yujunchuan@mail.cgs.gov.cn) 📧
"""
)
description = (
    """
Segment-RS is an interactive remote sensing interpretation tool built on [SAM](https://github.com/facebookresearch/segment-anything). It extracts various remote sensing targets in real time through interaction, and comes with two interpretation modes: interactive extraction and automatic extraction.
* Interactive extraction: manually select sample points (positive and negative) on the image to extract a clearly visible target. Note that this mode is intended for extracting a single, independent target in the scene; extracting multiple targets of the same type at once is still under development.
* Automatic extraction: no interaction is needed; simply click the "Auto Segment" button to get the segmentation result. The accuracy and granularity of the segmentation can be tuned through "Prediction Thresh" and "Points Per Side".
"""
)
descriptionend = (
    """
You can follow the WeChat public account [45度科研人] and leave me a message!
"""
)
def show_image_with_scatter(img, x, y, label):
    """Draw the recorded prompt points on the image (negative: blue, positive: red)."""
    x = np.array(x)
    y = np.array(y)
    label = np.array(label)
    # negative points (label == 0) in blue
    mask = label == 0
    color = (0, 0, 255)  # blue (Gradio images are RGB, not BGR)
    pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
    for pt in pts:
        img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
    # positive points (label == 1) in red
    mask = label == 1
    color = (255, 0, 0)  # red
    pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
    for pt in pts:
        img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
    return img, x, y, label
def get_select_coords(img, mode, x, y, label, evt: gr.SelectData):
    """Record a click on the image as a positive or negative prompt point."""
    x = list(x)
    y = list(y)
    label = list(label)
    x.append(evt.index[0])  # evt.index is the (x, y) pixel position of the click
    y.append(evt.index[1])
    if mode == 'Positive':
        label.append(1)
    if mode == 'Negative':
        label.append(0)
    out, x, y, label = show_image_with_scatter(img, x, y, label)
    return out, x, y, label
def save_color_mask(masks):
    """Turn a (1, H, W) boolean mask into a binary mask and an RGBA color mask."""
    bin_mask = masks.reshape(masks.shape[1], masks.shape[2]) * 255
    color = np.array([30, 144, 200, 255])
    mask_image = masks.reshape(masks.shape[1], masks.shape[2], 1) * color.reshape(1, 1, -1)
    mask_image = mask_image.astype(np.uint8)
    return mask_image, bin_mask
def img_compose(mask_image, image):
    """Alpha-blend the RGBA color mask over the original image."""
    mask_alpha = np.array(mask_image[:, :, -1] * 0.65, dtype=np.uint8)  # extract and scale the alpha channel
    mask_rgba = np.dstack((mask_image[:, :, :-1], mask_alpha))  # merge RGB and alpha back into RGBA
    new_a_pil = Image.fromarray(mask_rgba, mode='RGBA')
    b_pil = Image.fromarray(image).convert('RGBA')
    result_pil = Image.alpha_composite(b_pil, new_a_pil)
    return np.array(result_pil)
def interactive_seg(image, input_pointx, input_pointy, input_label):
    """Segment a single target from user-provided positive/negative prompt points."""
    input_point = np.array(list(zip(input_pointx, input_pointy)))
    input_label = np.array(input_label)
    if input_point.size == 0 and input_label.size == 0:
        logging.error('Please select the target you want to extract by clicking in the image above!')
        return None, None
    predictor.set_image(image)  # compute the image embedding
    masks, scores, logits = predictor.predict(point_coords=input_point,
                                              point_labels=input_label,
                                              multimask_output=False)
    mask_image, bin_mask = save_color_mask(masks)
    result = img_compose(mask_image, image)
    return result, bin_mask
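# For reference, a hedged standalone usage sketch (not part of the app flow):
# "example.png" and the click coordinates are hypothetical. A single positive
# click (label 1) at pixel (400, 300) extracts the object under that point;
# adding negative clicks (label 0) carves unwanted regions back out of the mask.
#
#   image = np.array(Image.open("example.png").convert("RGB"))
#   overlay, bin_mask = interactive_seg(image, [400], [300], [1])
#   Image.fromarray(overlay).save("overlay.png")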
def draw_masks(image, masks, alpha=0.35):
    """Overlay each auto-generated mask in a random color and outline it in white."""
    for mask in masks:
        color = [np.random.randint(0, 255) for _ in range(3)]
        # draw the mask overlay: masked pixels are filled with the random color
        colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
        colored_mask = np.moveaxis(colored_mask, 0, -1)
        masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
        image_overlay = masked.filled()
        image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
        # draw a white contour around the mask
        contours, _ = cv2.findContours(np.uint8(mask["segmentation"]),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
    return image
def auto_seg(image, pred_iou_thresh, points_per_side):
    """Automatically segment the whole scene with SAM's mask generator."""
    mask_generator = SamAutomaticMaskGenerator(model=sam,
                                               points_per_side=int(points_per_side),  # Gradio sliders may return floats
                                               pred_iou_thresh=pred_iou_thresh,
                                               min_mask_region_area=30)
    masks = mask_generator.generate(image)
    result = draw_masks(image, masks)
    return result
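# Hedged usage sketch: "scene.png" and the parameter values are illustrative.
# A denser point grid (points_per_side) and a higher pred_iou_thresh yield
# finer, cleaner masks at the cost of runtime.
#
#   image = np.array(Image.open("scene.png").convert("RGB"))
#   overlay = auto_seg(image, pred_iou_thresh=0.92, points_per_side=48)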
def clear_point():
    """Clear the drawn image and the recorded prompt points."""
    return None, [], [], []
def reset_state():
    """Reset all images and the prompt-point state."""
    logging.info("Reset")
    return None, None, None, None, [], [], []
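# The Gradio wiring for these callbacks is not shown in this section. Below is a
# minimal sketch of how they could be connected in a gr.Blocks app; all component
# names, slider ranges, and the layout are assumptions, not the project's actual UI.
# (reset_state's seven outputs would map to additional components not sketched here.)
if __name__ == '__main__':
    with gr.Blocks(title="Segment-RS") as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            input_img = gr.Image(label="Input image", type="numpy")
            output_img = gr.Image(label="Segmentation result", type="numpy")
        bin_mask_img = gr.Image(label="Binary mask", type="numpy")
        mode = gr.Radio(["Positive", "Negative"], value="Positive", label="Point type")
        pred_iou_thresh = gr.Slider(0.5, 1.0, value=0.88, label="Prediction Thresh")
        points_per_side = gr.Slider(8, 64, value=32, step=8, label="Points Per Side")
        # per-session prompt-point state
        xs, ys, labels = gr.State([]), gr.State([]), gr.State([])
        # a click on the input image records a prompt point and redraws the markers;
        # the gr.SelectData event payload is injected automatically by Gradio
        input_img.select(get_select_coords,
                         [input_img, mode, xs, ys, labels],
                         [input_img, xs, ys, labels])
        gr.Button("Interactive Segment").click(interactive_seg,
                                               [input_img, xs, ys, labels],
                                               [output_img, bin_mask_img])
        gr.Button("Auto Segment").click(auto_seg,
                                        [input_img, pred_iou_thresh, points_per_side],
                                        [output_img])
        gr.Button("Clear Points").click(clear_point,
                                        [], [input_img, xs, ys, labels])
        gr.Markdown(descriptionend)
    demo.launch()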