"""TransNetV2-style per-frame transition model plus a sliding-window video helper."""

import random
import torch
from torch import nn
import torch.nn.functional as F
import ffmpeg  # noqa: F401  (not referenced in this module)
import numpy as np
import cv2
from moviepy.editor import VideoFileClip
from .utils import get_frames  # noqa: F401  (not referenced in this module)


class TransNetV2(nn.Module):
    """Per-frame transition classifier built from stacked dilated 3D-conv stages.

    Args:
        F: filters per conv branch in the first stage; doubled at every later stage.
            (Inside ``__init__`` this shadows the ``torch.nn.functional`` alias; the
            name is kept for API compatibility.)
        L: number of ``StackedDDCNNV2`` stages.
        S: number of ``DilatedDCNNV2`` blocks per stage.
        D: hidden size of the fully connected layer ``fc1``.
    """

    def __init__(self, F=16, L=3, S=2, D=1024):
        super(TransNetV2, self).__init__()
        self.SDDCNN = nn.ModuleList(
            [
                StackedDDCNNV2(
                    in_filters=3, n_blocks=S, filters=F, stochastic_depth_drop_prob=0.0
                )
            ]
            + [
                # Each DilatedDCNNV2 concatenates 4 branches, hence the "* 4".
                StackedDDCNNV2(
                    in_filters=(F * 2 ** (i - 1)) * 4, n_blocks=S, filters=F * 2**i
                )
                for i in range(1, L)
            ]
        )
        # Frame-similarity branch over the features of every stage.
        self.frame_sim_layer = FrameSimilarity(
            sum([(F * 2**i) * 4 for i in range(L)]),
            lookup_window=101,
            output_dim=128,
            similarity_dim=128,
            use_bias=True,
        )
        # Colour-histogram similarity branch over the raw input frames.
        self.color_hist_layer = ColorHistograms(lookup_window=101, output_dim=128)
        self.dropout = nn.Dropout(0.5)

        # Flattened last-stage features: channels * 3 * 6 spatial cells, which is what
        # L == 3 average pools with kernel (1, 2, 2) leave of a 27 x 48 frame.
        output_dim = ((F * 2 ** (L - 1)) * 4) * 3 * 6
        output_dim += 128  # frame_sim_layer output width
        output_dim += 128  # color_hist_layer output width
        self.fc1 = nn.Linear(output_dim, D)
        self.cls_layer1 = nn.Linear(D, 1)
        self.cls_layer2 = nn.Linear(D, 1)

    def forward(self, inputs):
        """Score every frame of a batch of clips.

        Args:
            inputs: pixel tensor of shape [B, T, H, W, 3] with values in 0..255
                (the intended dtype is ``torch.uint8``). With the default config,
                (H, W) must be (27, 48) so the flattened features match ``fc1``.

        Returns:
            ``(one_hot, many_hot)``: logits from ``cls_layer1`` and ``cls_layer2``,
            each of shape [B, T, 1].
        """
        # [B, T, H, W, 3] -> [B, 3, T, H, W] scaled to [0, 1]. Divide out of place:
        # for a float32 input, ``.float()`` returns a view of ``inputs``, and an
        # in-place divide would corrupt the pixels read by ``color_hist_layer`` below.
        x = inputs.permute([0, 4, 1, 2, 3]).float() / 255.0

        # Keep every stage's feature map for the frame-similarity branch.
        block_features = []
        for block in self.SDDCNN:
            x = block(x)
            block_features.append(x)

        # [B, C, T, H, W] -> [B, T, H, W, C] -> [B, T, C * H * W]
        x = x.permute(0, 2, 3, 4, 1)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = torch.cat([self.frame_sim_layer(block_features), x], 2)
        x = torch.cat([self.color_hist_layer(inputs), x], 2)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        one_hot = self.cls_layer1(x)
        many_hot = self.cls_layer2(x)
        return one_hot, many_hot

    @torch.no_grad()
    def predict_video(
        self,
        mp4_file,
        cache_path="",
        c_box=None,
        width=48,
        height=27,
        input_frames=100,
        overlap=30,
        sample_fps=30,
        threshold=0.3,
    ):
        """Run the model over a whole video using overlapping sliding windows.

        Frames are read at ``sample_fps``, resized to ``(width, height)``, converted
        with ``cv2.COLOR_RGB2BGR`` and fed to :meth:`forward` in windows of
        ``input_frames`` frames. Consecutive windows share ``overlap`` frames, and
        ``overlap // 2`` predictions are dropped at every interior window edge, so
        each sampled frame gets exactly one prediction.

        Gradients are disabled here. The train/eval mode is NOT changed, so call
        ``self.eval()`` first if dropout/batch-norm should be in inference mode.

        Args:
            mp4_file: path of the video file, e.g. ``~/6712566330782010632.mp4``.
            cache_path: unused; kept for API compatibility.
            c_box: optional box passed positionally to moviepy's ``crop``
                (i.e. ``x1, y1, x2, y2``).
            width, height: frame size fed to the model.
            input_frames: frames per window.
            overlap: frames shared by consecutive windows; must be even and
                smaller than ``input_frames``.
            sample_fps: rate at which frames are sampled from the video.
            threshold: unused; kept for API compatibility.

        Returns:
            ``(single_frame_pred, all_frame_pred, fps, total_frames, duration, h, w)``.
            The first two are 1-D numpy arrays holding the sigmoid of the
            ``cls_layer1`` / ``cls_layer2`` logits, one value per sampled frame.
            ``fps``, ``duration`` and ``(w, h)`` come from the clip before cropping,
            and ``total_frames`` is ``int(duration * fps)``.

        Raises:
            AssertionError: if ``overlap`` is odd or ``input_frames <= overlap``.
        """
        assert overlap % 2 == 0
        assert input_frames > overlap

        step = input_frames - overlap
        half = overlap // 2
        device = next(self.parameters()).device

        video = VideoFileClip(mp4_file)
        try:
            fps = video.fps
            duration = video.duration
            total_frames = int(duration * fps)
            w, h = video.size
            print(fps, duration, total_frames, w, h)

            # moviepy effects return a new clip; they do not modify ``video``.
            clip = video.crop(*c_box) if c_box else video
            frame_iter = clip.iter_frames(fps=sample_fps)
            sample_total_frames = int(sample_fps * duration)

            frame_list = []
            single_chunks, all_chunks = [], []
            for i in range(sample_total_frames // step + 1):
                # Carry the last ``overlap`` frames over as context. The guard matters
                # because ``frame_list[-0:]`` would keep the whole list.
                frame_list = frame_list[-overlap:] if overlap > 0 else []
                start_frame = i * step
                end_frame = min(start_frame + input_frames, sample_total_frames)
                print("start_frame & end_frame: ", start_frame, end_frame)
                wanted = end_frame - start_frame

                # ``frame_iter`` is a generator, so each window resumes where the
                # previous one stopped reading.
                for frame in frame_iter:
                    frame = cv2.resize(frame, (width, height))
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    frame_list.append(frame)
                    if len(frame_list) >= wanted:
                        break
                if not frame_list:
                    break

                # No float conversion: forward() scales to [0, 1] itself, and
                # ColorHistograms bins the raw pixel values.
                frames = torch.from_numpy(np.stack(frame_list))
                short_read = frames.shape[0] < wanted
                if short_read:
                    # The frame stream ended before the expected frame count, e.g.
                    # the video track is shorter than the audio track.
                    print(
                        "total_frames is wrong: ",
                        total_frames,
                        "-->",
                        start_frame + frames.shape[0],
                    )

                single_frame_pred, all_frame_pred = self.forward(
                    frames.to(device).unsqueeze(0)
                )
                single_frame_pred = torch.sigmoid(single_frame_pred).reshape(-1)
                all_frame_pred = torch.sigmoid(all_frame_pred).reshape(-1)

                # Decide "last window" in sampled-frame units (same units as
                # end_frame). An exhausted stream also ends the video.
                is_last = end_frame >= sample_total_frames or short_read
                n = single_frame_pred.shape[0]
                keep_from = 0 if i == 0 else half
                keep_to = n if is_last else n - half
                single_chunks.append(single_frame_pred[keep_from:keep_to])
                all_chunks.append(all_frame_pred[keep_from:keep_to])
                if is_last:
                    break
        finally:
            video.close()

        if single_chunks:
            single_frame_pred_label = torch.cat(single_chunks).cpu().numpy()
            all_frame_pred_label = torch.cat(all_chunks).cpu().numpy()
        else:  # the video yielded no frames at all
            single_frame_pred_label = np.zeros(0, dtype=np.float32)
            all_frame_pred_label = np.zeros(0, dtype=np.float32)

        return (
            single_frame_pred_label,
            all_frame_pred_label,
            fps,
            total_frames,
            duration,
            h,
            w,
        )

    def predict_video_2(
        self,
        mp4_file,
        cache_path="",
        c_box=None,
        width=48,
        height=27,
        input_frames=100,
        overlap=30,
        sample_fps=30,
        threshold=0.3,
    ):
        """Identical to :meth:`predict_video`; kept as an alias for existing callers."""
        return self.predict_video(
            mp4_file,
            cache_path=cache_path,
            c_box=c_box,
            width=width,
            height=height,
            input_frames=input_frames,
            overlap=overlap,
            sample_fps=sample_fps,
            threshold=threshold,
        )


class StackedDDCNNV2(nn.Module):
    """A stage of ``n_blocks`` DilatedDCNNV2 blocks, an optional residual, then pooling.

    Args:
        in_filters: channels of the stage input.
        n_blocks: number of DilatedDCNNV2 blocks; all but the last use ReLU.
        filters: filters per conv branch; every block outputs ``filters * 4`` channels.
        shortcut: if truthy, add the first block's output to the stage output.
        pool_type: ``"max"`` for max pooling; anything else selects average pooling.
            Both pool with kernel (1, 2, 2), i.e. spatially only.
        stochastic_depth_drop_prob: in training, the probability of replacing the
            stage output with the shortcut alone. In eval, the block output is
            scaled by ``1 - p`` instead.
    """

    def __init__(
        self,
        in_filters,
        n_blocks,
        filters,
        shortcut=True,
        pool_type="avg",
        stochastic_depth_drop_prob=0.0,
    ):
        super(StackedDDCNNV2, self).__init__()
        self.shortcut = shortcut
        self.DDCNN = nn.ModuleList(
            [
                DilatedDCNNV2(
                    in_filters if i == 1 else filters * 4,
                    filters,
                    activation=F.relu if i != n_blocks else None,
                )
                for i in range(1, n_blocks + 1)
            ]
        )
        self.pool = (
            nn.MaxPool3d(kernel_size=(1, 2, 2))
            if pool_type == "max"
            else nn.AvgPool3d(kernel_size=(1, 2, 2))
        )
        self.stochastic_depth_drop_prob = stochastic_depth_drop_prob

    def forward(self, inputs):
        x = inputs
        shortcut = None
        for block in self.DDCNN:
            x = block(x)
            if shortcut is None:
                shortcut = x  # the first block's output feeds the residual
        x = F.relu(x)

        # Truthiness test: ``self.shortcut`` is a bool, so an ``is not None`` check
        # could never turn the residual off.
        if self.shortcut:
            p = self.stochastic_depth_drop_prob
            if p == 0.0:
                x = x + shortcut
            elif self.training:
                if random.random() < p:
                    x = shortcut
                else:
                    x = x + shortcut
            else:
                x = (1 - p) * x + shortcut

        return self.pool(x)


class DilatedDCNNV2(nn.Module):
    """Four parallel convs with temporal dilation 1, 2, 4, 8, concatenated on channels.

    Output has ``filters * 4`` channels, followed by optional BatchNorm3d
    (eps=1e-3) and an optional activation. Conv biases are disabled when
    batch norm is used.
    """

    def __init__(self, in_filters, filters, batch_norm=True, activation=None):
        super(DilatedDCNNV2, self).__init__()
        self.Conv3D_1 = Conv3DConfigurable(in_filters, filters, 1, use_bias=not batch_norm)
        self.Conv3D_2 = Conv3DConfigurable(in_filters, filters, 2, use_bias=not batch_norm)
        self.Conv3D_4 = Conv3DConfigurable(in_filters, filters, 4, use_bias=not batch_norm)
        self.Conv3D_8 = Conv3DConfigurable(in_filters, filters, 8, use_bias=not batch_norm)
        self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None
        self.activation = activation  # callable or None

    def forward(self, inputs):
        conv1 = self.Conv3D_1(inputs)
        conv2 = self.Conv3D_2(inputs)
        conv3 = self.Conv3D_4(inputs)
        conv4 = self.Conv3D_8(inputs)
        x = torch.cat([conv1, conv2, conv3, conv4], dim=1)
        if self.bn is not None:
            x = self.bn(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class Conv3DConfigurable(nn.Module):
    """3D conv with a configurable temporal dilation; padding keeps T, H, W unchanged.

    With ``separable=True`` this is a (2+1)D convolution
    (https://arxiv.org/pdf/1711.11248.pdf): a (1, 3, 3) spatial conv to
    ``2 * filters`` channels (no bias), then a (3, 1, 1) temporal conv to ``filters``.
    Otherwise a single 3x3x3 conv is used.
    """

    def __init__(
        self, in_filters, filters, dilation_rate, separable=True, use_bias=True
    ):
        super(Conv3DConfigurable, self).__init__()
        if separable:
            conv1 = nn.Conv3d(
                in_filters,
                2 * filters,
                kernel_size=(1, 3, 3),
                dilation=(1, 1, 1),
                padding=(0, 1, 1),
                bias=False,
            )
            conv2 = nn.Conv3d(
                2 * filters,
                filters,
                kernel_size=(3, 1, 1),
                dilation=(dilation_rate, 1, 1),
                padding=(dilation_rate, 0, 0),
                bias=use_bias,
            )
            self.layers = nn.ModuleList([conv1, conv2])
        else:
            conv = nn.Conv3d(
                in_filters,
                filters,
                kernel_size=3,
                dilation=(dilation_rate, 1, 1),
                padding=(dilation_rate, 1, 1),
                bias=use_bias,
            )
            self.layers = nn.ModuleList([conv])

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        return x


def _windowed_similarities(similarities, lookup_window):
    """Gather, for every frame t, its similarities to frames t - r .. t + r.

    Here ``r = (lookup_window - 1) // 2``. ``similarities`` is [batch, T, T]; the
    result is [batch, T, lookup_window], with zeros where the window extends past
    the clip (from zero padding).
    """
    batch_size, time_window = similarities.shape[0], similarities.shape[1]
    device = similarities.device
    radius = (lookup_window - 1) // 2
    similarities_padded = F.pad(similarities, [radius, radius])

    batch_indices = (
        torch.arange(0, batch_size, device=device)
        .view([batch_size, 1, 1])
        .repeat([1, time_window, lookup_window])
    )
    time_indices = (
        torch.arange(0, time_window, device=device)
        .view([1, time_window, 1])
        .repeat([batch_size, 1, lookup_window])
    )
    # Column t + k of the padded row is the similarity to frame t + k - radius.
    lookup_indices = (
        torch.arange(0, lookup_window, device=device)
        .view([1, 1, lookup_window])
        .repeat([batch_size, time_window, 1])
        + time_indices
    )
    return similarities_padded[batch_indices, time_indices, lookup_indices]


class FrameSimilarity(nn.Module):
    """Cosine similarity of each frame's pooled features to its temporal neighbours.

    ``forward`` takes a list of feature maps shaped [B, C_i, T, H_i, W_i]. Each
    map is averaged over H and W, the results are concatenated on channels,
    projected to ``similarity_dim`` and L2-normalised. The ``lookup_window``
    neighbour similarities per frame then go through a ReLU-activated linear
    layer, giving [B, T, output_dim].
    """

    def __init__(
        self,
        in_filters,
        similarity_dim=128,
        lookup_window=101,
        output_dim=128,
        use_bias=False,
    ):
        super(FrameSimilarity, self).__init__()
        self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias)
        self.fc = nn.Linear(lookup_window, output_dim)
        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    def forward(self, inputs):
        x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1)  # [B, C, T]
        x = torch.transpose(x, 1, 2)  # [B, T, C]
        x = self.projection(x)
        x = F.normalize(x, p=2, dim=2)
        # Cosine similarities between all frame pairs: [B, T, T].
        similarities = torch.bmm(x, x.transpose(1, 2))
        similarities = _windowed_similarities(similarities, self.lookup_window)
        return F.relu(self.fc(similarities))


class ColorHistograms(nn.Module):
    """Similarity of each frame's 512-bin colour histogram to its temporal neighbours.

    ``forward`` takes pixels shaped [B, T, H, W, 3] and returns [B, T, output_dim]
    (ReLU of a linear layer), or [B, T, lookup_window] raw similarities when
    ``output_dim`` is None.
    """

    def __init__(self, lookup_window=101, output_dim=None):
        super(ColorHistograms, self).__init__()
        self.fc = (
            nn.Linear(lookup_window, output_dim) if output_dim is not None else None
        )
        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    @staticmethod
    def compute_color_histograms(frames):
        """Return L2-normalised 512-bin histograms, shape [B, T, 512].

        Each channel is quantised to 3 bits (value >> 5), so values are expected
        in 0..255.
        """
        frames = frames.int()

        def get_bin(frames):
            # Combine the three 3-bit channel values into a bin index 0 .. 511.
            R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2]
            R, G, B = R >> 5, G >> 5, B >> 5
            return (R << 6) + (G << 3) + B

        batch_size, time_window, height, width, no_channels = frames.shape
        assert no_channels == 3
        frames_flatten = frames.reshape(batch_size * time_window, height * width, 3)

        binned_values = get_bin(frames_flatten)
        # Offset each frame's bins into its own 512-wide slot so a single
        # scatter_add builds every histogram at once.
        frame_bin_prefix = (
            torch.arange(0, batch_size * time_window, device=frames.device) << 9
        ).view(-1, 1)
        binned_values = (binned_values + frame_bin_prefix).view(-1)

        histograms = torch.zeros(
            batch_size * time_window * 512, dtype=torch.int32, device=frames.device
        )
        histograms.scatter_add_(
            0,
            binned_values,
            torch.ones(len(binned_values), dtype=torch.int32, device=frames.device),
        )

        histograms = histograms.view(batch_size, time_window, 512).float()
        return F.normalize(histograms, p=2, dim=2)

    def forward(self, inputs):
        x = self.compute_color_histograms(inputs)
        # Cosine similarities between all frame histograms: [B, T, T].
        similarities = torch.bmm(x, x.transpose(1, 2))
        similarities = _windowed_similarities(similarities, self.lookup_window)
        if self.fc is not None:
            return F.relu(self.fc(similarities))
        return similarities