Spaces:
Runtime error
Runtime error
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse | |
import os | |
import tqdm | |
from statistics import fmean | |
from eval.syncnet import SyncNetEval | |
from eval.syncnet_detect import SyncNetDetector | |
from latentsync.utils.util import red_text | |
import torch | |
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
    """Run the detector on one video and average SyncNet results over its crops.

    Args:
        syncnet: Object exposing ``evaluate(video_path=..., temp_dir=...)``
            returning a 3-tuple; the first item is used as the AV offset and
            the third as the confidence (the middle item is ignored).
        syncnet_detector: Callable invoked as
            ``syncnet_detector(video_path=..., min_track=50)``.
            NOTE(review): presumably it populates ``<detect_results_dir>/crop``
            for this video -- confirm against the detector implementation.
        video_path: Path of the video to evaluate.
        temp_dir: Forwarded unchanged to ``syncnet.evaluate``.
        detect_results_dir: Directory whose ``crop`` sub-directory is listed
            for the clips to evaluate.

    Returns:
        Tuple ``(av_offset, conf)``: the mean AV offset truncated to ``int``
        and the mean confidence over all entries in the crop directory.

    Raises:
        Exception: If the crop directory is empty after the detector ran.
    """
    syncnet_detector(video_path=video_path, min_track=50)

    crop_dir = os.path.join(detect_results_dir, "crop")
    crop_names = os.listdir(crop_dir)
    if not crop_names:
        raise Exception(red_text(f"Face not detected in {video_path}"))

    def _score(crop_name):
        # Evaluate a single crop and keep only the offset and the confidence.
        offset, _, score = syncnet.evaluate(video_path=os.path.join(crop_dir, crop_name), temp_dir=temp_dir)
        return offset, score

    scored = [_score(crop_name) for crop_name in crop_names]

    # int() truncates the mean offset toward zero.
    av_offset = int(fmean([offset for offset, _ in scored]))
    conf = fmean([score for _, score in scored])

    print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
    return av_offset, conf
def main():
    """Command-line entry point for SyncNet evaluation.

    Parses the CLI arguments, loads the SyncNet weights given by
    ``--initial_model`` and builds the detector. Then either:

    * evaluates the single video given by ``--video_path``, or
    * when ``--video_path`` is omitted, evaluates every ``*.mp4`` file in
      ``--videos_dir`` (in sorted order) and prints the average confidence
      over the videos that were evaluated successfully.

    In directory mode a failure on one video is printed and skipped so the
    remaining videos are still processed. If no video yields a confidence
    (no ``.mp4`` files, or all of them failed) a message is printed instead
    of computing a mean of an empty list.
    """
    parser = argparse.ArgumentParser(description="SyncNet")
    parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
    parser.add_argument("--video_path", type=str, default=None, help="")
    parser.add_argument("--videos_dir", type=str, default="/root/processed")
    parser.add_argument("--temp_dir", type=str, default="temp", help="")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    syncnet = SyncNetEval(device=device)
    syncnet.loadParameters(args.initial_model)

    # Must match the default ``detect_results_dir`` of ``syncnet_eval``, which
    # lists the "crop" sub-directory of this location.
    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if args.video_path is not None:
        syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
        return

    sync_conf_list = []
    video_names = sorted(f for f in os.listdir(args.videos_dir) if f.endswith(".mp4"))
    for video_name in tqdm.tqdm(video_names):
        try:
            _, conf = syncnet_eval(
                syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
            )
        except Exception as e:
            # Best-effort batch run: report the failure and keep going.
            print(e)
        else:
            sync_conf_list.append(conf)

    if sync_conf_list:
        print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")
    else:
        # fmean() raises StatisticsError on empty input; report it explicitly.
        print(red_text(f"No video in {args.videos_dir} could be evaluated; average sync confidence is unavailable"))


if __name__ == "__main__":
    main()