Spaces:
No application file
No application file
import os | |
import pickle | |
import ffmpeg | |
import numpy as np | |
import torch | |
def get_frames(fn, cache_path, start_frame, end_frame, width=48, height=27): | |
''' | |
先查询cache_path路径下是否有fn的缓存文件(.pkl),若存在则直接载入, 不存在则通过ffmpeg提取获得 | |
fn: mp4文件所在路径 | |
width, height: 转换为指定宽高的视频 | |
cache_path: 缓存文件所在位置, 例如/data_share7/v_hyggewang/视频单帧数据_h48_w27, 该文件夹中包含有mp4_id.pkl文件则会导入该文件 | |
return: [总帧率, height, width, 3], np.array | |
''' | |
cache_mp4_pkl = os.path.join(cache_path, os.path.basename(fn).replace('mp4','pkl')) | |
if os.path.exists(cache_mp4_pkl) : | |
video = pickle.load(open(cache_mp4_pkl,'rb')) | |
assert (video.shape[1]==height and video.shape[2]==width), "mp4缓存文件维度与指定h,w不同" | |
# print('mp4帧文件缓存载入成功') | |
else: | |
# print('使用ffmpeg提取mp4单帧图片数据') | |
# 视频转化为np数组 | |
video_stream, err = ( | |
ffmpeg | |
.input(fn) | |
# .filter('select', 'between(n,{},{})'.format(start_frame,end_frame)) | |
.trim(start_frame=start_frame, end_frame=end_frame) | |
.setpts('PTS-STARTPTS') | |
.output('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height)) | |
.run(capture_stdout=True, capture_stderr=True) | |
) | |
video = np.frombuffer(video_stream, np.uint8).reshape([-1, height, width, 3]) | |
return video | |
def get_frames_label(mp4_annotations, n_frames, fps): | |
''' | |
根据mp4的标注信息,以及总帧数长度,给所有帧打标签。 | |
mp4_annotations: 列表[item...], 每个item是一个字典, 包括有该时间点类别,具体时间。 | |
fps: 该视频的帧率 | |
n_frames: 该视频总帧数 | |
return: many_hot,one_hot维度为(n_frames,)。one_hot表示只有在标注帧的地方是有标签, many_hot表示标注帧的地方前后2帧有相同标签 | |
''' | |
one_hot = np.zeros([n_frames], np.int64) | |
many_hot = np.zeros([n_frames], np.int64) | |
for item in mp4_annotations: | |
if item["class"] == "transition": | |
transition_frame_idx = int(item['timaestamp'] * fps) # 根据据标注时间点标注关键帧标签 | |
one_hot[min(transition_frame_idx, n_frames-1)] = 1 # 卡点帧标记为1 ,有些标注视频总长度低于标注点位,因此需要处理一下 | |
for i in range(min(transition_frame_idx-2, n_frames-2), min(transition_frame_idx+2, n_frames)): | |
many_hot[i] = 1 # 卡点帧周围2帧标记为1(共5帧) | |
return one_hot, many_hot | |
def get_mp4_scenes(frames, one_hot, many_hot): | |
""" | |
对视频按照100帧顺序切片,单个视频顺序切出多段,并且每个片段重叠30帧,例如: 210帧总长度,需要切成0-100帧,70-170帧,140-240帧(重复30帧) | |
frames: (T, H, W, 通道数),视频单帧数据.(np.array) | |
one_hot: 视频标签one_hot | |
many_hot: 视频标签many_hot | |
return: 按100帧切割的图片块:[x, 100, H, W, 3], one_hot_scenes:[x, 100], many_hot_scenes:[x, 100], x:表示该mp4被切成了多少块 | |
""" | |
frames = torch.from_numpy(frames) | |
# 重复最后一帧图片使得图片能够正好按规则切分 | |
if (len(frames)-30)%70 != 0: | |
repeat_n = (len(frames)-30)//70*70 + 100 - len(frames) | |
last_scenes = frames[-1, :, :, :] # 获取最后一个图像 | |
last_scenes = torch.unsqueeze(last_scenes, 0).expand([repeat_n, -1, -1, -1]) | |
frames = torch.cat([frames, last_scenes], dim=0) # 在原数据最后增加几帧图片 | |
one_hot = np.concatenate((one_hot, [0]*repeat_n)) # 标签也需要增加长度 | |
many_hot = np.concatenate((many_hot, [0]*repeat_n)) | |
one_hot = torch.from_numpy(one_hot) # 转化为tensor | |
many_hot = torch.from_numpy(many_hot) | |
scenes = [] # 用于储存分割的片段 | |
one_hot_scenes = [] # 用于储存分割片段的label | |
many_hot_scenes = [] | |
# 按规则切片 | |
for i in range((len(frames)-30)//70): | |
scenes.append(frames[i*70:(i*70+100)]) | |
one_hot_scenes.append(one_hot[i*70:(i*70+100)]) | |
many_hot_scenes.append(many_hot[i*70:(i*70+100)]) | |
return torch.stack(scenes, dim=0), torch.stack(one_hot_scenes, dim=0), torch.stack(many_hot_scenes, dim=0), int((len(frames)-30)//70) | |
def get_mp4_random_scenes(frames, one_hot, many_hot): | |
""" | |
对视频随机设定起点, 抽取100帧数据 | |
frames: (T, H, W, 通道数),视频单帧数据.(np.array) | |
one_hot: 视频标签one_hot | |
many_hot: 视频标签many_hot | |
return: 随机起点的100帧图片数据[1, 100, H, W, 3], one_hot_scenes:[1, 100], many_hot_scenes:[1, 100] | |
""" | |
number_frames = len(frames) | |
start_frames = np.random.randint(0, number_frames-100+1) # 随机获得起始帧位置 | |
end_frames = start_frames + 100 # 获得终止帧位置 | |
fal_frames = torch.from_numpy(frames[start_frames:end_frames]) | |
fal_one_hot = torch.from_numpy(one_hot[start_frames:end_frames]) | |
fal_many_hot = torch.from_numpy(many_hot[start_frames:end_frames]) | |
return torch.stack([fal_frames], dim=0), torch.stack([fal_one_hot], dim=0), torch.stack([fal_many_hot], dim=0) | |