# MuseV-test / mmcm / vision / transition / scene_transition_predictor.py
from __future__ import print_function
import traceback
from typing import Dict
from moviepy.editor import VideoFileClip
import hashlib
import json
import numpy as np
import os
import time
import copy
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import librosa
from ...utils.util import load_dct_from_file
# from lgss.utilis.package import *
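# ImageNet mean/std normalization applied to every sampled video frame.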
normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer = transforms.Compose(
[
# transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.ToTensor(),
normalizer,
]
)
def wav2stft(data):
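    """Convert a mono waveform into a fixed-width STFT magnitude image.

    The signal is normalized to [-1, 1], transformed with a 512-point STFT and mapped
    to dB; the result is tiled or cropped so it always has 4 * 80 = 320 time frames
    (for long clips, the first and last 160 frames are kept).
    """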
# normalize
mean = (data.max() + data.min()) / 2
span = (data.max() - data.min()) / 2
if span < 1e-6:
span = 1
data = (data - mean) / span # range: [-1,1]
D = librosa.core.stft(data, n_fft=512)
freq = np.abs(D)
freq = librosa.core.amplitude_to_db(freq)
span = 80
thr = 4 * span
if freq.shape[1] <= thr:
copy_ = freq.copy()
while freq.shape[1] < thr:
tmp = copy_.copy()
freq = np.concatenate((freq, tmp), axis=1)
freq = freq[:, :thr]
else:
# sample
n = freq.shape[1]
stft_img = []
stft_img.append(freq[:, : 2 * span])
# stft_img.append(freq[:, n//2 - span : n//2 + span])
stft_img.append(freq[:, -2 * span :])
freq = np.concatenate(stft_img, axis=1)
return freq
def test(
model,
data_place,
data_cast=None,
data_act=None,
data_aud=None,
last_image_overlap_feat=None,
last_aud_overlap_feat=None,
):
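    """Run one forward pass of the model without gradients.

    Returns per-shot transition probabilities for the image and audio streams, the
    overlap features to carry into the next window, and the per-shot dynamic scores.
    """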
with torch.no_grad():
# data_place = data_place.cuda() if data_place is not None else []
data_cast = data_cast.cuda() if data_cast is not None else []
data_act = data_act.cuda() if data_act is not None else []
data_aud = data_aud.cuda() if data_aud is not None else []
(
img_output,
aud_output,
image_overlap_feat,
audio_overlap_feat,
shot_dynamic_list,
) = model(
data_place,
data_cast,
data_act,
data_aud,
last_image_overlap_feat,
last_aud_overlap_feat,
)
img_output = img_output.view(-1, 2)
img_output = F.softmax(img_output, dim=1)
img_prob = img_output[:, 1]
img_prob = img_prob.cpu()
aud_output = aud_output.view(-1, 2)
aud_output = F.softmax(aud_output, dim=1)
aud_prob = aud_output[:, 1]
aud_prob = aud_prob.cpu()
return img_prob, aud_prob, image_overlap_feat, audio_overlap_feat, shot_dynamic_list
def predict(
model,
cfg,
video_path,
save_path,
map_path,
seq_len=120,
shot_num=4,
overlap=21,
shot_frame_max_num=60,
):
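    """Predict shot-transition probabilities for ``video_path`` and group shots into scenes.

    Loads the precomputed shot map, runs the model over sliding windows of shots, and
    returns the map annotated with per-shot transition probabilities, per-shot dynamic
    scores and scene groupings.
    """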
assert overlap % 2 == 1
video_name = ".".join(video_path.split("/")[-1].split(".")[:-1])
if not os.path.exists(save_path):
os.makedirs(save_path)
# video_hash_code = (os.popen('md5sum {}'.format(video_path))).readlines()[0].split(' ')[0]
with open(video_path, "rb") as fd:
data = fd.read()
video_hash_code = hashlib.md5(data).hexdigest()
save_path = os.path.join(
save_path, "{}_{}.json".format(video_name, video_hash_code[:8])
)
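    # NOTE: ``args`` (providing ``overwrite`` and ``threshold``) is expected to be
    # defined by the calling script; it is not defined in this module.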
if os.path.exists(save_path) and not args.overwrite:
        video_map = json.load(open(save_path, encoding="UTF-8"))
valid_clips = []
for clip in video_map["clips"]:
if clip["cliptype"] == "body" and clip["duration"] > 0.25:
valid_clips.append(clip)
# Capture video
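        # Scale the shorter side of the content box to 256 px, then crop a 224-px-high
        # window (proportional width) centred on the content box.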
if (
video_map["content_box"][2] - video_map["content_box"][0]
> video_map["content_box"][3] - video_map["content_box"][1]
):
target_resolution = (
256
* video_map["height"]
/ (video_map["content_box"][3] - video_map["content_box"][1]),
None,
)
else:
target_resolution = (
None,
256
* video_map["width"]
/ (video_map["content_box"][2] - video_map["content_box"][0]),
)
video = VideoFileClip(
video_path,
target_resolution=target_resolution,
resize_algorithm="bilinear",
audio_fps=16000,
)
# video = video.crop(*video_map["content_box"])
x1 = video_map["content_box"][0] * video.size[0] // video_map["width"]
y1 = video_map["content_box"][1] * video.size[1] // video_map["height"]
x2 = video_map["content_box"][2] * video.size[0] // video_map["width"]
y2 = video_map["content_box"][3] * video.size[1] // video_map["height"]
video = video.crop(
width=(x2 - x1) * 224 / 256,
height=224,
x_center=(x1 + x2) // 2,
y_center=(y1 + y2) // 2,
)
print("exists " + save_path)
else:
map_path = os.path.join(
map_path, "{}_{}.json".format(video_name, video_hash_code[:8])
)
if not os.path.exists(map_path):
print("map not exist: ", map_path)
return
        video_map = json.load(open(map_path, encoding="UTF-8"))
assert video_hash_code == video_map["video_file_hash_code"]
# Capture video
if (
video_map["content_box"][2] - video_map["content_box"][0]
> video_map["content_box"][3] - video_map["content_box"][1]
):
target_resolution = (
256
* video_map["height"]
/ (video_map["content_box"][3] - video_map["content_box"][1]),
None,
)
else:
target_resolution = (
None,
256
* video_map["width"]
/ (video_map["content_box"][2] - video_map["content_box"][0]),
)
video = VideoFileClip(
video_path,
target_resolution=target_resolution,
resize_algorithm="bilinear",
audio_fps=16000,
)
# video = video.crop(*video_map["content_box"])
x1 = video_map["content_box"][0] * video.size[0] // video_map["width"]
y1 = video_map["content_box"][1] * video.size[1] // video_map["height"]
x2 = video_map["content_box"][2] * video.size[0] // video_map["width"]
y2 = video_map["content_box"][3] * video.size[1] // video_map["height"]
video = video.crop(
width=(x2 - x1) * 224 / 256,
height=224,
x_center=(x1 + x2) // 2,
y_center=(y1 + y2) // 2,
)
fps = video.fps
duration = video.duration
total_frames = int(duration * fps)
width, height = video.size
print("fps, frame_count, width, height:", fps, total_frames, width, height)
valid_clips = []
for clip in video_map["clips"]:
if clip["cliptype"] == "body" and clip["duration"] > 0.25:
valid_clips.append(clip)
# valid_clips = valid_clips[:150]
total_shot_num = len(valid_clips)
last_image_overlap_feat = None
last_aud_overlap_feat = None
truncate_time = 0.1
all_shot_dynamic_list = []
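    # Slide a window of ``seq_len`` shots over the video, with ``overlap`` shots shared
    # between consecutive windows so the model keeps temporal context across boundaries.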
for i in range(total_shot_num // (seq_len - overlap) + 1):
shot_frame_list = []
shot_audio_list = []
start_shot = i * (seq_len - overlap)
end_shot = min(start_shot + seq_len, total_shot_num)
if i != 0:
start_shot += overlap
print(start_shot, end_shot)
if start_shot >= end_shot:
break
for clip in valid_clips[start_shot:end_shot]:
time_start = clip["time_start"]
time_end = clip["time_start"] + clip["duration"]
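            # Trim a small margin from both ends of the shot and keep at most roughly
            # six seconds around its midpoint.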
truncate_time = min(clip["duration"] / 10, 0.1)
time_start += truncate_time
time_end -= truncate_time
time_start = max(time_start, (time_end + time_start) / 2 - 3)
time_end = min(time_end, (time_end + time_start) / 2 + 3)
duration = time_end - time_start
t0 = time.time()
video_subclip = video.subclip(time_start, time_end)
# video_save_path = os.path.join(args.video_save_path, 'shot_{:04d}.mp4'.format(clip["clipid"]))
# video_subclip.write_videofile(video_save_path, threads=8, codec='libx264')
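            # Visual stream: sample frames at 10 fps (capped at about
            # ``shot_frame_max_num`` frames), normalize them and move them to the GPU.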
if "image" in cfg.dataset["mode"]:
frame_iter = video_subclip.iter_frames(fps=10)
shot_frame = []
for frame in frame_iter:
frame = transformer(frame)
shot_frame.append(frame)
if len(shot_frame) > shot_frame_max_num:
break
shot_frame = torch.stack(shot_frame)
shot_frame = shot_frame.cuda()
shot_frame_list.append(shot_frame)
t5 = time.time()
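            # Audio stream: take the shot's mono waveform at 16 kHz and convert it to a
            # fixed-size STFT image.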
if "aud" in cfg.dataset["mode"]:
try:
sub_audio = video.audio.subclip(
clip["time_start"], clip["time_start"] + clip["duration"]
)
sub_audio = sub_audio.to_soundarray(
fps=16000, quantize=True, buffersize=20000
)
sub_audio = sub_audio.mean(axis=1)
                except Exception:
                    # fall back to 4 s of silence if the shot has no usable audio track
                    sub_audio = np.zeros((16000 * 4), np.float32)
sub_audio = wav2stft(sub_audio)
sub_audio = torch.from_numpy(sub_audio).float()
sub_audio = sub_audio.unsqueeze(dim=0)
shot_audio_list.append(sub_audio)
t6 = time.time()
print(clip["clipid"], t5 - t0, t6 - t5)
data_place = data_aud = None
if len(shot_frame_list) > 0:
# data_place = torch.stack(shot_frame_list)
data_place = shot_frame_list
if len(shot_audio_list) > 0:
data_aud = torch.stack(shot_audio_list)
data_aud = data_aud.unsqueeze(dim=0)
(
img_preds,
aud_preds,
last_image_overlap_feat,
last_aud_overlap_feat,
shot_dynamic_list,
) = test(
model,
data_place=data_place,
data_aud=data_aud,
last_image_overlap_feat=last_image_overlap_feat,
last_aud_overlap_feat=last_aud_overlap_feat,
)
print(shot_dynamic_list)
all_shot_dynamic_list.extend(shot_dynamic_list)
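        # Stitch window predictions together: drop (overlap - shot_num + 1) // 2 shots
        # from each window edge so overlapping shots are not counted twice.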
if total_shot_num > end_shot:
if i == 0:
img_preds_all = img_preds[: -(overlap - shot_num + 1) // 2]
aud_preds_all = aud_preds[: -(overlap - shot_num + 1) // 2]
else:
img_preds_all = torch.cat(
(
img_preds_all,
img_preds[
(overlap - shot_num + 1)
// 2 : -(overlap - shot_num + 1)
// 2
],
),
dim=0,
)
aud_preds_all = torch.cat(
(
aud_preds_all,
aud_preds[
(overlap - shot_num + 1)
// 2 : -(overlap - shot_num + 1)
// 2
],
),
dim=0,
)
else:
if i == 0:
img_preds_all = img_preds
aud_preds_all = aud_preds
else:
img_preds_all = torch.cat(
(img_preds_all, img_preds[(overlap - shot_num + 1) // 2 :]),
dim=0,
)
aud_preds_all = torch.cat(
(aud_preds_all, aud_preds[(overlap - shot_num + 1) // 2 :]),
dim=0,
)
print(
img_preds_all.shape[0],
total_shot_num - shot_num + 1,
len(all_shot_dynamic_list),
)
assert img_preds_all.shape[0] == total_shot_num - shot_num + 1
assert len(all_shot_dynamic_list) == total_shot_num
print("img_preds_all: ", img_preds_all)
print("aud_preds_all: ", aud_preds_all)
video_map["scenes_img_preds"] = img_preds_all.tolist()
video_map["scenes_aud_preds"] = aud_preds_all.tolist()
for clip, dynamic in zip(valid_clips, all_shot_dynamic_list):
clip["dynamic"] = None
if dynamic is not None:
clip["dynamic"] = round(np.clip(dynamic, 0, 1), 5)
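    # Fuse the image and audio transition probabilities with the weights configured in
    # ``cfg.model.ratio``.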
preds_all = cfg.model.ratio[0] * np.array(
video_map["scenes_img_preds"]
) + cfg.model.ratio[3] * np.array(video_map["scenes_aud_preds"])
video_map["scenes_preds"] = preds_all.tolist()
scene_boundary = np.where(preds_all > args.threshold)[0]
video_map["scenes"] = []
scene = {
"sceneid": 0,
"clip_start": valid_clips[0]["clipid"],
"clip_end": valid_clips[0]["clipid"],
"time_start": valid_clips[0]["time_start"],
"time_end": valid_clips[0]["time_start"] + valid_clips[0]["duration"],
}
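    # Each predicted boundary index closes the current scene at shot
    # ``i + shot_num // 2 - 1`` and opens a new one at the following shot; roles and an
    # average dynamic score are aggregated per scene.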
for i in scene_boundary:
scene["clip_end"] = valid_clips[i + shot_num // 2 - 1]["clipid"]
scene["time_end"] = (
valid_clips[i + shot_num // 2 - 1]["time_start"]
+ valid_clips[i + shot_num // 2 - 1]["duration"]
)
scene["roles"] = {}
scene["dynamic"] = None
dynamic_num = 0
dynamic = 0
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
for roleid in clip["roles"].keys():
if roleid not in scene["roles"]:
scene["roles"][roleid] = {
"name": clip["roles"][roleid]["name"]
if "name" in clip["roles"][roleid]
else ""
}
            if "dynamic" in clip and clip["dynamic"] is not None:
dynamic += clip["dynamic"]
dynamic_num += 1
if dynamic_num > 0:
scene["dynamic"] = dynamic / dynamic_num
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
clip["scene_roles"] = scene["roles"]
clip["scene_dynamic"] = scene["dynamic"]
clip["sceneid"] = scene["sceneid"]
video_map["scenes"].append(copy.deepcopy(scene))
scene["sceneid"] += 1
scene["clip_start"] = scene["clip_end"] = valid_clips[i + shot_num // 2][
"clipid"
]
scene["time_start"] = valid_clips[i + shot_num // 2]["time_start"]
scene["time_end"] = (
valid_clips[i + shot_num // 2]["time_start"]
+ valid_clips[i + shot_num // 2]["duration"]
)
scene["clip_end"] = valid_clips[-1]["clipid"]
scene["time_end"] = valid_clips[-1]["time_start"] + valid_clips[-1]["duration"]
scene["roles"] = {}
scene["dynamic"] = None
dynamic_num = 0
dynamic = 0
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
for roleid in clip["roles"].keys():
if roleid not in scene["roles"]:
scene["roles"][roleid] = {
"name": clip["roles"][roleid]["name"]
if "name" in clip["roles"][roleid]
else ""
}
        if "dynamic" in clip and clip["dynamic"] is not None:
dynamic += clip["dynamic"]
dynamic_num += 1
if dynamic_num > 0:
scene["dynamic"] = dynamic / dynamic_num
for clip in video_map["clips"][scene["clip_start"] : scene["clip_end"] + 1]:
clip["scene_roles"] = scene["roles"]
clip["scene_dynamic"] = scene["dynamic"]
clip["sceneid"] = scene["sceneid"]
video_map["scenes"].append(scene)
return video_map
class SceneTransitionPredictor(object):
def __init__(self, config_path, overlap=41, model_path=None) -> None:
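        """Build the LGSS scene-segmentation model from ``config_path`` and load the best
        image (and, if configured, audio) checkpoints."""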
from mmcv import Config
from lgss.utilis import load_checkpoint
import lgss.src.models as models
        self.config_path = config_path
        # store the window overlap so __call__ can forward it to predict()
        self.overlap = overlap
cfg = Config.fromfile(config_path)
# cfg = load_dct_from_file(config_path)
self.cfg = cfg
self.model = models.__dict__[cfg.model.name](cfg, overlap).cuda()
self.model = nn.DataParallel(self.model)
checkpoint = load_checkpoint(
osp.join(cfg.logger.logs_dir, "model_best.pth.tar")
)
paras = {}
for key, value in checkpoint["state_dict"].items():
if key in self.model.state_dict():
paras[key] = value
if "aud" in cfg.dataset["mode"]:
c_logs_dir = cfg.logger.logs_dir.replace("image50", "aud")
checkpoint = load_checkpoint(osp.join(c_logs_dir, "model_best.pth.tar"))
for key, value in checkpoint["state_dict"].items():
if key in self.model.state_dict():
paras[key] = value
print(list(paras.keys()))
self.model.load_state_dict(paras)
self.model.eval()
def __call__(
self,
video_path,
video_map,
) -> Dict:
        # ``predict`` expects both a save directory and a map directory; the single
        # ``video_map`` argument is assumed here to serve as both (an assumption, not
        # confirmed upstream).
        video_info = predict(
            self.model,
            self.cfg,
            video_path,
            video_map,
            video_map,
            overlap=self.overlap,
        )
return video_info
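

# A minimal usage sketch (the config path, video path and map directory below are
# hypothetical, not part of this module):
#
#     predictor = SceneTransitionPredictor("configs/scene_seg.py", overlap=41)
#     annotated_map = predictor("some_video.mp4", "path/to/video_map_dir")
#     print(annotated_map["scenes"])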