|
import easing_functions as ef |
|
import random |
|
import torch |
|
from torchvision import transforms |
|
from torchvision.transforms import functional as F |
|
|
|
|
|
class MotionAugmentation:
    """Randomized augmentation pipeline for three parallel frame sequences.

    The callable takes ``fgrs``, ``phas`` and ``bgrs`` (presumably foreground,
    alpha-matte and background clips -- confirm against the dataset code) and
    applies a chain of random transforms: time-varying and static affine
    warps, conversion to tensors, a random-aspect resize, horizontal flips,
    noise, colour jitter, grayscale, sharpening, blur and frame pauses.

    Args:
        size: Output size forwarded to ``F.resized_crop``.
        prob_fgr_affine: Probability of a time-varying affine on ``fgrs``/``phas``.
        prob_bgr_affine: Split in half between two independent draws: one
            warps ``bgrs`` alone, the other warps all three sequences together.
        prob_noise: Probability of adding noise to ``fgrs`` and ``bgrs``.
        prob_color_jitter: Probability of colour jitter, drawn independently
            for ``fgrs`` and for ``bgrs``.
        prob_grayscale: Probability of converting ``fgrs`` and ``bgrs`` to
            3-channel grayscale.
        prob_sharpness: Probability of applying one shared sharpness factor to
            all three sequences.
        prob_blur: Split in thirds between three independent draws
            (``fgrs``/``phas``, ``bgrs`` alone, all three together).
        prob_hflip: Probability of a horizontal flip, drawn independently for
            ``fgrs``/``phas`` and for ``bgrs``.
        prob_pause: Probability of freezing a run of frames.
        static_affine: When true, always apply one fixed random affine per
            group (``fgrs``/``phas`` and ``bgrs`` separately).
        aspect_ratio_range: ``ratio`` range passed to
            ``RandomResizedCrop.get_params`` for the resize step.
    """

    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range

    def __call__(self, fgrs, phas, bgrs):
        """Augment the three sequences and return them as stacked tensors.

        Args:
            fgrs, phas, bgrs: Mutable sequences of frames accepted by
                ``F.affine`` and ``F.to_tensor`` (presumably lists of PIL
                images -- confirm against the caller). The motion-affine step
                replaces frames in these sequences in place.

        Returns:
            Tuple ``(fgrs, phas, bgrs)`` of tensors produced by
            ``torch.stack`` over the per-frame ``F.to_tensor`` results.
        """
        # Foreground motion affine (the second sequence shares the same warp).
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Background motion affine: background alone, or everything together.
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)

        # Static affine: one fixed transform per group for the whole clip.
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))

        # To tensor: stack the per-frame tensors along a new leading dimension.
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])

        # Resize: scale=(1, 1) with a random aspect ratio, then resize to
        # self.size. The two groups draw their crop parameters independently.
        params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flip, drawn independently for the two groups.
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise (phas is deliberately not passed).
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)

        # Colour jitter, drawn independently for fgrs and bgrs.
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)

        # Grayscale (kept at 3 channels).
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()

        # Sharpen: one factor in [0, 8) shared by all three sequences.
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)

        # Blur: three independent draws, each with a third of prob_blur.
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Pause: freeze a run of frames in all three sequences.
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    @staticmethod
    def _progress(t, num_frames):
        """Return frame ``t``'s position in the clip as ``t / (num_frames - 1)``.

        A single-frame clip yields 0.0 instead of dividing by zero.
        """
        return t / (num_frames - 1) if num_frames > 1 else 0.0

    def _static_affine(self, *imgs, scale_ranges):
        """Apply one random affine transform to every frame of every sequence.

        Parameters are sampled once, using ``imgs[0][0].size`` as the image
        size, and reused for all frames. New lists are returned; the inputs
        are not modified.

        Returns:
            A list of per-sequence frame lists, or the single frame list when
            only one sequence was passed.
        """
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
            shears=(-5, 5), img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_affine(self, *imgs):
        """Warp the frames with an affine transform that changes over time.

        Two parameter sets (A and B) are sampled and every frame uses their
        interpolation, driven by a random easing curve. Frames are replaced
        in place in the given sequences.

        Returns:
            The tuple of sequences, or the single sequence when only one was
            passed.
        """
        config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(self._progress(t, T))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_noise(self, *imgs):
        """Add random Gaussian noise to each ``(T, C, H, W)`` tensor in place.

        A grain size in [1, 4) and a monochrome flag are drawn once per call;
        the noise amplitude is drawn per tensor. Noise is generated on a grid
        downscaled by the grain size, resized back to ``(H, W)``, added, and
        the result is clamped to [0, 1].

        Returns:
            The tuple of tensors, or the single tensor when only one was
            passed.
        """
        grain_size = random.random() * 3 + 1
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            # Keep at least one noise cell per axis so very small frames do
            # not produce an empty grid that cannot be resized.
            noise_h = max(1, round(H / grain_size))
            noise_w = max(1, round(W / grain_size))
            noise = torch.randn((T, 1 if monochrome else C, noise_h, noise_w))
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_color_jitter(self, *imgs):
        """Apply brightness/contrast/saturation/hue jitter that drifts over time.

        Start (A) and end (B) offsets for each property are drawn from a
        normal distribution scaled by 0.1. Each frame interpolates between
        them using a random easing curve multiplied by a random strength in
        [0, 0.2). Brightness, contrast and saturation factors are clamped to
        at least 0.1; the hue shift is clamped to [-0.5, 0.5]. Frames are
        replaced in place.

        Returns:
            The tuple of sequences, or the single sequence when only one was
            passed.
        """
        brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            percentage = easing(self._progress(t, T)) * strength
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
                img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
                # Use the saturation offsets; the brightness pair was
                # previously reused here, leaving saturationA/B unused.
                img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
                img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_blur(self, *imgs):
        """Gaussian-blur the frames with a sigma that changes over time.

        Start and end sigmas are drawn from [0, 10) and interpolated per frame
        with a random easing curve; negative results are clamped to 0 and a
        sigma of exactly 0 skips the frame. Frames are replaced in place.

        Returns:
            The tuple of sequences, or the single sequence when only one was
            passed.
        """
        blurA = random.random() * 10
        blurB = random.random() * 10

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(self._progress(t, T))
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1  # force an odd kernel size
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)

        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_pause(self, *imgs):
        """Freeze a random run of frames by repeating one frame, in place.

        A pause frame and a pause length are drawn at random, and the frames
        in ``[pause_frame + 1, pause_frame + pause_length)`` are overwritten
        with the pause frame. Clips with fewer than two frames are returned
        unchanged because there is nothing to pause.

        Returns:
            The tuple of sequences, or the single sequence when only one was
            passed.
        """
        T = len(imgs[0])
        if T >= 2:
            pause_frame = random.choice(range(T - 1))
            pause_length = random.choice(range(T - pause_frame))
            for img in imgs:
                img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]
|
|
|
|
|
def lerp(a, b, percentage):
    """Linearly interpolate between ``a`` and ``b``.

    Returns ``a`` when ``percentage`` is 0 and ``b`` when it is 1; values
    outside [0, 1] extrapolate along the same line.
    """
    weight_a = 1 - percentage
    weight_b = percentage
    return a * weight_a + b * weight_b
|
|
|
|
|
def random_easing_fn():
    """Return a newly constructed easing callable picked at random.

    With probability 0.2 the result is ``ef.LinearInOut()``; otherwise one
    class is chosen uniformly from the candidate list below (the
    ``easing_functions`` curves plus the local ``Step``) and instantiated.
    """
    use_linear = random.random() < 0.2
    if use_linear:
        return ef.LinearInOut()

    candidates = [
        ef.BackEaseIn,
        ef.BackEaseOut,
        ef.BackEaseInOut,
        ef.BounceEaseIn,
        ef.BounceEaseOut,
        ef.BounceEaseInOut,
        ef.CircularEaseIn,
        ef.CircularEaseOut,
        ef.CircularEaseInOut,
        ef.CubicEaseIn,
        ef.CubicEaseOut,
        ef.CubicEaseInOut,
        ef.ExponentialEaseIn,
        ef.ExponentialEaseOut,
        ef.ExponentialEaseInOut,
        ef.ElasticEaseIn,
        ef.ElasticEaseOut,
        ef.ElasticEaseInOut,
        ef.QuadEaseIn,
        ef.QuadEaseOut,
        ef.QuadEaseInOut,
        ef.QuarticEaseIn,
        ef.QuarticEaseOut,
        ef.QuarticEaseInOut,
        ef.QuinticEaseIn,
        ef.QuinticEaseOut,
        ef.QuinticEaseInOut,
        ef.SineEaseIn,
        ef.SineEaseOut,
        ef.SineEaseInOut,
        Step,  # Custom step easing defined in this module
    ]
    easing_cls = random.choice(candidates)
    return easing_cls()
|
|
|
class Step:
    """Easing callable that switches from 0 to 1 at the halfway point."""

    def __call__(self, value):
        """Return 0 when ``value`` is below 0.5, otherwise 1."""
        if value < 0.5:
            return 0
        return 1
|
|
|
|
|
|
|
|
|
|
|
class TrainFrameSampler:
    """Sample a randomized list of frame indices for a training clip.

    Args:
        speed: Candidate playback-speed multipliers; one is chosen at random
            per call. Defaults to ``[0.5, 1, 2, 3, 4, 5]``. A caller-supplied
            sequence is stored as-is (not copied).
    """

    def __init__(self, speed=None):
        # Build the default per instance: a list literal in the signature
        # would be one shared object mutated across every sampler.
        self.speed = [0.5, 1, 2, 3, 4, 5] if speed is None else speed

    def __call__(self, seq_length):
        """Return ``seq_length`` frame indices with random speed, shift and direction.

        Each index is ``int(i * speed) + shift`` for ``i`` in
        ``range(seq_length)``, where ``shift`` is drawn from
        ``range(seq_length)``; the list is reversed with probability 0.5.
        Indices can exceed ``seq_length - 1`` -- presumably the caller maps
        them onto the available frames (verify against the dataset code).

        A non-positive ``seq_length`` yields an empty list.
        """
        if seq_length <= 0:
            # Nothing to sample; avoids random.choice on an empty range.
            return []

        # Speed up / slow down
        speed = random.choice(self.speed)
        # Shift the whole window forward
        shift = random.choice(range(seq_length))
        frames = [int(f * speed) + shift for f in range(seq_length)]

        # Reverse playback half of the time
        if random.random() < 0.5:
            frames.reverse()

        return frames
|
|
|
class ValidFrameSampler:
    """Deterministic sampler that yields every frame index in order."""

    def __call__(self, seq_length):
        """Return ``range(seq_length)``: indices 0 through ``seq_length - 1``."""
        indices = range(seq_length)
        return indices
|
|