import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import SpecAugment
from torchvision.transforms import functional as F
class AugmentLayer(nn.Module):
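    """Train-time augmentation layer for mel spectrograms.

    Applies MixUp to the batch and targets, then SpecAugment-style time and
    frequency masking. ``cfg`` is expected to expose ``num_classes`` and an
    ``augment`` namespace with ``mixup_alpha``, ``mixup_p``, ``n_time_masks``,
    ``time_mask_param``, ``n_freq_masks``, ``freq_mask_param`` and
    ``time_freq_mask_p`` (see the illustrative ``__main__`` sketch at the
    bottom of this file).
    """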
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# Initialize MixUp
self.mixup = MixUp(
alpha=cfg.augment.mixup_alpha,
num_classes=cfg.num_classes,
p=cfg.augment.mixup_p,
inplace=True,
)
        # Initialize SpecAugment time/frequency masking. Note: torchaudio
        # documents `p` as the maximum proportion of time steps a time mask
        # may cover, not as a probability of applying the transform.
        self.time_freq_mask = SpecAugment(
n_time_masks=cfg.augment.n_time_masks,
time_mask_param=cfg.augment.time_mask_param,
n_freq_masks=cfg.augment.n_freq_masks,
freq_mask_param=cfg.augment.freq_mask_param,
p=cfg.augment.time_freq_mask_p,
zero_masking=True,
)
    def forward(self, spec, y=None):
        # Apply MixUp when targets are provided (i.e. at train time). MixUp
        # operates directly on (batch_size, n_mels, n_frames) spectrograms,
        # so no channel dimension needs to be added.
        if y is not None:
            spec, y = self.mixup(spec, y)
        # Apply SpecAugment-style time and frequency masking
        spec = self.time_freq_mask(spec)
return spec, y
class MixUp(torch.nn.Module):
"""Randomly apply MixUp to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for mixup.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
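
    Example (illustrative usage; shapes and class count are arbitrary)::

        >>> mixup = MixUp(num_classes=10, p=1.0, alpha=0.4)
        >>> spec = torch.randn(8, 128, 256)      # (batch, n_mels, n_frames)
        >>> labels = torch.randint(0, 10, (8,))
        >>> spec, targets = mixup(spec, labels)  # targets: (8, 10) soft labels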
"""
def __init__(
self,
num_classes: int,
p: float = 0.5,
alpha: float = 1.0,
inplace: bool = False,
) -> None:
super().__init__()
if num_classes < 1:
raise ValueError(
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
)
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
Tensor: Randomly transformed batch.
"""
if batch.ndim != 3 and batch.ndim != 2:
raise ValueError(
f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
)
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64 and self.num_classes > 1:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
target = target.clone()
if target.ndim == 1 and self.num_classes > 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
target = target.to(dtype=batch.dtype)
if torch.rand(1).item() >= self.p:
return batch, target
        # It's faster to roll the batch by one instead of shuffling it to create sample pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)
        # Sample the mixing coefficient as in the mixup paper, page 3:
        # lambda ~ Beta(alpha, alpha). The first component of a
        # Dirichlet([alpha, alpha]) sample has exactly this distribution.
lambda_param = float(
torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
)
batch_rolled.mul_(1.0 - lambda_param)
batch.mul_(lambda_param).add_(batch_rolled)
target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)
return batch, target
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"num_classes={self.num_classes}"
f", p={self.p}"
f", alpha={self.alpha}"
f", inplace={self.inplace}"
f")"
)
return s
# TODO: CutMix expects image-like 4D input; adapt it for raw audio input of
# shape (batch_size, n_samples), e.g. by treating the spectrogram height as 1.
class CutMix(torch.nn.Module):
"""Randomly apply CutMix to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for cutmix.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
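
    Example (illustrative usage; shapes and class count are arbitrary)::

        >>> cutmix = CutMix(num_classes=10, p=1.0, alpha=1.0)
        >>> images = torch.randn(8, 1, 128, 256)      # (B, C, H, W)
        >>> labels = torch.randint(0, 10, (8,))
        >>> images, targets = cutmix(images, labels)  # targets: (8, 10)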
"""
def __init__(
self,
num_classes: int,
p: float = 0.5,
alpha: float = 1.0,
inplace: bool = False,
) -> None:
super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
Tensor: Randomly transformed batch.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64 and self.num_classes > 1:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
target = target.clone()
if target.ndim == 1 and self.num_classes > 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
target = target.to(dtype=batch.dtype)
if torch.rand(1).item() >= self.p:
return batch, target
# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)
        # Sample the combination ratio as in the CutMix paper, page 12
        # (with minor corrections of typos): lambda ~ Beta(alpha, alpha).
lambda_param = float(
torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
)
_, H, W = F.get_dimensions(batch)
r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))
r = 0.5 * math.sqrt(1.0 - lambda_param)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)
return batch, target
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"num_classes={self.num_classes}"
f", p={self.p}"
f", alpha={self.alpha}"
f", inplace={self.inplace}"
f")"
)
return s
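

# Minimal smoke test for AugmentLayer. This is an illustrative sketch, not
# part of the training pipeline: the SimpleNamespace cfg below merely mirrors
# the attributes read in AugmentLayer.__init__, and all values are arbitrary.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        num_classes=10,
        augment=SimpleNamespace(
            mixup_alpha=1.0,
            mixup_p=0.5,
            n_time_masks=2,
            time_mask_param=30,
            n_freq_masks=2,
            freq_mask_param=15,
            time_freq_mask_p=0.5,
        ),
    )
    aug = AugmentLayer(cfg)
    spec = torch.randn(8, 128, 256)  # (batch_size, n_mels, n_frames)
    y = torch.randint(0, 10, (8,))   # integer class labels
    spec, y = aug(spec, y)
    # MixUp one-hot encodes the targets, so y is now (8, 10) and float.
    print(spec.shape, y.shape)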