import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import glob
import gradio as gr
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
import logging
plt.switch_backend('Agg')  # use the non-interactive Agg backend so matplotlib works inside Gradio

# Load the SAM ViT-H checkpoint and build a point-prompt predictor.
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
logging.basicConfig(filename="app.log", level=logging.INFO)
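# NOTE: the checkpoint is not bundled with this file; if it is missing, it can
# be downloaded from the official segment-anything release, e.g.:
#   wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth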

title = (
        """
        # <p align="center"> Segment-RS  🛰️
        ## <p align="center"> An interactive remote sensing interpretation tool based on segment-anything (SAM 👍)
        ### <p align="center"> YJC ([email protected])  📧

        """
        )
description = (
        """
        Segment-RS is an interactive remote sensing interpretation tool built on [SAM](https://github.com/facebookresearch/segment-anything). It extracts remote sensing targets in real time through user interaction and offers two interpretation modes: interactive extraction and automatic extraction.
        * Interactive extraction: manually select positive and negative sample points on the image to extract a salient target. Note that this mode is intended for extracting a single, independent target in the scene; extracting multiple targets of the same type at once is not yet supported and is still under development.
        * Automatic extraction: no interaction is required; simply click the "Auto Segment" button to obtain the segmentation result. The accuracy and granularity of the segmentation can be tuned via "Prediction Thresh" and "Points Per Side".
        """
        )
descriptionend=(
        """
        <div align=center><img src="https://em-content.zobj.net/source/microsoft-teams/337/robot_1f916.png" style="width:5%;"></div>
        <br />
        <div align=center>You can follow the WeChat public account [45度科研人] and leave me a message!</div>
        <br />
        <div style="display:flex; justify-content:center;">
        <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
        <div style="width:25px;"></div>
        <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
        </div>
        """        
    )

def show_image_with_scatter(img, x, y, label):
    """Draw the clicked prompt points on the image: blue for negative points
    (label 0), red for positive points (label 1)."""
    x = np.array(x)
    y = np.array(y)
    label = np.array(label)
    # the image array is RGB, so (0, 0, 255) is blue and (255, 0, 0) is red
    for value, color in ((0, (0, 0, 255)), (1, (255, 0, 0))):
        mask = label == value
        pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
        for pt in pts:
            img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
    return img, x, y, label

def get_select_coords(img, mode, x, y, label, evt: gr.SelectData):
    """Record a click on the image as a positive or negative prompt point."""
    x = list(x)
    y = list(y)
    label = list(label)
    x.append(evt.index[0])
    y.append(evt.index[1])
    if mode == 'Positive':
        label.append(1)
    elif mode == 'Negative':
        label.append(0)
    out, x, y, label = show_image_with_scatter(img, x, y, label)
    return out, x, y, label

def save_color_mask(masks):
    """Turn a (1, H, W) boolean mask into a binary mask and an RGBA color mask."""
    bin_mask = masks.reshape(masks.shape[1], masks.shape[2]) * 255
    color = np.array([30, 144, 200, 255])  # RGBA overlay color
    mask_image = masks.reshape(masks.shape[1], masks.shape[2], 1) * color.reshape(1, 1, -1)
    mask_image = mask_image.astype(np.uint8)
    return mask_image, bin_mask

def img_compose(mask_image, image):
    """Alpha-composite the colored mask over the original image."""
    mask_alpha = np.array(mask_image[:, :, -1] * 0.65, dtype=np.uint8)  # extract the alpha channel and scale it down
    mask_rgba = np.dstack((mask_image[:, :, :-1], mask_alpha))  # merge RGB and alpha into RGBA
    new_a_pil = Image.fromarray(mask_rgba, mode='RGBA')
    b_pil = Image.fromarray(image).convert('RGBA')
    result_pil = Image.alpha_composite(b_pil, new_a_pil)
    return np.array(result_pil)

def interactive_seg(image, input_pointx, input_pointy, input_label):
    """Run SAM point-prompted segmentation using the collected clicks."""
    input_point = np.array(list(zip(input_pointx, input_pointy)))
    input_label = np.array(input_label)
    if input_point.size == 0 and input_label.size == 0:
        logging.error('Please select the target you want to extract by clicking in the image above!')
        return None, None
    predictor.set_image(image)  # compute the image embedding
    masks, scores, logits = predictor.predict(point_coords=input_point,
        point_labels=input_label, multimask_output=False)
    mask_image, bin_mask = save_color_mask(masks)
    result = img_compose(mask_image, image)
    return result, bin_mask

def draw_masks(image, masks, alpha=0.35):
    """Overlay each automatically generated mask in a random color and outline it in white."""
    for mask in masks:
        color = [np.random.randint(0, 256) for _ in range(3)]  # random RGB color per mask
        # blend the colored mask into the image
        colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
        colored_mask = np.moveaxis(colored_mask, 0, -1)
        masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
        image_overlay = masked.filled()
        image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
        # trace the mask boundary in white
        contours, _ = cv2.findContours(np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
    return image

def auto_seg(image, pred_iou_thresh, points_per_side):
    """Run SAM automatic mask generation (SAM's own defaults are points_per_side=32, pred_iou_thresh=0.88)."""
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        min_mask_region_area=30,
    )
    masks = mask_generator.generate(image)
    result = draw_masks(image, masks)
    return result

def clear_point():
    """Clear the prompt points and the point-annotated preview image."""
    return None, [], [], []

def reset_state():
    """Reset all images, masks, and prompt-point state."""
    logging.info("Reset")
    return None, None, None, None, [], [], []
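
# ---------------------------------------------------------------------------
# UI wiring. The original Blocks layout is not part of this file snapshot, so
# the sketch below is a minimal, hypothetical reconstruction inferred from the
# callback signatures above; all component names are assumptions, not the
# author's originals. It assumes Gradio 3.x (gr.SelectData requires >= 3.22).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # per-session prompt-point state: x coords, y coords, 0/1 labels
    x_state = gr.State([])
    y_state = gr.State([])
    label_state = gr.State([])
    with gr.Row():
        input_image = gr.Image(label="Input Image")
        point_image = gr.Image(label="Selected Points")
    mode = gr.Radio(["Positive", "Negative"], value="Positive", label="Point Type")
    pred_iou_thresh = gr.Slider(0.0, 1.0, value=0.88, label="Prediction Thresh")
    points_per_side = gr.Slider(1, 64, value=32, step=1, label="Points Per Side")
    with gr.Row():
        seg_btn = gr.Button("Interactive Segment")
        auto_btn = gr.Button("Auto Segment")
        clear_btn = gr.Button("Clear Points")
        reset_btn = gr.Button("Reset")
    with gr.Row():
        result_image = gr.Image(label="Result")
        mask_out = gr.Image(label="Binary Mask")

    # clicking on the input image records a prompt point
    input_image.select(get_select_coords,
                       [input_image, mode, x_state, y_state, label_state],
                       [point_image, x_state, y_state, label_state])
    seg_btn.click(interactive_seg,
                  [input_image, x_state, y_state, label_state],
                  [result_image, mask_out])
    auto_btn.click(auto_seg,
                   [input_image, pred_iou_thresh, points_per_side],
                   [result_image])
    clear_btn.click(clear_point, [],
                    [point_image, x_state, y_state, label_state])
    reset_btn.click(reset_state, [],
                    [input_image, point_image, result_image, mask_out,
                     x_state, y_state, label_state])
    gr.Markdown(descriptionend)

demo.launch()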