from __future__ import print_function

import copy
import hashlib
import json
import os
import os.path as osp
import time
import traceback
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from moviepy.editor import VideoFileClip
from torchvision import transforms

from ...utils.util import load_dct_from_file
# from lgss.utilis.package import *
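
# Frame preprocessing: convert each frame to a tensor and normalize with the
# standard ImageNet mean/std expected by torchvision backbones.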
normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer = transforms.Compose(
    [
        # transforms.Resize(256),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalizer,
    ]
)


def wav2stft(data):
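    """Convert a mono waveform into a fixed-width STFT magnitude image.

    The signal is rescaled to [-1, 1], transformed with a 512-point STFT,
    converted to dB, and then tiled or subsampled along the time axis so the
    output always has 4 * 80 = 320 columns.
    """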
    # normalize
    mean = (data.max() + data.min()) / 2
    span = (data.max() - data.min()) / 2
    if span < 1e-6:
        span = 1
    data = (data - mean) / span  # range: [-1, 1]
    D = librosa.core.stft(data, n_fft=512)
    freq = np.abs(D)
    freq = librosa.core.amplitude_to_db(freq)
    span = 80
    thr = 4 * span
    if freq.shape[1] <= thr:
        copy_ = freq.copy()
        while freq.shape[1] < thr:
            tmp = copy_.copy()
            freq = np.concatenate((freq, tmp), axis=1)
        freq = freq[:, :thr]
    else:
        # sample
        n = freq.shape[1]
        stft_img = []
        stft_img.append(freq[:, : 2 * span])
        # stft_img.append(freq[:, n//2 - span : n//2 + span])
        stft_img.append(freq[:, -2 * span :])
        freq = np.concatenate(stft_img, axis=1)
    return freq


def test(
    model,
    data_place,
    data_cast=None,
    data_act=None,
    data_aud=None,
    last_image_overlap_feat=None,
    last_aud_overlap_feat=None,
):
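    """Run one window of shot features through the model without gradients.

    Returns the image and audio boundary probabilities (softmax scores for the
    transition class), the image/audio overlap features to carry into the next
    window, and the per-shot dynamic scores.
    """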
    with torch.no_grad():
        # data_place = data_place.cuda() if data_place is not None else []
        data_cast = data_cast.cuda() if data_cast is not None else []
        data_act = data_act.cuda() if data_act is not None else []
        data_aud = data_aud.cuda() if data_aud is not None else []
        (
            img_output,
            aud_output,
            image_overlap_feat,
            audio_overlap_feat,
            shot_dynamic_list,
        ) = model(
            data_place,
            data_cast,
            data_act,
            data_aud,
            last_image_overlap_feat,
            last_aud_overlap_feat,
        )
        img_output = img_output.view(-1, 2)
        img_output = F.softmax(img_output, dim=1)
        img_prob = img_output[:, 1]
        img_prob = img_prob.cpu()
        aud_output = aud_output.view(-1, 2)
        aud_output = F.softmax(aud_output, dim=1)
        aud_prob = aud_output[:, 1]
        aud_prob = aud_prob.cpu()
    return img_prob, aud_prob, image_overlap_feat, audio_overlap_feat, shot_dynamic_list


def predict(
    model,
    cfg,
    video_path,
    save_path,
    map_path,
    seq_len=120,
    shot_num=4,
    overlap=21,
    shot_frame_max_num=60,
):
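    """Predict scene transitions for one video and write them into its shot map.

    The per-video shot map JSON is located via the video's MD5 hash, the video
    is opened so that the shorter side of its content box is 256 px and then
    center-cropped to a 224-px-high region, and the valid "body" clips are
    processed in sliding windows of ``seq_len`` shots that share ``overlap``
    shots.  Image and audio boundary scores are fused with ``cfg.model.ratio``
    and thresholded with ``args.threshold`` to produce the scenes written into
    ``video_map["scenes"]``.
    """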
    assert overlap % 2 == 1
    video_name = ".".join(video_path.split("/")[-1].split(".")[:-1])
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # video_hash_code = (os.popen('md5sum {}'.format(video_path))).readlines()[0].split(' ')[0]
    with open(video_path, "rb") as fd:
        data = fd.read()
    video_hash_code = hashlib.md5(data).hexdigest()
    save_path = os.path.join(
        save_path, "{}_{}.json".format(video_name, video_hash_code[:8])
    )
    # NOTE: `args` is expected to be an argparse.Namespace (providing
    # `overwrite` and `threshold`) defined at module level by the calling
    # script; it is not created in this file.
    if os.path.exists(save_path) and not args.overwrite:
        video_map = json.load(open(save_path, encoding="UTF-8"))
        valid_clips = []
        for clip in video_map["clips"]:
            if clip["cliptype"] == "body" and clip["duration"] > 0.25:
                valid_clips.append(clip)
        # Capture video
        if (
            video_map["content_box"][2] - video_map["content_box"][0]
            > video_map["content_box"][3] - video_map["content_box"][1]
        ):
            target_resolution = (
                256
                * video_map["height"]
                / (video_map["content_box"][3] - video_map["content_box"][1]),
                None,
            )
        else:
            target_resolution = (
                None,
                256
                * video_map["width"]
                / (video_map["content_box"][2] - video_map["content_box"][0]),
            )
        video = VideoFileClip(
            video_path,
            target_resolution=target_resolution,
            resize_algorithm="bilinear",
            audio_fps=16000,
        )
        # video = video.crop(*video_map["content_box"])
        x1 = video_map["content_box"][0] * video.size[0] // video_map["width"]
        y1 = video_map["content_box"][1] * video.size[1] // video_map["height"]
        x2 = video_map["content_box"][2] * video.size[0] // video_map["width"]
        y2 = video_map["content_box"][3] * video.size[1] // video_map["height"]
        video = video.crop(
            width=(x2 - x1) * 224 / 256,
            height=224,
            x_center=(x1 + x2) // 2,
            y_center=(y1 + y2) // 2,
        )
        print("exists " + save_path)
    else:
        map_path = os.path.join(
            map_path, "{}_{}.json".format(video_name, video_hash_code[:8])
        )
        if not os.path.exists(map_path):
            print("map not exist: ", map_path)
            return
        video_map = json.load(open(map_path, encoding="UTF-8"))
        assert video_hash_code == video_map["video_file_hash_code"]
        # Capture video
        if (
            video_map["content_box"][2] - video_map["content_box"][0]
            > video_map["content_box"][3] - video_map["content_box"][1]
        ):
            target_resolution = (
                256
                * video_map["height"]
                / (video_map["content_box"][3] - video_map["content_box"][1]),
                None,
            )
        else:
            target_resolution = (
                None,
                256
                * video_map["width"]
                / (video_map["content_box"][2] - video_map["content_box"][0]),
            )
        video = VideoFileClip(
            video_path,
            target_resolution=target_resolution,
            resize_algorithm="bilinear",
            audio_fps=16000,
        )
        # video = video.crop(*video_map["content_box"])
        x1 = video_map["content_box"][0] * video.size[0] // video_map["width"]
        y1 = video_map["content_box"][1] * video.size[1] // video_map["height"]
        x2 = video_map["content_box"][2] * video.size[0] // video_map["width"]
        y2 = video_map["content_box"][3] * video.size[1] // video_map["height"]
        video = video.crop(
            width=(x2 - x1) * 224 / 256,
            height=224,
            x_center=(x1 + x2) // 2,
            y_center=(y1 + y2) // 2,
        )
    fps = video.fps
    duration = video.duration
    total_frames = int(duration * fps)
    width, height = video.size
    print("fps, frame_count, width, height:", fps, total_frames, width, height)
    valid_clips = []
    for clip in video_map["clips"]:
        if clip["cliptype"] == "body" and clip["duration"] > 0.25:
            valid_clips.append(clip)
    # valid_clips = valid_clips[:150]
    total_shot_num = len(valid_clips)
    last_image_overlap_feat = None
    last_aud_overlap_feat = None
    truncate_time = 0.1
    all_shot_dynamic_list = []
    for i in range(total_shot_num // (seq_len - overlap) + 1):
        shot_frame_list = []
        shot_audio_list = []
        start_shot = i * (seq_len - overlap)
        end_shot = min(start_shot + seq_len, total_shot_num)
        if i != 0:
            start_shot += overlap
        print(start_shot, end_shot)
        if start_shot >= end_shot:
            break
        for clip in valid_clips[start_shot:end_shot]:
            time_start = clip["time_start"]
            time_end = clip["time_start"] + clip["duration"]
            truncate_time = min(clip["duration"] / 10, 0.1)
            time_start += truncate_time
            time_end -= truncate_time
            time_start = max(time_start, (time_end + time_start) / 2 - 3)
            time_end = min(time_end, (time_end + time_start) / 2 + 3)
            duration = time_end - time_start
            t0 = time.time()
            video_subclip = video.subclip(time_start, time_end)
            # video_save_path = os.path.join(args.video_save_path, 'shot_{:04d}.mp4'.format(clip["clipid"]))
            # video_subclip.write_videofile(video_save_path, threads=8, codec='libx264')
            if "image" in cfg.dataset["mode"]:
                frame_iter = video_subclip.iter_frames(fps=10)
                shot_frame = []
                for frame in frame_iter:
                    frame = transformer(frame)
                    shot_frame.append(frame)
                    if len(shot_frame) > shot_frame_max_num:
                        break
                shot_frame = torch.stack(shot_frame)
                shot_frame = shot_frame.cuda()
                shot_frame_list.append(shot_frame)
            t5 = time.time()
            if "aud" in cfg.dataset["mode"]:
                try:
                    sub_audio = video.audio.subclip(
                        clip["time_start"], clip["time_start"] + clip["duration"]
                    )
                    sub_audio = sub_audio.to_soundarray(
                        fps=16000, quantize=True, buffersize=20000
                    )
                    sub_audio = sub_audio.mean(axis=1)
                except Exception:
                    sub_audio = np.zeros((16000 * 4), np.float32)
                sub_audio = wav2stft(sub_audio)
                sub_audio = torch.from_numpy(sub_audio).float()
                sub_audio = sub_audio.unsqueeze(dim=0)
                shot_audio_list.append(sub_audio)
            t6 = time.time()
            print(clip["clipid"], t5 - t0, t6 - t5)
        data_place = data_aud = None
        if len(shot_frame_list) > 0:
            # data_place = torch.stack(shot_frame_list)
            data_place = shot_frame_list
        if len(shot_audio_list) > 0:
            data_aud = torch.stack(shot_audio_list)
            data_aud = data_aud.unsqueeze(dim=0)
        (
            img_preds,
            aud_preds,
            last_image_overlap_feat,
            last_aud_overlap_feat,
            shot_dynamic_list,
        ) = test(
            model,
            data_place=data_place,
            data_aud=data_aud,
            last_image_overlap_feat=last_image_overlap_feat,
            last_aud_overlap_feat=last_aud_overlap_feat,
        )
        print(shot_dynamic_list)
        all_shot_dynamic_list.extend(shot_dynamic_list)
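        # Stitch the windows together: except at the sequence ends, drop
        # (overlap - shot_num + 1) // 2 boundary predictions on each side of a
        # window so overlapping windows do not contribute the same boundary twice.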
        if total_shot_num > end_shot:
            if i == 0:
                img_preds_all = img_preds[: -(overlap - shot_num + 1) // 2]
                aud_preds_all = aud_preds[: -(overlap - shot_num + 1) // 2]
            else:
                img_preds_all = torch.cat(
                    (
                        img_preds_all,
                        img_preds[
                            (overlap - shot_num + 1)
                            // 2 : -(overlap - shot_num + 1)
                            // 2
                        ],
                    ),
                    dim=0,
                )
                aud_preds_all = torch.cat(
                    (
                        aud_preds_all,
                        aud_preds[
                            (overlap - shot_num + 1)
                            // 2 : -(overlap - shot_num + 1)
                            // 2
                        ],
                    ),
                    dim=0,
                )
        else:
            if i == 0:
                img_preds_all = img_preds
                aud_preds_all = aud_preds
            else:
                img_preds_all = torch.cat(
                    (img_preds_all, img_preds[(overlap - shot_num + 1) // 2 :]),
                    dim=0,
                )
                aud_preds_all = torch.cat(
                    (aud_preds_all, aud_preds[(overlap - shot_num + 1) // 2 :]),
                    dim=0,
                )
    print(
        img_preds_all.shape[0],
        total_shot_num - shot_num + 1,
        len(all_shot_dynamic_list),
    )
    assert img_preds_all.shape[0] == total_shot_num - shot_num + 1
    assert len(all_shot_dynamic_list) == total_shot_num
    print("img_preds_all: ", img_preds_all)
    print("aud_preds_all: ", aud_preds_all)
    video_map["scenes_img_preds"] = img_preds_all.tolist()
    video_map["scenes_aud_preds"] = aud_preds_all.tolist()
    for clip, dynamic in zip(valid_clips, all_shot_dynamic_list):
        clip["dynamic"] = None
        if dynamic is not None:
            clip["dynamic"] = round(float(np.clip(dynamic, 0, 1)), 5)
    preds_all = cfg.model.ratio[0] * np.array(
        video_map["scenes_img_preds"]
    ) + cfg.model.ratio[3] * np.array(video_map["scenes_aud_preds"])
    video_map["scenes_preds"] = preds_all.tolist()
    scene_boundary = np.where(preds_all > args.threshold)[0]
    video_map["scenes"] = []
    scene = {
        "sceneid": 0,
        "clip_start": valid_clips[0]["clipid"],
        "clip_end": valid_clips[0]["clipid"],
        "time_start": valid_clips[0]["time_start"],
        "time_end": valid_clips[0]["time_start"] + valid_clips[0]["duration"],
    }
    for i in scene_boundary:
        scene["clip_end"] = valid_clips[i + shot_num // 2 - 1]["clipid"]
        scene["time_end"] = (
            valid_clips[i + shot_num // 2 - 1]["time_start"]
            + valid_clips[i + shot_num // 2 - 1]["duration"]
        )
        scene["roles"] = {}
        scene["dynamic"] = None
        dynamic_num = 0
        dynamic = 0
        for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
            for roleid in clip["roles"].keys():
                if roleid not in scene["roles"]:
                    scene["roles"][roleid] = {
                        "name": clip["roles"][roleid]["name"]
                        if "name" in clip["roles"][roleid]
                        else ""
                    }
            if "dynamic" in clip and clip["dynamic"] is not None:
                dynamic += clip["dynamic"]
                dynamic_num += 1
        if dynamic_num > 0:
            scene["dynamic"] = dynamic / dynamic_num
        for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
            clip["scene_roles"] = scene["roles"]
            clip["scene_dynamic"] = scene["dynamic"]
            clip["sceneid"] = scene["sceneid"]
        video_map["scenes"].append(copy.deepcopy(scene))
        scene["sceneid"] += 1
        scene["clip_start"] = scene["clip_end"] = valid_clips[i + shot_num // 2][
            "clipid"
        ]
        scene["time_start"] = valid_clips[i + shot_num // 2]["time_start"]
        scene["time_end"] = (
            valid_clips[i + shot_num // 2]["time_start"]
            + valid_clips[i + shot_num // 2]["duration"]
        )
scene["clip_end"] = valid_clips[-1]["clipid"] | |
scene["time_end"] = valid_clips[-1]["time_start"] + valid_clips[-1]["duration"] | |
scene["roles"] = {} | |
scene["dynamic"] = None | |
dynamic_num = 0 | |
dynamic = 0 | |
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]: | |
for roleid in clip["roles"].keys(): | |
if roleid not in scene["roles"]: | |
scene["roles"][roleid] = { | |
"name": clip["roles"][roleid]["name"] | |
if "name" in clip["roles"][roleid] | |
else "" | |
} | |
if "dynamic" in clip and clip["dynamic"] != None: | |
dynamic += clip["dynamic"] | |
dynamic_num += 1 | |
if dynamic_num > 0: | |
scene["dynamic"] = dynamic / dynamic_num | |
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]: | |
clip["scene_roles"] = scene["roles"] | |
clip["scene_dynamic"] = scene["dynamic"] | |
clip["sceneid"] = scene["sceneid"] | |
video_map["scenes"].append(scene) | |
return video_map | |


class SceneTransitionPredictor(object):
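    """Inference wrapper around the LGSS scene-segmentation model.

    Builds the model named by ``cfg.model.name``, loads the image checkpoint
    from ``cfg.logger.logs_dir`` and, when the dataset mode includes "aud",
    the matching audio checkpoint, then exposes ``__call__`` for prediction.
    """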
    def __init__(self, config_path, overlap=41, model_path=None) -> None:
        from mmcv import Config
        from lgss.utilis import load_checkpoint
        import lgss.src.models as models

        self.config_path = config_path
        # keep the overlap so __call__ can pass it on to predict()
        self.overlap = overlap
        cfg = Config.fromfile(config_path)
        # cfg = load_dct_from_file(config_path)
        self.cfg = cfg
        self.model = models.__dict__[cfg.model.name](cfg, overlap).cuda()
        self.model = nn.DataParallel(self.model)
        checkpoint = load_checkpoint(
            osp.join(cfg.logger.logs_dir, "model_best.pth.tar")
        )
        paras = {}
        for key, value in checkpoint["state_dict"].items():
            if key in self.model.state_dict():
                paras[key] = value
        if "aud" in cfg.dataset["mode"]:
            c_logs_dir = cfg.logger.logs_dir.replace("image50", "aud")
            checkpoint = load_checkpoint(osp.join(c_logs_dir, "model_best.pth.tar"))
            for key, value in checkpoint["state_dict"].items():
                if key in self.model.state_dict():
                    paras[key] = value
        print(list(paras.keys()))
        self.model.load_state_dict(paras)
        self.model.eval()
    def __call__(
        self,
        video_path,
        video_map,
    ) -> Dict:
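        """Predict scene transitions for ``video_path``.

        ``video_map`` is used as the directory holding the per-video shot map
        JSON (and for saving the result); the updated map dict is returned,
        or ``None`` if the shot map is missing.
        """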
        # Assumption: the same directory is used both to read the shot map and
        # to save the prediction JSON; predict() requires both a save_path and
        # a map_path, and the original call supplied only one path.
        video_info = predict(
            self.model,
            self.cfg,
            video_path,
            save_path=video_map,
            map_path=video_map,
            overlap=self.overlap,
        )
        return video_info
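

# Minimal usage sketch. Everything below is illustrative and not part of the
# original interface: the config path and CLI flags are assumptions, and the
# checkpoint layout must match cfg.logger.logs_dir. Because this module uses a
# relative import, run it as a module (python -m <package>.<this_module> ...).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--video", required=True, help="path to the input video")
    parser.add_argument("--map_dir", required=True, help="directory with the shot map JSON")
    parser.add_argument("--config", default="config/scene_seg.py", help="hypothetical LGSS config path")
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--overwrite", action="store_true")
    # predict() reads `args.threshold` and `args.overwrite` from this
    # module-level namespace.
    args = parser.parse_args()

    stp = SceneTransitionPredictor(args.config, overlap=41)
    result = stp(args.video, args.map_dir)
    if result is not None:
        print("scenes:", len(result["scenes"]))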