import librosa import numpy as np import torch import torch.nn as nn import logging import math import random CONSTANT = 1e-5 def normalize_batch(x, seq_len, normalize_type): x_mean = None x_std = None if normalize_type == "per_feature": x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) for i in range(x.shape[0]): if x[i, :, : seq_len[i]].shape[1] == 1: raise ValueError( "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result " "in torch.std() returning nan. Make sure your audio length has enough samples for a single " "feature (ex. at least `hop_length` for Mel Spectrograms)." ) x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1) x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1) # make sure x_std is not zero x_std += CONSTANT return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std elif normalize_type == "all_features": x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) for i in range(x.shape[0]): x_mean[i] = x[i, :, : seq_len[i].item()].mean() x_std[i] = x[i, :, : seq_len[i].item()].std() # make sure x_std is not zero x_std += CONSTANT return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) return ( (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2), x_mean, x_std, ) else: return x, x_mean, x_std def splice_frames(x, frame_splicing): """ Stacks frames together across feature dim input is batch_size, feature_dim, num_frames output is batch_size, feature_dim*frame_splicing, num_frames """ seq = [x] for n in range(1, frame_splicing): seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2)) return torch.cat(seq, dim=1) class FilterbankFeatures(nn.Module): """Featurizer that converts wavs to Mel Spectrograms. See AudioToMelSpectrogramPreprocessor for args. "normalize": "per_feature", "window_size": 0.025, "sample_rate": 16000, "window_stride": 0.01, "window": "hann", "features": 80, "n_fft": 512, "frame_splicing": 1, "dither": 1e-05 n_window_size=window_size * sample_rate, n_window_stride = window_stride * sample_rate, """ def __init__( self, sample_rate=16000, n_window_size=400, n_window_stride=160, window="hann", normalize="per_feature", n_fft=512, preemph=0.97, nfilt=80, lowfreq=0, highfreq=None, log=True, log_zero_guard_type="add", log_zero_guard_value=2 ** -24, dither=CONSTANT, pad_to=16, max_duration=16.7, frame_splicing=1, exact_pad=False, pad_value=0, mag_power=2.0, use_grads=False, rng=None, nb_augmentation_prob=0.0, nb_max_freq=4000, stft_exact_pad=False, # Deprecated arguments; kept for config compatibility stft_conv=False, # Deprecated arguments; kept for config compatibility ): super().__init__() if stft_conv or stft_exact_pad: logging.warning( "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False " "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True " "as needed." ) if exact_pad and n_window_stride % 2 == 1: raise NotImplementedError( f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the " "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size." ) self.log_zero_guard_value = log_zero_guard_value if ( n_window_size is None or n_window_stride is None or not isinstance(n_window_size, int) or not isinstance(n_window_stride, int) or n_window_size <= 0 or n_window_stride <= 0 ): raise ValueError( f"{self} got an invalid value for either n_window_size or " f"n_window_stride. Both must be positive ints." ) logging.info(f"PADDING: {pad_to}") self.win_length = n_window_size self.hop_length = n_window_stride self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None if exact_pad: logging.info("STFT using exact pad") torch_windows = { 'hann': torch.hann_window, 'hamming': torch.hamming_window, 'blackman': torch.blackman_window, 'bartlett': torch.bartlett_window, 'none': None, } window_fn = torch_windows.get(window, None) window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None self.register_buffer("window", window_tensor) self.stft = lambda x: torch.stft( x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=False if exact_pad else True, window=self.window.to(dtype=torch.float), return_complex=True, ) self.normalize = normalize self.log = log self.dither = dither self.frame_splicing = frame_splicing self.nfilt = nfilt self.preemph = preemph self.pad_to = pad_to highfreq = highfreq or sample_rate / 2 filterbanks = torch.tensor( librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float, ).unsqueeze(0) self.register_buffer("fb", filterbanks) # Calculate maximum sequence length max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float)) max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 self.max_length = max_length + max_pad self.pad_value = pad_value self.mag_power = mag_power # We want to avoid taking the log of zero # There are two options: either adding or clamping to a small value if log_zero_guard_type not in ["add", "clamp"]: raise ValueError( f"{self} received {log_zero_guard_type} for the " f"log_zero_guard_type parameter. It must be either 'add' or " f"'clamp'." ) self.use_grads = use_grads if not use_grads: self.forward = torch.no_grad()(self.forward) self._rng = random.Random() if rng is None else rng self.nb_augmentation_prob = nb_augmentation_prob if self.nb_augmentation_prob > 0.0: if nb_max_freq >= sample_rate / 2: self.nb_augmentation_prob = 0.0 else: self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft) # log_zero_guard_value is the the small we want to use, we support # an actual number, or "tiny", or "eps" self.log_zero_guard_type = log_zero_guard_type logging.debug(f"sr: {sample_rate}") logging.debug(f"n_fft: {self.n_fft}") logging.debug(f"win_length: {self.win_length}") logging.debug(f"hop_length: {self.hop_length}") logging.debug(f"n_mels: {nfilt}") logging.debug(f"fmin: {lowfreq}") logging.debug(f"fmax: {highfreq}") logging.debug(f"using grads: {use_grads}") logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}") def log_zero_guard_value_fn(self, x): if isinstance(self.log_zero_guard_value, str): if self.log_zero_guard_value == "tiny": return torch.finfo(x.dtype).tiny elif self.log_zero_guard_value == "eps": return torch.finfo(x.dtype).eps else: raise ValueError( f"{self} received {self.log_zero_guard_value} for the " f"log_zero_guard_type parameter. It must be either a " f"number, 'tiny', or 'eps'" ) else: return self.log_zero_guard_value def get_seq_len(self, seq_len): # Assuming that center is True is stft_pad_amount = 0 pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2 seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1 return seq_len.to(dtype=torch.long) @property def filter_banks(self): return self.fb def forward(self, x, seq_len, linear_spec=False): seq_len = self.get_seq_len(seq_len.float()) if self.stft_pad_amount is not None: x = torch.nn.functional.pad( x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect" ).squeeze(1) # dither (only in training mode for eval determinism) if self.training and self.dither > 0: x += self.dither * torch.randn_like(x) # do preemphasis if self.preemph is not None: x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1) # disable autocast to get full range of stft values with torch.cuda.amp.autocast(enabled=False): x = self.stft(x) # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude # guard is needed for sqrt if grads are passed through guard = 0 if not self.use_grads else CONSTANT x = torch.view_as_real(x) x = torch.sqrt(x.pow(2).sum(-1) + guard) if self.training and self.nb_augmentation_prob > 0.0: for idx in range(x.shape[0]): if self._rng.random() < self.nb_augmentation_prob: x[idx, self._nb_max_fft_bin :, :] = 0.0 # get power spectrum if self.mag_power != 1.0: x = x.pow(self.mag_power) # return plain spectrogram if required if linear_spec: return x, seq_len # dot with filterbank energies x = torch.matmul(self.fb.to(x.dtype), x) # log features if required if self.log: if self.log_zero_guard_type == "add": x = torch.log(x + self.log_zero_guard_value_fn(x)) elif self.log_zero_guard_type == "clamp": x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) else: raise ValueError("log_zero_guard_type was not understood") # frame splicing if required if self.frame_splicing > 1: x = splice_frames(x, self.frame_splicing) # normalize if required if self.normalize: x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize) # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency) max_len = x.size(-1) mask = torch.arange(max_len).to(x.device) mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1) x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value) del mask pad_to = self.pad_to if pad_to == "max": x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value) elif pad_to > 0: pad_amt = x.size(-1) % pad_to if pad_amt != 0: x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) return x, seq_len