sindhuhegde committed on
Commit c4c6512 · 1 Parent(s): bd49b58

Update app

Files changed (3):
  1. app.py +324 -1
  2. app_v1.py +1482 -0
  3. preprocess/inference_preprocess.py +3 -2
app.py CHANGED
@@ -9,11 +9,21 @@ import subprocess
9
  import cv2
10
  import pickle
11
  import librosa
 
12
  from decord import VideoReader
13
  from decord import cpu, gpu
14
  from utils.audio_utils import *
15
  from utils.inference_utils import *
16
  from sync_models.gestsync_models import *
17
  from tqdm import tqdm
18
  from glob import glob
19
  from scipy.io.wavfile import write
@@ -33,11 +43,307 @@ use_cuda = torch.cuda.is_available()
33
  batch_size = 12
34
  fps = 25
35
  n_negative_samples = 100
36
  print("Device: ", device)
37
 
38
  # Initialize the mediapipe holistic keypoint detection model
39
  holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
40
41
  @spaces.GPU(duration=140)
42
  def preprocess_video(path, result_folder, apply_preprocess, padding=20):
43
 
@@ -881,9 +1187,26 @@ def save_video(output_tracks, input_frames, wav_file, result_folder):
881
 
882
  return video_output, "success"
883
 
884
- @spaces.GPU(duration=140)
885
  def preprocess_asd(video_path, result_folder_input):
886
887
  print("Pre-processing the input video...")
888
  status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
889
  if status != 0:
 
9
  import cv2
10
  import pickle
11
  import librosa
12
+ from ultralytics import YOLO
13
  from decord import VideoReader
14
  from decord import cpu, gpu
15
  from utils.audio_utils import *
16
  from utils.inference_utils import *
17
  from sync_models.gestsync_models import *
18
+ from shutil import rmtree, copy, copytree
19
+ import scenedetect
20
+ from scenedetect.video_manager import VideoManager
21
+ from scenedetect.scene_manager import SceneManager
22
+ from scenedetect.stats_manager import StatsManager
23
+ from scenedetect.detectors import ContentDetector
24
+ from scipy.interpolate import interp1d
25
+ from scipy import signal
26
+
27
  from tqdm import tqdm
28
  from glob import glob
29
  from scipy.io.wavfile import write
 
43
  batch_size = 12
44
  fps = 25
45
  n_negative_samples = 100
46
+
47
+ facedet_scale=0.25
48
+ crop_scale=0
49
+ min_track=50
50
+ frame_rate=25
51
+ num_failed_det=25
52
+ min_frame_size=64
53
  print("Device: ", device)
54
 
55
  # Initialize the mediapipe holistic keypoint detection model
56
  holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
57
 
58
+
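+ # Standard intersection-over-union (IoU) between two boxes given as [x1, y1, x2, y2];
+ # used below to decide whether consecutive detections belong to the same person track.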
59
+ def bb_intersection_over_union(boxA, boxB):
60
+ xA = max(boxA[0], boxB[0])
61
+ yA = max(boxA[1], boxB[1])
62
+ xB = min(boxA[2], boxB[2])
63
+ yB = min(boxA[3], boxB[3])
64
+
65
+ interArea = max(0, xB - xA) * max(0, yB - yA)
66
+
67
+ boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
68
+ boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
69
+
70
+ iou = interArea / float(boxAArea + boxBArea - interArea)
71
+
72
+ return iou
73
+
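+ # Greedy tracker: grow a track by attaching any detection that is at most num_failed_det
+ # frames after the track's last detection and overlaps it with IoU > iouThres, then keep
+ # tracks longer than min_track frames, interpolating the boxes over missing frames.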
74
+ def track_shot(scenefaces):
75
+ print("Tracking video...")
76
+ iouThres = 0.5 # Minimum IOU between consecutive face detections
77
+ tracks = []
78
+
79
+ while True:
80
+ track = []
81
+ for framefaces in scenefaces:
82
+ for face in framefaces:
83
+ if track == []:
84
+ track.append(face)
85
+ framefaces.remove(face)
86
+ elif face['frame'] - track[-1]['frame'] <= num_failed_det:
87
+ iou = bb_intersection_over_union(face['bbox'], track[-1]['bbox'])
88
+ if iou > iouThres:
89
+ track.append(face)
90
+ framefaces.remove(face)
91
+ continue
92
+ else:
93
+ break
94
+
95
+ if track == []:
96
+ break
97
+ elif len(track) > min_track:
98
+ framenum = np.array([f['frame'] for f in track])
99
+ bboxes = np.array([np.array(f['bbox']) for f in track])
100
+
101
+ frame_i = np.arange(framenum[0], framenum[-1] + 1)
102
+
103
+ bboxes_i = []
104
+ for ij in range(0, 4):
105
+ interpfn = interp1d(framenum, bboxes[:, ij])
106
+ bboxes_i.append(interpfn(frame_i))
107
+ bboxes_i = np.stack(bboxes_i, axis=1)
108
+
109
+ if max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1])) > min_frame_size:
110
+ tracks.append({'frame': frame_i, 'bbox': bboxes_i})
111
+
112
+ return tracks
113
+
114
+ def check_folder(folder):
115
+ if os.path.exists(folder):
116
+ return True
117
+ return False
118
+
119
+ def del_folder(folder):
120
+ if os.path.exists(folder):
121
+ rmtree(folder)
122
+
123
+ def read_video(o, start_idx):
124
+ with open(o, 'rb') as o:
125
+ video_stream = VideoReader(o)
126
+ if start_idx > 0:
127
+ video_stream.skip_frames(start_idx)
128
+ return video_stream
129
+
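+ # Write a 480x270 clip for one person track using median-smoothed square crops around the
+ # detection centres, and cut the matching audio segment out of audio.wav with ffmpeg.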
130
+ def crop_video(avi_dir, tmp_dir, track, cropfile, tight_scale=1):
131
+ print("Cropping video...")
132
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
133
+ vOut = cv2.VideoWriter(cropfile + '.avi', fourcc, frame_rate, (480, 270))
134
+
135
+ dets = {'x': [], 'y': [], 's': [], 'bbox': track['bbox'], 'frame': track['frame']}
136
+
137
+ for det in track['bbox']:
138
+ # Reduce the size of the bounding box by a small factor if tighter crops are needed (default -> no reduction in size)
139
+ width = (det[2] - det[0]) * tight_scale
140
+ height = (det[3] - det[1]) * tight_scale
141
+ center_x = (det[0] + det[2]) / 2
142
+ center_y = (det[1] + det[3]) / 2
143
+
144
+ dets['s'].append(max(height, width) / 2)
145
+ dets['y'].append(center_y) # crop center y
146
+ dets['x'].append(center_x) # crop center x
147
+
148
+ # Smooth detections
149
+ dets['s'] = signal.medfilt(dets['s'], kernel_size=13)
150
+ dets['x'] = signal.medfilt(dets['x'], kernel_size=13)
151
+ dets['y'] = signal.medfilt(dets['y'], kernel_size=13)
152
+
153
+ videofile = os.path.join(avi_dir, 'video.avi')
154
+ frame_no_to_start = track['frame'][0]
155
+ video_stream = cv2.VideoCapture(videofile)
156
+ video_stream.set(cv2.CAP_PROP_POS_FRAMES, frame_no_to_start)
157
+ for fidx, frame in enumerate(track['frame']):
158
+ cs = crop_scale
159
+ bs = dets['s'][fidx] # Detection box size
160
+ bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount
161
+
162
+ image = video_stream.read()[1]
163
+ frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), 'constant', constant_values=(110, 110))
164
+
165
+ my = dets['y'][fidx] + bsi # BBox center Y
166
+ mx = dets['x'][fidx] + bsi # BBox center X
167
+
168
+ face = frame[int(my - bs):int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)):int(mx + bs * (1 + cs))]
169
+ vOut.write(cv2.resize(face, (480, 270)))
170
+ video_stream.release()
171
+ audiotmp = os.path.join(tmp_dir, 'audio.wav')
172
+ audiostart = (track['frame'][0]) / frame_rate
173
+ audioend = (track['frame'][-1] + 1) / frame_rate
174
+
175
+ vOut.release()
176
+
177
+ # ========== CROP AUDIO FILE ==========
178
+
179
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ss %.3f -to %.3f %s" % (os.path.join(avi_dir, 'audio.wav'), audiostart, audioend, audiotmp))
180
+ output = subprocess.call(command, shell=True, stdout=None)
181
+
182
+ copy(audiotmp, cropfile + '.wav')
183
+
184
+ # print('Written %s' % cropfile)
185
+ # print('Mean pos: x %.2f y %.2f s %.2f' % (np.mean(dets['x']), np.mean(dets['y']), np.mean(dets['s'])))
186
+
187
+ return {'track': track, 'proc_track': dets}
188
+
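+ # Run YOLOv9 person detection on every frame of video.avi and cache the per-frame person
+ # boxes (class 0, confidence > 0.7) to faces.pckl in the work directory.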
189
+ @spaces.GPU(duration=140)
190
+ def inference_video(avi_dir, work_dir, padding=0):
191
+ videofile = os.path.join(avi_dir, 'video.avi')
192
+ vidObj = cv2.VideoCapture(videofile)
193
+ yolo_model = YOLO("yolov9m.pt")
194
+ global dets, fidx
195
+ dets = []
196
+ fidx = 0
197
+
198
+ print("Detecting people in the video using YOLO (slowest step in the pipeline)...")
199
+ def generate_detections():
200
+ global dets, fidx
201
+ while True:
202
+ success, image = vidObj.read()
203
+ if not success:
204
+ break
205
+
206
+ image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
207
+
208
+ # Perform person detection
209
+ results = yolo_model(image_np, verbose=False)
210
+ detections = results[0].boxes
211
+
212
+ dets.append([])
213
+ for i, det in enumerate(detections):
214
+ x1, y1, x2, y2 = det.xyxy[0].detach().cpu().numpy()
215
+ cls = det.cls[0].detach().cpu().numpy()
216
+ conf = det.conf[0].detach().cpu().numpy()
217
+ if int(cls) == 0 and conf>0.7: # Class 0 is 'person' in COCO dataset
218
+ x1 = max(0, int(x1) - padding)
219
+ y1 = max(0, int(y1) - padding)
220
+ x2 = min(image_np.shape[1], int(x2) + padding)
221
+ y2 = min(image_np.shape[0], int(y2) + padding)
222
+ dets[-1].append({'frame': fidx, 'bbox': [x1, y1, x2, y2], 'conf': conf})
223
+
224
+ fidx += 1
225
+ yield
226
+
227
+ return dets
228
+
229
+ for _ in tqdm(generate_detections()):
230
+ pass
231
+
232
+
233
+ print("Successfully detected people in the video")
234
+ savepath = os.path.join(work_dir, 'faces.pckl')
235
+
236
+ with open(savepath, 'wb') as fil:
237
+ pickle.dump(dets, fil)
238
+
239
+ return dets
240
+
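+ # Split the video into shots with PySceneDetect's ContentDetector; if no cut is found,
+ # treat the whole video as a single scene.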
241
+ def scene_detect(avi_dir, work_dir):
242
+ print("Detecting scenes in the video...")
243
+ video_manager = VideoManager([os.path.join(avi_dir, 'video.avi')])
244
+ stats_manager = StatsManager()
245
+ scene_manager = SceneManager(stats_manager)
246
+ scene_manager.add_detector(ContentDetector())
247
+ base_timecode = video_manager.get_base_timecode()
248
+
249
+ video_manager.set_downscale_factor()
250
+ video_manager.start()
251
+ scene_manager.detect_scenes(frame_source=video_manager)
252
+ scene_list = scene_manager.get_scene_list(base_timecode)
253
+
254
+ savepath = os.path.join(work_dir, 'scene.pckl')
255
+
256
+ if scene_list == []:
257
+ scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())]
258
+
259
+ with open(savepath, 'wb') as fil:
260
+ pickle.dump(scene_list, fil)
261
+
262
+ print('%s - scenes detected %d' % (os.path.join(avi_dir, 'video.avi'), len(scene_list)))
263
+
264
+ return scene_list
265
+
266
+
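+ # End-to-end ASD pre-processing: re-encode the input to 25-fps AVI, extract 16 kHz mono
+ # audio, detect people, split the video into scenes, track each person within every scene,
+ # and save the per-track crops plus tracks.pckl metadata for the downstream model.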
267
+ def process_video_asd(file, sd_root, work_root, data_root, avi_dir, tmp_dir, work_dir, crop_dir, frames_dir):
268
+
269
+ video_file_name = os.path.basename(file.strip())
270
+ sd_dest_folder = sd_root
271
+ work_dest_folder = work_root
272
+
273
+
274
+ del_folder(sd_dest_folder)
275
+ del_folder(work_dest_folder)
276
+
277
+ videofile = file
278
+
279
+ if os.path.exists(work_dir):
280
+ rmtree(work_dir)
281
+
282
+ if os.path.exists(crop_dir):
283
+ rmtree(crop_dir)
284
+
285
+ if os.path.exists(avi_dir):
286
+ rmtree(avi_dir)
287
+
288
+ if os.path.exists(frames_dir):
289
+ rmtree(frames_dir)
290
+
291
+ if os.path.exists(tmp_dir):
292
+ rmtree(tmp_dir)
293
+
294
+ os.makedirs(work_dir)
295
+ os.makedirs(crop_dir)
296
+ os.makedirs(avi_dir)
297
+ os.makedirs(frames_dir)
298
+ os.makedirs(tmp_dir)
299
+
300
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -qscale:v 2 -async 1 -r 25 %s" % (videofile,
301
+ os.path.join(avi_dir,
302
+ 'video.avi')))
303
+ output = subprocess.call(command, shell=True, stdout=None)
304
+ if output != 0:
305
+ return
306
+
307
+ command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (os.path.join(avi_dir,
308
+ 'video.avi'),
309
+ os.path.join(avi_dir,
310
+ 'audio.wav')))
311
+ output = subprocess.call(command, shell=True, stdout=None)
312
+ if output != 0:
313
+ return
314
+
315
+ faces = inference_video(avi_dir, work_dir)
316
+
317
+ try:
318
+ scene = scene_detect(avi_dir, work_dir)
319
+ except scenedetect.video_stream.VideoOpenFailure:
320
+ return
321
+
322
+
323
+ allscenes = []
324
+ for shot in scene:
325
+ if shot[1].frame_num - shot[0].frame_num >= min_track:
326
+ allscenes.append(track_shot(faces[shot[0].frame_num:shot[1].frame_num]))
327
+
328
+ alltracks = []
329
+ for sc_num in range(len(allscenes)):
330
+ vidtracks = []
331
+ for ii, track in enumerate(allscenes[sc_num]):
332
+ os.makedirs(os.path.join(crop_dir, 'scene_'+str(sc_num)), exist_ok=True)
333
+ vidtracks.append(crop_video(avi_dir, tmp_dir, track, os.path.join(crop_dir, 'scene_'+str(sc_num), '%05d' % ii)))
334
+ alltracks.append(vidtracks)
335
+
336
+ savepath = os.path.join(work_dir, 'tracks.pckl')
337
+
338
+ with open(savepath, 'wb') as fil:
339
+ pickle.dump(alltracks, fil)
340
+
341
+ rmtree(tmp_dir)
342
+ rmtree(avi_dir)
343
+ rmtree(frames_dir)
344
+ copytree(crop_dir, sd_dest_folder)
345
+ copytree(work_dir, work_dest_folder)
346
+
347
  @spaces.GPU(duration=140)
348
  def preprocess_video(path, result_folder, apply_preprocess, padding=20):
349
 
 
1187
 
1188
  return video_output, "success"
1189
 
 
1190
  def preprocess_asd(video_path, result_folder_input):
1191
 
1192
+ file = video_path
1193
+
1194
+ data_dir = os.path.join(result_folder_input, 'temp')
1195
+ sd_root = os.path.join(result_folder_input, 'crops')
1196
+ work_root = os.path.join(result_folder_input, 'metadata')
1197
+ data_root = result_folder_input
1198
+
1199
+ os.makedirs(sd_root, exist_ok=True)
1200
+ os.makedirs(work_root, exist_ok=True)
1201
+
1202
+ avi_dir = os.path.join(data_dir, 'pyavi')
1203
+ tmp_dir = os.path.join(data_dir, 'pytmp')
1204
+ work_dir = os.path.join(data_dir, 'pywork')
1205
+ crop_dir = os.path.join(data_dir, 'pycrop')
1206
+ frames_dir = os.path.join(data_dir, 'pyframes')
1207
+
1208
+ process_video_asd(file, sd_root, work_root, data_root, avi_dir, tmp_dir, work_dir, crop_dir, frames_dir)
1209
+
1210
  print("Pre-processing the input video...")
1211
  status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
1212
  if status != 0:
app_v1.py ADDED
@@ -0,0 +1,1482 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from shutil import rmtree
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ import numpy as np
8
+ import subprocess
9
+ import cv2
10
+ import pickle
11
+ import librosa
12
+ from decord import VideoReader
13
+ from decord import cpu, gpu
14
+ from utils.audio_utils import *
15
+ from utils.inference_utils import *
16
+ from sync_models.gestsync_models import *
17
+ from tqdm import tqdm
18
+ from glob import glob
19
+ from scipy.io.wavfile import write
20
+ import mediapipe as mp
21
+ from protobuf_to_dict import protobuf_to_dict
22
+ import warnings
23
+ import spaces
24
+
25
+ mp_holistic = mp.solutions.holistic
26
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
27
+ warnings.filterwarnings("ignore", category=UserWarning)
28
+
29
+ # Initialize global variables
30
+ CHECKPOINT_PATH = "model_rgb.pth"
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ use_cuda = torch.cuda.is_available()
33
+ batch_size = 12
34
+ fps = 25
35
+ n_negative_samples = 100
36
+ print("Device: ", device)
37
+
38
+ # Initialize the mediapipe holistic keypoint detection model
39
+ holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
40
+
41
+ @spaces.GPU(duration=140)
42
+ def preprocess_video(path, result_folder, apply_preprocess, padding=20):
43
+
44
+ '''
45
+ This function preprocesses the input video to extract the audio and crop the frames using YOLO model
46
+
47
+ Args:
48
+ - path (string) : Path of the input video file
49
+ - result_folder (string) : Path of the folder to save the extracted audio and cropped video
50
+ - padding (int) : Padding to add to the bounding box
51
+ Returns:
52
+ - wav_file (string) : Path of the extracted audio file
53
+ - fps (int) : FPS of the input video
54
+ - video_output (string) : Path of the cropped video file
55
+ - msg (string) : Message to be returned
56
+ '''
57
+
58
+ # Load all video frames
59
+ try:
60
+ vr = VideoReader(path, ctx=cpu(0))
61
+ fps = vr.get_avg_fps()
62
+ frame_count = len(vr)
63
+ except:
64
+ msg = "Oops! Could not load the video. Please check the input video and try again."
65
+ return None, None, None, msg
66
+
67
+ if frame_count < 25:
68
+ msg = "Not enough frames to process! Please give a longer video as input"
69
+ return None, None, None, msg
70
+
71
+ # Extract the audio from the input video file using ffmpeg
72
+ wav_file = os.path.join(result_folder, "audio.wav")
73
+
74
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \
75
+ -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True)
76
+
77
+ if status != 0:
78
+ msg = "Oops! Could not load the audio file. Please check the input video and try again."
79
+ return None, None, None, msg
80
+ print("Extracted the audio from the video")
81
+
82
+ if apply_preprocess=="True":
83
+ all_frames = []
84
+ for k in range(len(vr)):
85
+ all_frames.append(vr[k].asnumpy())
86
+ all_frames = np.asarray(all_frames)
87
+ print("Extracted the frames for pre-processing")
88
+
89
+ # Load YOLOv9 model (pre-trained on COCO dataset)
90
+ yolo_model = YOLO("yolov9s.pt")
91
+ print("Loaded the YOLO model")
92
+
93
+
94
+
95
+ person_videos = {}
96
+ person_tracks = {}
97
+
98
+ print("Processing the frames...")
99
+ for frame_idx in tqdm(range(frame_count)):
100
+
101
+ frame = all_frames[frame_idx]
102
+
103
+ # Perform person detection
104
+ results = yolo_model(frame, verbose=False)
105
+ detections = results[0].boxes
106
+
107
+ for i, det in enumerate(detections):
108
+ x1, y1, x2, y2 = det.xyxy[0]
109
+ cls = det.cls[0]
110
+ if int(cls) == 0: # Class 0 is 'person' in COCO dataset
111
+
112
+ x1 = max(0, int(x1) - padding)
113
+ y1 = max(0, int(y1) - padding)
114
+ x2 = min(frame.shape[1], int(x2) + padding)
115
+ y2 = min(frame.shape[0], int(y2) + padding)
116
+
117
+ if i not in person_videos:
118
+ person_videos[i] = []
119
+ person_tracks[i] = []
120
+
121
+ person_videos[i].append(frame)
122
+ person_tracks[i].append([x1,y1,x2,y2])
123
+
124
+
125
+ num_persons = 0
126
+ for i in person_videos.keys():
127
+ if len(person_videos[i]) >= frame_count//2:
128
+ num_persons+=1
129
+
130
+ if num_persons==0:
131
+ msg = "No person detected in the video! Please give a video with one person as input"
132
+ return None, None, None, msg
133
+ if num_persons>1:
134
+ msg = "More than one person detected in the video! Please give a video with only one person as input"
135
+ return None, None, None, msg
136
+
137
+
138
+
139
+ # For the person detected, crop the frame based on the bounding box
140
+ if len(person_videos[0]) > frame_count-10:
141
+ crop_filename = os.path.join(result_folder, "preprocessed_video.avi")
142
+ fourcc = cv2.VideoWriter_fourcc(*'DIVX')
143
+
144
+ # Get bounding box coordinates based on person_tracks[i]
145
+ max_x1 = min([track[0] for track in person_tracks[0]])
146
+ max_y1 = min([track[1] for track in person_tracks[0]])
147
+ max_x2 = max([track[2] for track in person_tracks[0]])
148
+ max_y2 = max([track[3] for track in person_tracks[0]])
149
+
150
+ max_width = max_x2 - max_x1
151
+ max_height = max_y2 - max_y1
152
+
153
+ out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height))
154
+ for frame in person_videos[0]:
155
+ crop = frame[max_y1:max_y2, max_x1:max_x2]
156
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
157
+ out.write(crop)
158
+ out.release()
159
+
160
+ no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4'
161
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True)
162
+ if status != 0:
163
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
164
+ return None, None, None, msg
165
+
166
+ video_output = crop_filename.split('.')[0] + '.mp4'
167
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
168
+ (wav_file , no_sound_video, video_output), shell=True)
169
+ if status != 0:
170
+ msg = "Oops! Could not preprocess the video. Please check the input video and try again."
171
+ return None, None, None, msg
172
+
173
+ os.remove(crop_filename)
174
+ os.remove(no_sound_video)
175
+
176
+ print("Successfully saved the pre-processed video: ", video_output)
177
+ else:
178
+ msg = "Could not track the person in the full video! Please give a single-speaker video as input"
179
+ return None, None, None, msg
180
+
181
+ else:
182
+ video_output = path
183
+
184
+ return wav_file, fps, video_output, "success"
185
+
186
+ def resample_video(video_file, video_fname, result_folder):
187
+
188
+ '''
189
+ This function resamples the video to 25 fps
190
+
191
+ Args:
192
+ - video_file (string) : Path of the input video file
193
+ - video_fname (string) : Name of the input video file
194
+ - result_folder (string) : Path of the folder to save the resampled video
195
+ Returns:
196
+ - video_file_25fps (string) : Path of the resampled video file
197
+ '''
198
+ video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname))
199
+
200
+ # Resample the video to 25 fps
201
+ status = subprocess.call("ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps), shell=True)
202
+ if status != 0:
203
+ msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again."
204
+ return None, msg
205
+ print('Resampled the video to 25 fps: {}'.format(video_file_25fps))
206
+
207
+ return video_file_25fps, "success"
208
+
209
+ def load_checkpoint(path, model):
210
+ '''
211
+ This function loads the trained model from the checkpoint
212
+
213
+ Args:
214
+ - path (string) : Path of the checkpoint file
215
+ - model (object) : Model object
216
+ Returns:
217
+ - model (object) : Model object with the weights loaded from the checkpoint
218
+ '''
219
+
220
+ # Load the checkpoint
221
+ if use_cuda:
222
+ checkpoint = torch.load(path)
223
+ else:
224
+ checkpoint = torch.load(path, map_location="cpu")
225
+
226
+ s = checkpoint["state_dict"]
227
+ new_s = {}
228
+
229
+ for k, v in s.items():
230
+ new_s[k.replace('module.', '')] = v
231
+ model.load_state_dict(new_s)
232
+
233
+ if use_cuda:
234
+ model.cuda()
235
+
236
+ print("Loaded checkpoint from: {}".format(path))
237
+
238
+ return model.eval()
239
+
240
+
241
+ def load_video_frames(video_file):
242
+ '''
243
+ This function extracts the frames from the video
244
+
245
+ Args:
246
+ - video_file (string) : Path of the video file
247
+ Returns:
248
+ - frames (list) : List of frames extracted from the video
249
+ - msg (string) : Message to be returned
250
+ '''
251
+
252
+ # Read the video
253
+ try:
254
+ vr = VideoReader(video_file, ctx=cpu(0))
255
+ except:
256
+ msg = "Oops! Could not load the input video file"
257
+ return None, msg
258
+
259
+
260
+ # Extract the frames
261
+ frames = []
262
+ for k in range(len(vr)):
263
+ frames.append(vr[k].asnumpy())
264
+
265
+ frames = np.asarray(frames)
266
+
267
+ return frames, "success"
268
+
269
+
270
+
271
+ def get_keypoints(frames):
272
+
273
+ '''
274
+ This function extracts the keypoints from the frames using MediaPipe Holistic pipeline
275
+
276
+ Args:
277
+ - frames (list) : List of frames extracted from the video
278
+ Returns:
279
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
280
+ - msg (string) : Message to be returned
281
+ '''
282
+
283
+ try:
284
+ holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
285
+
286
+ resolution = frames[0].shape
287
+ all_frame_kps = []
288
+
289
+ for frame in frames:
290
+
291
+ results = holistic.process(frame)
292
+
293
+ pose, left_hand, right_hand, face = None, None, None, None
294
+ if results.pose_landmarks is not None:
295
+ pose = protobuf_to_dict(results.pose_landmarks)['landmark']
296
+ if results.left_hand_landmarks is not None:
297
+ left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark']
298
+ if results.right_hand_landmarks is not None:
299
+ right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark']
300
+ if results.face_landmarks is not None:
301
+ face = protobuf_to_dict(results.face_landmarks)['landmark']
302
+
303
+ frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face}
304
+
305
+ all_frame_kps.append(frame_dict)
306
+
307
+ kp_dict = {"kps":all_frame_kps, "resolution":resolution}
308
+ except Exception as e:
309
+ print("Error: ", e)
310
+ return None, "Error: Could not extract keypoints from the frames"
311
+
312
+ return kp_dict, "success"
313
+
314
+
315
+ def check_visible_gestures(kp_dict):
316
+
317
+ '''
318
+ This function checks if the gestures in the video are visible
319
+
320
+ Args:
321
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
322
+ Returns:
323
+ - msg (string) : Message to be returned
324
+ '''
325
+
326
+ keypoints = kp_dict['kps']
327
+ keypoints = np.array(keypoints)
328
+
329
+ if len(keypoints)<25:
330
+ msg = "Not enough keypoints to process! Please give a longer video as input"
331
+ return msg
332
+
333
+ pose_count, hand_count = 0, 0
334
+ for frame_kp_dict in keypoints:
335
+
336
+ pose = frame_kp_dict["pose"]
337
+ left_hand = frame_kp_dict["left_hand"]
338
+ right_hand = frame_kp_dict["right_hand"]
339
+
340
+ if pose is None:
341
+ pose_count += 1
342
+
343
+ if left_hand is None and right_hand is None:
344
+ hand_count += 1
345
+
346
+
347
+ if hand_count/len(keypoints) > 0.6 or pose_count/len(keypoints) > 0.6:
348
+ msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input."
349
+ return msg
350
+
351
+ print("Successfully verified the input video - Gestures are visible!")
352
+
353
+ return "success"
354
+
355
+ def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270):
356
+
357
+ '''
358
+ This function masks the faces using the keypoints extracted from the frames
359
+
360
+ Args:
361
+ - input_frames (list) : List of frames extracted from the video
362
+ - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames
363
+ - stride (int) : Stride to extract the frames
364
+ - window_frames (int) : Number of frames in each window that is given as input to the model
365
+ - width (int) : Width of the frames
366
+ - height (int) : Height of the frames
367
+ Returns:
368
+ - input_frames (array) : Frame window to be given as input to the model
369
+ - num_frames (int) : Number of frames to extract
370
+ - orig_masked_frames (array) : Masked frames extracted from the video
371
+ - msg (string) : Message to be returned
372
+ '''
373
+
374
+ print("Creating masked input frames...")
375
+
376
+ input_frames_masked = []
377
+ if kp_dict is None:
378
+ for img in tqdm(input_frames):
379
+ img = cv2.resize(img, (width, height))
380
+ masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
381
+ input_frames_masked.append(masked_img)
382
+
383
+ else:
384
+ # Face indices to extract the face-coordinates needed for masking
385
+ face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172,
386
+ 176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454]
387
+
388
+ input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution']
389
+ print("Input keypoints: ", len(input_keypoints))
390
+
391
+ for i, frame_kp_dict in tqdm(enumerate(input_keypoints)):
392
+
393
+ img = input_frames[i]
394
+ face = frame_kp_dict["face"]
395
+
396
+ if face is None:
397
+ img = cv2.resize(img, (width, height))
398
+ masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1)
399
+ else:
400
+ face_kps = []
401
+ for idx in range(len(face)):
402
+ if idx in face_oval_idx:
403
+ x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0])
404
+ face_kps.append((x,y))
405
+
406
+ face_kps = np.array(face_kps)
407
+ x1, y1 = min(face_kps[:,0]), min(face_kps[:,1])
408
+ x2, y2 = max(face_kps[:,0]), max(face_kps[:,1])
409
+ masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1)
410
+
411
+ if masked_img.shape[0] != width or masked_img.shape[1] != height:
412
+ masked_img = cv2.resize(masked_img, (width, height))
413
+
414
+ input_frames_masked.append(masked_img)
415
+
416
+ orig_masked_frames = np.array(input_frames_masked)
417
+ input_frames = np.array(input_frames_masked) / 255.
418
+ if asd:
419
+ input_frames = np.pad(input_frames, ((12, 12), (0,0), (0,0), (0,0)), 'edge')
420
+ # print("Input images full: ", input_frames.shape) # num_framesx270x480x3
421
+
422
+ input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
423
+ # print("Input images window: ", input_frames.shape) # Tx25x270x480x3
424
+ print("Successfully created masked input frames")
425
+
426
+ num_frames = input_frames.shape[0]
427
+
428
+ if num_frames<10:
429
+ msg = "Not enough frames to process! Please give a longer video as input."
430
+ return None, None, None, msg
431
+
432
+ return input_frames, num_frames, orig_masked_frames, "success"
433
+
434
+ def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4):
435
+
436
+ '''
437
+ This function extracts the spectrogram from the audio file
438
+
439
+ Args:
440
+ - wav_file (string) : Path of the extracted audio file
441
+ - num_frames (int) : Number of frames to extract
442
+ - window_frames (int) : Number of frames in each window that is given as input to the model
443
+ - stride (int) : Stride to extract the audio frames
444
+ Returns:
445
+ - spec (array) : Spectrogram array window to be used as input to the model
446
+ - orig_spec (array) : Spectrogram array extracted from the audio file
447
+ - msg (string) : Message to be returned
448
+ '''
449
+
450
+ # Extract the audio from the input video file using ffmpeg
451
+ try:
452
+ wav = librosa.load(wav_file, sr=16000)[0]
453
+ except:
454
+ msg = "Oops! Could not extract the spectrograms from the audio file. Please check the input and try again."
455
+ return None, None, msg
456
+
457
+ # Convert to tensor
458
+ wav = torch.FloatTensor(wav).unsqueeze(0)
459
+ mel, _, _, _ = wav2filterbanks(wav.to(device))
460
+ spec = mel.squeeze(0).cpu().numpy()
461
+ orig_spec = spec
462
+ spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])])
463
+
464
+ if num_frames is not None:
465
+ if len(spec) != num_frames:
466
+ spec = spec[:num_frames]
467
+ frame_diff = np.abs(len(spec) - num_frames)
468
+ if frame_diff > 60:
469
+ print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")
470
+
471
+ if asd:
472
+ pad_frames = (window_frames//2)
473
+ spec = np.pad(spec, ((pad_frames, pad_frames), (0,0), (0,0)), 'edge')
474
+
475
+ return spec, orig_spec, "success"
476
+
477
+
478
+ def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model):
479
+ '''
480
+ This function calculates the audio-visual offset between the video and audio
481
+
482
+ Args:
483
+ - vid_emb (array) : Video embedding array
484
+ - aud_emb (array) : Audio embedding array
485
+ - num_avg_frames (int) : Number of frames to average the scores
486
+ - model (object) : Model object
487
+ Returns:
488
+ - offset (int) : Optimal audio-visual offset
489
+ - msg (string) : Message to be returned
490
+ '''
491
+
492
+ pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames)
493
+ if status != "success":
494
+ return None, status
495
+ scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model)
496
+ offset = scores.argmax()*stride - pos_idx
497
+
498
+ return offset.item(), "success"
499
+
500
+ def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5):
501
+
502
+ '''
503
+ This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset
504
+
505
+ Args:
506
+ - vid_emb (array) : Video embedding array
507
+ - aud_emb (array) : Audio embedding array
508
+ - num_avg_frames (int) : Number of frames to average the scores
509
+ - stride (int) : Stride to extract the negative windows
510
+ Returns:
511
+ - vid_emb_pos (array) : Positive video embedding array
512
+ - aud_emb_posneg (array) : All possible combinations of audio embedding array
513
+ - pos_idx_frame (int) : Positive video embedding array frame
514
+ - stride (int) : Stride used to extract the negative windows
515
+ - msg (string) : Message to be returned
516
+ '''
517
+
518
+ slice_size = num_avg_frames
519
+ aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride)
520
+ aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3])
521
+ aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1]
522
+
523
+ pos_idx = (aud_emb_posneg.shape[1]//2)
524
+ pos_idx_frame = pos_idx*stride
525
+
526
+ min_offset_frames = -(pos_idx)*stride
527
+ max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 1)*stride
528
+ print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames))
529
+
530
+ vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size]
531
+ if vid_emb_pos.shape[2] != slice_size:
532
+ msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size)
533
+ return None, None, None, None, msg
534
+
535
+ return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success"
536
+
537
+ def calc_av_scores(vid_emb, aud_emb, model):
538
+
539
+ '''
540
+ This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings
541
+
542
+ Args:
543
+ - vid_emb (array) : Video embedding array
544
+ - aud_emb (array) : Audio embedding array
545
+ - model (object) : Model object
546
+ Returns:
547
+ - scores (array) : Audio-visual similarity scores
548
+ - att_map (array) : Attention map
549
+ '''
550
+
551
+ scores = calc_att_map(vid_emb, aud_emb, model)
552
+ att_map = logsoftmax_2d(scores)
553
+ scores = scores.mean(-1)
554
+
555
+ return scores, att_map
556
+
557
+ def calc_att_map(vid_emb, aud_emb, model):
558
+
559
+ '''
560
+ This function calculates the similarity between the video and audio embeddings
561
+
562
+ Args:
563
+ - vid_emb (array) : Video embedding array
564
+ - aud_emb (array) : Audio embedding array
565
+ - model (object) : Model object
566
+ Returns:
567
+ - scores (array) : Audio-visual similarity scores
568
+ '''
569
+
570
+ vid_emb = vid_emb[:, :, None]
571
+ aud_emb = aud_emb.transpose(1, 2)
572
+
573
+ scores = run_func_in_parts(lambda x, y: (x * y).sum(1),
574
+ vid_emb,
575
+ aud_emb,
576
+ part_len=10,
577
+ dim=3,
578
+ device=device)
579
+
580
+ scores = model.logits_scale(scores[..., None]).squeeze(-1)
581
+
582
+ return scores
583
+
584
+ def generate_video(frames, audio_file, video_fname):
585
+
586
+ '''
587
+ This function generates the video from the frames and audio file
588
+
589
+ Args:
590
+ - frames (array) : Frames to be used to generate the video
591
+ - audio_file (string) : Path of the audio file
592
+ - video_fname (string) : Path of the video file
593
+ Returns:
594
+ - video_output (string) : Path of the video file
595
+ '''
596
+
597
+ fname = 'inference.avi'
598
+ video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0]))
599
+
600
+ for i in range(len(frames)):
601
+ video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
602
+ video.release()
603
+
604
+ no_sound_video = video_fname + '_nosound.mp4'
605
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
606
+ if status != 0:
607
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
608
+ return None, msg
609
+
610
+ video_output = video_fname + '.mp4'
611
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' %
612
+ (audio_file, no_sound_video, video_output), shell=True)
613
+
614
+ if status != 0:
615
+ msg = "Oops! Could not generate the video. Please check the input video and try again."
616
+ return None, msg
617
+
618
+ os.remove(fname)
619
+ os.remove(no_sound_video)
620
+
621
+ return video_output, "success"
622
+
623
+ def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25):
624
+
625
+ '''
626
+ This function corrects the video and audio to sync with each other
627
+
628
+ Args:
629
+ - video_path (string) : Path of the video file
630
+ - frames (array) : Frames to be used to generate the video
631
+ - wav_file (string) : Path of the audio file
632
+ - offset (int) : Predicted sync-offset to be used to correct the video
633
+ - result_folder (string) : Path of the result folder to save the output sync-corrected video
634
+ - sample_rate (int) : Sample rate of the audio
635
+ - fps (int) : Frames per second of the video
636
+ Returns:
637
+ - video_output (string) : Path of the video file
638
+ '''
639
+
640
+ if offset == 0:
641
+ print("The input audio and video are in-sync! No need to perform sync correction.")
642
+ return video_path, "success"
643
+
644
+ print("Performing Sync Correction...")
645
+ corrected_frames = np.zeros_like(frames)
646
+ if offset > 0:
647
+ audio_offset = int(offset*(sample_rate/fps))
648
+ wav = librosa.core.load(wav_file, sr=sample_rate)[0]
649
+ corrected_wav = wav[audio_offset:]
650
+ corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav")
651
+ write(corrected_wav_file, sample_rate, corrected_wav)
652
+ wav_file = corrected_wav_file
653
+ corrected_frames = frames
654
+ elif offset < 0:
655
+ corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):]
656
+ corrected_frames = corrected_frames[:len(frames)-np.abs(offset)]
657
+
658
+ corrected_video_path = os.path.join(result_folder, "result_sync_corrected")
659
+ video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path)
660
+ if status != "success":
661
+ return None, status
662
+
663
+ return video_output, "success"
664
+
665
+
666
+ def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder):
667
+
668
+ '''
669
+ This function loads the masked input frames from the video
670
+
671
+ Args:
672
+ - test_videos (list) : List of videos to be processed (speaker-specific tracks)
673
+ - spec (array) : Spectrogram of the audio
674
+ - wav_file (string) : Path of the audio file
675
+ - scene_num (int) : Scene number to be used to save the input masked video
676
+ - result_folder (string) : Path of the folder to save the input masked video
677
+ Returns:
678
+ - all_frames (list) : List of masked input frames window to be used as input to the model
679
+ - all_orig_frames (list) : List of original masked input frames
680
+ '''
681
+
682
+ all_frames, all_orig_frames = [], []
683
+ for video_num, video in enumerate(test_videos):
684
+
685
+ print("Processing video: ", video)
686
+
687
+ # Load the video frames
688
+ frames, status = load_video_frames(video)
689
+ if status != "success":
690
+ return None, None, status
691
+ print("Successfully loaded the video frames")
692
+
693
+ # Extract the keypoints from the frames
694
+ kp_dict, status = get_keypoints(frames)
695
+ if status != "success":
696
+ return None, None, status
697
+ print("Successfully extracted the keypoints")
698
+
699
+ # Mask the frames using the keypoints extracted from the frames and prepare the input to the model
700
+ masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=True)
701
+ if status != "success":
702
+ return None, None, status
703
+ print("Successfully loaded the masked frames")
704
+
705
+
706
+ # Check if the length of the input frames is equal to the length of the spectrogram
707
+ if spec.shape[2]!=masked_frames.shape[0]:
708
+ num_frames = spec.shape[2]
709
+ masked_frames = masked_frames[:num_frames]
710
+ orig_masked_frames = orig_masked_frames[:num_frames]
711
+ frame_diff = np.abs(spec.shape[2] - num_frames)
712
+ if frame_diff > 60:
713
+ print("The input video and audio length do not match - The results can be unreliable! Please check the input video.")
714
+
715
+ # Transpose the frames to the correct format
716
+ frames = np.transpose(masked_frames, (4, 0, 1, 2, 3))
717
+ frames = torch.FloatTensor(np.array(frames)).unsqueeze(0)
718
+ print("Successfully converted the frames to tensor")
719
+
720
+ all_frames.append(frames)
721
+ all_orig_frames.append(orig_masked_frames)
722
+
723
+
724
+ return all_frames, all_orig_frames, "success"
725
+
726
+ def extract_audio(video, result_folder):
727
+
728
+ '''
729
+ This function extracts the audio from the video file
730
+
731
+ Args:
732
+ - video (string) : Path of the video file
733
+ - result_folder (string) : Path of the folder to save the extracted audio file
734
+ Returns:
735
+ - wav_file (string) : Path of the extracted audio file
736
+ '''
737
+
738
+ wav_file = os.path.join(result_folder, "audio.wav")
739
+
740
+ status = subprocess.call('ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn \
741
+ -acodec pcm_s16le -ar 16000 %s' % (video, wav_file), shell=True)
742
+
743
+ if status != 0:
744
+ msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again"
745
+ return None, msg
746
+
747
+ return wav_file, "success"
748
+
749
+ @spaces.GPU(duration=140)
750
+ def get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True):
751
+
752
+ '''
753
+ This function extracts the video and audio embeddings from the input frames and audio sequences
754
+
755
+ Args:
756
+ - video_sequences (array) : Array of video frames to be used as input to the model
757
+ - audio_sequences (array) : Array of audio frames to be used as input to the model
758
+ - model (object) : Model object
759
+ - calc_aud_emb (bool) : Flag to calculate the audio embedding
760
+ Returns:
761
+ - video_emb (array) : Video embedding
762
+ - audio_emb (array) : Audio embedding
763
+ '''
764
+
765
+ batch_size = 12
766
+ video_emb = []
767
+ audio_emb = []
768
+
769
+ for i in range(0, len(video_sequences), batch_size):
770
+ video_inp = video_sequences[i:i+batch_size, ]
771
+ vid_emb = model.forward_vid(video_inp.to(device), return_feats=False)
772
+ vid_emb = torch.mean(vid_emb, axis=-1)
773
+
774
+ video_emb.append(vid_emb.detach())
775
+
776
+ if calc_aud_emb:
777
+ audio_inp = audio_sequences[i:i+batch_size, ]
778
+ aud_emb = model.forward_aud(audio_inp.to(device))
779
+ audio_emb.append(aud_emb.detach())
780
+
781
+ torch.cuda.empty_cache()
782
+
783
+ video_emb = torch.cat(video_emb, dim=0)
784
+
785
+ if calc_aud_emb:
786
+ audio_emb = torch.cat(audio_emb, dim=0)
787
+
788
+ return video_emb, audio_emb
789
+
790
+ return video_emb
791
+
792
+
793
+
794
+ def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model):
795
+
796
+ '''
797
+ This function predicts the active speaker in each frame
798
+
799
+ Args:
800
+ - all_video_embeddings (array) : Array of video embeddings of all speakers
801
+ - audio_embedding (array) : Audio embedding
802
+ - global_score (bool) : Flag to calculate the global score
803
+ Returns:
804
+ - pred_speaker (list) : List of active speakers in each frame
805
+ '''
806
+
807
+ cos = nn.CosineSimilarity(dim=1)
808
+
809
+ audio_embedding = audio_embedding.squeeze(2)
810
+
811
+ scores = []
812
+ for i in range(len(all_video_embeddings)):
813
+ video_embedding = all_video_embeddings[i]
814
+
815
+ # Compute the similarity of each speaker's video embeddings with the audio embedding
816
+ sim = cos(video_embedding, audio_embedding)
817
+
818
+ # Apply the logits scale to the similarity scores (scaling the scores)
819
+ output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1)
820
+
821
+ if global_score=="True":
822
+ score = output.mean(0)
823
+ else:
824
+ if output.shape[0]<num_avg_frames:
825
+ num_avg_frames = output.shape[0]
826
+ output_batch = output.unfold(0, num_avg_frames, 1)
827
+ score = torch.mean(output_batch, axis=-1)
828
+
829
+ scores.append(score.detach().cpu().numpy())
830
+
831
+ if global_score=="True":
832
+ print("Using global predictions")
833
+ pred_speaker = np.argmax(scores)
834
+ else:
835
+ print("Using per-frame predictions")
836
+ pred_speaker = []
837
+ num_negs = list(range(0, len(all_video_embeddings)))
838
+ for frame_idx in range(len(scores[0])):
839
+ score = [scores[i][frame_idx] for i in num_negs]
840
+ pred_idx = np.argmax(score)
841
+ pred_speaker.append(pred_idx)
842
+
843
+ return pred_speaker, num_avg_frames
844
+
845
+
846
+ def save_video(output_tracks, input_frames, wav_file, result_folder):
847
+
848
+ '''
849
+ This function saves the output video with the active speaker detections
850
+
851
+ Args:
852
+ - output_tracks (list) : List of active speakers in each frame
853
+ - input_frames (array) : Frames to be used to generate the video
854
+ - wav_file (string) : Path of the audio file
855
+ - result_folder (string) : Path of the result folder to save the output video
856
+ Returns:
857
+ - video_output (string) : Path of the output video
858
+ '''
859
+
860
+ try:
861
+ output_frames = []
862
+ for i in range(len(input_frames)):
863
+
864
+ # If the active speaker is found, draw a bounding box around the active speaker
865
+ if i in output_tracks:
866
+ bbox = output_tracks[i]
867
+ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
868
+ out = cv2.rectangle(input_frames[i].copy(), (x1, y1), (x2, y2), color=[0, 255, 0], thickness=3)
869
+ else:
870
+ out = input_frames[i]
871
+
872
+ output_frames.append(out)
873
+
874
+ # Generate the output video
875
+ output_video_fname = os.path.join(result_folder, "result_active_speaker_det")
876
+ video_output, status = generate_video(output_frames, wav_file, output_video_fname)
877
+ if status != "success":
878
+ return None, status
879
+ except Exception as e:
880
+ return None, f"Error: {str(e)}"
881
+
882
+ return video_output, "success"
883
+
884
+ @spaces.GPU(duration=140)
885
+ def preprocess_asd(video_path, result_folder_input):
886
+
887
+ print("Pre-processing the input video...")
888
+ status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
889
+ if status != 0:
890
+ msg = "Error in pre-processing the input video, please check the input video and try again..."
891
+ return msg
892
+
893
+ return "success"
894
+
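+ # Full sync-offset pipeline: preprocess the input, resample to 25 fps, build masked RGB
+ # windows and spectrograms, extract GestSync embeddings, predict the audio-visual offset,
+ # and write out a sync-corrected video.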
895
+ def process_video_syncoffset(video_path, num_avg_frames, apply_preprocess):
896
+
897
+ try:
898
+ # Extract the video filename
899
+ video_fname = os.path.basename(video_path.split(".")[0])
900
+
901
+ # Create folders to save the inputs and results
902
+ result_folder = os.path.join("results", video_fname)
903
+ result_folder_input = os.path.join(result_folder, "input")
904
+ result_folder_output = os.path.join(result_folder, "output")
905
+
906
+ if os.path.exists(result_folder):
907
+ rmtree(result_folder)
908
+
909
+ os.makedirs(result_folder)
910
+ os.makedirs(result_folder_input)
911
+ os.makedirs(result_folder_output)
912
+
913
+
914
+ # Preprocess the video
915
+ print("Applying preprocessing: ", apply_preprocess)
916
+ wav_file, fps, vid_path_processed, status = preprocess_video(video_path, result_folder_input, apply_preprocess)
917
+ if status != "success":
918
+ return None, status
919
+ print("Successfully preprocessed the video")
920
+
921
+ # Resample the video to 25 fps if it is not already 25 fps
922
+ print("FPS of video: ", fps)
923
+ if fps!=25:
924
+ vid_path, status = resample_video(vid_path_processed, "preprocessed_video_25fps", result_folder_input)
925
+ if status != "success":
926
+ return None, status
927
+ orig_vid_path_25fps, status = resample_video(video_path, "input_video_25fps", result_folder_input)
928
+ if status != "success":
929
+ return None, status
930
+ else:
931
+ vid_path = vid_path_processed
932
+ orig_vid_path_25fps = video_path
933
+
934
+ # Load the original video frames (before pre-processing) - Needed for the final sync-correction
935
+ orig_frames, status = load_video_frames(orig_vid_path_25fps)
936
+ if status != "success":
937
+ return None, status
938
+
939
+ # Load the pre-processed video frames
940
+ frames, status = load_video_frames(vid_path)
941
+ if status != "success":
942
+ return None, status
943
+ print("Successfully extracted the video frames")
944
+
945
+ if len(frames) < num_avg_frames:
946
+ msg = "Error: The input video is too short. Please use a longer input video."
947
+ return None, msg
948
+
949
+ # Load keypoints and check if gestures are visible
950
+ kp_dict, status = get_keypoints(frames)
951
+ if status != "success":
952
+ return None, status
953
+ print("Successfully extracted the keypoints: ", len(kp_dict), len(kp_dict["kps"]))
954
+
955
+ status = check_visible_gestures(kp_dict)
956
+ if status != "success":
957
+ return None, status
958
+
959
+ # Load RGB frames
960
+ rgb_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict, asd=False, window_frames=25, width=480, height=270)
961
+ if status != "success":
962
+ return None, status
963
+ print("Successfully loaded the RGB frames")
964
+
965
+ # Convert frames to tensor
966
+ rgb_frames = np.transpose(rgb_frames, (4, 0, 1, 2, 3))
967
+ rgb_frames = torch.FloatTensor(rgb_frames).unsqueeze(0)
968
+ B = rgb_frames.size(0)
969
+ print("Successfully converted the frames to tensor")
970
+
971
+ # Load spectrograms
972
+ spec, orig_spec, status = load_spectrograms(wav_file, asd=False, num_frames=num_frames)
973
+ if status != "success":
974
+ return None, status
975
+ spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0, 1, 2, 4, 3)
976
+ print("Successfully loaded the spectrograms")
977
+
978
+ # Create input windows
979
+ video_sequences = torch.cat([rgb_frames[:, :, i] for i in range(rgb_frames.size(2))], dim=0)
980
+ audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
981
+
982
+ # Load the trained model
983
+ model = Transformer_RGB()
984
+ model = load_checkpoint(CHECKPOINT_PATH, model)
985
+ print("Successfully loaded the model")
986
+
987
+ video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True)
988
+
989
+ # Process in batches
990
+ # batch_size = 12
991
+ # video_emb = []
992
+ # audio_emb = []
993
+
994
+ # for i in tqdm(range(0, len(video_sequences), batch_size)):
995
+ # video_inp = video_sequences[i:i+batch_size, ]
996
+ # audio_inp = audio_sequences[i:i+batch_size, ]
997
+
998
+ # vid_emb = model.forward_vid(video_inp.to(device))
999
+ # vid_emb = torch.mean(vid_emb, axis=-1).unsqueeze(-1)
1000
+ # aud_emb = model.forward_aud(audio_inp.to(device))
1001
+
1002
+ # video_emb.append(vid_emb.detach())
1003
+ # audio_emb.append(aud_emb.detach())
1004
+
1005
+ # torch.cuda.empty_cache()
1006
+
1007
+ # audio_emb = torch.cat(audio_emb, dim=0)
1008
+ # video_emb = torch.cat(video_emb, dim=0)
1009
+
1010
+ # L2 normalize embeddings
1011
+ video_emb = torch.nn.functional.normalize(video_emb, p=2, dim=1)
1012
+ audio_emb = torch.nn.functional.normalize(audio_emb, p=2, dim=1)
1013
+
1014
+ audio_emb = torch.split(audio_emb, B, dim=0)
1015
+ audio_emb = torch.stack(audio_emb, dim=2)
1016
+ audio_emb = audio_emb.squeeze(3)
1017
+ audio_emb = audio_emb[:, None]
1018
+
1019
+ video_emb = torch.split(video_emb, B, dim=0)
1020
+ video_emb = torch.stack(video_emb, dim=2)
1021
+ video_emb = video_emb.squeeze(3)
1022
+ print("Successfully extracted GestSync embeddings")
1023
+
1024
+ # Calculate sync offset
1025
+ pred_offset, status = calc_optimal_av_offset(video_emb, audio_emb, num_avg_frames, model)
1026
+ if status != "success":
1027
+ return None, status
1028
+ print("Predicted offset: ", pred_offset)
1029
+
1030
+ # Generate sync-corrected video
1031
+ video_output, status = sync_correct_video(video_path, orig_frames, wav_file, pred_offset, result_folder_output, sample_rate=16000, fps=fps)
1032
+ if status != "success":
1033
+ return None, status
1034
+ print("Successfully generated the video:", video_output)
1035
+
1036
+ return video_output, f"Predicted offset: {pred_offset}"
1037
+
1038
+ except Exception as e:
1039
+ return None, f"Error: {str(e)}"
1040
+
1041
+
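+ # Active-speaker detection pipeline: extract the audio, pre-process the video into
+ # per-speaker tracks for each scene, score every track's video embedding against the
+ # audio embedding, and render the predicted active speaker on the output video.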
1042
+ def process_video_activespeaker(video_path, global_speaker, num_avg_frames):
1043
+ try:
1044
+ # Extract the video filename
1045
+ video_fname = os.path.basename(video_path.split(".")[0])
1046
+
1047
+ # Create folders to save the inputs and results
1048
+ result_folder = os.path.join("results", video_fname)
1049
+ result_folder_input = os.path.join(result_folder, "input")
1050
+ result_folder_output = os.path.join(result_folder, "output")
1051
+
1052
+ if os.path.exists(result_folder):
1053
+ rmtree(result_folder)
1054
+
1055
+ os.makedirs(result_folder)
1056
+ os.makedirs(result_folder_input)
1057
+ os.makedirs(result_folder_output)
1058
+
1059
+ if global_speaker=="per-frame-prediction" and num_avg_frames<25:
1060
+ msg = "Number of frames to average need to be set to a minimum of 25 frames. Atleast 1-second context is needed for the model. Please change the num_avg_frames and try again..."
1061
+ return None, msg
1062
+
1063
+ # Read the video
1064
+ try:
1065
+ vr = VideoReader(video_path, ctx=cpu(0))
1066
+ except:
1067
+ msg = "Oops! Could not load the input video file"
1068
+ return None, msg
1069
+
1070
+ # Get the FPS of the video
1071
+ fps = vr.get_avg_fps()
1072
+ print("FPS of video: ", fps)
1073
+
1074
+ # Resample the video to 25 FPS if the original video is of a different frame-rate
1075
+ if fps!=25:
1076
+ test_video_25fps, status = resample_video(video_path, video_fname, result_folder_input)
1077
+ if status != "success":
1078
+ return None, status
1079
+ else:
1080
+ test_video_25fps = video_path
1081
+
1082
+ # Load the video frames
1083
+ orig_frames, status = load_video_frames(test_video_25fps)
1084
+ if status != "success":
1085
+ return None, status
1086
+
1087
+ # Extract and save the audio file
1088
+ orig_wav_file, status = extract_audio(video_path, result_folder)
1089
+ if status != "success":
1090
+ return None, status
1091
+
1092
+ # Pre-process and extract per-speaker tracks in each scene
1093
+ print("Pre-processing the input video...")
1094
+ # status = subprocess.call("python preprocess/inference_preprocess.py --data_dir={}/temp --sd_root={}/crops --work_root={}/metadata --data_root={}".format(result_folder_input, result_folder_input, result_folder_input, video_path), shell=True)
1095
+ # if status != 0:
1096
+ # msg = "Error in pre-processing the input video, please check the input video and try again..."
1097
+ # return None, msg
1098
+ status = preprocess_asd(video_path, result_folder_input)
1099
+ if status != "success":
1100
+ return None, status
1101
+
1102
+ # Load the tracks file saved during pre-processing
1103
+ with open('{}/metadata/tracks.pckl'.format(result_folder_input), 'rb') as file:
1104
+ tracks = pickle.load(file)
1105
+
1106
+
1107
+ # Create a dictionary of all tracks found along with the bounding-boxes
1108
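+ # track_dict[scene][track_id][frame] -> bounding box, so the predicted active-speaker label can be mapped back to a box for each frame when drawing the output video.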
+ track_dict = {}
1109
+ for scene_num in range(len(tracks)):
1110
+ track_dict[scene_num] = {}
1111
+ for i in range(len(tracks[scene_num])):
1112
+ track_dict[scene_num][i] = {}
1113
+ for frame_num, bbox in zip(tracks[scene_num][i]['track']['frame'], tracks[scene_num][i]['track']['bbox']):
1114
+ track_dict[scene_num][i][frame_num] = bbox
1115
+
1116
+ # Get the total number of scenes
1117
+ test_scenes = os.listdir("{}/crops".format(result_folder_input))
1118
+ print("Total scenes found in the input video = ", len(test_scenes))
1119
+
1120
+ # Load the trained model
1121
+ model = Transformer_RGB()
1122
+ model = load_checkpoint(CHECKPOINT_PATH, model)
1123
+
1124
+ # Compute the active speaker in each scene
1125
+ output_tracks = {}
1126
+ for scene_num in tqdm(range(len(test_scenes))):
1127
+ test_videos = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.avi"))
1128
+ test_videos.sort(key=lambda x: int(os.path.basename(x).split('.')[0]))
1129
+ print("Scene {} -> Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos)))
1130
+
1131
+ if len(test_videos)<=1:
1132
+ msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..."
1133
+ return None, msg
1134
+
1135
+ # Load the audio file
1136
+ audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0]
1137
+ spec, _, status = load_spectrograms(audio_file, asd=True)
1138
+ if status != "success":
1139
+ return None, status
1140
+ spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3)
1141
+ print("Successfully loaded the spectrograms")
1142
+
1143
+ # Load the masked input frames
1144
+ all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input)
1145
+ if status != "success":
1146
+ return None, status
1147
+ print("Successfully loaded the masked input frames")
1148
+
1149
+ # Prepare the audio and video sequences for the model
1150
+ audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0)
1151
+
1152
+ print("Obtaining audio and video embeddings...")
1153
+ all_video_embs = []
1154
+ for idx in tqdm(range(len(all_masked_frames))):
1155
+ with torch.no_grad():
1156
+ video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0)
1157
+
1158
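+ # The audio track is shared by all speaker crops in the scene, so the audio embedding only needs to be computed once (for the first crop).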
+ if idx==0:
1159
+ video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=True)
1160
+ else:
1161
+ video_emb = get_embeddings(video_sequences, audio_sequences, model, calc_aud_emb=False)
1162
+ all_video_embs.append(video_emb)
1163
+ print("Successfully extracted GestSync embeddings")
1164
+
1165
+ # Predict the active speaker in each scene
1166
+ if global_speaker=="per-frame-prediction":
1167
+ predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model)
1168
+ else:
1169
+ predictions, _ = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model)
1170
+
1171
+ # Get the frames present in the scene
1172
+ frames_scene = tracks[scene_num][0]['track']['frame']
1173
+
1174
+ # Prepare the active speakers list to draw the bounding boxes
1175
+ if global_speaker=="global-prediction":
1176
+ print("Aggregating scores using global predictoins")
1177
+ active_speakers = [predictions]*len(frames_scene)
1178
+ start, end = 0, len(frames_scene)
1179
+ else:
1180
+ print("Aggregating scores using per-frame predictions")
1181
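+ # Per-frame predictions only exist for frames with a full num_avg_frames window of context; the leading and trailing frames are filled in further below.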
+ active_speakers = [0]*len(frames_scene)
1182
+ mid = num_avg_frames//2
1183
+
1184
+ if num_avg_frames%2==0:
1185
+ frame_pred = len(frames_scene)-(mid*2)+1
1186
+ start, end = mid, len(frames_scene)-mid+1
1187
+ else:
1188
+ frame_pred = len(frames_scene)-(mid*2)
1189
+ start, end = mid, len(frames_scene)-mid
1190
+
1191
+ print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred))
1192
+ if len(predictions) != frame_pred:
1193
+ msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred)
1194
+ return None, msg
1195
+
1196
+ active_speakers[start:end] = predictions[0:]
1197
+
1198
+ # Depending on num_avg_frames, pad the initial and final frame predictions (majority vote over the first/last window) to cover the full video
1199
+ initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count)
1200
+ active_speakers[0:start] = [initial_preds] * start
1201
+
1202
+ final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count)
1203
+ active_speakers[end:] = [final_preds] * (len(frames_scene) - end)
1204
+ start, end = 0, len(active_speakers)
1205
+
1206
+ # Get the output tracks for each frame
1207
+ pred_idx = 0
1208
+ for frame in frames_scene[start:end]:
1209
+ label = active_speakers[pred_idx]
1210
+ pred_idx += 1
1211
+ output_tracks[frame] = track_dict[scene_num][label][frame]
1212
+
1213
+ # Save the output video
1214
+ video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
1215
+ if status != "success":
1216
+ return None, status
1217
+ print("Successfully saved the output video: ", video_output)
1218
+
1219
+ return video_output, "success"
1220
+
1221
+ except Exception as e:
1222
+ return None, f"Error: {str(e)}"
1223
+
1224
+ if __name__ == "__main__":
1225
+
1226
+
1227
+ # Custom CSS and HTML
1228
+ custom_css = """
1229
+ <style>
1230
+ body {
1231
+ background-color: #ffffff;
1232
+ color: #333333; /* Default text color */
1233
+ }
1234
+ .container {
1235
+ max-width: 100% !important;
1236
+ padding-left: 0 !important;
1237
+ padding-right: 0 !important;
1238
+ }
1239
+ .header {
1240
+ background-color: #f0f0f0;
1241
+ color: #333333;
1242
+ padding: 30px;
1243
+ margin-bottom: 30px;
1244
+ text-align: center;
1245
+ font-family: 'Helvetica Neue', Arial, sans-serif;
1246
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
1247
+ }
1248
+ .header h1 {
1249
+ font-size: 36px;
1250
+ margin-bottom: 15px;
1251
+ font-weight: bold;
1252
+ color: #333333; /* Explicitly set heading color */
1253
+ }
1254
+ .header h2 {
1255
+ font-size: 24px;
1256
+ margin-bottom: 10px;
1257
+ color: #333333; /* Explicitly set subheading color */
1258
+ }
1259
+ .header p {
1260
+ font-size: 18px;
1261
+ margin: 5px 0;
1262
+ color: #666666;
1263
+ }
1264
+ .blue-text {
1265
+ color: #4a90e2;
1266
+ }
1267
+ /* Custom styles for slider container */
1268
+ .slider-container {
1269
+ background-color: white !important;
1270
+ padding-top: 0.9em;
1271
+ padding-bottom: 0.9em;
1272
+ }
1273
+ /* Add gap before examples */
1274
+ .examples-holder {
1275
+ margin-top: 2em;
1276
+ }
1277
+
1278
+ /* Set fixed size for example videos */
1279
+ .gradio-container .gradio-examples .gr-sample {
1280
+ width: 240px !important;
1281
+ height: 135px !important;
1282
+ object-fit: cover;
1283
+ display: inline-block;
1284
+ margin-right: 10px;
1285
+ }
1286
+
1287
+ .gradio-container .gradio-examples {
1288
+ display: flex;
1289
+ flex-wrap: wrap;
1290
+ gap: 10px;
1291
+ }
1292
+
1293
+ /* Ensure the parent container does not stretch */
1294
+ .gradio-container .gradio-examples {
1295
+ max-width: 100%;
1296
+ overflow: hidden;
1297
+ }
1298
+
1299
+ /* Additional styles to ensure proper sizing in Safari */
1300
+ .gradio-container .gradio-examples .gr-sample img {
1301
+ width: 240px !important;
1302
+ height: 135px !important;
1303
+ object-fit: cover;
1304
+ }
1305
+ </style>
1306
+ """
1307
+
1308
+ custom_html = custom_css + """
1309
+ <div class="header">
1310
+ <h1><span class="blue-text">GestSync:</span> Determining who is speaking without a talking head</h1>
1311
+ <h2>Synchronization and Active Speaker Detection Demo</h2>
1312
+ <p><a href='https://www.robots.ox.ac.uk/~vgg/research/gestsync/'>Project Page</a> | <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> | <a href='https://arxiv.org/abs/2310.05304'>Paper</a></p>
1313
+ </div>
1314
+ """
1315
+
1316
+
1317
+ tips = """
1318
+ <div>
1319
+ <br><br>
1320
+ Please give us a 🌟 on <a href='https://github.com/Sindhu-Hegde/gestsync'>Github</a> if you like our work!
1321
+
1322
+ Tips to get better results:
1323
+ <ul>
1324
+ <li>Number of Average Frames: Higher the number, better the results.</li>
1325
+ <li>Clicking on "apply pre-processing" will give better results for synchornization, but this is an expensive operation and might take a while.</li>
1326
+ <li>Input videos with clearly visible gestures work better.</li>
1327
+ </ul>
1328
+
1329
+ </div>
1330
+ """
1331
+
1332
+ # Define functions
1333
+ def toggle_slider(global_speaker):
1334
+ if global_speaker == "per-frame-prediction":
1335
+ return gr.update(visible=True)
1336
+ else:
1337
+ return gr.update(visible=False)
1338
+
1339
+ def toggle_demo(demo_choice):
1340
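+ # The gr.update tuple returned here must stay in the same order as the outputs list wired to demo_choice.change further below.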
+ if demo_choice == "Synchronization-correction":
1341
+ return (
1342
+ gr.update(value=None, visible=True), # video_input
1343
+ gr.update(value=75, visible=True), # num_avg_frames
1344
+ gr.update(value=None, visible=True), # apply_preprocess
1345
+ gr.update(value="global-prediction", visible=False), # global_speaker
1346
+ gr.update(value=None, visible=True), # output_video
1347
+ gr.update(value="", visible=True), # result_text
1348
+ gr.update(visible=True), # submit_button
1349
+ gr.update(visible=True), # clear_button
1350
+ gr.update(visible=True), # sync_examples
1351
+ gr.update(visible=False), # asd_examples
1352
+ gr.update(visible=True) # tips
1353
+ )
1354
+ else:
1355
+ return (
1356
+ gr.update(value=None, visible=True), # video_input
1357
+ gr.update(value=75, visible=True), # num_avg_frames
1358
+ gr.update(value=None, visible=False), # apply_preprocess
1359
+ gr.update(value="global-prediction", visible=True), # global_speaker
1360
+ gr.update(value=None, visible=True), # output_video
1361
+ gr.update(value="", visible=True), # result_text
1362
+ gr.update(visible=True), # submit_button
1363
+ gr.update(visible=True), # clear_button
1364
+ gr.update(visible=False), # sync_examples
1365
+ gr.update(visible=True), # asd_examples
1366
+ gr.update(visible=True) # tips
1367
+ )
1368
+
1369
+ def clear_inputs():
1370
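+ # Values are returned in the order of the clear_button.click outputs: demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video.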
+ return None, None, "global-prediction", 75, None, "", None
1371
+
1372
+ def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess):
1373
+ if demo_choice == "Synchronization-correction":
1374
+ return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess)
1375
+ else:
1376
+ return process_video_activespeaker(video_input, global_speaker, num_avg_frames)
1377
+
1378
+ # Define paths to sample videos
1379
+ sync_sample_videos = [
1380
+ ["samples/sync_sample_1.mp4"],
1381
+ ["samples/sync_sample_2.mp4"]
1382
+ ]
1383
+
1384
+ asd_sample_videos = [
1385
+ ["samples/asd_sample_1.mp4"],
1386
+ ["samples/asd_sample_2.mp4"]
1387
+ ]
1388
+
1389
+ # Define Gradio interface
1390
+ with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
1391
+ gr.HTML(custom_html)
1392
+ demo_choice = gr.Radio(
1393
+ choices=["Synchronization-correction", "Active-speaker-detection"],
1394
+ label="Please select the task you want to perform"
1395
+ )
1396
+ with gr.Row():
1397
+ with gr.Column():
1398
+ video_input = gr.Video(label="Upload Video", height=400, visible=False)
1399
+ num_avg_frames = gr.Slider(
1400
+ minimum=50,
1401
+ maximum=150,
1402
+ step=5,
1403
+ value=75,
1404
+ label="Number of Average Frames",
1405
+ visible=False
1406
+ )
1407
+ apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False)
1408
+ global_speaker = gr.Radio(
1409
+ choices=["global-prediction", "per-frame-prediction"],
1410
+ value="global-prediction",
1411
+ label="Global Speaker Prediction",
1412
+ visible=False
1413
+ )
1414
+ global_speaker.change(
1415
+ fn=toggle_slider,
1416
+ inputs=global_speaker,
1417
+ outputs=num_avg_frames
1418
+ )
1419
+ with gr.Column():
1420
+ output_video = gr.Video(label="Output Video", height=400, visible=False)
1421
+ result_text = gr.Textbox(label="Result", visible=False)
1422
+
1423
+ with gr.Row():
1424
+ submit_button = gr.Button("Submit", variant="primary", visible=False)
1425
+ clear_button = gr.Button("Clear", visible=False)
1426
+
1427
+ # Add a gap before examples
1428
+ gr.HTML('<div class="examples-holder"></div>')
1429
+
1430
+
1431
+ # Add examples that only populate the video input
1432
+ sync_examples = gr.Dataset(
1433
+ samples=sync_sample_videos,
1434
+ components=[video_input],
1435
+ type="values",
1436
+ visible=False
1437
+ )
1438
+
1439
+ asd_examples = gr.Dataset(
1440
+ samples=asd_sample_videos,
1441
+ components=[video_input],
1442
+ type="values",
1443
+ visible=False
1444
+ )
1445
+
1446
+ tips = gr.Markdown(tips, visible=False)
1447
+
1448
+
1449
+ demo_choice.change(
1450
+ fn=toggle_demo,
1451
+ inputs=demo_choice,
1452
+ outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips]
1453
+ )
1454
+
1455
+ sync_examples.select(
1456
+ fn=lambda x: gr.update(value=x[0], visible=True),
1457
+ inputs=sync_examples,
1458
+ outputs=video_input
1459
+ )
1460
+
1461
+ asd_examples.select(
1462
+ fn=lambda x: gr.update(value=x[0], visible=True),
1463
+ inputs=asd_examples,
1464
+ outputs=video_input
1465
+ )
1466
+
1467
+
1468
+ submit_button.click(
1469
+ fn=process_video,
1470
+ inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess],
1471
+ outputs=[output_video, result_text]
1472
+ )
1473
+
1474
+ clear_button.click(
1475
+ fn=clear_inputs,
1476
+ inputs=[],
1477
+ outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video]
1478
+ )
1479
+
1480
+
1481
+ # Launch the interface
1482
+ demo.launch(allowed_paths=["."], share=True)
preprocess/inference_preprocess.py CHANGED
@@ -4,6 +4,7 @@ import sys, os, argparse, pickle, subprocess, cv2, math
4
  import numpy as np
5
  from shutil import rmtree, copy, copytree
6
  from tqdm import tqdm
 
7
 
8
  import scenedetect
9
  from scenedetect.video_manager import VideoManager
@@ -33,7 +34,7 @@ parser.add_argument('--work_root', type=str, required=True, help='Path to save m
33
  parser.add_argument('--data_root', type=str, required=True, help='Directory containing ONLY full uncropped videos')
34
  opt = parser.parse_args()
35
 
36
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
  def bb_intersection_over_union(boxA, boxB):
39
  xA = max(boxA[0], boxB[0])
@@ -181,7 +182,7 @@ def inference_video(opt, padding=0):
181
  if not success:
182
  break
183
 
184
- image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
185
 
186
  # Perform person detection
187
  results = yolo_model(image_np, verbose=False)
 
4
  import numpy as np
5
  from shutil import rmtree, copy, copytree
6
  from tqdm import tqdm
7
+ import torch
8
 
9
  import scenedetect
10
  from scenedetect.video_manager import VideoManager
 
34
  parser.add_argument('--data_root', type=str, required=True, help='Directory containing ONLY full uncropped videos')
35
  opt = parser.parse_args()
36
 
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
 
39
  def bb_intersection_over_union(boxA, boxB):
40
  xA = max(boxA[0], boxB[0])
 
182
  if not success:
183
  break
184
 
185
+ image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
186
 
187
  # Perform person detection
188
  results = yolo_model(image_np, verbose=False)