# Committed by Higobeatz
# Commit: "Initial commit" (0a97d6c)
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as transforms
class LogMelSpectrogram(torch.nn.Module):
    """Log-compressed mel spectrogram of a waveform.

    Wraps ``torchaudio.transforms.MelSpectrogram`` (magnitude spectrogram,
    Slaney mel scale and normalisation, ``center=False``). It adds manual
    reflect padding before the transform and a floored natural log after it.
    """

    def __init__(self, sr=24000, frame_length=1920, hop_length=480, n_mel=128, f_min=0, f_max=12000,):
        """Configure the mel transform.

        Args:
            sr: Sample rate of the input waveform, in Hz.
            frame_length: FFT size and analysis window length, in samples.
            hop_length: Hop between consecutive frames, in samples.
            n_mel: Number of mel bins.
            f_min: Lowest mel filter frequency, in Hz.
            f_max: Highest mel filter frequency, in Hz.
        """
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.mel = transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=frame_length,
            win_length=frame_length,
            hop_length=hop_length,
            # forward() pads the signal itself, so torchaudio must not
            # centre-pad as well.
            center=False,
            power=1.0,  # magnitude (not power) spectrogram
            norm="slaney",
            n_mels=n_mel,
            mel_scale="slaney",
            f_min=f_min,
            f_max=f_max,
        )

    @torch.no_grad()
    def forward(self, x, target_length=None):
        """Compute the log-mel spectrogram of ``x``.

        Args:
            x: Waveform tensor of shape ``(..., time)``. Any number of leading
                dimensions, including none, is accepted.
            target_length: Desired number of output frames. A longer target
                right-pads with the log floor ``log(1e-5)``. A shorter target
                truncates. ``None`` keeps the natural frame count.

        Returns:
            Tensor of shape ``(..., n_mel, frames)`` holding
            ``log(max(mel, 1e-5))``, on the same device and with the same
            dtype as the mel transform's output.
        """
        lead_shape = x.shape[:-1]
        # For 1-D padding, F.pad's 'reflect' mode only accepts 2-D/3-D input,
        # so fold every leading dimension into a single batch dimension.
        x = x.reshape(-1, x.shape[-1])
        # Symmetric padding of (frame_length - hop_length) // 2 per side. With
        # center=False and an even (frame_length - hop_length), this yields
        # time // hop_length frames for time >= hop_length.
        pad = (self.frame_length - self.hop_length) // 2
        x = F.pad(x, (pad, pad), "reflect")
        mel = self.mel(x)
        mel = mel.reshape(*lead_shape, *mel.shape[-2:])

        n_frames = mel.shape[-1]
        if target_length is None:
            target_length = n_frames
        # Allocate directly with mel's device/dtype rather than CPU + .to().
        logmel = mel.new_zeros(*mel.shape[:-1], target_length)
        keep = min(n_frames, target_length)
        logmel[..., :keep] = mel[..., :keep]
        # Floor before the log to avoid -inf. Zero-padded frames become log(1e-5).
        return torch.log(torch.clamp(logmel, min=1e-5))