import gradio as gr
import os
import torch
from shutil import rmtree, copy, copytree
from torch import nn
from torch.nn import functional as F
import numpy as np
import subprocess
import cv2
import pickle
import librosa
from ultralytics import YOLO
from decord import VideoReader
from decord import cpu, gpu
from utils.audio_utils import *
from utils.inference_utils import *
from sync_models.gestsync_models import *
import scenedetect
from scenedetect.video_manager import VideoManager
from scenedetect.scene_manager import SceneManager
from scenedetect.stats_manager import StatsManager
from scenedetect.detectors import ContentDetector
from scipy.interpolate import interp1d
from scipy import signal
from tqdm import tqdm
from glob import glob
from scipy.io.wavfile import write
import mediapipe as mp
from protobuf_to_dict import protobuf_to_dict
import warnings
import spaces

mp_holistic = mp.solutions.holistic
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Initialize global variables
CHECKPOINT_PATH = "model_rgb.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
print("Use cuda status: ", use_cuda)

batch_size = 24
fps = 25
n_negative_samples = 100

facedet_scale = 0.25
crop_scale = 0
min_track = 50
frame_rate = 25
num_failed_det = 25
min_frame_size = 64

print("Device: ", device)

# Initialize the mediapipe holistic keypoint detection model
holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)


def bb_intersection_over_union(boxA, boxB):
    # Coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    interArea = max(0, xB - xA) * max(0, yB - yA)

    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    iou = interArea / float(boxAArea + boxBArea - interArea)

    return iou


def track_shot(scenefaces):
    iouThres = 0.5  # Minimum IOU between consecutive face detections
    tracks = []

    while True:
        track = []
        for framefaces in scenefaces:
            for face in framefaces:
                if track == []:
                    track.append(face)
                    framefaces.remove(face)
                elif face['frame'] - track[-1]['frame'] <= num_failed_det:
                    iou = bb_intersection_over_union(face['bbox'], track[-1]['bbox'])
                    if iou > iouThres:
                        track.append(face)
                        framefaces.remove(face)
                        continue
                else:
                    break

        if track == []:
            break
        elif len(track) > min_track:
            framenum = np.array([f['frame'] for f in track])
            bboxes = np.array([np.array(f['bbox']) for f in track])

            frame_i = np.arange(framenum[0], framenum[-1] + 1)

            bboxes_i = []
            for ij in range(0, 4):
                interpfn = interp1d(framenum, bboxes[:, ij])
                bboxes_i.append(interpfn(frame_i))
            bboxes_i = np.stack(bboxes_i, axis=1)

            if max(np.mean(bboxes_i[:, 2] - bboxes_i[:, 0]), np.mean(bboxes_i[:, 3] - bboxes_i[:, 1])) > min_frame_size:
                tracks.append({'frame': frame_i, 'bbox': bboxes_i})

    return tracks


def check_folder(folder):
    if os.path.exists(folder):
        return True
    return False


def del_folder(folder):
    if os.path.exists(folder):
        rmtree(folder)


def read_video(o, start_idx):
    with open(o, 'rb') as o:
        video_stream = VideoReader(o)
        if start_idx > 0:
            video_stream.skip_frames(start_idx)
        return video_stream


def crop_video(avi_dir, tmp_dir, track, cropfile, tight_scale=1):
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    vOut = cv2.VideoWriter(cropfile + '.avi', fourcc, frame_rate, (480, 270))

    dets = {'x': [], 'y': [], 's': [], 'bbox': track['bbox'], 'frame': track['frame']}

    for det in track['bbox']:
        # Reduce the size of the bounding box by a small factor if tighter crops are needed (default -> no reduction in size)
        width = (det[2] - det[0]) * tight_scale
        height = (det[3] - det[1]) * tight_scale
        center_x = (det[0] + det[2]) / 2
        center_y = (det[1] + det[3]) / 2

        dets['s'].append(max(height, width) / 2)
        dets['y'].append(center_y)  # crop center y
        dets['x'].append(center_x)  # crop center x

    # Smooth detections
    dets['s'] = signal.medfilt(dets['s'], kernel_size=13)
    dets['x'] = signal.medfilt(dets['x'], kernel_size=13)
    dets['y'] = signal.medfilt(dets['y'], kernel_size=13)

    videofile = os.path.join(avi_dir, 'video.avi')
    frame_no_to_start = track['frame'][0]
    video_stream = cv2.VideoCapture(videofile)
    video_stream.set(cv2.CAP_PROP_POS_FRAMES, frame_no_to_start)

    for fidx, frame in enumerate(track['frame']):
        cs = crop_scale
        bs = dets['s'][fidx]          # Detection box size
        bsi = int(bs * (1 + 2 * cs))  # Pad videos by this amount

        image = video_stream.read()[1]
        frame = np.pad(image, ((bsi, bsi), (bsi, bsi), (0, 0)), 'constant', constant_values=(110, 110))

        my = dets['y'][fidx] + bsi  # BBox center Y
        mx = dets['x'][fidx] + bsi  # BBox center X

        face = frame[int(my - bs):int(my + bs * (1 + 2 * cs)), int(mx - bs * (1 + cs)):int(mx + bs * (1 + cs))]
        vOut.write(cv2.resize(face, (480, 270)))
    video_stream.release()

    audiotmp = os.path.join(tmp_dir, 'audio.wav')
    audiostart = (track['frame'][0]) / frame_rate
    audioend = (track['frame'][-1] + 1) / frame_rate
    vOut.release()

    # ========== CROP AUDIO FILE ==========
    command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ss %.3f -to %.3f %s" % (os.path.join(avi_dir, 'audio.wav'), audiostart, audioend, audiotmp))
    output = subprocess.call(command, shell=True, stdout=None)

    copy(audiotmp, cropfile + '.wav')

    # print('Written %s' % cropfile)
    # print('Mean pos: x %.2f y %.2f s %.2f' % (np.mean(dets['x']), np.mean(dets['y']), np.mean(dets['s'])))

    return {'track': track, 'proc_track': dets}


@spaces.GPU(duration=60)
def inference_video(avi_dir, work_dir, padding=0):
    videofile = os.path.join(avi_dir, 'video.avi')
    vidObj = cv2.VideoCapture(videofile)
    yolo_model = YOLO("yolov9m.pt")

    global dets, fidx
    dets = []
    fidx = 0

    print("Detecting people in the video using YOLO...")

    def generate_detections():
        global dets, fidx
        while True:
            success, image = vidObj.read()
            if not success:
                break

            image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Perform person detection
            results = yolo_model(image_np, verbose=False)
            detections = results[0].boxes

            dets.append([])
            for i, det in enumerate(detections):
                x1, y1, x2, y2 = det.xyxy[0].detach().cpu().numpy()
                cls = det.cls[0].detach().cpu().numpy()
                conf = det.conf[0].detach().cpu().numpy()
                if int(cls) == 0 and conf > 0.7:  # Class 0 is 'person' in COCO dataset
                    x1 = max(0, int(x1) - padding)
                    y1 = max(0, int(y1) - padding)
                    x2 = min(image_np.shape[1], int(x2) + padding)
                    y2 = min(image_np.shape[0], int(y2) + padding)
                    dets[-1].append({'frame': fidx, 'bbox': [x1, y1, x2, y2], 'conf': conf})

            fidx += 1
            yield

        return dets

    for _ in tqdm(generate_detections()):
        pass

    print("Successfully detected people in the video")

    savepath = os.path.join(work_dir, 'faces.pckl')
    with open(savepath, 'wb') as fil:
        pickle.dump(dets, fil)

    return dets


def scene_detect(avi_dir, work_dir):
    video_manager = VideoManager([os.path.join(avi_dir, 'video.avi')])
    stats_manager = StatsManager()
    scene_manager = SceneManager(stats_manager)
    scene_manager.add_detector(ContentDetector())
    base_timecode = video_manager.get_base_timecode()
    video_manager.set_downscale_factor()
    video_manager.start()
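    # The call below runs PySceneDetect's ContentDetector over the (downscaled)
    # video and collects the shot boundaries. process_video_asd() later uses
    # these boundaries to split the YOLO person detections per shot
    # (faces[shot[0].frame_num:shot[1].frame_num]) so that a track never spans a cut.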
scene_manager.detect_scenes(frame_source=video_manager) scene_list = scene_manager.get_scene_list(base_timecode) savepath = os.path.join(work_dir, 'scene.pckl') if scene_list == []: scene_list = [(video_manager.get_base_timecode(), video_manager.get_current_timecode())] with open(savepath, 'wb') as fil: pickle.dump(scene_list, fil) print('%s - scenes detected %d' % (os.path.join(avi_dir, 'video.avi'), len(scene_list))) return scene_list def process_video_asd(file, sd_root, work_root, data_root, avi_dir, tmp_dir, work_dir, crop_dir, frames_dir): video_file_name = os.path.basename(file.strip()) sd_dest_folder = sd_root work_dest_folder = work_root del_folder(sd_dest_folder) del_folder(work_dest_folder) videofile = file if os.path.exists(work_dir): rmtree(work_dir) if os.path.exists(crop_dir): rmtree(crop_dir) if os.path.exists(avi_dir): rmtree(avi_dir) if os.path.exists(frames_dir): rmtree(frames_dir) if os.path.exists(tmp_dir): rmtree(tmp_dir) os.makedirs(work_dir) os.makedirs(crop_dir) os.makedirs(avi_dir) os.makedirs(frames_dir) os.makedirs(tmp_dir) command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -qscale:v 2 -async 1 -r 25 %s" % (videofile, os.path.join(avi_dir, 'video.avi'))) status = subprocess.call(command, shell=True, stdout=None) if status != 0: msg = "Error in pre-processing the video, please check the input video and try again" return msg command = ("ffmpeg -hide_banner -loglevel panic -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (os.path.join(avi_dir, 'video.avi'), os.path.join(avi_dir, 'audio.wav'))) status = subprocess.call(command, shell=True, stdout=None) if status != 0: msg = "Error in pre-processing the video, please check the input video and try again" return msg try: faces = inference_video(avi_dir, work_dir) except: msg = "Error in pre-processing the video, please check the input video and try again" return msg print("YOLO done") print("Detecting scenes in the video...") try: scene = scene_detect(avi_dir, work_dir) except: msg = "Error in detecting the scenes in the video, please check the input video and try again" return msg print("Scene detect done") print("Tracking video...") allscenes = [] for shot in scene: if shot[1].frame_num - shot[0].frame_num >= min_track: allscenes.append(track_shot(faces[shot[0].frame_num:shot[1].frame_num])) print("Cropping video...") alltracks = [] for sc_num in range(len(allscenes)): vidtracks = [] for ii, track in enumerate(allscenes[sc_num]): os.makedirs(os.path.join(crop_dir, 'scene_'+str(sc_num)), exist_ok=True) vidtracks.append(crop_video(avi_dir, tmp_dir, track, os.path.join(crop_dir, 'scene_'+str(sc_num), '%05d' % ii))) alltracks.append(vidtracks) savepath = os.path.join(work_dir, 'tracks.pckl') with open(savepath, 'wb') as fil: pickle.dump(alltracks, fil) rmtree(tmp_dir) rmtree(avi_dir) rmtree(frames_dir) copytree(crop_dir, sd_dest_folder) copytree(work_dir, work_dest_folder) return "success" @spaces.GPU(duration=60) def get_person_detection(all_frames, frame_count, padding=20): try: # Load YOLOv9 model (pre-trained on COCO dataset) yolo_model = YOLO("yolov9s.pt") print("Loaded the YOLO model") person_videos = {} person_tracks = {} print("Processing the frames...") for frame_idx in tqdm(range(frame_count)): frame = all_frames[frame_idx] # Perform person detection results = yolo_model(frame, verbose=False) detections = results[0].boxes for i, det in enumerate(detections): x1, y1, x2, y2 = det.xyxy[0] cls = det.cls[0] if int(cls) == 0: # Class 0 is 'person' in COCO dataset x1 = max(0, int(x1) - padding) y1 = max(0, 
int(y1) - padding) x2 = min(frame.shape[1], int(x2) + padding) y2 = min(frame.shape[0], int(y2) + padding) if i not in person_videos: person_videos[i] = [] person_tracks[i] = [] person_videos[i].append(frame) person_tracks[i].append([x1,y1,x2,y2]) num_persons = 0 for i in person_videos.keys(): if len(person_videos[i]) >= frame_count//2: num_persons+=1 if num_persons==0: msg = "No person detected in the video! Please give a video with one person as input" return None, None, msg if num_persons>1: msg = "More than one person detected in the video! Please give a video with only one person as input" return None, None, msg except: msg = "Error in detecting person in the video, please check the input video and try again" return None, None, msg return person_videos, person_tracks, "success" def preprocess_video(path, result_folder, apply_preprocess, padding=20): ''' This function preprocesses the input video to extract the audio and crop the frames using YOLO model Args: - path (string) : Path of the input video file - result_folder (string) : Path of the folder to save the extracted audio and cropped video - padding (int) : Padding to add to the bounding box Returns: - wav_file (string) : Path of the extracted audio file - fps (int) : FPS of the input video - video_output (string) : Path of the cropped video file - msg (string) : Message to be returned ''' # Load all video frames try: vr = VideoReader(path, ctx=cpu(0)) fps = vr.get_avg_fps() frame_count = len(vr) except: msg = "Oops! Could not load the video. Please check the input video and try again." return None, None, None, msg if frame_count < 25: msg = "Not enough frames to process! Please give a longer video as input" return None, None, None, msg # Extract the audio from the input video file using ffmpeg wav_file = os.path.join(result_folder, "audio.wav") status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -async 1 -ac 1 -vn \ -acodec pcm_s16le -ar 16000 %s -y' % (path, wav_file), shell=True) if status != 0: msg = "Oops! Could not load the audio file. Please check the input video and try again." return None, None, None, msg print("Extracted the audio from the video") if apply_preprocess=="True": all_frames = [] for k in range(len(vr)): all_frames.append(vr[k].asnumpy()) all_frames = np.asarray(all_frames) print("Extracted the frames for pre-processing") person_videos, person_tracks, msg = get_person_detection(all_frames, frame_count, padding) if msg != "success": return None, None, None, msg # For the person detected, crop the frame based on the bounding box if len(person_videos[0]) > frame_count-10: crop_filename = os.path.join(result_folder, "preprocessed_video.avi") fourcc = cv2.VideoWriter_fourcc(*'DIVX') # Get bounding box coordinates based on person_tracks[i] max_x1 = min([track[0] for track in person_tracks[0]]) max_y1 = min([track[1] for track in person_tracks[0]]) max_x2 = max([track[2] for track in person_tracks[0]]) max_y2 = max([track[3] for track in person_tracks[0]]) max_width = max_x2 - max_x1 max_height = max_y2 - max_y1 out = cv2.VideoWriter(crop_filename, fourcc, fps, (max_width, max_height)) for frame in person_videos[0]: crop = frame[max_y1:max_y2, max_x1:max_x2] crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) out.write(crop) out.release() no_sound_video = crop_filename.split('.')[0] + '_nosound.mp4' status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (crop_filename, no_sound_video), shell=True) if status != 0: msg = "Oops! Could not preprocess the video. 
Please check the input video and try again." return None, None, None, msg video_output = crop_filename.split('.')[0] + '.mp4' status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' % (wav_file , no_sound_video, video_output), shell=True) if status != 0: msg = "Oops! Could not preprocess the video. Please check the input video and try again." return None, None, None, msg os.remove(crop_filename) os.remove(no_sound_video) print("Successfully saved the pre-processed video: ", video_output) else: msg = "Could not track the person in the full video! Please give a single-speaker video as input" return None, None, None, msg else: video_output = path return wav_file, fps, video_output, "success" def resample_video(video_file, video_fname, result_folder): ''' This function resamples the video to 25 fps Args: - video_file (string) : Path of the input video file - video_fname (string) : Name of the input video file - result_folder (string) : Path of the folder to save the resampled video Returns: - video_file_25fps (string) : Path of the resampled video file - msg (string) : Message to be returned ''' video_file_25fps = os.path.join(result_folder, '{}.mp4'.format(video_fname)) # Resample the video to 25 fps status = subprocess.call("ffmpeg -hide_banner -loglevel panic -y -i {} -c:v libx264 -preset veryslow -crf 0 -filter:v fps=25 -pix_fmt yuv420p {}".format(video_file, video_file_25fps), shell=True) if status != 0: msg = "Oops! Could not resample the video to 25 FPS. Please check the input video and try again." return None, msg print('Resampled the video to 25 fps: {}'.format(video_file_25fps)) return video_file_25fps, "success" def load_checkpoint(path, model): ''' This function loads the trained model from the checkpoint Args: - path (string) : Path of the checkpoint file - model (object) : Model object Returns: - model (object) : Model object with the weights loaded from the checkpoint ''' # Load the checkpoint checkpoint = torch.load(path, map_location="cpu") s = checkpoint["state_dict"] new_s = {} for k, v in s.items(): new_s[k.replace('module.', '')] = v model.load_state_dict(new_s) print("Loaded checkpoint from: {}".format(path)) return model.eval() def load_video_frames(video_file): ''' This function extracts the frames from the video Args: - video_file (string) : Path of the video file Returns: - frames (list) : List of frames extracted from the video - msg (string) : Message to be returned ''' # Read the video try: vr = VideoReader(video_file, ctx=cpu(0)) except: msg = "Oops! 
Could not load the input video file" return None, msg # Extract the frames frames = [] for k in range(len(vr)): frames.append(vr[k].asnumpy()) frames = np.asarray(frames) return frames, "success" def get_keypoints(frames): ''' This function extracts the keypoints from the frames using MediaPipe Holistic pipeline Args: - frames (list) : List of frames extracted from the video Returns: - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames - msg (string) : Message to be returned ''' try: holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) resolution = frames[0].shape all_frame_kps = [] for frame in frames: results = holistic.process(frame) pose, left_hand, right_hand, face = None, None, None, None if results.pose_landmarks is not None: pose = protobuf_to_dict(results.pose_landmarks)['landmark'] if results.left_hand_landmarks is not None: left_hand = protobuf_to_dict(results.left_hand_landmarks)['landmark'] if results.right_hand_landmarks is not None: right_hand = protobuf_to_dict(results.right_hand_landmarks)['landmark'] if results.face_landmarks is not None: face = protobuf_to_dict(results.face_landmarks)['landmark'] frame_dict = {"pose":pose, "left_hand":left_hand, "right_hand":right_hand, "face":face} all_frame_kps.append(frame_dict) kp_dict = {"kps":all_frame_kps, "resolution":resolution} except Exception as e: print("Error: ", e) return None, "Error: Could not extract keypoints from the frames" return kp_dict, "success" def check_visible_gestures(kp_dict): ''' This function checks if the gestures in the video are visible Args: - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames Returns: - msg (string) : Message to be returned ''' keypoints = kp_dict['kps'] keypoints = np.array(keypoints) if len(keypoints)<25: msg = "Not enough keypoints to process! Please give a longer video as input" return msg pose_count, hand_count = 0, 0 for frame_kp_dict in keypoints: pose = frame_kp_dict["pose"] left_hand = frame_kp_dict["left_hand"] right_hand = frame_kp_dict["right_hand"] if pose is None: pose_count += 1 if left_hand is None and right_hand is None: hand_count += 1 if hand_count/len(keypoints) > 0.6 or pose_count/len(keypoints) > 0.6: msg = "The gestures in the input video are not visible! Please give a video with visible gestures as input." 
return msg print("Successfully verified the input video - Gestures are visible!") return "success" def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_frames=25, width=480, height=270): ''' This function masks the faces using the keypoints extracted from the frames Args: - input_frames (list) : List of frames extracted from the video - kp_dict (dict) : Dictionary containing the keypoints and the resolution of the frames - asd (bool) : Whether to use padding (needed for active speaker detection task) or not - stride (int) : Stride to extract the frames - window_frames (int) : Number of frames in each window that is given as input to the model - width (int) : Width of the frames - height (int) : Height of the frames Returns: - input_frames (array) : Frame window to be given as input to the model - num_frames (int) : Number of frames to extract - orig_masked_frames (array) : Masked frames extracted from the video - msg (string) : Message to be returned ''' print("Creating masked input frames...") input_frames_masked = [] if kp_dict is None: for img in tqdm(input_frames): img = cv2.resize(img, (width, height)) masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1) input_frames_masked.append(masked_img) else: # Face indices to extract the face-coordinates needed for masking face_oval_idx = [10, 21, 54, 58, 67, 93, 103, 109, 127, 132, 136, 148, 149, 150, 152, 162, 172, 176, 234, 251, 284, 288, 297, 323, 332, 338, 356, 361, 365, 377, 378, 379, 389, 397, 400, 454] input_keypoints, resolution = kp_dict['kps'], kp_dict['resolution'] print("Input keypoints: ", len(input_keypoints)) for i, frame_kp_dict in tqdm(enumerate(input_keypoints)): img = input_frames[i] face = frame_kp_dict["face"] if face is None: img = cv2.resize(img, (width, height)) masked_img = cv2.rectangle(img, (0,0), (width,110), (0,0,0), -1) else: face_kps = [] for idx in range(len(face)): if idx in face_oval_idx: x, y = int(face[idx]["x"]*resolution[1]), int(face[idx]["y"]*resolution[0]) face_kps.append((x,y)) face_kps = np.array(face_kps) x1, y1 = min(face_kps[:,0]), min(face_kps[:,1]) x2, y2 = max(face_kps[:,0]), max(face_kps[:,1]) masked_img = cv2.rectangle(img, (0,0), (resolution[1],y2+15), (0,0,0), -1) if masked_img.shape[0] != width or masked_img.shape[1] != height: masked_img = cv2.resize(masked_img, (width, height)) input_frames_masked.append(masked_img) orig_masked_frames = np.array(input_frames_masked) input_frames = np.array(input_frames_masked) / 255. if asd: input_frames = np.pad(input_frames, ((12, 12), (0,0), (0,0), (0,0)), 'edge') input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])]) print("Successfully created masked input frames") num_frames = input_frames.shape[0] if num_frames<10: msg = "Not enough frames to process! Please give a longer video as input." 
return None, None, None, msg return input_frames, num_frames, orig_masked_frames, "success" def load_spectrograms(wav_file, asd=False, num_frames=None, window_frames=25, stride=4): ''' This function extracts the spectrogram from the audio file Args: - wav_file (string) : Path of the extracted audio file - asd (bool) : Whether to use padding (needed for active speaker detection task) or not - num_frames (int) : Number of frames to extract - window_frames (int) : Number of frames in each window that is given as input to the model - stride (int) : Stride to extract the audio frames Returns: - spec (array) : Spectrogram array window to be used as input to the model - orig_spec (array) : Spectrogram array extracted from the audio file - msg (string) : Message to be returned ''' # Extract the audio from the input video file using ffmpeg try: wav = librosa.load(wav_file, sr=16000)[0] except: msg = "Oops! Could extract the spectrograms from the audio file. Please check the input and try again." return None, None, msg # Convert to tensor wav = torch.FloatTensor(wav).unsqueeze(0) mel, _, _, _ = wav2filterbanks(wav) spec = mel.squeeze(0).cpu().numpy() orig_spec = spec spec = np.array([spec[i:i+(window_frames*stride), :] for i in range(0, spec.shape[0], stride) if (i+(window_frames*stride) <= spec.shape[0])]) if num_frames is not None: if len(spec) != num_frames: spec = spec[:num_frames] frame_diff = np.abs(len(spec) - num_frames) if frame_diff > 60: print("The input video and audio length do not match - The results can be unreliable! Please check the input video.") if asd: pad_frames = (window_frames//2) spec = np.pad(spec, ((pad_frames, pad_frames), (0,0), (0,0)), 'edge') return spec, orig_spec, "success" def calc_optimal_av_offset(vid_emb, aud_emb, num_avg_frames, model): ''' This function calculates the audio-visual offset between the video and audio Args: - vid_emb (array) : Video embedding array - aud_emb (array) : Audio embedding array - num_avg_frames (int) : Number of frames to average the scores - model (object) : Model object Returns: - offset (int) : Optimal audio-visual offset - msg (string) : Message to be returned ''' pos_vid_emb, all_aud_emb, pos_idx, stride, status = create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames) if status != "success": return None, status scores, _ = calc_av_scores(pos_vid_emb, all_aud_emb, model) offset = scores.argmax()*stride - pos_idx return offset.item(), "success" def create_online_sync_negatives(vid_emb, aud_emb, num_avg_frames, stride=5): ''' This function creates all possible positive and negative audio embeddings to compare and obtain the sync offset Args: - vid_emb (array) : Video embedding array - aud_emb (array) : Audio embedding array - num_avg_frames (int) : Number of frames to average the scores - stride (int) : Stride to extract the negative windows Returns: - vid_emb_pos (array) : Positive video embedding array - aud_emb_posneg (array) : All possible combinations of audio embedding array - pos_idx_frame (int) : Positive video embedding array frame - stride (int) : Stride used to extract the negative windows - msg (string) : Message to be returned ''' slice_size = num_avg_frames aud_emb_posneg = aud_emb.squeeze(1).unfold(-1, slice_size, stride) aud_emb_posneg = aud_emb_posneg.permute([0, 2, 1, 3]) aud_emb_posneg = aud_emb_posneg[:, :int(n_negative_samples/stride)+1] pos_idx = (aud_emb_posneg.shape[1]//2) pos_idx_frame = pos_idx*stride min_offset_frames = -(pos_idx)*stride max_offset_frames = (aud_emb_posneg.shape[1] - pos_idx - 
1)*stride print("With the current video length and the number of average frames, the model can predict the offsets in the range: [{}, {}]".format(min_offset_frames, max_offset_frames)) vid_emb_pos = vid_emb[:, :, pos_idx_frame:pos_idx_frame+slice_size] if vid_emb_pos.shape[2] != slice_size: msg = "Video is too short to use {} frames to average the scores. Please use a longer input video or reduce the number of average frames".format(slice_size) return None, None, None, None, msg return vid_emb_pos, aud_emb_posneg, pos_idx_frame, stride, "success" def calc_av_scores(vid_emb, aud_emb, model): ''' This function calls functions to calculate the audio-visual similarity and attention map between the video and audio embeddings Args: - vid_emb (array) : Video embedding array - aud_emb (array) : Audio embedding array - model (object) : Model object Returns: - scores (array) : Audio-visual similarity scores - att_map (array) : Attention map ''' scores = calc_att_map(vid_emb, aud_emb, model) att_map = logsoftmax_2d(torch.Tensor(scores)) scores = scores.mean(-1) return scores, att_map def calc_att_map(vid_emb, aud_emb, model): ''' This function calculates the similarity between the video and audio embeddings Args: - vid_emb (array) : Video embedding array - aud_emb (array) : Audio embedding array - model (object) : Model object Returns: - scores (array) : Audio-visual similarity scores ''' vid_emb = vid_emb[:, :, None] aud_emb = aud_emb.transpose(1, 2) scores = run_func_in_parts(lambda x, y: (x * y).sum(1), vid_emb, aud_emb, part_len=10, dim=3) scores = model.logits_scale(scores[..., None]).squeeze(-1) return scores.detach().cpu().numpy() def generate_video(frames, audio_file, video_fname): ''' This function generates the video from the frames and audio file Args: - frames (array) : Frames to be used to generate the video - audio_file (string) : Path of the audio file - video_fname (string) : Path of the video file Returns: - video_output (string) : Path of the video file - msg (string) : Message to be returned ''' fname = 'inference.avi' video = cv2.VideoWriter(fname, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frames[0].shape[1], frames[0].shape[0])) for i in range(len(frames)): video.write(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) video.release() no_sound_video = video_fname + '_nosound.mp4' status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True) if status != 0: msg = "Oops! Could not generate the video. Please check the input video and try again." return None, msg video_output = video_fname + '.mp4' status = subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -c:v libx264 -preset veryslow -crf 18 -pix_fmt yuv420p -strict -2 -q:v 1 -shortest %s' % (audio_file, no_sound_video, video_output), shell=True) if status != 0: msg = "Oops! Could not generate the video. Please check the input video and try again." 
return None, msg os.remove(fname) os.remove(no_sound_video) return video_output, "success" def sync_correct_video(video_path, frames, wav_file, offset, result_folder, sample_rate=16000, fps=25): ''' This function corrects the video and audio to sync with each other Args: - video_path (string) : Path of the video file - frames (array) : Frames to be used to generate the video - wav_file (string) : Path of the audio file - offset (int) : Predicted sync-offset to be used to correct the video - result_folder (string) : Path of the result folder to save the output sync-corrected video - sample_rate (int) : Sample rate of the audio - fps (int) : Frames per second of the video Returns: - video_output (string) : Path of the video file - msg (string) : Message to be returned ''' if offset == 0: print("The input audio and video are in-sync! No need to perform sync correction.") return video_path, "success" print("Performing Sync Correction...") corrected_frames = np.zeros_like(frames) if offset > 0: audio_offset = int(offset*(sample_rate/fps)) wav = librosa.core.load(wav_file, sr=sample_rate)[0] corrected_wav = wav[audio_offset:] corrected_wav_file = os.path.join(result_folder, "audio_sync_corrected.wav") write(corrected_wav_file, sample_rate, corrected_wav) wav_file = corrected_wav_file corrected_frames = frames elif offset < 0: corrected_frames[0:len(frames)+offset] = frames[np.abs(offset):] corrected_frames = corrected_frames[:len(frames)-np.abs(offset)] corrected_video_path = os.path.join(result_folder, "result_sync_corrected") video_output, status = generate_video(corrected_frames, wav_file, corrected_video_path) if status != "success": return None, status return video_output, "success" def load_masked_input_frames(test_videos, spec, wav_file, scene_num, result_folder): ''' This function loads the masked input frames from the video Args: - test_videos (list) : List of videos to be processed (speaker-specific tracks) - spec (array) : Spectrogram of the audio - wav_file (string) : Path of the audio file - scene_num (int) : Scene number to be used to save the input masked video - result_folder (string) : Path of the folder to save the input masked video Returns: - all_frames (list) : List of masked input frames window to be used as input to the model - all_orig_frames (list) : List of original masked input frames ''' all_frames, all_orig_frames = [], [] for video_num, video in enumerate(test_videos): print("Processing video: ", video) # Load the video frames frames, status = load_video_frames(video) if status != "success": return None, None, status print("Successfully loaded the video frames") # Extract the keypoints from the frames # kp_dict, status = get_keypoints(frames) # if status != "success": # return None, None, status # print("Successfully extracted the keypoints") # Mask the frames using the keypoints extracted from the frames and prepare the input to the model masked_frames, num_frames, orig_masked_frames, status = load_rgb_masked_frames(frames, kp_dict=None, asd=True) if status != "success": return None, None, status print("Successfully loaded the masked frames") # Check if the length of the input frames is equal to the length of the spectrogram if spec.shape[2]!=masked_frames.shape[0]: num_frames = spec.shape[2] masked_frames = masked_frames[:num_frames] orig_masked_frames = orig_masked_frames[:num_frames] frame_diff = np.abs(spec.shape[2] - num_frames) if frame_diff > 60: print("The input video and audio length do not match - The results can be unreliable! 
Please check the input video.") # Transpose the frames to the correct format frames = np.transpose(masked_frames, (4, 0, 1, 2, 3)) frames = torch.FloatTensor(np.array(frames)).unsqueeze(0) print("Successfully converted the frames to tensor") all_frames.append(frames) all_orig_frames.append(orig_masked_frames) return all_frames, all_orig_frames, "success" def extract_audio(video, result_folder): ''' This function extracts the audio from the video file Args: - video (string) : Path of the video file - result_folder (string) : Path of the folder to save the extracted audio file Returns: - wav_file (string) : Path of the extracted audio file - msg (string) : Message to be returned ''' wav_file = os.path.join(result_folder, "audio.wav") status = subprocess.call('ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn \ -acodec pcm_s16le -ar 16000 %s' % (video, wav_file), shell=True) if status != 0: msg = "Oops! Could not load the audio file in the given input video. Please check the input and try again" return None, msg return wav_file, "success" @spaces.GPU(duration=60) def get_embeddings(video_sequences, audio_sequences, model, asd=False, calc_aud_emb=True): ''' This function extracts the video and audio embeddings from the input frames and audio sequences Args: - video_sequences (array) : Array of video frames to be used as input to the model - audio_sequences (array) : Array of audio frames to be used as input to the model - model (object) : Model object - asd (bool) : Active speaker detection task flag to return the correct dimensions for the embeddings - calc_aud_emb (bool) : Flag to calculate the audio embedding Returns: - video_emb (array) : Video embedding - audio_emb (array) : Audio embedding ''' video_emb = [] audio_emb = [] for i in range(0, len(video_sequences), batch_size): video_inp = video_sequences[i:i+batch_size, ] vid_emb = model.forward_vid(video_inp, return_feats=False) vid_emb = torch.mean(vid_emb, axis=-1) if not asd: vid_emb = vid_emb.unsqueeze(-1) video_emb.extend(vid_emb.detach().cpu().numpy()) if calc_aud_emb: audio_inp = audio_sequences[i:i+batch_size, ] aud_emb = model.forward_aud(audio_inp) audio_emb.extend(aud_emb.detach().cpu().numpy()) torch.cuda.empty_cache() video_emb = np.array(video_emb) print("Video Embedding Shape: ", video_emb.shape) if calc_aud_emb: audio_emb = np.array(audio_emb) print("Audio Embedding Shape: ", audio_emb.shape) return video_emb, audio_emb return video_emb def predict_active_speaker(all_video_embeddings, audio_embedding, global_score, num_avg_frames, model): ''' This function predicts the active speaker in each frame Args: - all_video_embeddings (array) : Array of video embeddings of all speakers - audio_embedding (array) : Audio embedding - global_score (bool) : Flag to calculate the global score - num_avg_frames (int) : Number of frames to average the scores - model (object) : Model object Returns: - pred_speaker (list) : List of active speakers in each frame - num_avg_frames (int) : Number of frames to average the scores ''' cos = nn.CosineSimilarity(dim=1) audio_embedding = torch.tensor(audio_embedding).squeeze(2) scores = [] for i in range(len(all_video_embeddings)): video_embedding = torch.tensor(all_video_embeddings[i]) # Compute the similarity of each speaker's video embeddings with the audio embedding sim = cos(video_embedding, audio_embedding) # Apply the logits scale to the similarity scores (scaling the scores) output = model.logits_scale(sim.unsqueeze(-1)).squeeze(-1) if global_score=="True": score = 
output.mean(0) else: if output.shape[0] Total video files found (speaker-specific tracks) = {}".format(scene_num, len(test_videos))) if len(test_videos)<=1: msg = "To detect the active speaker, at least 2 visible speakers are required for each scene! Please check the input video and try again..." return None, msg # Load the audio file audio_file = glob(os.path.join("{}/crops".format(result_folder_input), "scene_{}".format(str(scene_num)), "*.wav"))[0] spec, _, status = load_spectrograms(audio_file, asd=True) if status != "success": return None, status spec = torch.FloatTensor(spec).unsqueeze(0).unsqueeze(0).permute(0,1,2,4,3) print("Successfully loaded the spectrograms") # Load the masked input frames all_masked_frames, all_orig_masked_frames, status = load_masked_input_frames(test_videos, spec, audio_file, scene_num, result_folder_input) if status != "success": return None, status print("Successfully loaded the masked input frames") # Prepare the audio and video sequences for the model audio_sequences = torch.cat([spec[:, :, i] for i in range(spec.size(2))], dim=0) print("Obtaining audio and video embeddings...") all_video_embs = [] for idx in tqdm(range(len(all_masked_frames))): with torch.no_grad(): video_sequences = torch.cat([all_masked_frames[idx][:, :, i] for i in range(all_masked_frames[idx].size(2))], dim=0) if idx==0: video_emb, audio_emb = get_embeddings(video_sequences, audio_sequences, model, asd=True, calc_aud_emb=True) else: video_emb = get_embeddings(video_sequences, audio_sequences, model, asd=True, calc_aud_emb=False) all_video_embs.append(video_emb) print("Successfully extracted GestSync embeddings") # Predict the active speaker in each scene if global_speaker=="per-frame-prediction": predictions, num_avg_frames = predict_active_speaker(all_video_embs, audio_emb, "False", num_avg_frames, model) else: predictions, _ = predict_active_speaker(all_video_embs, audio_emb, "True", num_avg_frames, model) # Get the frames present in the scene frames_scene = tracks[scene_num][0]['track']['frame'] # Prepare the active speakers list to draw the bounding boxes if global_speaker=="global-prediction": print("Aggregating scores using global predictions") active_speakers = [predictions]*len(frames_scene) start, end = 0, len(frames_scene) else: print("Aggregating scores using per-frame predictions") active_speakers = [0]*len(frames_scene) mid = num_avg_frames//2 if num_avg_frames%2==0: frame_pred = len(frames_scene)-(mid*2)+1 start, end = mid, len(frames_scene)-mid+1 else: frame_pred = len(frames_scene)-(mid*2) start, end = mid, len(frames_scene)-mid print("Frame scene: {} | Avg frames: {} | Frame predictions: {}".format(len(frames_scene), num_avg_frames, frame_pred)) if len(predictions) != frame_pred: msg = "Predicted frames {} and input video frames {} do not match!!".format(len(predictions), frame_pred) return None, msg active_speakers[start:end] = predictions[0:] # Depending on the num_avg_frames, interpolate the intial and final frame predictions to get a full video output initial_preds = max(set(predictions[:num_avg_frames]), key=predictions[:num_avg_frames].count) active_speakers[0:start] = [initial_preds] * start final_preds = max(set(predictions[-num_avg_frames:]), key=predictions[-num_avg_frames:].count) active_speakers[end:] = [final_preds] * (len(frames_scene) - end) start, end = 0, len(active_speakers) # Get the output tracks for each frame pred_idx = 0 for frame in frames_scene[start:end]: label = active_speakers[pred_idx] pred_idx += 1 output_tracks[frame] = 
                track_dict[scene_num][label][frame]

        # Save the output video
        video_output, status = save_video(output_tracks, orig_frames.copy(), orig_wav_file, result_folder_output)
        if status != "success":
            return None, status
        print("Successfully saved the output video: ", video_output)

        return video_output, "success"

    except Exception as e:
        return None, f"Error: {str(e)}"


if __name__ == "__main__":

    # Custom CSS and HTML
    custom_css = """ """

    custom_html = custom_css + """
    GestSync: Determining who is speaking without a talking head
    Synchronization and Active Speaker Detection Demo
    Project Page | Github | Paper
    """

    tips = """
    Please give us a 🌟 on Github if you like our work!

    Tips to get better results:

    Inference time:

    Note: Occasionally, there may be a delay in acquiring a GPU, as the model runs on a free community GPU from ZeroGPU.
""" # Define functions def toggle_slider(global_speaker): if global_speaker == "per-frame-prediction": return gr.update(visible=True) else: return gr.update(visible=False) def toggle_demo(demo_choice): if demo_choice == "Synchronization-correction": return ( gr.update(value=None, visible=True), # video_input gr.update(value=75, visible=True), # num_avg_frames gr.update(value=None, visible=True), # apply_preprocess gr.update(value="global-prediction", visible=False), # global_speaker gr.update(value=None, visible=True), # output_video gr.update(value="", visible=True), # result_text gr.update(visible=True), # submit_button gr.update(visible=True), # clear_button gr.update(visible=True), # sync_examples gr.update(visible=False), # asd_examples gr.update(visible=True) # tips ) else: return ( gr.update(value=None, visible=True), # video_input gr.update(value=75, visible=True), # num_avg_frames gr.update(value=None, visible=False), # apply_preprocess gr.update(value="global-prediction", visible=True), # global_speaker gr.update(value=None, visible=True), # output_video gr.update(value="", visible=True), # result_text gr.update(visible=True), # submit_button gr.update(visible=True), # clear_button gr.update(visible=False), # sync_examples gr.update(visible=True), # asd_examples gr.update(visible=True) # tips ) def clear_inputs(): return None, None, "global-prediction", 75, None, "", None def process_video(video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess): if demo_choice == "Synchronization-correction": return process_video_syncoffset(video_input, num_avg_frames, apply_preprocess) else: return process_video_activespeaker(video_input, global_speaker, num_avg_frames) # Define paths to sample videos sync_sample_videos = [ ["samples/sync_sample_1.mp4"], ["samples/sync_sample_2.mp4"] ] asd_sample_videos = [ ["samples/asd_sample_1.mp4"], ["samples/asd_sample_2.mp4"] ] # Define Gradio interface with gr.Blocks(css=custom_css, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo: gr.HTML(custom_html) demo_choice = gr.Radio( choices=["Synchronization-correction", "Active-speaker-detection"], label="Please select the task you want to perform" ) with gr.Row(): with gr.Column(): video_input = gr.Video(label="Upload Video", height=400, visible=False) num_avg_frames = gr.Slider( minimum=50, maximum=150, step=5, value=75, label="Number of Average Frames", visible=False ) apply_preprocess = gr.Checkbox(label="Apply Preprocessing", value=False, visible=False) global_speaker = gr.Radio( choices=["global-prediction", "per-frame-prediction"], value="global-prediction", label="Global Speaker Prediction", visible=False ) global_speaker.change( fn=toggle_slider, inputs=global_speaker, outputs=num_avg_frames ) with gr.Column(): output_video = gr.Video(label="Output Video", height=400, visible=False) result_text = gr.Textbox(label="Result", visible=False) with gr.Row(): submit_button = gr.Button("Submit", variant="primary", visible=False) clear_button = gr.Button("Clear", visible=False) # Add a gap before examples gr.HTML('
') # Add examples that only populate the video input sync_examples = gr.Dataset( samples=sync_sample_videos, components=[video_input], type="values", visible=False ) asd_examples = gr.Dataset( samples=asd_sample_videos, components=[video_input], type="values", visible=False ) tips = gr.Markdown(tips, visible=False) demo_choice.change( fn=toggle_demo, inputs=demo_choice, outputs=[video_input, num_avg_frames, apply_preprocess, global_speaker, output_video, result_text, submit_button, clear_button, sync_examples, asd_examples, tips] ) sync_examples.select( fn=lambda x: gr.update(value=x[0], visible=True), inputs=sync_examples, outputs=video_input ) asd_examples.select( fn=lambda x: gr.update(value=x[0], visible=True), inputs=asd_examples, outputs=video_input ) submit_button.click( fn=process_video, inputs=[video_input, demo_choice, global_speaker, num_avg_frames, apply_preprocess], outputs=[output_video, result_text] ) clear_button.click( fn=clear_inputs, inputs=[], outputs=[demo_choice, video_input, global_speaker, num_avg_frames, apply_preprocess, result_text, output_video] ) # Launch the interface demo.launch(allowed_paths=["."], share=True)
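# ---------------------------------------------------------------------------
# Illustrative sketch (assumption: not part of the original app, included only
# to clarify the offset computation). calc_optimal_av_offset() unfolds the
# audio embedding into candidate windows with a fixed stride, scores each
# window against the reference video window, and maps the best-scoring index
# back to a frame offset via `scores.argmax()*stride - pos_idx`. The helper
# name and the dummy scores below are hypothetical.
def _offset_from_scores_sketch(scores, stride=5):
    """Map the argmax over candidate audio windows to a sync offset in frames (at 25 FPS)."""
    pos_idx = len(scores) // 2        # index of the in-sync (positive) window
    pos_idx_frame = pos_idx * stride  # the same quantity expressed in frames
    return int(np.argmax(scores)) * stride - pos_idx_frame

# Example: with 21 candidate windows and stride 5, the best-matching window
# being index 12 corresponds to a predicted offset of (12 - 10) * 5 = +10 frames:
# _offset_from_scores_sketch(np.eye(21)[12])  ->  10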