JunchuanYu commited on
Commit
2f550e0
·
1 Parent(s): dd41d89

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -142
utils.py DELETED
@@ -1,142 +0,0 @@
1
- import os
2
- import cv2
3
- import matplotlib
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- import torch
7
- import torchvision
8
- import glob
9
- import gradio as gr
10
- from PIL import Image
11
- from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
12
- import logging
13
- matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
14
-
15
- sam_checkpoint = "sam_vit_h_4b8939.pth"
16
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
17
- model_type = "vit_h"
18
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
19
- sam.to(device=device)
20
- predictor = SamPredictor(sam)
21
- logging.basicConfig(filename="app.log", level=logging.INFO)
22
-
23
- title=(
24
- """
25
- # <p align="center"> Segment-RS 🛰️ <b>
26
- ## <p align="center"> A remote sensing interactive interpretation tools based on segment-anything (SAM 👍) <b>
27
- ### <p align="center"> YJC ([email protected]) 📧<b>
28
-
29
- """
30
- )
31
- description =(
32
- """
33
- Segment-RS is an interactive remote sensing interpretation tool that has been developed based on [SAM](https://github.com/facebookresearch/segment-anything). It allows for the real-time extraction of various remote sensing targets through interaction. Segment-RS is equipped with two interpretation models, namely, interactive extraction and automatic extraction.
34
- * Interactive extraction involves manually selecting samples (positive and negative) from the image to extract obvious targets. It should be emphasized that this manual interaction method is suitable for extracting an independent target in the scene and not suitable for extracting multiple targets of the same type at once as it is still under development.
35
- * Automatic extraction does not require any interaction, one can simply click the "Auto Segment" button to get the segmentation result. Additionally, the accuracy and granularity of segmentation can be adjusted through "Prediction Thresh" and "Points Per Side".
36
- """
37
- )
38
- descriptionend=(
39
- """
40
- <div align=center><img src="https://em-content.zobj.net/source/microsoft-teams/337/robot_1f916.png" style="width:5%;"></div>
41
- <br />
42
- <div align=center>you can follow the WeChat public account [45度科研人] and leave me a message! </div>
43
- <br />
44
- <div style="display:flex; justify-content:center;">
45
- <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/wechat-simple.png" style="margin-right:25px;width:200px;height:200px;">
46
- <div style="width:25px;"></div>
47
- <img src="https://dunazo.oss-cn-beijing.aliyuncs.com/blog/shoukuanma222.png" style="margin-left:25px;width:170px;height:190px;">
48
- </div>
49
- """
50
- )
51
-
52
- def show_image_with_scatter(img, x, y, label):
53
- # convert to numpy array
54
- x = np.array(x)
55
- y = np.array(y)
56
- label = np.array(label)
57
- # scatter plot
58
- mask = label == 0
59
- color = (0, 0, 255) # blue
60
- pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
61
- for pt in pts:
62
- img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
63
- mask = label == 1
64
- color = (255, 0, 0) # red
65
- pts = np.stack((x[mask], y[mask]), axis=-1).astype(int)
66
- for pt in pts:
67
- img = cv2.circle(img, tuple(pt), radius=10, color=color, thickness=-1)
68
- return img, x, y, label
69
-
70
- def get_select_coords(img,mode,x,y,label,evt:gr.SelectData):
71
- x=list(x)
72
- y=list(y)
73
- label=list(label)
74
- x.append(evt.index[0])
75
- y.append(evt.index[1])
76
- if mode=='Positive':
77
- label.append((1))
78
- if mode=='Negative':
79
- label.append((0))
80
- out,x,y,label=show_image_with_scatter(img,x,y,label)
81
- # print(x,y,label)
82
- return out,x,y,label
83
-
84
- def save_color_mask(masks):
85
- bin_mask=masks.reshape(masks.shape[1], masks.shape[2])*255
86
- color = np.array([30, 144, 200,255])
87
- mask_image = masks.reshape(masks.shape[1], masks.shape[2], 1) * color.reshape(1, 1, -1)
88
- mask_image = mask_image.astype(np.uint8)
89
- # pil=Image.fromarray(mask_image)
90
- # pil.save('result.png', format='PNG', mode='RGBA')
91
- return mask_image,bin_mask
92
-
93
- def img_compose(mask_image,image):
94
- mask_alpha = np.array(mask_image[:, :, -1]*0.65, dtype=np.uint8) # 提取出 alpha 通道
95
- mask_rgba = np.dstack((mask_image[:, :, :-1], mask_alpha)) # 将 RGB 和 alpha 合并成 RGBA
96
- new_a_pil = Image.fromarray(mask_rgba, mode='RGBA')
97
- b_pil=Image.fromarray(image).convert('RGBA')
98
- result_pil = Image.alpha_composite(b_pil,new_a_pil)
99
- # result_pil.save('result.png', format='PNG', mode='RGBA')
100
- return np.array(result_pil)
101
-
102
- def interactive_seg(image,input_pointx,input_pointy,input_label):
103
- # print(input_pointx,input_pointy,input_label)
104
- tmp=list(zip(input_pointx,input_pointy))
105
- input_point = np.array(tmp)
106
- input_label = np.array(input_label)
107
- if np.all([input_point.size == 0, input_label.size == 0]):
108
- logging.error('Please select the target you want to extract by click in the image above!')
109
- return None,None
110
- predictor.set_image(image) # embedding操作
111
- masks, scores, logits = predictor.predict(point_coords=input_point,
112
- point_labels=input_label,multimask_output=False,)
113
- mask_image,bin_mask=save_color_mask(masks)
114
- result=img_compose(mask_image,image)
115
- return result,bin_mask
116
- def draw_masks(image, masks, alpha=0.35):
117
- for mask in masks:
118
- color = [np.random.randint(0,255)for _ in range(3)]
119
- # draw mask overlay
120
- colored_mask = np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0)
121
- colored_mask = np.moveaxis(colored_mask, 0, -1)
122
- masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
123
- image_overlay = masked.filled()
124
- image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
125
- # draw contour
126
- contours, _ = cv2.findContours(np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
127
- cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
128
- return image
129
-
130
- def auto_seg(image,pred_iou_thresh,points_per_side):
131
- mask_generator = SamAutomaticMaskGenerator(model=sam,points_per_side=points_per_side,pred_iou_thresh=pred_iou_thresh,min_mask_region_area=30)
132
- masks = mask_generator.generate(image)
133
- result=draw_masks(image,masks)
134
- return result
135
-
136
- def clear_point():
137
- return None,[],[],[]
138
-
139
- def reset_state():
140
- logging.info("Reset")
141
- # delete_temp()
142
- return None,None,None,None,[],[],[]