|
import os |
|
import random |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
|
|
from .augmentation import MotionAugmentation |
|
|
|
|
|
class VideoMatteDataset(Dataset):
    """Video-matte foreground/alpha sequences paired with random backgrounds.

    Directory layout read by this class:

    * ``videomatte_dir/fgr/<clip>/<frame>`` and ``videomatte_dir/pha/<clip>/<frame>``.
      Clip and frame names are listed under ``fgr``; the alpha frame is opened
      under ``pha`` with the same clip and frame file name.
    * ``background_image_dir/<image>``: still background images.
    * ``background_video_dir/<clip>/<frame>``: background clips stored as frame files.

    Each item corresponds to one ``(matte clip, start frame)`` pair. Start frames
    are spaced ``seq_length`` apart. The frame offsets actually read are produced
    by ``seq_sampler``; offsets past the end of a clip wrap around modulo the
    clip length.

    Args:
        videomatte_dir: Root directory containing the ``fgr`` and ``pha`` sub-directories.
        background_image_dir: Directory of background still images.
        background_video_dir: Directory of background clip sub-directories.
            Clips that contain no frames are ignored.
        size: Images whose shorter side exceeds ``size`` are downsampled so the
            shorter side equals ``size``.
        seq_length: Nominal frames per sample, and the stride between start frames.
        seq_sampler: Callable ``seq_sampler(seq_length)`` returning integer frame offsets.
            NOTE(review): image backgrounds are always ``seq_length`` long, so the
            sampler presumably returns exactly ``seq_length`` offsets. Confirm
            against the sampler implementations.
        transform: Optional callable ``transform(fgrs, phas, bgrs)``. When given,
            ``__getitem__`` returns its result instead of the raw lists.
    """

    def __init__(self,
                 videomatte_dir,
                 background_image_dir,
                 background_video_dir,
                 size,
                 seq_length,
                 seq_sampler,
                 transform=None):
        self.background_image_dir = background_image_dir
        # Sorted like every other listing here. os.listdir order is arbitrary,
        # and random.choice must pick the same file for the same RNG state.
        self.background_image_files = sorted(os.listdir(background_image_dir))

        self.background_video_dir = background_video_dir
        clips = sorted(os.listdir(background_video_dir))
        frames_per_clip = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
                           for clip in clips]
        # Drop frameless clips: sampling one would divide by zero when the
        # frame index is wrapped modulo the clip length.
        non_empty = [(clip, frames) for clip, frames in zip(clips, frames_per_clip) if frames]
        self.background_video_clips = [clip for clip, _ in non_empty]
        self.background_video_frames = [frames for _, frames in non_empty]

        self.videomatte_dir = videomatte_dir
        self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr')))
        self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip)))
                                  for clip in self.videomatte_clips]
        # One sample per (clip index, start frame), with start frames seq_length
        # apart. Clips without frames contribute no samples.
        self.videomatte_idx = [(clip_idx, frame_idx)
                               for clip_idx, frames in enumerate(self.videomatte_frames)
                               for frame_idx in range(0, len(frames), seq_length)]

        self.size = size
        self.seq_length = seq_length
        self.seq_sampler = seq_sampler
        self.transform = transform

    def __len__(self):
        """Return the number of (clip, start frame) samples."""
        return len(self.videomatte_idx)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(fgrs, phas, bgrs)``, or ``transform(fgrs, phas, bgrs)``.

        With probability 0.5 the background is one still image repeated;
        otherwise it is a window of a random background video clip. The
        background choice does not depend on ``idx``.
        """
        if random.random() < 0.5:
            bgrs = self._get_random_image_background()
        else:
            bgrs = self._get_random_video_background()

        fgrs, phas = self._get_videomatte(idx)

        if self.transform is not None:
            return self.transform(fgrs, phas, bgrs)
        return fgrs, phas, bgrs

    def _get_random_image_background(self):
        """Load one random background image (RGB) and repeat it ``seq_length`` times.

        Every entry of the returned list is the same image object.
        """
        path = os.path.join(self.background_image_dir, random.choice(self.background_image_files))
        with Image.open(path) as bgr:
            bgr = self._downsample_if_needed(bgr.convert('RGB'))
        return [bgr] * self.seq_length

    def _get_random_video_background(self):
        """Load a window of RGB frames from a random background clip.

        The start frame is random. Offsets from ``seq_sampler`` are added to it
        and wrapped modulo the clip length.
        """
        clip_idx = random.randrange(len(self.background_video_clips))
        clip = self.background_video_clips[clip_idx]
        frames = self.background_video_frames[clip_idx]
        frame_count = len(frames)
        # There are frame_count - seq_length + 1 start positions at which a
        # contiguous window fits. max(1, ...) keeps the range non-empty for
        # clips shorter than the window.
        frame_idx = random.randrange(max(1, frame_count - self.seq_length + 1))

        bgrs = []
        for i in self.seq_sampler(self.seq_length):
            frame = frames[(frame_idx + i) % frame_count]
            with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr:
                bgr = self._downsample_if_needed(bgr.convert('RGB'))
            bgrs.append(bgr)
        return bgrs

    def _get_videomatte(self, idx):
        """Load the foreground (RGB) and alpha (L) frames for sample ``idx``.

        Returns:
            ``(fgrs, phas)``: two lists with one image per offset from ``seq_sampler``.
        """
        clip_idx, frame_idx = self.videomatte_idx[idx]
        clip = self.videomatte_clips[clip_idx]
        frames = self.videomatte_frames[clip_idx]
        frame_count = len(frames)

        fgrs, phas = [], []
        for i in self.seq_sampler(self.seq_length):
            # Offsets past the clip end wrap back to its start.
            frame = frames[(frame_idx + i) % frame_count]
            with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \
                 Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha:
                fgr = self._downsample_if_needed(fgr.convert('RGB'))
                pha = self._downsample_if_needed(pha.convert('L'))
            fgrs.append(fgr)
            phas.append(pha)
        return fgrs, phas

    def _downsample_if_needed(self, img):
        """Resize ``img`` so its shorter side equals ``self.size``, keeping aspect ratio.

        Images whose shorter side is already ``<= self.size`` are returned
        unchanged. The longer side is truncated to an integer.
        """
        w, h = img.size
        short = min(w, h)
        if short <= self.size:
            return img
        # Integer floor division instead of int(float_scale * side). Float error
        # can truncate the short side: int(512 / 784 * 784) == 511.
        new_w = int(w * self.size // short)
        new_h = int(h * self.size // short)
        return img.resize((new_w, new_h))
|
|
|
class VideoMatteTrainAugmentation(MotionAugmentation):
    """Training preset of ``MotionAugmentation`` for video-matte sequences.

    Forwards ``size`` together with a fixed set of non-zero ``prob_*``
    keyword arguments to ``MotionAugmentation.__init__``.
    """

    def __init__(self, size):
        probabilities = {
            'prob_fgr_affine': 0.3,
            'prob_bgr_affine': 0.3,
            'prob_noise': 0.1,
            'prob_color_jitter': 0.3,
            'prob_grayscale': 0.02,
            'prob_sharpness': 0.1,
            'prob_blur': 0.02,
            'prob_hflip': 0.5,
            'prob_pause': 0.03,
        }
        super().__init__(size=size, **probabilities)
|
|
|
class VideoMatteValidAugmentation(MotionAugmentation):
    """Validation preset of ``MotionAugmentation``: every ``prob_*`` argument is 0.

    Only ``size`` is forwarded with a meaningful value.
    """

    def __init__(self, size):
        disabled = dict.fromkeys(
            (
                'prob_fgr_affine',
                'prob_bgr_affine',
                'prob_noise',
                'prob_color_jitter',
                'prob_grayscale',
                'prob_sharpness',
                'prob_blur',
                'prob_hflip',
                'prob_pause',
            ),
            0,
        )
        super().__init__(size=size, **disabled)
|
|