# -*- coding: utf-8 -*-
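"""TransNetV2-based shot transition detection.

Runs TransNetV2 over a video, merges adjacent transition frames into
transition segments, and assembles the result into a VideoMap of
alternating "body" and "transition" clips.
"""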
from typing import Dict, List, Tuple, Union
import os
import json
import traceback
import argparse
import hashlib
import librosa
import soundfile as sf
import numpy as np
import torch
from moviepy.editor import *
from .TransNetmodels import TransNetV2
from ...video_map.video_meta_info import VideoMetaInfo
from ...video_map.video_map import VideoMap
from ...video_map.video_clip import VideoClipSeq
from ...black_border import det_video_black_border
from ...utils.path_util import get_video_signature
from ...data.video_dataset import MoviepyVideoDataset, SequentialDataset
def predict(
model,
video_path,
threshold=0.3,
sample_fps=25,
content_box=None,
single_frame_ratio=1,
map_path: str = None,
ignored_keys: List = None,
) -> Tuple[VideoMap, np.ndarray, np.ndarray]:
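    """Detect shot transitions in a video with TransNetV2 and build a VideoMap.

    The video is split into alternating "body" and "transition" clips. When
    `map_path` is given, the VideoMap is also written there as JSON. Returns
    the VideoMap together with the raw per-frame scores of the single-frame
    and all-frames prediction heads.
    """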
    video_hash_code, video_path = get_video_signature(path=video_path, rename=True)
    basename = os.path.basename(video_path)
    filename, ext = os.path.splitext(basename)
with torch.no_grad():
(
single_frame_result,
all_frame_result,
fps,
total_frames,
duration,
height,
width,
) = model.predict_video(
video_path,
cache_path="",
c_box=content_box,
width=48,
height=27,
input_frames=10000,
overlap=100,
sample_fps=sample_fps,
)
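    # TransNetV2 yields two per-frame scores (a single-frame head and an
    # all-frames head); their element-wise max is used as the transition score.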
# pred_label = single_frame_ratio * single_frame_result + (1 - single_frame_ratio) * all_frame_result
pred_label = np.array([single_frame_result, all_frame_result])
pred_label = pred_label.max(axis=0)
    transition_index = np.where(pred_label > threshold)[0]  # indices of transition frames
    transition_index = transition_index.astype(float)  # np.float was removed from NumPy
    # Post-process the predictions: merge adjacent transition frames into groups
result_transition = []
for i, transition in enumerate(transition_index):
if i == 0:
result_transition.append([transition])
else:
if abs(result_transition[-1][-1] - transition) <= 4:
result_transition[-1].append(transition)
else:
result_transition.append([transition])
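    # Keep a merged group only if its peak score clears a length-dependent
    # threshold: groups longer than 3 frames need > 0.3, groups of 2-3 frames
    # need > 0.4, and single frames need > 0.45.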
result = [[0]]
for item in result_transition:
start_idx = int(item[0])
end_idx = int(item[-1])
if len(item) > 3:
if max(pred_label[start_idx : end_idx + 1]) > 0.3:
result.append([item[0], item[-1]])
elif len(item) > 1:
if max(pred_label[start_idx : end_idx + 1]) > 0.4:
result.append([item[0], item[-1]])
else:
if pred_label[start_idx] > 0.45:
result.append(item)
result.append([pred_label.shape[0]])
video_meta_info_dct = {
"video_name": filename,
"video_path": video_path,
"video_file_hash_code": video_hash_code,
"fps": fps,
"frame_num": total_frames,
"duration": duration,
"height": height,
"width": width,
"content_box": content_box,
"sample_fps": sample_fps,
}
video_meta_info = VideoMetaInfo.from_video_path(video_path)
video_meta_info.__dict__.update(video_meta_info_dct)
video_clipseq = []
slice_id = 0
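    # Turn the frame-index groups into clips: a 1-element group starts a "body"
    # clip, a 2-element group [start, end] becomes a "transition" clip followed
    # by the next "body" clip.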
    for i in range(len(result) - 1):
        if len(result[i]) == 1:
            video_clip = {
                "time_start": round(result[i][0] / sample_fps, 4),  # start time
                "duration": round(
                    result[i + 1][0] / sample_fps - result[i][0] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][0],
                "frame_end": result[i + 1][0],
                "clipid": slice_id,  # clip index
                "cliptype": "body",
            }
            video_clipseq.append(video_clip)
            slice_id += 1
        elif len(result[i]) == 2:
            video_clip = {
                "time_start": round(result[i][0] / sample_fps, 4),  # start time
                "duration": round(
                    result[i][1] / sample_fps - result[i][0] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][0],
                "frame_end": result[i][1],
                "clipid": slice_id,  # clip index
                "cliptype": "transition",
            }
            video_clipseq.append(video_clip)
            slice_id += 1
            video_clip = {
                "time_start": round(result[i][1] / sample_fps, 4),  # start time
                "duration": round(
                    result[i + 1][0] / sample_fps - result[i][1] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][1],
                "frame_end": result[i + 1][0],
                "clipid": slice_id,  # clip index
                "cliptype": "body",
            }
            video_clipseq.append(video_clip)
            slice_id += 1
video_clipseq = VideoClipSeq.from_data(video_clipseq)
video_map = VideoMap(meta_info=video_meta_info, clipseq=video_clipseq)
if map_path is not None:
with open(map_path, "w") as f:
json.dump(video_map.to_dct(ignored_keys=ignored_keys), f, indent=4)
return video_map, single_frame_result, all_frame_result
class TransNetV2Predictor(object):
    def __init__(self, model_path: str, device: str) -> None:
        # Model initialization and weight loading
        self.model = TransNetV2()
        checkpoint = torch.load(model_path, map_location="cpu")  # load model parameters
        self.model.load_state_dict(
            {k.replace("model.", ""): v for k, v in checkpoint.items()}
        )
        # model.load_state_dict(checkpoint['state_dict'])
        self.model.eval().to(device)
        self.device = device
    def __call__(
        self, video_path, map_path, content_box
    ) -> Tuple[VideoMap, np.ndarray, np.ndarray]:
        return predict(
            self.model, video_path, map_path=map_path, content_box=content_box
        )
    # TODO: work in progress
def predict_video_write(
self,
video_dataset: Union[str, SequentialDataset],
c_box=None,
width=48,
height=27,
input_frames=100,
overlap=30,
sample_fps=30,
threshold=0.3,
drop_last=False,
):
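        """Run TransNetV2 over overlapping frame windows of a video.

        The video is read through a SequentialDataset (built from the path if a
        string is given); predictions of neighbouring windows are stitched by
        trimming half of the overlap on each side, and the stitched per-frame
        scores are returned together with basic video metadata (work in
        progress, see the TODO above).
        """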
# check parameters
assert overlap % 2 == 0
assert input_frames > overlap
# prepare video_dataset
if isinstance(video_dataset, str):
video_dataset = MoviepyVideoDataset(video_dataset)
step = input_frames - overlap
if (
video_dataset.step != step
or video_dataset.time_size != input_frames
or video_dataset.drop_last != drop_last
):
video_dataset.generate_sample_idxs(
time_size=input_frames, step=step, drop_last=drop_last
)
fps = video_dataset.fps
duration = video_dataset.duration
total_frames = video_dataset.total_frames
w, h = video_dataset.size
if c_box:
video_dataset.cap.crop(*c_box)
single_frame_pred_lst, all_frame_pred_lst, index_lst = [], [], []
for i, batch in enumerate(video_dataset):
data, data_index = batch.data, batch.index
data = data.to(self.device)
# shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels
            single_frame_pred, all_frame_pred = self.model(data.unsqueeze(0))  # forward pass
            # single_frame_pred = F.softmax(single_frame_pred, dim=-1)  # per-frame class probabilities
# single_frame_pred = torch.argmax(single_frame_pred, dim=-1).reshape(-1)
single_frame_pred = torch.sigmoid(single_frame_pred).reshape(-1)
all_frame_pred = torch.sigmoid(all_frame_pred).reshape(-1)
# single_frame_pred = (single_frame_pred>threshold)*1
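            # Consecutive windows overlap by `overlap` frames; drop half of the
            # overlap on each side so the concatenated predictions cover every
            # frame exactly once.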
if total_frames > data_index[-1]:
if i == 0:
single_frame_pred_label = single_frame_pred[: -overlap // 2]
all_frame_pred_label = all_frame_pred[: -overlap // 2]
else:
single_frame_pred_label = single_frame_pred[
overlap // 2 : -overlap // 2
]
all_frame_pred_label = all_frame_pred[overlap // 2 : -overlap // 2]
else:
if i == 0:
single_frame_pred_label = single_frame_pred
all_frame_pred_label = all_frame_pred
else:
single_frame_pred_label = single_frame_pred[overlap // 2 :]
all_frame_pred_label = all_frame_pred[overlap // 2 :]
single_frame_pred_lst.append(single_frame_pred_label)
all_frame_pred_lst.append(all_frame_pred_label)
            index_lst.extend(data_index)
single_frame_pred_label = torch.concat(single_frame_pred_lst, dim=0)
all_frame_pred_label = torch.concat(all_frame_pred_lst, dim=0)
single_frame_pred_label = single_frame_pred_label.cpu().numpy()
all_frame_pred_label = all_frame_pred_label.cpu().numpy()
        # Post-process the predictions: merge adjacent transition frames into groups
        pred_label = np.array([single_frame_pred_label, all_frame_pred_label])
        pred_label = pred_label.max(axis=0)
        transition_index = np.where(pred_label > threshold)[0]  # indices of transition frames
        transition_index = transition_index.astype(float)  # np.float was removed from NumPy
result_transition = []
for i, transition in enumerate(transition_index):
if i == 0:
result_transition.append([transition])
else:
if abs(result_transition[-1][-1] - transition) <= 4:
result_transition[-1].append(transition)
else:
result_transition.append([transition])
result = [[0]]
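        # Same length-dependent score thresholds as in `predict` above.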
for item in result_transition:
start_idx = int(item[0])
end_idx = int(item[-1])
if len(item) > 3:
if max(pred_label[start_idx : end_idx + 1]) > 0.3:
result.append([item[0], item[-1]])
elif len(item) > 1:
if max(pred_label[start_idx : end_idx + 1]) > 0.4:
result.append([item[0], item[-1]])
else:
if pred_label[start_idx] > 0.45:
result.append(item)
result.append([pred_label.shape[0]])
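        # NOTE: the grouped `result` computed above is not returned; only the
        # stitched raw predictions and basic video metadata are.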
return (
single_frame_pred_label,
all_frame_pred_label,
fps,
total_frames,
duration,
h,
w,
)
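

# Minimal usage sketch (the checkpoint and file paths below are placeholders,
# not files shipped with this module): build a predictor from a TransNetV2
# checkpoint and write the detected shot map to JSON.
if __name__ == "__main__":
    predictor = TransNetV2Predictor(
        model_path="path/to/transnetv2.pth",  # placeholder checkpoint path
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    video_map, single_frame_pred, all_frame_pred = predictor(
        video_path="path/to/video.mp4",  # placeholder input video
        map_path="path/to/video_map.json",  # the VideoMap is also dumped here as JSON
        content_box=None,  # optionally a crop box for the content area (excluding black borders)
    )
    # `video_map` is a VideoMap object; the raw per-frame transition scores of
    # the two prediction heads are returned alongside it.
    print("video map written to path/to/video_map.json")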