import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import SpecAugment
from torchvision.transforms import functional as F


class AugmentLayer(nn.Module):
    """Applies MixUp followed by SpecAugment time/frequency masking."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.mixup = MixUp(
            alpha=cfg.augment.mixup_alpha,
            num_classes=cfg.num_classes,
            p=cfg.augment.mixup_p,
            inplace=True,
        )

        self.time_freq_mask = SpecAugment(
            n_time_masks=cfg.augment.n_time_masks,
            time_mask_param=cfg.augment.time_mask_param,
            n_freq_masks=cfg.augment.n_freq_masks,
            freq_mask_param=cfg.augment.freq_mask_param,
            p=cfg.augment.time_freq_mask_p,
            zero_masking=True,
        )

    def forward(self, spec, y=None):
        # MixUp needs targets, so it is only applied when labels are provided;
        # the SpecAugment time/frequency masking is applied either way.
        if y is not None:
            spec, y = self.mixup(spec, y)

        spec = self.time_freq_mask(spec)
        return spec, y


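# A minimal usage sketch for AugmentLayer. The config layout mirrors the
# attribute accesses in __init__ above; the concrete values are illustrative
# assumptions, not defaults from this project:
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(
#         num_classes=10,
#         augment=SimpleNamespace(
#             mixup_alpha=1.0, mixup_p=0.5,
#             n_time_masks=2, time_mask_param=25,
#             n_freq_masks=2, freq_mask_param=7,
#             time_freq_mask_p=1.0,
#         ),
#     )
#     layer = AugmentLayer(cfg)
#     spec = torch.randn(8, 64, 256)               # (batch, freq, time)
#     y = torch.randint(0, cfg.num_classes, (8,))
#     spec, y = layer(spec, y)                     # y becomes (8, 10) soft labels

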
class MixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_,
    adapted here to spectrogram batches of shape (B, F, T) or feature batches
    of shape (B, N).

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, F, T) or (B, N)
            target (Tensor): Integer tensor of size (B,)

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and mixed targets.
        """
        if batch.ndim != 3 and batch.ndim != 2:
            raise ValueError(
                f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # One-hot encode integer targets (target.ndim == 1 is guaranteed above).
        if self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # Pair each sample with its neighbour by rolling the batch by one.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Sample the mixing coefficient lambda ~ Beta(alpha, alpha); the first
        # component of a two-component Dirichlet draw is equivalent.
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )

        # Convex combination: x = lambda * x + (1 - lambda) * x_rolled.
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        # Mix the targets with the same coefficient.
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


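# Usage sketch for MixUp on its own; shapes and values below are illustrative:
#
#     mixup = MixUp(num_classes=10, p=0.5, alpha=0.4)
#     batch = torch.randn(8, 64, 256)              # (B, F, T) spectrograms
#     target = torch.randint(0, 10, (8,))
#     batch, target = mixup(batch, target)         # target becomes (8, 10) soft labels

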
class CutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B,)

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and mixed targets.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # One-hot encode integer targets (target.ndim == 1 is guaranteed above).
        if self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # Pair each sample with its neighbour by rolling the batch by one.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Sample lambda ~ Beta(alpha, alpha), then cut a box whose area
        # fraction is (1 - lambda), centred at a uniformly random point.
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        _, H, W = F.get_dimensions(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        # Paste the box from the rolled batch, then recompute lambda as the
        # fraction of the image that was kept, since clipping at the borders
        # can shrink the box.
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        # Mix the targets in proportion to the kept area.
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s

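
# Usage sketch for CutMix; shapes and values below are illustrative:
#
#     cutmix = CutMix(num_classes=10, p=0.5, alpha=1.0)
#     batch = torch.randn(8, 3, 224, 224)          # (B, C, H, W) images
#     target = torch.randint(0, 10, (8,))
#     batch, target = cutmix(batch, target)        # labels weighted by kept box area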