# Uploaded by awsaf49 — commit 3f50570 ("Initial Commit"), 8.44 kB.
import math
from typing import Tuple
import torch
import torch.nn as nn
from torchaudio.transforms import SpecAugment
from torch import Tensor
from torchvision.transforms import functional as F
class AugmentLayer(nn.Module):
    """Augmentation stack applied to a batch of spectrograms.

    MixUp is applied to ``(spec, y)`` when labels are supplied. Time/frequency
    masking (torchaudio ``SpecAugment``) is then applied to the spectrogram.
    Hyperparameters come from ``cfg.augment``, plus ``cfg.num_classes`` for MixUp.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        aug = cfg.augment
        # Label-mixing augmentation. inplace=True makes MixUp work on the
        # incoming tensors instead of cloning them first.
        self.mixup = MixUp(
            num_classes=cfg.num_classes,
            p=aug.mixup_p,
            alpha=aug.mixup_alpha,
            inplace=True,
        )
        # Random time and frequency masks; masked bins are filled with zeros.
        self.time_freq_mask = SpecAugment(
            n_time_masks=aug.n_time_masks,
            time_mask_param=aug.time_mask_param,
            n_freq_masks=aug.n_freq_masks,
            freq_mask_param=aug.freq_mask_param,
            p=aug.time_freq_mask_p,
            zero_masking=True,
        )

    def forward(self, spec, y=None):
        """Augment a spectrogram batch and, optionally, its labels.

        Args:
            spec: batch of spectrograms; presumably (batch_size, n_mels, n_frames)
                (MixUp only accepts 2-D or 3-D batches) — confirm with caller.
            y: optional labels. When ``None``, MixUp is skipped and only
                time/frequency masking is applied.

        Returns:
            tuple: ``(spec, y)`` after augmentation. ``y`` is MixUp's output
            when labels were given, otherwise ``None``.
        """
        mixed, targets = (spec, y) if y is None else self.mixup(spec, y)
        return self.time_freq_mask(mixed), targets
class MixUp(torch.nn.Module):
    """Randomly apply MixUp to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
    Each sample is blended with its neighbour in the batch (the batch rolled by
    one along dim 0). A single mixing weight drawn from Beta(alpha, alpha) is
    used for the whole batch.

    Args:
        num_classes (int): number of classes used for one-hot encoding. When 1,
            targets are used as-is: no one-hot encoding and no int64 requirement.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Must be positive. Default value is 1.0.
        inplace (bool): when True, the input tensors are not cloned, so the
            caller's batch is modified in place. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Mix the batch and its targets with a randomly drawn weight.

        Args:
            batch (Tensor): Float tensor of size (B, F, T) or (B, N).
            target (Tensor): Tensor of size (B,) holding int64 class indices
                (any dtype when ``num_classes == 1``).

        Returns:
            Tuple[Tensor, Tensor]: the (possibly) mixed batch and the targets.
            Targets are always cast to ``batch.dtype``, and are one-hot encoded
            to (B, num_classes) when ``num_classes > 1``. This happens even when
            the batch is not mixed.

        Raises:
            ValueError: on unsupported batch/target rank or mismatched batch sizes.
            TypeError: on a non-float batch, or a non-int64 target when
                ``num_classes > 1``.
        """
        if batch.ndim != 3 and batch.ndim != 2:
            raise ValueError(
                f"Batch ndim should be 3 (b, f, t) or 2 (b, n). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        # Without this check, mismatched sizes would be mixed silently.
        if batch.shape[0] != target.shape[0]:
            raise ValueError(
                "Batch and target must have the same size along dim 0. "
                f"Got {batch.shape[0]} and {target.shape[0]}."
            )
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is known to be 1-D here (checked above).
        if self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # Rolling by one pairs each sample with its neighbour; this is cheaper
        # than shuffling the batch to build pairs.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # lambda ~ Beta(alpha, alpha), drawn as the first component of a
        # 2-way Dirichlet (mixup paper, page 3).
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
# CutMix accepts image batches (B, C, H, W) and spectrogram batches (B, F, T).
# TODO: support raw audio input (B, n_samples). With a height of 1, the box
# sampling below always yields an empty box, so a 1-D segment cut is needed.
class CutMix(torch.nn.Module):
    """Randomly apply CutMix to the provided batch and targets.

    The class implements the data augmentation described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.
    A rectangle over the last two dimensions of each sample is replaced by the
    same rectangle from its neighbour in the batch. The targets are mixed in
    proportion to the pasted area.

    Args:
        num_classes (int): number of classes used for one-hot encoding. When 1,
            targets are used as-is: no one-hot encoding and no int64 requirement.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Must be positive. Default value is 1.0.
        inplace (bool): when True, the input tensors are not cloned, so the
            caller's batch is modified in place. Default set to False.
    """

    def __init__(
        self,
        num_classes: int,
        p: float = 0.5,
        alpha: float = 1.0,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Paste a random box from each sample's neighbour and mix the targets.

        Args:
            batch (Tensor): Float tensor of size (B, C, H, W) or (B, F, T).
                The box is cut over the last two dimensions.
            target (Tensor): Tensor of size (B,) holding int64 class indices
                (any dtype when ``num_classes == 1``).

        Returns:
            Tuple[Tensor, Tensor]: the (possibly) transformed batch and the
            targets. Targets are always cast to ``batch.dtype``, and are one-hot
            encoded to (B, num_classes) when ``num_classes > 1``.

        Raises:
            ValueError: on unsupported batch/target rank or mismatched batch sizes.
            TypeError: on a non-float batch, or a non-int64 target when
                ``num_classes > 1``.
        """
        if batch.ndim not in (3, 4):
            raise ValueError(
                f"Batch ndim should be 4 (b, c, h, w) or 3 (b, f, t). Got {batch.ndim}"
            )
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if batch.shape[0] != target.shape[0]:
            raise ValueError(
                "Batch and target must have the same size along dim 0. "
                f"Got {batch.shape[0]} and {target.shape[0]}."
            )
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64 and self.num_classes > 1:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is known to be 1-D here (checked above).
        if self.num_classes > 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes)
        target = target.to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # Rolling by one pairs each sample with its neighbour; this is cheaper
        # than shuffling the batch to build pairs.
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as in the cutmix paper, page 12 (with minor corrections
        # on typos). lambda ~ Beta(alpha, alpha) via a 2-way Dirichlet draw.
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        H, W = batch.shape[-2], batch.shape[-1]

        # Box centre is uniform over the plane; its half-sizes scale with
        # sqrt(1 - lambda), so the unclipped area is (1 - lambda) * H * W.
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[..., y1:y2, x1:x2] = batch_rolled[..., y1:y2, x1:x2]
        # Recompute lambda from the actual (possibly border-clipped) box so the
        # label weights match the pasted area exactly.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)
        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s