aikenml committed
Commit 31be3d8 · Parent(s): e021a24

Upload folder using huggingface_hub

Files changed (6)
  1. SegTracker.py +264 -0
  2. aot_tracker.py +186 -0
  3. app.py +782 -0
  4. img2vid.py +26 -0
  5. model_args.py +28 -0
  6. seg_track_anything.py +300 -0
SegTracker.py ADDED
@@ -0,0 +1,264 @@
+ import sys
+ sys.path.append("..")
+ sys.path.append("./sam")
+ from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+ from aot_tracker import get_aot
+ import numpy as np
+ from tool.segmentor import Segmentor
+ from tool.detector import Detector
+ from tool.transfer_tools import draw_outline, draw_points
+ import cv2
+ from seg_track_anything import draw_mask
+
+
+ class SegTracker():
+     def __init__(self, segtracker_args, sam_args, aot_args) -> None:
+         """
+         Initialize SAM and AOT.
+         """
+         self.sam = Segmentor(sam_args)
+         self.tracker = get_aot(aot_args)
+         self.detector = Detector(self.sam.device)
+         self.sam_gap = segtracker_args['sam_gap']
+         self.min_area = segtracker_args['min_area']
+         self.max_obj_num = segtracker_args['max_obj_num']
+         self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
+         self.reference_objs_list = []
+         self.object_idx = 1
+         self.curr_idx = 1
+         self.origin_merged_mask = None  # init by segment-everything or update
+         self.first_frame_mask = None
+
+         # debug
+         self.everything_points = []
+         self.everything_labels = []
+         print("SegTracker has been initialized")
+
+     def seg(self, frame):
+         '''
+         Arguments:
+             frame: numpy array (h, w, 3)
+         Return:
+             origin_merged_mask: numpy array (h, w)
+         '''
+         frame = frame[:, :, ::-1]
+         anns = self.sam.everything_generator.generate(frame)
+
+         # anns is a list recording all predictions in an image
+         if len(anns) == 0:
+             return
+         # merge all predictions into one mask (h, w);
+         # note that the merged mask may lose some objects due to overlap
+         self.origin_merged_mask = np.zeros(anns[0]['segmentation'].shape, dtype=np.uint8)
+         idx = 1
+         for ann in anns:
+             if ann['area'] > self.min_area:
+                 m = ann['segmentation']
+                 self.origin_merged_mask[m == 1] = idx
+                 idx += 1
+                 self.everything_points.append(ann["point_coords"][0])
+                 self.everything_labels.append(1)
+
+         obj_ids = np.unique(self.origin_merged_mask)
+         obj_ids = obj_ids[obj_ids != 0]
+
+         self.object_idx = 1
+         for id in obj_ids:
+             if np.sum(self.origin_merged_mask == id) < self.min_area or self.object_idx > self.max_obj_num:
+                 self.origin_merged_mask[self.origin_merged_mask == id] = 0
+             else:
+                 self.origin_merged_mask[self.origin_merged_mask == id] = self.object_idx
+                 self.object_idx += 1
+
+         self.first_frame_mask = self.origin_merged_mask
+         return self.origin_merged_mask
+
+     def update_origin_merged_mask(self, updated_merged_mask):
+         self.origin_merged_mask = updated_merged_mask
+         # obj_ids = np.unique(updated_merged_mask)
+         # obj_ids = obj_ids[obj_ids != 0]
+         # self.object_idx = int(max(obj_ids)) + 1
+
+     def reset_origin_merged_mask(self, mask, id):
+         self.origin_merged_mask = mask
+         self.curr_idx = id
+
+     def add_reference(self, frame, mask, frame_step=0):
+         '''
+         Add objects in a mask for tracking.
+         Arguments:
+             frame: numpy array (h, w, 3)
+             mask: numpy array (h, w)
+         '''
+         self.reference_objs_list.append(np.unique(mask))
+         self.curr_idx = self.get_obj_num() + 1
+         self.tracker.add_reference_frame(frame, mask, self.curr_idx - 1, frame_step)
+
+     def track(self, frame, update_memory=False):
+         '''
+         Track all known objects.
+         Arguments:
+             frame: numpy array (h, w, 3)
+         Return:
+             origin_merged_mask: numpy array (h, w)
+         '''
+         pred_mask = self.tracker.track(frame)
+         if update_memory:
+             self.tracker.update_memory(pred_mask)
+         return pred_mask.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.uint8)
+
+     def get_tracking_objs(self):
+         objs = set()
+         for ref in self.reference_objs_list:
+             objs.update(set(ref))
+         objs = list(sorted(list(objs)))
+         objs = [i for i in objs if i != 0]
+         return objs
+
+     def get_obj_num(self):
+         objs = self.get_tracking_objs()
+         if len(objs) == 0:
+             return 0
+         return int(max(objs))
+
+     def find_new_objs(self, track_mask, seg_mask):
+         '''
+         Compare tracked results from AOT with segmented results from SAM.
+         Select objects from the background if they are not tracked yet.
+         Arguments:
+             track_mask: numpy array (h, w)
+             seg_mask: numpy array (h, w)
+         Return:
+             new_obj_mask: numpy array (h, w)
+         '''
+         new_obj_mask = (track_mask == 0) * seg_mask
+         new_obj_ids = np.unique(new_obj_mask)
+         new_obj_ids = new_obj_ids[new_obj_ids != 0]
+         # obj_num = self.get_obj_num() + 1
+         obj_num = self.curr_idx
+         for idx in new_obj_ids:
+             new_obj_area = np.sum(new_obj_mask == idx)
+             obj_area = np.sum(seg_mask == idx)
+             if new_obj_area / obj_area < self.min_new_obj_iou or new_obj_area < self.min_area \
+                     or obj_num > self.max_obj_num:
+                 new_obj_mask[new_obj_mask == idx] = 0
+             else:
+                 new_obj_mask[new_obj_mask == idx] = obj_num
+                 obj_num += 1
+         return new_obj_mask
+
+     def restart_tracker(self):
+         self.tracker.restart()
+
+     def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray):
+         '''
+         Use a bbox prompt to get a mask.
+         Parameters:
+             origin_frame: H, W, C
+             bbox: [[x0, y0], [x1, y1]]
+         Return:
+             refined_merged_mask: numpy array (h, w)
+             masked_frame: numpy array (h, w, c)
+         '''
+         # get interactive_mask
+         interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
+         refined_merged_mask = self.add_mask(interactive_mask)
+
+         # draw mask
+         masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
+
+         # draw bbox
+         masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))
+
+         return refined_merged_mask, masked_frame
+
+     def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
+         '''
+         Use point prompts to get a mask.
+         Parameters:
+             origin_frame: H, W, C
+             coords: np.ndarray [[x, y]]
+             modes: np.ndarray [[1]]
+         Return:
+             refined_merged_mask: numpy array (h, w)
+             masked_frame: numpy array (h, w, c)
+         '''
+         # get interactive_mask
+         interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)
+
+         refined_merged_mask = self.add_mask(interactive_mask)
+
+         # draw mask
+         masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
+
+         # draw points
+         # self.everything_labels = np.array(self.everything_labels).astype(np.int64)
+         # self.everything_points = np.array(self.everything_points).astype(np.int64)
+
+         masked_frame = draw_points(coords, modes, masked_frame)
+
+         # draw outline
+         masked_frame = draw_outline(interactive_mask, masked_frame)
+
+         return refined_merged_mask, masked_frame
+
+     def add_mask(self, interactive_mask: np.ndarray):
+         '''
+         Merge an interactive mask with self.origin_merged_mask.
+         Parameters:
+             interactive_mask: numpy array (h, w)
+         Return:
+             refined_merged_mask: numpy array (h, w)
+         '''
+         if self.origin_merged_mask is None:
+             self.origin_merged_mask = np.zeros(interactive_mask.shape, dtype=np.uint8)
+
+         refined_merged_mask = self.origin_merged_mask.copy()
+         refined_merged_mask[interactive_mask > 0] = self.curr_idx
+
+         return refined_merged_mask
+
+     def detect_and_seg(self, origin_frame: np.ndarray, grounding_caption, box_threshold, text_threshold, box_size_threshold=1, reset_image=False):
+         '''
+         Use Grounding-DINO to detect objects according to a text prompt.
+         Return:
+             refined_merged_mask: numpy array (h, w)
+             annotated_frame: numpy array (h, w, 3)
+         '''
+         # back up the current id and the origin merged mask
+         bc_id = self.curr_idx
+         bc_mask = self.origin_merged_mask
+
+         # get annotated_frame and boxes
+         annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
+         refined_merged_mask = self.origin_merged_mask  # fallback; avoids a NameError when no box passes the filters
+         for i in range(len(boxes)):
+             bbox = boxes[i]
+             if (bbox[1][0] - bbox[0][0]) * (bbox[1][1] - bbox[0][1]) > annotated_frame.shape[0] * annotated_frame.shape[1] * box_size_threshold:
+                 continue
+             interactive_mask = self.sam.segment_with_box(origin_frame, bbox, reset_image)[0]
+             refined_merged_mask = self.add_mask(interactive_mask)
+             self.update_origin_merged_mask(refined_merged_mask)
+             self.curr_idx += 1
+
+         # reset origin_mask
+         self.reset_origin_merged_mask(bc_mask, bc_id)
+
+         return refined_merged_mask, annotated_frame
+
+
+ if __name__ == '__main__':
+     from model_args import segtracker_args, sam_args, aot_args
+
+     Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
+
+     # ------------------ detect test ----------------------
+     origin_frame = cv2.imread('/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png')
+     origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
+     grounding_caption = "swan.water"
+     box_threshold = 0.25
+     text_threshold = 0.25
+
+     predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
+     masked_frame = draw_mask(annotated_frame, predicted_mask)
+     origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_RGB2BGR)
+
+     cv2.imwrite('./debug/masked_frame.png', masked_frame)
+     cv2.imwrite('./debug/x.png', annotated_frame)
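
Note: the per-frame driver that exercises this class lives in seg_track_anything.py (added below). Condensed to its core, the intended usage is roughly the following sketch; `frames` is assumed to be a list of RGB numpy arrays, and frame loading is omitted:

from model_args import segtracker_args, sam_args, aot_args
from SegTracker import SegTracker

tracker = SegTracker(segtracker_args, sam_args, aot_args)
pred_mask = tracker.seg(frames[0])                  # segment everything on the first frame
tracker.add_reference(frames[0], pred_mask)         # register the objects with AOT
for frame_idx, frame in enumerate(frames[1:], 1):
    if frame_idx % tracker.sam_gap == 0:
        seg_mask = tracker.seg(frame)               # periodically re-run SAM
        track_mask = tracker.track(frame)
        new_objs = tracker.find_new_objs(track_mask, seg_mask)
        pred_mask = track_mask + new_objs           # merge tracked and newly found objects
        tracker.add_reference(frame, pred_mask)
    else:
        pred_mask = tracker.track(frame, update_memory=True)
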
aot_tracker.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn.functional as F
+ import os
+ import sys
+ sys.path.append("./aot")
+ from aot.networks.engines.aot_engine import AOTEngine, AOTInferEngine
+ from aot.networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine
+ import importlib
+ import numpy as np
+ from PIL import Image
+ from skimage.morphology.binary import binary_dilation
+
+
+ np.random.seed(200)
+ _palette = ((np.random.random((3 * 255)) * 0.7 + 0.3) * 255).astype(np.uint8).tolist()
+ _palette = [0, 0, 0] + _palette
+
+ import aot.dataloaders.video_transforms as tr
+ from aot.utils.checkpoint import load_network
+ from aot.networks.models import build_vos_model
+ from aot.networks.engines import build_engine
+ from torchvision import transforms
+
+ class AOTTracker(object):
+     def __init__(self, cfg, gpu_id=0):
+         self.gpu_id = gpu_id
+         self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id)
+         self.model, _ = load_network(self.model, cfg.TEST_CKPT_PATH, gpu_id)
+         # self.engine = self.build_tracker_engine(cfg.MODEL_ENGINE,
+         #                                         aot_model=self.model,
+         #                                         gpu_id=gpu_id,
+         #                                         short_term_mem_skip=4,
+         #                                         long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP)
+         self.engine = build_engine(cfg.MODEL_ENGINE,
+                                    phase='eval',
+                                    aot_model=self.model,
+                                    gpu_id=gpu_id,
+                                    short_term_mem_skip=1,
+                                    long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP,
+                                    max_len_long_term=cfg.MAX_LEN_LONG_TERM)
+
+         self.transform = transforms.Compose([
+             tr.MultiRestrictSize(cfg.TEST_MAX_SHORT_EDGE,
+                                  cfg.TEST_MAX_LONG_EDGE, cfg.TEST_FLIP,
+                                  cfg.TEST_MULTISCALE, cfg.MODEL_ALIGN_CORNERS),
+             tr.MultiToTensor()
+         ])
+
+         self.model.eval()
+
+     @torch.no_grad()
+     def add_reference_frame(self, frame, mask, obj_nums, frame_step, incremental=False):
+         # mask = cv2.resize(mask, frame.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
+
+         sample = {
+             'current_img': frame,
+             'current_label': mask,
+         }
+
+         sample = self.transform(sample)
+         frame = sample[0]['current_img'].unsqueeze(0).float().cuda(self.gpu_id)
+         mask = sample[0]['current_label'].unsqueeze(0).float().cuda(self.gpu_id)
+         _mask = F.interpolate(mask, size=frame.shape[-2:], mode='nearest')
+
+         if incremental:
+             self.engine.add_reference_frame_incremental(frame, _mask, obj_nums=obj_nums, frame_step=frame_step)
+         else:
+             self.engine.add_reference_frame(frame, _mask, obj_nums=obj_nums, frame_step=frame_step)
+
+     @torch.no_grad()
+     def track(self, image):
+         output_height, output_width = image.shape[0], image.shape[1]
+         sample = {'current_img': image}
+         sample = self.transform(sample)
+         image = sample[0]['current_img'].unsqueeze(0).float().cuda(self.gpu_id)
+         self.engine.match_propogate_one_frame(image)
+         pred_logit = self.engine.decode_current_logits((output_height, output_width))
+
+         # pred_prob = torch.softmax(pred_logit, dim=1)
+         pred_label = torch.argmax(pred_logit, dim=1, keepdim=True).float()
+
+         return pred_label
+
+     @torch.no_grad()
+     def update_memory(self, pred_label):
+         self.engine.update_memory(pred_label)
+
+     @torch.no_grad()
+     def restart(self):
+         self.engine.restart_engine()
+
+     @torch.no_grad()
+     def build_tracker_engine(self, name, **kwargs):
+         if name == 'aotengine':
+             return AOTTrackerInferEngine(**kwargs)
+         elif name == 'deaotengine':
+             return DeAOTTrackerInferEngine(**kwargs)
+         else:
+             raise NotImplementedError
+
+
+ class AOTTrackerInferEngine(AOTInferEngine):
+     def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None):
+         super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num)
+
+     def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1):
+         if isinstance(obj_nums, list):
+             obj_nums = obj_nums[0]
+         self.obj_nums = obj_nums
+         aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
+         while (aot_num > len(self.aot_engines)):
+             new_engine = AOTEngine(self.AOT, self.gpu_id,
+                                    self.long_term_mem_gap,
+                                    self.short_term_mem_skip)
+             new_engine.eval()
+             self.aot_engines.append(new_engine)
+
+         separated_masks, separated_obj_nums = self.separate_mask(mask, obj_nums)
+         img_embs = None
+         for aot_engine, separated_mask, separated_obj_num in zip(
+                 self.aot_engines, separated_masks, separated_obj_nums):
+             if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num:
+                 aot_engine.add_reference_frame(img,
+                                                separated_mask,
+                                                obj_nums=[separated_obj_num],
+                                                frame_step=frame_step,
+                                                img_embs=img_embs)
+             else:
+                 aot_engine.update_short_term_memory(separated_mask)
+
+             if img_embs is None:  # reuse image embeddings
+                 img_embs = aot_engine.curr_enc_embs
+
+         self.update_size()
+
+
+ class DeAOTTrackerInferEngine(DeAOTInferEngine):
+     def __init__(self, aot_model, gpu_id=0, long_term_mem_gap=9999, short_term_mem_skip=1, max_aot_obj_num=None):
+         super().__init__(aot_model, gpu_id, long_term_mem_gap, short_term_mem_skip, max_aot_obj_num)
+
+     def add_reference_frame_incremental(self, img, mask, obj_nums, frame_step=-1):
+         if isinstance(obj_nums, list):
+             obj_nums = obj_nums[0]
+         self.obj_nums = obj_nums
+         aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
+         while (aot_num > len(self.aot_engines)):
+             new_engine = DeAOTEngine(self.AOT, self.gpu_id,
+                                      self.long_term_mem_gap,
+                                      self.short_term_mem_skip)
+             new_engine.eval()
+             self.aot_engines.append(new_engine)
+
+         separated_masks, separated_obj_nums = self.separate_mask(mask, obj_nums)
+         img_embs = None
+         for aot_engine, separated_mask, separated_obj_num in zip(
+                 self.aot_engines, separated_masks, separated_obj_nums):
+             if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num:
+                 aot_engine.add_reference_frame(img,
+                                                separated_mask,
+                                                obj_nums=[separated_obj_num],
+                                                frame_step=frame_step,
+                                                img_embs=img_embs)
+             else:
+                 aot_engine.update_short_term_memory(separated_mask)
+
+             if img_embs is None:  # reuse image embeddings
+                 img_embs = aot_engine.curr_enc_embs
+
+         self.update_size()
+
+
+ def get_aot(args):
+     # build the VOS engine
+     engine_config = importlib.import_module('configs.' + 'pre_ytb_dav')
+     cfg = engine_config.EngineConfig(args['phase'], args['model'])
+     cfg.TEST_CKPT_PATH = args['model_path']
+     cfg.TEST_LONG_TERM_MEM_GAP = args['long_term_mem_gap']
+     cfg.MAX_LEN_LONG_TERM = args['max_len_long_term']
+     # init AOTTracker
+     tracker = AOTTracker(cfg, args['gpu_id'])
+     return tracker
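
For context, SegTracker.__init__ constructs this tracker via get_aot(aot_args) with the defaults from model_args.py. A minimal usage sketch (the shape comment follows from decode_current_logits plus the argmax above):

from model_args import aot_args
from aot_tracker import get_aot

tracker = get_aot(aot_args)  # AOTTracker wrapping R50-DeAOTL from ckpt/
# tracker.add_reference_frame(frame, mask, obj_nums, frame_step=0)  # register objects to track
# label = tracker.track(next_frame)   # (1, 1, H, W) float tensor of per-pixel object ids
# tracker.update_memory(label)        # optionally commit the prediction to memory
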
app.py ADDED
@@ -0,0 +1,782 @@
+ import gradio as gr
+ import importlib
+ import sys
+ import os
+
+ from model_args import segtracker_args, sam_args, aot_args
+ from SegTracker import SegTracker
+
+ # sys.path.append('.')
+ # sys.path.append('..')
+
+ import cv2
+ from PIL import Image
+ from skimage.morphology.binary import binary_dilation
+ import argparse
+ import torch
+ import time
+ from seg_track_anything import aot_model2ckpt, tracking_objects_in_video, draw_mask
+ import gc
+ import numpy as np
+ import json
+ from tool.transfer_tools import mask2bbox
+
+ def clean():
+     return None, None, None, None, None, None, [[], []]
+
+ def get_click_prompt(click_stack, point):
+     click_stack[0].append(point["coord"])
+     click_stack[1].append(point["mode"])
+
+     prompt = {
+         "points_coord": click_stack[0],
+         "points_mode": click_stack[1],
+         "multimask": "True",
+     }
+
+     return prompt
+
+ def get_meta_from_video(input_video):
+     if input_video is None:
+         return None, None, None, ""
+
+     print("get meta information of input video")
+     cap = cv2.VideoCapture(input_video)
+
+     _, first_frame = cap.read()
+     cap.release()
+
+     first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+     return first_frame, first_frame, first_frame, ""
+
+ def get_meta_from_img_seq(input_img_seq):
+     if input_img_seq is None:
+         return None, None, None, ""
+
+     print("get meta information of img seq")
+     # Create dir
+     file_name = input_img_seq.name.split('/')[-1].split('.')[0]
+     file_path = f'./assets/{file_name}'
+     if os.path.isdir(file_path):
+         os.system(f'rm -r {file_path}')
+     os.makedirs(file_path)
+     # Unzip file
+     os.system(f'unzip {input_img_seq.name} -d ./assets ')
+
+     imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
+     first_frame = imgs_path[0]
+     first_frame = cv2.imread(first_frame)
+     first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+     # return four values so both the .change listener and the extract button
+     # (each wired to four outputs below) receive matching outputs
+     return first_frame, first_frame, first_frame, ""
+
+ def SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask):
+     with torch.cuda.amp.autocast():
+         # Reset the first frame's mask
+         frame_idx = 0
+         Seg_Tracker.restart_tracker()
+         Seg_Tracker.add_reference(origin_frame, predicted_mask, frame_idx)
+         Seg_Tracker.first_frame_mask = predicted_mask
+
+     return Seg_Tracker
+
+ def init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
+
+     if origin_frame is None:
+         return None, origin_frame, [[], []], ""
+
+     # reset aot args
+     aot_args["model"] = aot_model
+     aot_args["model_path"] = aot_model2ckpt[aot_model]
+     aot_args["long_term_mem_gap"] = long_term_mem
+     aot_args["max_len_long_term"] = max_len_long_term
+     # reset sam args
+     segtracker_args["sam_gap"] = sam_gap
+     segtracker_args["max_obj_num"] = max_obj_num
+     sam_args["generator_args"]["points_per_side"] = points_per_side
+
+     Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
+     Seg_Tracker.restart_tracker()
+
+     return Seg_Tracker, origin_frame, [[], []], ""
+
+ def init_SegTracker_Stroke(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
+
+     if origin_frame is None:
+         return None, origin_frame, [[], []], origin_frame
+
+     # reset aot args
+     aot_args["model"] = aot_model
+     aot_args["model_path"] = aot_model2ckpt[aot_model]
+     aot_args["long_term_mem_gap"] = long_term_mem
+     aot_args["max_len_long_term"] = max_len_long_term
+
+     # reset sam args
+     segtracker_args["sam_gap"] = sam_gap
+     segtracker_args["max_obj_num"] = max_obj_num
+     sam_args["generator_args"]["points_per_side"] = points_per_side
+
+     Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)
+     Seg_Tracker.restart_tracker()
+     return Seg_Tracker, origin_frame, [[], []], origin_frame
+
+ def undo_click_stack_and_refine_seg(Seg_Tracker, origin_frame, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
+
+     if Seg_Tracker is None:
+         return Seg_Tracker, origin_frame, [[], []]
+
+     print("Undo!")
+     if len(click_stack[0]) > 0:
+         click_stack[0] = click_stack[0][:-1]
+         click_stack[1] = click_stack[1][:-1]
+
+     if len(click_stack[0]) > 0:
+         prompt = {
+             "points_coord": click_stack[0],
+             "points_mode": click_stack[1],
+             "multimask": "True",
+         }
+
+         masked_frame = seg_acc_click(Seg_Tracker, prompt, origin_frame)
+         return Seg_Tracker, masked_frame, click_stack
+     else:
+         return Seg_Tracker, origin_frame, [[], []]
+
+ def seg_acc_click(Seg_Tracker, prompt, origin_frame):
+     # segment according to the click prompt
+     predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
+         origin_frame=origin_frame,
+         coords=np.array(prompt["points_coord"]),
+         modes=np.array(prompt["points_mode"]),
+         multimask=prompt["multimask"],
+     )
+
+     Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
+
+     return masked_frame
+
+ def sam_click(Seg_Tracker, origin_frame, point_mode, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, evt: gr.SelectData):
+     """
+     Args:
+         origin_frame: np.ndarray
+         click_stack: [[coordinate], [point_mode]]
+     """
+
+     print("Click")
+
+     if point_mode == "Positive":
+         point = {"coord": [evt.index[0], evt.index[1]], "mode": 1}
+     else:
+         # TODO: add everything positive points
+         point = {"coord": [evt.index[0], evt.index[1]], "mode": 0}
+
+     if Seg_Tracker is None:
+         Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
+
+     # get click prompts for SAM to predict the mask
+     click_prompt = get_click_prompt(click_stack, point)
+
+     # refine according to the prompt
+     masked_frame = seg_acc_click(Seg_Tracker, click_prompt, origin_frame)
+
+     return Seg_Tracker, masked_frame, click_stack
+
+ def sam_stroke(Seg_Tracker, origin_frame, drawing_board, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
+
+     if Seg_Tracker is None:
+         Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
+
+     print("Stroke")
+     mask = drawing_board["mask"]
+     bbox = mask2bbox(mask[:, :, 0])  # bbox: [[x0, y0], [x1, y1]]
+     predicted_mask, masked_frame = Seg_Tracker.seg_acc_bbox(origin_frame, bbox)
+
+     Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
+
+     return Seg_Tracker, masked_frame, origin_frame
+
+ def gd_detect(Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
+     if Seg_Tracker is None:
+         Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
+
+     print("Detect")
+     predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
+
+     Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)
+
+     masked_frame = draw_mask(annotated_frame, predicted_mask)
+
+     # return two values to match the detect button's two outputs below
+     return Seg_Tracker, masked_frame
+
+ def segment_everything(Seg_Tracker, aot_model, long_term_mem, max_len_long_term, origin_frame, sam_gap, max_obj_num, points_per_side):
+
+     if Seg_Tracker is None:
+         Seg_Tracker, _, _, _ = init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame)
+
+     print("Everything")
+
+     frame_idx = 0
+
+     with torch.cuda.amp.autocast():
+         pred_mask = Seg_Tracker.seg(origin_frame)
+         torch.cuda.empty_cache()
+         gc.collect()
+         Seg_Tracker.add_reference(origin_frame, pred_mask, frame_idx)
+         Seg_Tracker.first_frame_mask = pred_mask
+
+     masked_frame = draw_mask(origin_frame.copy(), pred_mask)
+
+     return Seg_Tracker, masked_frame
+
+ def add_new_object(Seg_Tracker):
+
+     prev_mask = Seg_Tracker.first_frame_mask
+     Seg_Tracker.update_origin_merged_mask(prev_mask)
+     Seg_Tracker.curr_idx += 1
+
+     print("Ready to add new object!")
+
+     return Seg_Tracker, [[], []]
+
+ def tracking_objects(Seg_Tracker, input_video, input_img_seq, fps):
+     print("Start tracking!")
+     return tracking_objects_in_video(Seg_Tracker, input_video, input_img_seq, fps)
+
+ def seg_track_app():
+
+     ##########################################################
+     ######################  Front-end  #######################
+     ##########################################################
+     app = gr.Blocks()
+
+     with app:
+         gr.Markdown(
+             '''
+             <div style="text-align:center;">
+                 <span style="font-size:3em; font-weight:bold;">Segment and Track Anything (SAM-Track)</span>
+             </div>
+             '''
+         )
+
+         click_stack = gr.State([[], []])  # stores the click state
+         origin_frame = gr.State(None)
+         Seg_Tracker = gr.State(None)
+
+         aot_model = gr.State(None)
+         sam_gap = gr.State(None)
+         points_per_side = gr.State(None)
+         max_obj_num = gr.State(None)
+
+         with gr.Row():
+             # video input
+             with gr.Column(scale=0.5):
+
+                 tab_video_input = gr.Tab(label="Video type input")
+                 with tab_video_input:
+                     input_video = gr.Video(label='Input video').style(height=550)
+
+                 tab_img_seq_input = gr.Tab(label="Image-Seq type input")
+                 with tab_img_seq_input:
+                     with gr.Row():
+                         input_img_seq = gr.File(label='Input Image-Seq').style(height=550)
+                         with gr.Column(scale=0.25):
+                             extract_button = gr.Button(value="extract")
+                             fps = gr.Slider(label='fps', minimum=5, maximum=50, value=8, step=1)
+
+                 input_first_frame = gr.Image(label='Segment result of first frame', interactive=True).style(height=550)
+
+                 tab_everything = gr.Tab(label="Everything")
+                 with tab_everything:
+                     with gr.Row():
+                         seg_every_first_frame = gr.Button(value="Segment everything for first frame", interactive=True)
+                         point_mode = gr.Radio(
+                             choices=["Positive"],
+                             value="Positive",
+                             label="Point Prompt",
+                             interactive=True)
+
+                         every_undo_but = gr.Button(
+                             value="Undo",
+                             interactive=True
+                         )
+
+                         # every_reset_but = gr.Button(
+                         #     value="Reset",
+                         #     interactive=True
+                         # )
+
+                 tab_click = gr.Tab(label="Click")
+                 with tab_click:
+                     with gr.Row():
+                         point_mode = gr.Radio(
+                             choices=["Positive", "Negative"],
+                             value="Positive",
+                             label="Point Prompt",
+                             interactive=True)
+
+                         # args for modify and tracking
+                         click_undo_but = gr.Button(
+                             value="Undo",
+                             interactive=True
+                         )
+                         # click_reset_but = gr.Button(
+                         #     value="Reset",
+                         #     interactive=True
+                         # )
+
+                 tab_stroke = gr.Tab(label="Stroke")
+                 with tab_stroke:
+                     drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
+                     with gr.Row():
+                         seg_acc_stroke = gr.Button(value="Segment", interactive=True)
+                         # stroke_reset_but = gr.Button(
+                         #     value="Reset",
+                         #     interactive=True
+                         # )
+
+                 tab_text = gr.Tab(label="Text")
+                 with tab_text:
+                     grounding_caption = gr.Textbox(label="Detection Prompt")
+                     detect_button = gr.Button(value="Detect")
+                     with gr.Accordion("Advanced options", open=False):
+                         with gr.Row():
+                             with gr.Column(scale=0.5):
+                                 box_threshold = gr.Slider(
+                                     label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                                 )
+                             with gr.Column(scale=0.5):
+                                 text_threshold = gr.Slider(
+                                     label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                                 )
+
+                 with gr.Row():
+                     with gr.Column(scale=0.5):
+                         with gr.Tab(label="SegTracker Args"):
+                             # args for running segment-everything during video tracking
+                             points_per_side = gr.Slider(
+                                 label="points_per_side",
+                                 minimum=1,
+                                 step=1,
+                                 maximum=100,
+                                 value=16,
+                                 interactive=True
+                             )
+
+                             sam_gap = gr.Slider(
+                                 label='sam_gap',
+                                 minimum=1,
+                                 step=1,
+                                 maximum=9999,
+                                 value=100,
+                                 interactive=True,
+                             )
+
+                             max_obj_num = gr.Slider(
+                                 label='max_obj_num',
+                                 minimum=50,
+                                 step=1,
+                                 maximum=300,
+                                 value=255,
+                                 interactive=True
+                             )
+                             with gr.Accordion("aot advanced options", open=False):
+                                 aot_model = gr.Dropdown(
+                                     label="aot_model",
+                                     choices=[
+                                         "deaotb",
+                                         "deaotl",
+                                         "r50_deaotl"
+                                     ],
+                                     value="r50_deaotl",
+                                     interactive=True,
+                                 )
+                                 long_term_mem = gr.Slider(label="long term memory gap", minimum=1, maximum=9999, value=9999, step=1)
+                                 max_len_long_term = gr.Slider(label="max len of long term memory", minimum=1, maximum=9999, value=9999, step=1)
+
+                     with gr.Column():
+                         new_object_button = gr.Button(
+                             value="Add new object",
+                             interactive=True
+                         )
+                         reset_button = gr.Button(
+                             value="Reset",
+                             interactive=True,
+                         )
+                         track_for_video = gr.Button(
+                             value="Start Tracking",
+                             interactive=True,
+                         )
+
+             with gr.Column(scale=0.5):
+                 output_video = gr.Video(label='Output video').style(height=550)
+                 output_mask = gr.File(label="Predicted masks")
+
+         ##########################################################
+         ######################  Back-end  ########################
+         ##########################################################
+
+         # listen to input_video to get the first frame of the video
+         input_video.change(
+             fn=get_meta_from_video,
+             inputs=[
+                 input_video
+             ],
+             outputs=[
+                 input_first_frame, origin_frame, drawing_board, grounding_caption
+             ]
+         )
+
+         # listen to input_img_seq to get the first frame of the sequence
+         input_img_seq.change(
+             fn=get_meta_from_img_seq,
+             inputs=[
+                 input_img_seq
+             ],
+             outputs=[
+                 input_first_frame, origin_frame, drawing_board, grounding_caption
+             ]
+         )
+
+         # -------------- Input component --------------
+         tab_video_input.select(
+             fn=clean,
+             inputs=[],
+             outputs=[
+                 input_video,
+                 input_img_seq,
+                 Seg_Tracker,
+                 input_first_frame,
+                 origin_frame,
+                 drawing_board,
+                 click_stack,
+             ]
+         )
+
+         tab_img_seq_input.select(
+             fn=clean,
+             inputs=[],
+             outputs=[
+                 input_video,
+                 input_img_seq,
+                 Seg_Tracker,
+                 input_first_frame,
+                 origin_frame,
+                 drawing_board,
+                 click_stack,
+             ]
+         )
+
+         extract_button.click(
+             fn=get_meta_from_img_seq,
+             inputs=[
+                 input_img_seq
+             ],
+             outputs=[
+                 input_first_frame, origin_frame, drawing_board, grounding_caption
+             ]
+         )
+
+         # ------------------- Interactive component -------------------
+
+         # listen to the tabs to init the SegTracker
+         tab_everything.select(
+             fn=init_SegTracker,
+             inputs=[
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+                 origin_frame
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack, grounding_caption
+             ],
+             queue=False,
+         )
+
+         tab_click.select(
+             fn=init_SegTracker,
+             inputs=[
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+                 origin_frame
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack, grounding_caption
+             ],
+             queue=False,
+         )
+
+         tab_stroke.select(
+             fn=init_SegTracker_Stroke,
+             inputs=[
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+                 origin_frame,
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack, drawing_board
+             ],
+             queue=False,
+         )
+
+         tab_text.select(
+             fn=init_SegTracker,
+             inputs=[
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+                 origin_frame
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack, grounding_caption
+             ],
+             queue=False,
+         )
+
+         # Use SAM to segment everything in the first frame of the video
+         seg_every_first_frame.click(
+             fn=segment_everything,
+             inputs=[
+                 Seg_Tracker,
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 origin_frame,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+             ],
+             outputs=[
+                 Seg_Tracker,
+                 input_first_frame,
+             ],
+         )
+
+         # Interactively modify the mask according to clicks
+         input_first_frame.select(
+             fn=sam_click,
+             inputs=[
+                 Seg_Tracker, origin_frame, point_mode, click_stack,
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack
+             ]
+         )
+
+         # Interactively segment according to a stroke
+         seg_acc_stroke.click(
+             fn=sam_stroke,
+             inputs=[
+                 Seg_Tracker, origin_frame, drawing_board,
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, drawing_board
+             ]
+         )
+
+         # Use Grounding-DINO to detect objects
+         detect_button.click(
+             fn=gd_detect,
+             inputs=[
+                 Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold,
+                 aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame
+             ]
+         )
+
+         # Add a new object
+         new_object_button.click(
+             fn=add_new_object,
+             inputs=[
+                 Seg_Tracker
+             ],
+             outputs=[
+                 Seg_Tracker, click_stack
+             ]
+         )
+
+         # Track objects in the video
+         track_for_video.click(
+             fn=tracking_objects,
+             inputs=[
+                 Seg_Tracker,
+                 input_video,
+                 input_img_seq,
+                 fps,
+             ],
+             outputs=[
+                 output_video, output_mask
+             ]
+         )
+
+         # ----------------- Reset and Undo ---------------------------
+
+         # Reset
+         reset_button.click(
+             fn=init_SegTracker,
+             inputs=[
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+                 origin_frame
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack, grounding_caption
+             ],
+             queue=False,
+             show_progress=False
+         )
+
+         # every_reset_but.click(
+         #     fn=init_SegTracker,
+         #     inputs=[
+         #         aot_model,
+         #         sam_gap,
+         #         max_obj_num,
+         #         points_per_side,
+         #         origin_frame
+         #     ],
+         #     outputs=[
+         #         Seg_Tracker, input_first_frame, click_stack, grounding_caption
+         #     ],
+         #     queue=False,
+         #     show_progress=False
+         # )
+
+         # click_reset_but.click(
+         #     fn=init_SegTracker,
+         #     inputs=[
+         #         aot_model,
+         #         sam_gap,
+         #         max_obj_num,
+         #         points_per_side,
+         #         origin_frame
+         #     ],
+         #     outputs=[
+         #         Seg_Tracker, input_first_frame, click_stack, grounding_caption
+         #     ],
+         #     queue=False,
+         #     show_progress=False
+         # )
+
+         # stroke_reset_but.click(
+         #     fn=init_SegTracker_Stroke,
+         #     inputs=[
+         #         aot_model,
+         #         sam_gap,
+         #         max_obj_num,
+         #         points_per_side,
+         #         origin_frame,
+         #     ],
+         #     outputs=[
+         #         Seg_Tracker, input_first_frame, click_stack, drawing_board
+         #     ],
+         #     queue=False,
+         #     show_progress=False
+         # )
+
+         # Undo a click
+         click_undo_but.click(
+             fn=undo_click_stack_and_refine_seg,
+             inputs=[
+                 Seg_Tracker, origin_frame, click_stack,
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack
+             ]
+         )
+
+         every_undo_but.click(
+             fn=undo_click_stack_and_refine_seg,
+             inputs=[
+                 Seg_Tracker, origin_frame, click_stack,
+                 aot_model,
+                 long_term_mem,
+                 max_len_long_term,
+                 sam_gap,
+                 max_obj_num,
+                 points_per_side,
+             ],
+             outputs=[
+                 Seg_Tracker, input_first_frame, click_stack
+             ]
+         )
+
+         with gr.Tab(label='Video example'):
+             gr.Examples(
+                 examples=[
+                     # os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.mp4"),
+                     os.path.join(os.path.dirname(__file__), "assets", "blackswan.mp4"),
+                     # os.path.join(os.path.dirname(__file__), "assets", "bear.mp4"),
+                     # os.path.join(os.path.dirname(__file__), "assets", "camel.mp4"),
+                     # os.path.join(os.path.dirname(__file__), "assets", "skate-park.mp4"),
+                     # os.path.join(os.path.dirname(__file__), "assets", "swing.mp4"),
+                 ],
+                 inputs=[input_video],
+             )
+
+         with gr.Tab(label='Image-seq example'):
+             gr.Examples(
+                 examples=[
+                     os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.zip"),
+                 ],
+                 inputs=[input_img_seq],
+             )
+
+     app.queue(concurrency_count=1)
+     app.launch(debug=True, enable_queue=True, share=True)
+
+
+ if __name__ == "__main__":
+     seg_track_app()
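
The callbacks above can also be driven without the Gradio front end. A minimal headless sketch, assuming a hypothetical test image frame.png and one positive click at (320, 240):

import cv2
from app import init_SegTracker, seg_acc_click

origin_frame = cv2.cvtColor(cv2.imread('frame.png'), cv2.COLOR_BGR2RGB)  # hypothetical input
Seg_Tracker, _, _, _ = init_SegTracker("r50_deaotl", 9999, 9999, 100, 255, 16, origin_frame)
prompt = {"points_coord": [[320, 240]], "points_mode": [1], "multimask": "True"}
masked_frame = seg_acc_click(Seg_Tracker, prompt, origin_frame)  # segments and registers frame 0
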
img2vid.py ADDED
@@ -0,0 +1,26 @@
+ import cv2
+ import os
+
+ # set the directory containing the images
+ img_dir = './assets/840_iSXIa0hE8Ek'
+
+ # set the output video file name and codec
+ out_file = './assets/840_iSXIa0hE8Ek.mp4'
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+
+ # get the dimensions of the first image (sorted, so it matches the write loop below)
+ img_path = os.path.join(img_dir, sorted(os.listdir(img_dir))[0])
+ img = cv2.imread(img_path)
+ height, width, channels = img.shape
+
+ # create the VideoWriter object (10 fps)
+ out = cv2.VideoWriter(out_file, fourcc, 10, (width, height))
+
+ # loop through the images and write them to the video
+ for img_name in sorted(os.listdir(img_dir)):
+     img_path = os.path.join(img_dir, img_name)
+     img = cv2.imread(img_path)
+     out.write(img)
+
+ # release the VideoWriter object and close the video file
+ out.release()
model_args.py ADDED
@@ -0,0 +1,28 @@
+ # Explanation of generator_args is in sam/segment_anything/automatic_mask_generator.py: SamAutomaticMaskGenerator
+ sam_args = {
+     'sam_checkpoint': "ckpt/sam_vit_b_01ec64.pth",
+     'model_type': "vit_b",
+     'generator_args': {
+         'points_per_side': 16,
+         'pred_iou_thresh': 0.8,
+         'stability_score_thresh': 0.9,
+         'crop_n_layers': 1,
+         'crop_n_points_downscale_factor': 2,
+         'min_mask_region_area': 200,
+     },
+     'gpu_id': 0,
+ }
+ aot_args = {
+     'phase': 'PRE_YTB_DAV',
+     'model': 'r50_deaotl',
+     'model_path': 'ckpt/R50_DeAOTL_PRE_YTB_DAV.pth',
+     'long_term_mem_gap': 9999,
+     'max_len_long_term': 9999,
+     'gpu_id': 0,
+ }
+ segtracker_args = {
+     'sam_gap': 10,  # interval (in frames) at which SAM is re-run to segment new objects
+     'min_area': 200,  # minimum mask area for adding a new object
+     'max_obj_num': 255,  # maximum number of objects to track in a video
+     'min_new_obj_iou': 0.8,  # a new object's background-area ratio must be > 80%
+ }
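
The generator_args keys above map one-to-one onto SamAutomaticMaskGenerator's constructor arguments. The Segmentor wrapper in tool/segmentor.py is not part of this commit, but it presumably wires them up roughly like this sketch:

from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
sam.to(device=f"cuda:{sam_args['gpu_id']}")
everything_generator = SamAutomaticMaskGenerator(model=sam, **sam_args['generator_args'])
# SegTracker.seg() calls everything_generator.generate(frame) on each key frame
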
seg_track_anything.py ADDED
@@ -0,0 +1,300 @@
+ import os
+ import cv2
+ from model_args import segtracker_args, sam_args, aot_args
+ from PIL import Image
+ from aot_tracker import _palette
+ import numpy as np
+ import torch
+ import gc
+ import imageio
+ from scipy.ndimage import binary_dilation
+
+ def save_prediction(pred_mask, output_dir, file_name):
+     save_mask = Image.fromarray(pred_mask.astype(np.uint8))
+     save_mask = save_mask.convert(mode='P')
+     save_mask.putpalette(_palette)
+     save_mask.save(os.path.join(output_dir, file_name))
+
+ def colorize_mask(pred_mask):
+     save_mask = Image.fromarray(pred_mask.astype(np.uint8))
+     save_mask = save_mask.convert(mode='P')
+     save_mask.putpalette(_palette)
+     save_mask = save_mask.convert(mode='RGB')
+     return np.array(save_mask)
+
+ def draw_mask(img, mask, alpha=0.5, id_countour=False):
+     img_mask = img  # draw in place on img
+     if id_countour:
+         # very slow, ~1 s per image
+         obj_ids = np.unique(mask)
+         obj_ids = obj_ids[obj_ids != 0]
+
+         for id in obj_ids:
+             # Overlay color on the binary mask
+             if id <= 255:
+                 color = _palette[id * 3:id * 3 + 3]
+             else:
+                 color = [0, 0, 0]
+             foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
+             binary_mask = (mask == id)
+
+             # Compose image
+             img_mask[binary_mask] = foreground[binary_mask]
+
+             contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
+             img_mask[contours, :] = 0
+     else:
+         binary_mask = (mask != 0)
+         contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
+         foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
+         img_mask[binary_mask] = foreground[binary_mask]
+         img_mask[contours, :] = 0
+
+     return img_mask.astype(img.dtype)
+
+ def create_dir(dir_path):
+     if os.path.isdir(dir_path):
+         os.system(f"rm -r {dir_path}")
+
+     os.makedirs(dir_path)
+
+ aot_model2ckpt = {
+     "deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
+     "deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV.pth",
+     "r50_deaotl": "./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth",
+ }
+
+
+ def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps):
+
+     if input_video is not None:
+         video_name = os.path.basename(input_video).split('.')[0]
+     elif input_img_seq is not None:
+         file_name = input_img_seq.name.split('/')[-1].split('.')[0]
+         file_path = f'./assets/{file_name}'
+         imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
+         video_name = file_name
+     else:
+         return None, None
+
+     # create a dir to save the result
+     tracking_result_dir = os.path.join(os.path.dirname(__file__), "tracking_results", video_name)
+     create_dir(tracking_result_dir)
+
+     io_args = {
+         'tracking_result_dir': tracking_result_dir,
+         'output_mask_dir': f'{tracking_result_dir}/{video_name}_masks',
+         'output_masked_frame_dir': f'{tracking_result_dir}/{video_name}_masked_frames',
+         'output_video': f'{tracking_result_dir}/{video_name}_seg.mp4',  # keep the same format as the input video
+         'output_gif': f'{tracking_result_dir}/{video_name}_seg.gif',
+     }
+
+     if input_video is not None:
+         return video_type_input_tracking(SegTracker, input_video, io_args, video_name)
+     elif input_img_seq is not None:
+         return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps)
+
+
+ def video_type_input_tracking(SegTracker, input_video, io_args, video_name):
+
+     # source video to segment
+     cap = cv2.VideoCapture(input_video)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     # create dirs to save the predicted masks and masked frames
+     output_mask_dir = io_args['output_mask_dir']
+     create_dir(io_args['output_mask_dir'])
+     create_dir(io_args['output_masked_frame_dir'])
+
+     pred_list = []
+     masked_pred_list = []
+
+     torch.cuda.empty_cache()
+     gc.collect()
+     sam_gap = SegTracker.sam_gap
+     frame_idx = 0
+
+     with torch.cuda.amp.autocast():
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+             if frame_idx == 0:
+                 pred_mask = SegTracker.first_frame_mask
+                 torch.cuda.empty_cache()
+                 gc.collect()
+             elif (frame_idx % sam_gap) == 0:
+                 seg_mask = SegTracker.seg(frame)
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 track_mask = SegTracker.track(frame)
+                 # find new objects, and update the tracker with them
+                 new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
+                 save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png')
+                 pred_mask = track_mask + new_obj_mask
+                 # segtracker.restart_tracker()
+                 SegTracker.add_reference(frame, pred_mask)
+             else:
+                 pred_mask = SegTracker.track(frame, update_memory=True)
+             torch.cuda.empty_cache()
+             gc.collect()
+
+             save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png')
+             pred_list.append(pred_mask)
+
+             print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
+             frame_idx += 1
+         cap.release()
+         print('\nfinished')
+
+     ##################
+     # Visualization
+     ##################
+
+     # draw the predicted masks on the frames and save them as a video
+     cap = cv2.VideoCapture(input_video)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     # if input_video[-3:] == 'mp4':
+     #     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     # elif input_video[-3:] == 'avi':
+     #     fourcc = cv2.VideoWriter_fourcc(*"MJPG")
+     #     # fourcc = cv2.VideoWriter_fourcc(*"XVID")
+     # else:
+     #     fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
+     out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
+
+     frame_idx = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pred_mask = pred_list[frame_idx]
+         masked_frame = draw_mask(frame, pred_mask)
+         cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{str(frame_idx).zfill(5)}.png", masked_frame[:, :, ::-1])
+
+         masked_pred_list.append(masked_frame)
+         masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
+         out.write(masked_frame)
+         print('frame {} written'.format(frame_idx), end='\r')
+         frame_idx += 1
+     out.release()
+     cap.release()
+     print("\n{} saved".format(io_args['output_video']))
+     print('\nfinished')
+
+     # save the colorized masks as a gif
+     imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
+     print("{} saved".format(io_args['output_gif']))
+
+     # zip the predicted masks
+     os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")
+
+     # manually release memory (after a CUDA out-of-memory error)
+     del SegTracker
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
+
+
+ def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps):
+
+     # create dirs to save the predicted masks and masked frames
+     output_mask_dir = io_args['output_mask_dir']
+     create_dir(io_args['output_mask_dir'])
+     create_dir(io_args['output_masked_frame_dir'])
+
+     pred_list = []
+     masked_pred_list = []
+
+     torch.cuda.empty_cache()
+     gc.collect()
+     sam_gap = SegTracker.sam_gap
+     frame_idx = 0
+
+     with torch.cuda.amp.autocast():
+         for img_path in imgs_path:
+             frame_name = os.path.basename(img_path).split('.')[0]
+             frame = cv2.imread(img_path)
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+             if frame_idx == 0:
+                 pred_mask = SegTracker.first_frame_mask
+                 torch.cuda.empty_cache()
+                 gc.collect()
+             elif (frame_idx % sam_gap) == 0:
+                 seg_mask = SegTracker.seg(frame)
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 track_mask = SegTracker.track(frame)
+                 # find new objects, and update the tracker with them
+                 new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
+                 save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png')
+                 pred_mask = track_mask + new_obj_mask
+                 # segtracker.restart_tracker()
+                 SegTracker.add_reference(frame, pred_mask)
+             else:
+                 pred_mask = SegTracker.track(frame, update_memory=True)
+             torch.cuda.empty_cache()
+             gc.collect()
+
+             save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png')
+             pred_list.append(pred_mask)
+
+             print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
+             frame_idx += 1
+         print('\nfinished')
+
+     ##################
+     # Visualization
+     ##################
+
+     # draw the predicted masks on the frames and save them as a video
+     height, width = pred_list[0].shape
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+     out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
+
+     frame_idx = 0
+     for img_path in imgs_path:
+         frame_name = os.path.basename(img_path).split('.')[0]
+         frame = cv2.imread(img_path)
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+         pred_mask = pred_list[frame_idx]
+         masked_frame = draw_mask(frame, pred_mask)
+         masked_pred_list.append(masked_frame)
+         cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_name}.png", masked_frame[:, :, ::-1])
+
+         masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
+         out.write(masked_frame)
+         print('frame {} written'.format(frame_name), end='\r')
+         frame_idx += 1
+     out.release()
+     print("\n{} saved".format(io_args['output_video']))
+     print('\nfinished')
+
+     # save the colorized masks as a gif
+     imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
+     print("{} saved".format(io_args['output_gif']))
+
+     # zip the predicted masks
+     os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")
+
+     # manually release memory (after a CUDA out-of-memory error)
+     del SegTracker
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"