# -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple, Union
import os
import json
import traceback
import argparse
import hashlib

import librosa
import soundfile as sf
import numpy as np
import torch
from moviepy.editor import *

from .TransNetmodels import TransNetV2
from ...video_map.video_meta_info import VideoMetaInfo
from ...video_map.video_map import VideoMap
from ...video_map.video_clip import VideoClipSeq
from ...black_border import det_video_black_border
from ...utils.path_util import get_video_signature
from ...data.video_dataset import MoviepyVideoDataset, SequentialDataset


def predict(
    model,
    video_path,
    threshold=0.3,
    sample_fps=25,
    content_box=None,
    single_frame_ratio=1,
    map_path: Optional[str] = None,
    ignored_keys: Optional[List] = None,
) -> Tuple[VideoMap, np.ndarray, np.ndarray]:
    video_hash_code, video_path = get_video_signature(path=video_path, rename=True)
    basename = os.path.basename(video_path)
    filename, ext = os.path.splitext(video_path)
    with torch.no_grad():
        (
            single_frame_result,
            all_frame_result,
            fps,
            total_frames,
            duration,
            height,
            width,
        ) = model.predict_video(
            video_path,
            cache_path="",
            c_box=content_box,
            width=48,
            height=27,
            input_frames=10000,
            overlap=100,
            sample_fps=sample_fps,
        )

    # pred_label = single_frame_ratio * single_frame_result + (1 - single_frame_ratio) * all_frame_result
    pred_label = np.array([single_frame_result, all_frame_result])
    pred_label = pred_label.max(axis=0)
    transition_index = np.where(pred_label > threshold)[0]  # transition frame positions
    transition_index = transition_index.astype(float)

    # post-process: group transition frames that are at most 4 frames apart
    result_transition = []
    for i, transition in enumerate(transition_index):
        if i == 0:
            result_transition.append([transition])
        else:
            if abs(result_transition[-1][-1] - transition) <= 4:
                result_transition[-1].append(transition)
            else:
                result_transition.append([transition])

    # keep a group only if its peak probability exceeds a length-dependent threshold
    result = [[0]]
    for item in result_transition:
        start_idx = int(item[0])
        end_idx = int(item[-1])
        if len(item) > 3:
            if max(pred_label[start_idx : end_idx + 1]) > 0.3:
                result.append([item[0], item[-1]])
        elif len(item) > 1:
            if max(pred_label[start_idx : end_idx + 1]) > 0.4:
                result.append([item[0], item[-1]])
        else:
            if pred_label[start_idx] > 0.45:
                result.append(item)
    result.append([pred_label.shape[0]])

    video_meta_info_dct = {
        "video_name": filename,
        "video_path": video_path,
        "video_file_hash_code": video_hash_code,
        "fps": fps,
        "frame_num": total_frames,
        "duration": duration,
        "height": height,
        "width": width,
        "content_box": content_box,
        "sample_fps": sample_fps,
    }
    video_meta_info = VideoMetaInfo.from_video_path(video_path)
    video_meta_info.__dict__.update(video_meta_info_dct)

    video_clipseq = []
    slice_id = 0
    for i in range(len(result) - 1):
        if len(result[i]) == 1:
            video_clip = {
                "time_start": round(result[i][0] / sample_fps, 4),  # start time
                "duration": round(
                    result[i + 1][0] / sample_fps - result[i][0] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][0],
                "frame_end": result[i + 1][0],
                "clipid": slice_id,  # clip index
                "cliptype": "body",
            }
            video_clipseq.append(video_clip)
            slice_id += 1
        elif len(result[i]) == 2:
            video_clip = {
                "time_start": round(result[i][0] / sample_fps, 4),  # start time
                "duration": round(
                    result[i][1] / sample_fps - result[i][0] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][0],
                "frame_end": result[i][1],
                "clipid": slice_id,  # clip index
                "cliptype": "transition",
            }
            video_clipseq.append(video_clip)
            slice_id += 1
            video_clip = {
                "time_start": round(result[i][1] / sample_fps, 4),  # start time
                "duration": round(
                    result[i + 1][0] / sample_fps - result[i][1] / sample_fps,
                    4,
                ),  # clip duration
                "frame_start": result[i][1],
                "frame_end": result[i + 1][0],
"clipid": slice_id, # 片段序号, "cliptype": "body", } video_clipseq.append(vidoe_clip) slice_id += 1 video_clipseq = VideoClipSeq.from_data(video_clipseq) video_map = VideoMap(meta_info=video_meta_info, clipseq=video_clipseq) if map_path is not None: with open(map_path, "w") as f: json.dump(video_map.to_dct(ignored_keys=ignored_keys), f, indent=4) return video_map, single_frame_result, all_frame_result class TransNetV2Predictor(object): def __init__(self, model_path: str, device: str) -> None: # 模型初始化和参数载入 self.model = TransNetV2() checkpoint = torch.load(model_path) # 载入模型参数 self.model.load_state_dict( {k.replace("model.", ""): v for k, v in checkpoint.items()} ) # model.load_state_dict(checkpoint['state_dict']) self.model.eval().to(device) self.device = device def __call__(self, video_path, map_path, content_box) -> Dict: return predict( self.model, video_path, map_path=map_path, content_box=content_box ) # TODO: is writing def predict_video_write( self, video_dataset: Union[str, SequentialDataset], c_box=None, width=48, height=27, input_frames=100, overlap=30, sample_fps=30, threshold=0.3, drop_last=False, ): # check parameters assert overlap % 2 == 0 assert input_frames > overlap # prepare video_dataset if isinstance(video_dataset, str): video_dataset = MoviepyVideoDataset(video_dataset) step = input_frames - overlap if ( video_dataset.step != step or video_dataset.time_size != input_frames or video_dataset.drop_last != drop_last ): video_dataset.generate_sample_idxs( time_size=input_frames, step=step, drop_last=drop_last ) fps = video_dataset.fps duration = video_dataset.duration total_frames = video_dataset.total_frames w, h = video_dataset.size if c_box: video_dataset.cap.crop(*c_box) single_frame_pred_lst, all_frame_pred_lst, index_lst = [], [], [] for i, batch in enumerate(video_dataset): data, data_index = batch.data, batch.index data = data.to(self.device) # shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels single_frame_pred, all_frame_pred = self.forward(data.unsqueeze(0)) # 前向推理 # single_frame_pred = F.softmax(single_frame_pred, dim=-1) # 获得每一帧对应的类别概率 # single_frame_pred = torch.argmax(single_frame_pred, dim=-1).reshape(-1) single_frame_pred = torch.sigmoid(single_frame_pred).reshape(-1) all_frame_pred = torch.sigmoid(all_frame_pred).reshape(-1) # single_frame_pred = (single_frame_pred>threshold)*1 if total_frames > data_index[-1]: if i == 0: single_frame_pred_label = single_frame_pred[: -overlap // 2] all_frame_pred_label = all_frame_pred[: -overlap // 2] else: single_frame_pred_label = single_frame_pred[ overlap // 2 : -overlap // 2 ] all_frame_pred_label = all_frame_pred[overlap // 2 : -overlap // 2] else: if i == 0: single_frame_pred_label = single_frame_pred all_frame_pred_label = all_frame_pred else: single_frame_pred_label = single_frame_pred[overlap // 2 :] all_frame_pred_label = all_frame_pred[overlap // 2 :] single_frame_pred_lst.append(single_frame_pred_label) all_frame_pred_lst.append(all_frame_pred_label) index_lst.extent(data_index) single_frame_pred_label = torch.concat(single_frame_pred_lst, dim=0) all_frame_pred_label = torch.concat(all_frame_pred_lst, dim=0) single_frame_pred_label = single_frame_pred_label.cpu().numpy() all_frame_pred_label = all_frame_pred_label.cpu().numpy() # 对返回结果做后处理合并相邻帧 pred_label = np.array([single_frame_pred_label, all_frame_pred_label]) pred_label = pred_label.max(axis=0) transition_index = np.where(pred_label > threshold)[0] # 转场帧位置 transition_index = transition_index.astype(np.float) 
        # group transition frames that are at most 4 frames apart
        result_transition = []
        for i, transition in enumerate(transition_index):
            if i == 0:
                result_transition.append([transition])
            else:
                if abs(result_transition[-1][-1] - transition) <= 4:
                    result_transition[-1].append(transition)
                else:
                    result_transition.append([transition])

        # keep a group only if its peak probability exceeds a length-dependent threshold
        result = [[0]]
        for item in result_transition:
            start_idx = int(item[0])
            end_idx = int(item[-1])
            if len(item) > 3:
                if max(pred_label[start_idx : end_idx + 1]) > 0.3:
                    result.append([item[0], item[-1]])
            elif len(item) > 1:
                if max(pred_label[start_idx : end_idx + 1]) > 0.4:
                    result.append([item[0], item[-1]])
            else:
                if pred_label[start_idx] > 0.45:
                    result.append(item)
        result.append([pred_label.shape[0]])

        return (
            single_frame_pred_label,
            all_frame_pred_label,
            fps,
            total_frames,
            duration,
            h,
            w,
        )
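

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint path,
    # video path, and map path below are hypothetical placeholders. Because this
    # module uses package-relative imports, the block only executes when the
    # module is run through its parent package (e.g. ``python -m <package>.<module>``).
    predictor = TransNetV2Predictor(model_path="transnetv2.pth", device="cuda:0")
    video_map, single_frame_pred, all_frame_pred = predictor(
        video_path="example.mp4",
        map_path="example_video_map.json",  # predict() dumps the VideoMap here as JSON
        content_box=None,  # or a content crop box if black borders should be ignored
    )
    print("shot/transition map written to example_video_map.json")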