import os.path as osp
import random
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
from torchvision.transforms import Lambda

from ....dataset.transform import ToTensorVideo, CenterCropVideo
from ....utils.dataset_utils import DecordInit


def TemporalRandomCrop(total_frames, size):
    """
    Performs a random temporal crop on a video sequence.

    Randomly selects a contiguous window of `size` frames from a video containing
    `total_frames` frames. If the video is shorter than `size`, the window is
    truncated at the end of the video.

    Parameters:
    - total_frames (int): The total number of frames in the video sequence.
    - size (int): The length of the frame window to crop.

    Returns:
    - (int, int): The starting frame index (inclusive) and the ending frame
      index (exclusive) of the cropped window.
    """
    rand_end = max(0, total_frames - size - 1)
    begin_index = random.randint(0, rand_end)
    end_index = min(begin_index + size, total_frames)
    return begin_index, end_index
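
# Illustrative example (values chosen arbitrarily): with total_frames=100 and size=16,
# begin_index is drawn uniformly from [0, 83], so a possible result is (27, 43),
# i.e. the window covers frames 27..42.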


def resize(x, resolution):
    """Resize a batch of frames so the shorter side becomes min(2 * resolution,
    original shorter side), preserving the aspect ratio. Expects a 4D tensor
    of shape (N, C, H, W)."""
    height, width = x.shape[-2:]
    resolution = min(2 * resolution, height, width)
    aspect_ratio = width / height
    if width <= height:
        new_width = resolution
        new_height = int(resolution / aspect_ratio)
    else:
        new_height = resolution
        new_width = int(resolution * aspect_ratio)
    resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
    return resized_x
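
# Illustrative example (shapes chosen arbitrarily): a clip of shape (16, 3, 480, 720)
# with resolution=128 gives min(256, 480, 720) = 256 for the shorter side, so the
# output has shape (16, 3, 256, 384); the 3:2 aspect ratio is preserved.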


class VideoDataset(data.Dataset):
    """Generic dataset for video files stored in folders.

    Returns BCTHW videos scaled to the range [-1, 1]."""

    video_exts = ['avi', 'mp4', 'webm']

    def __init__(self, video_folder, sequence_length, image_folder=None, train=True,
                 resolution=64, sample_rate=1, dynamic_sample=True):
        self.train = train
        self.sequence_length = sequence_length
        self.sample_rate = sample_rate
        self.resolution = resolution
        self.v_decoder = DecordInit()
        self.video_folder = video_folder
        self.dynamic_sample = dynamic_sample

        self.transform = transforms.Compose([
            ToTensorVideo(),
            CenterCropVideo(self.resolution),
            Lambda(lambda x: 2.0 * x - 1.0),
        ])
        print('Building datasets...')
        self.samples = self._make_dataset()

    def _make_dataset(self):
        samples = []
        samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True)
                        for ext in self.video_exts], [])
        return samples
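
    # Illustrative example (paths are hypothetical): with video_folder='/data/videos',
    # the recursive glob above matches files such as '/data/videos/clips/0001.mp4'
    # or '/data/videos/webvid/part0/dog.webm'.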

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path = self.samples[idx]
        try:
            video = self.decord_read(video_path)
            video = self.transform(video)
            video = video.transpose(0, 1)  # swap the first two dims so clips stack to BCTHW
            return dict(video=video, label="")
        except Exception as e:
            # Fall back to a randomly chosen sample if this video cannot be decoded.
            print(f'Error with {e}, {video_path}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)

        # Optionally jitter the temporal stride, then crop a window long enough
        # to hold `sequence_length` frames at that stride.
        if self.dynamic_sample:
            sample_rate = random.randint(1, self.sample_rate)
        else:
            sample_rate = self.sample_rate
        size = self.sequence_length * sample_rate
        start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)

        # Evenly spaced frame indices inside the cropped window.
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int)

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data
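

# Minimal usage sketch (the folder path and loader settings below are
# hypothetical, not part of this module):
#
#   dataset = VideoDataset('/path/to/videos', sequence_length=16,
#                          resolution=128, sample_rate=2)
#   loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   batch['video']  # float tensor in [-1, 1], batched as BCTHW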