Spaces:
Runtime error
Runtime error
JunchuanYu
commited on
Commit
·
a9fdb4a
1
Parent(s):
54a4759
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
import glob
|
9 |
+
import gradio as gr
|
10 |
+
from PIL import Image
|
11 |
+
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
12 |
+
import logging
|
13 |
+
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
14 |
+
|
15 |
+
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
16 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
17 |
+
model_type = "vit_h"
|
18 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
19 |
+
sam.to(device=device)
|
20 |
+
predictor = SamPredictor(sam)
|
21 |
+
logging.basicConfig(filename="app.log", level=logging.INFO)
|
22 |
+
|
23 |
+
title=(
|
24 |
+
"""
|
25 |
+
# <p align="center"> Segment-RS 🛰️ <b>
|
26 |
+
## <p align="center"> A remote sensing interactive interpretation tools based on segment-anything (SAM 👍) <b>
|
27 |
+
### <p align="center"> YJC ([email protected]) 📧<b>
|
28 |
+
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
description =(
|
32 |
+
"""
|
33 |
+
Segment-RS is an interactive remote sensing interpretation tool that has been developed based on [SAM](https://github.com/facebookresearch/segment-anything). It allows for the real-time extraction of various remote sensing targets through interaction. Segment-RS is equipped with two interpretation models, namely, interactive extraction and automatic extraction.
|
34 |
+
* Interactive extraction involves manually selecting samples (positive and negative) from the image to extract obvious targets. It should be emphasized that this manual interaction method is suitable for extracting an independent target in the scene and not suitable for extracting multiple targets of the same type at once as it is still under development.
|
35 |
+
* Automatic extraction does not require any interaction, one can simply click the "Auto Segment" button to get the segmentation result. Additionally, the accuracy and granularity of segmentation can be adjusted through "Prediction Thresh" and "Points Per Side".
|
36 |
+
"""
|
37 |
+
)
|
38 |
+
descriptionend=(
|
39 |
+
"""
|
40 |
+
<div align=center><img src="https://em-content.zobj.net/source/microsoft-teams/337/robot_1f916.png" style="width:5%;"></div>
|
41 |
+
<br />
|
42 |
+
<div align=center>you can follow the WeChat public account [45度科研人] and leave me a message! </div>
|
43 |
+
<br />
|
44 |
+
<div style="display:flex; justify-content:center;">
|
45 |
+
<img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
|
46 |
+
<div style="width:25px;"></div>
|
47 |
+
<img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
|
48 |
+
</div>
|
49 |
+
"""
|
50 |
+
)
|
51 |
+
|
52 |
+
def show_image_with_scatter(img, x, y, label):
|
53 |
+
# convert to numpy array
|
54 |
+
x = np.array(x)
|
55 |
+
y = np.array(y)
|
56 |
+
label = np.array(label)
|
57 |
+
# scatter plot
|
58 |
+
mask = label == 0
|
59 |
+
color = (0, 0, 255) # blue
|
60 |
+
pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
|
61 |
+
for pt in pts:
|
62 |
+
img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
|
63 |
+
mask = label == 1
|
64 |
+
color = (255, 0, 0) # red
|
65 |
+
pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
|
66 |
+
for pt in pts:
|
67 |
+
img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
|
68 |
+
return img, x, y, label
|
69 |
+
|
70 |
+
def get_select_coords(img,mode,x,y,label,evt:gr.SelectData):
|
71 |
+
x=list(x)
|
72 |
+
y=list(y)
|
73 |
+
label=list(label)
|
74 |
+
x.append(evt.index[0])
|
75 |
+
y.append(evt.index[1])
|
76 |
+
if mode=='Positive':
|
77 |
+
label.append((1))
|
78 |
+
if mode=='Negative':
|
79 |
+
label.append((0))
|
80 |
+
out,x,y,label=show_image_with_scatter(img,x,y,label)
|
81 |
+
# print(x,y,label)
|
82 |
+
return out,x,y,label
|
83 |
+
|
84 |
+
def save_color_mask(masks):
|
85 |
+
bin_mask=masks.reshape(masks.shape[1], masks.shape[2])*255
|
86 |
+
color = np.array([30, 144, 200,255])
|
87 |
+
mask_image = masks.reshape(masks.shape[1], masks.shape[2], 1) * color.reshape(1, 1, -1)
|
88 |
+
mask_image = mask_image.astype(np.uint8)
|
89 |
+
# pil=Image.fromarray(mask_image)
|
90 |
+
# pil.save('result.png', format='PNG', mode='RGBA')
|
91 |
+
return mask_image,bin_mask
|
92 |
+
|
93 |
+
def img_compose(mask_image,image):
|
94 |
+
mask_alpha = np.array(mask_image[:, :, -1]*0.65, dtype=np.uint8) # 提取出 alpha 通道
|
95 |
+
mask_rgba = np.dstack((mask_image[:, :, :-1], mask_alpha)) # 将 RGB 和 alpha 合并成 RGBA
|
96 |
+
new_a_pil = Image.fromarray(mask_rgba, mode='RGBA')
|
97 |
+
b_pil=Image.fromarray(image).convert('RGBA')
|
98 |
+
result_pil = Image.alpha_composite(b_pil,new_a_pil)
|
99 |
+
# result_pil.save('result.png', format='PNG', mode='RGBA')
|
100 |
+
return np.array(result_pil)
|
101 |
+
|
102 |
+
def interactive_seg(image,input_pointx,input_pointy,input_label):
|
103 |
+
# print(input_pointx,input_pointy,input_label)
|
104 |
+
tmp=list(zip(input_pointx,input_pointy))
|
105 |
+
input_point = np.array(tmp)
|
106 |
+
input_label = np.array(input_label)
|
107 |
+
if np.all([input_point.size == 0, input_label.size == 0]):
|
108 |
+
logging.error('Please select the target you want to extract by click in the image above!')
|
109 |
+
return None,None
|
110 |
+
predictor.set_image(image) # embedding操作
|
111 |
+
masks, scores, logits = predictor.predict(point_coords=input_point,
|
112 |
+
point_labels=input_label,multimask_output=False,)
|
113 |
+
mask_image,bin_mask=save_color_mask(masks)
|
114 |
+
result=img_compose(mask_image,image)
|
115 |
+
return result,bin_mask
|
116 |
+
def draw_masks(image, masks, alpha=0.35):
|
117 |
+
for mask in masks:
|
118 |
+
color = [np.random.randint(0,255)for _ in range(3)]
|
119 |
+
# draw mask overlay
|
120 |
+
colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
|
121 |
+
colored_mask = np.moveaxis(colored_mask, 0, -1)
|
122 |
+
masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
|
123 |
+
image_overlay = masked.filled()
|
124 |
+
image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
|
125 |
+
# draw contour
|
126 |
+
contours, _ = cv2.findContours(np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
127 |
+
cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
|
128 |
+
return image
|
129 |
+
|
130 |
+
def auto_seg(image,pred_iou_thresh,points_per_side):
|
131 |
+
mask_generator = SamAutomaticMaskGenerator(model=sam,points_per_side=points_per_side,pred_iou_thresh=pred_iou_thresh,min_mask_region_area=30)
|
132 |
+
masks = mask_generator.generate(image)
|
133 |
+
result=draw_masks(image,masks)
|
134 |
+
return result
|
135 |
+
|
136 |
+
def clear_point():
|
137 |
+
return None,[],[],[]
|
138 |
+
|
139 |
+
def reset_state():
|
140 |
+
logging.info("Reset")
|
141 |
+
# delete_temp()
|
142 |
+
return None,None,None,None,[],[],[]
|