"""Segment-RS: interactive remote-sensing segmentation helpers built on SAM.

This module loads a Segment Anything (SAM) ``vit_h`` model at import time and
defines the callbacks (point selection, interactive segmentation, automatic
segmentation, state reset) plus the markdown texts used by a Gradio UI.
"""
import glob
import logging
import os

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

matplotlib.pyplot.switch_backend('Agg')  # for matplotlib to work in gradio

sam_checkpoint = "sam_vit_h_4b8939.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
logging.basicConfig(filename="app.log", level=logging.INFO)

# Markdown shown at the top of the UI. Each heading marker must sit on the
# same line as its text, otherwise it renders as an empty heading followed by
# a plain paragraph.
title = (
    """
# Segment-RS 🛰️

## A remote sensing interactive interpretation tools based on segment-anything (SAM 👍)

### YJC (yujunchuan@mail.cgs.gov.cn) 📧
"""
)

# Markdown description; the bullet items need their own lines to render as a list.
description = (
    """
Segment-RS is an interactive remote sensing interpretation tool that has been developed based on [SAM](https://github.com/facebookresearch/segment-anything). It allows for the real-time extraction of various remote sensing targets through interaction. Segment-RS is equipped with two interpretation models, namely, interactive extraction and automatic extraction.

* Interactive extraction involves manually selecting samples (positive and negative) from the image to extract obvious targets. It should be emphasized that this manual interaction method is suitable for extracting an independent target in the scene and not suitable for extracting multiple targets of the same type at once as it is still under development.
* Automatic extraction does not require any interaction, one can simply click the "Auto Segment" button to get the segmentation result. Additionally, the accuracy and granularity of segmentation can be adjusted through "Prediction Thresh" and "Points Per Side".
"""
)

descriptionend = (
    """


you can follow the WeChat public account [45度科研人] and leave me a message!

"""
)


def show_image_with_scatter(img, x, y, label):
    """Draw the selected sample points onto ``img`` as filled circles.

    Points whose label is 0 are drawn with colour (0, 0, 255) and points whose
    label is 1 with colour (255, 0, 0). ``cv2.circle`` draws in place, so the
    array passed as ``img`` is modified.

    Args:
        img: Image array to draw on.
        x: Sequence of point x coordinates.
        y: Sequence of point y coordinates.
        label: Sequence of point labels (0 or 1), same length as ``x``/``y``.

    Returns:
        Tuple ``(img, x, y, label)`` where the coordinates and labels have
        been converted to numpy arrays.
    """
    # convert to numpy array
    x = np.array(x)
    y = np.array(y)
    label = np.array(label)

    # (label value, marker colour) pairs; negatives first, as in the original order
    marker_colors = (
        (0, (0, 0, 255)),  # blue
        (1, (255, 0, 0)),  # red
    )
    for value, color in marker_colors:
        selected = label == value
        pts = np.stack((x[selected], y[selected]), axis=-1).astype(int)
        for px, py in pts:
            # Plain Python ints: some OpenCV builds reject numpy integer
            # scalars when parsing the centre point.
            img = cv2.circle(img, (int(px), int(py)), radius=10, color=color, thickness=-1)
    return img, x, y, label


def get_select_coords(img, mode, x, y, label, evt: gr.SelectData):
    """Record a clicked point and redraw all points on the image.

    Args:
        img: Image the user clicked on.
        mode: ``'Positive'`` (label 1) or ``'Negative'`` (label 0).
        x: Previously collected x coordinates.
        y: Previously collected y coordinates.
        label: Previously collected labels.
        evt: Gradio select event; ``evt.index`` holds the clicked (x, y).

    Returns:
        Tuple ``(image_with_points, x, y, label)``. When ``mode`` is not one
        of the two known values the click is ignored and the inputs are
        returned unchanged (as lists), so coordinates and labels never get
        out of sync.
    """
    x = list(x)
    y = list(y)
    label = list(label)

    if mode == 'Positive':
        point_label = 1
    elif mode == 'Negative':
        point_label = 0
    else:
        # Appending a coordinate without a label would make the arrays
        # different lengths and break the boolean indexing downstream.
        return img, x, y, label

    x.append(evt.index[0])
    y.append(evt.index[1])
    label.append(point_label)
    out, x, y, label = show_image_with_scatter(img, x, y, label)
    return out, x, y, label


def save_color_mask(masks):
    """Turn a single predicted mask into a colour overlay and a 0/255 mask.

    Args:
        masks: Array of shape ``(1, H, W)``; the reshape below requires the
            first dimension to be 1.

    Returns:
        Tuple ``(mask_image, bin_mask)``: ``mask_image`` is an ``(H, W, 4)``
        uint8 array holding the colour (30, 144, 200, 255) where the mask is
        set and zeros elsewhere; ``bin_mask`` is the ``(H, W)`` mask
        multiplied by 255.
    """
    height, width = masks.shape[1], masks.shape[2]
    bin_mask = masks.reshape(height, width) * 255
    color = np.array([30, 144, 200, 255])
    mask_image = masks.reshape(height, width, 1) * color.reshape(1, 1, -1)
    mask_image = mask_image.astype(np.uint8)
    return mask_image, bin_mask


def img_compose(mask_image, image):
    """Alpha-composite the coloured mask over the image.

    Args:
        mask_image: ``(H, W, 4)`` uint8 overlay as produced by
            ``save_color_mask``.
        image: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The composited image as a numpy array converted from an RGBA PIL image.
    """
    # extract the alpha channel and scale it to 65% opacity
    mask_alpha = np.array(mask_image[:, :, -1] * 0.65, dtype=np.uint8)
    # merge the colour channels and the scaled alpha back into RGBA
    mask_rgba = np.dstack((mask_image[:, :, :-1], mask_alpha))
    new_a_pil = Image.fromarray(mask_rgba, mode='RGBA')
    b_pil = Image.fromarray(image).convert('RGBA')
    result_pil = Image.alpha_composite(b_pil, new_a_pil)
    return np.array(result_pil)


def interactive_seg(image, input_pointx, input_pointy, input_label):
    """Segment the target indicated by the selected points.

    Args:
        image: Input image array, or ``None`` when no image is loaded.
        input_pointx: x coordinates of the selected points.
        input_pointy: y coordinates of the selected points.
        input_label: Labels of the selected points (1 positive, 0 negative).

    Returns:
        Tuple ``(result, bin_mask)``: the image with the mask composited on
        top and the 0/255 mask. ``(None, None)`` when there is no image or no
        points/labels to prompt the model with.
    """
    input_point = np.array(list(zip(input_pointx, input_pointy)))
    input_label = np.array(input_label)

    if image is None:
        logging.error('No input image provided.')
        return None, None
    if input_point.size == 0 or input_label.size == 0:
        logging.error('Please select the target you want to extract by click in the image above!')
        return None, None

    predictor.set_image(image)  # compute the image embedding
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    mask_image, bin_mask = save_color_mask(masks)
    result = img_compose(mask_image, image)
    return result, bin_mask


def draw_masks(image, masks, alpha=0.35):
    """Blend every mask onto the image in a random colour and outline it.

    Args:
        image: Image array to draw on; the input array itself is not modified
            when ``masks`` is non-empty because blending creates a new array.
        masks: Iterable of dicts with a ``"segmentation"`` entry holding a
            2-D mask.
        alpha: Blend weight of the coloured overlay.

    Returns:
        The image with all masks drawn; the input image when ``masks`` is empty.
    """
    for mask in masks:
        color = [np.random.randint(0, 255) for _ in range(3)]
        segmentation = np.asarray(mask["segmentation"], dtype=bool)

        # draw mask overlay: paint the masked pixels, then blend with the image
        image_overlay = image.copy()
        image_overlay[segmentation] = color
        image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)

        # draw contour
        contours, _ = cv2.findContours(np.uint8(segmentation), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
    return image


def auto_seg(image, pred_iou_thresh, points_per_side):
    """Run SAM automatic mask generation and draw the result.

    Args:
        image: Input image array, or ``None`` when no image is loaded.
        pred_iou_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
        points_per_side: Forwarded to ``SamAutomaticMaskGenerator``.

    Returns:
        The image with all generated masks drawn, or ``None`` when no image
        was provided.
    """
    if image is None:
        logging.error('No input image provided.')
        return None

    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        min_mask_region_area=30,
    )
    masks = mask_generator.generate(image)
    result = draw_masks(image, masks)
    return result


def clear_point():
    """Return cleared values: no image and empty x, y and label lists."""
    return None, [], [], []


def reset_state():
    """Log the reset and return cleared values for four outputs and three lists."""
    logging.info("Reset")
    return None, None, None, None, [], [], []