File size: 8,221 Bytes
96e64e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn.functional as F
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import Optional, List, Union, Dict, Tuple
from collections import namedtuple
import math
import functools
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    sampling_rate : int
        Sample rate of the audio; used to build the mel filterbanks.
    n_mels : List[int]
        Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default
        [32, 64, 128, 256, 512, 1024, 2048]. The hop length of each STFT is
        ``window_length // 4``.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to the mel magnitude before taking the log,
        by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 0.0 (the raw
        magnitude term is disabled)
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 1.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is stored as ``self.weight``
        but is not applied inside ``forward`` (presumably the caller scales
        the returned loss by it).
    match_stride : bool, optional
        If True, pad the signal and drop the outer STFT frames so that
        ``num_frames * hop_length == num_samples``, by default False
    mel_fmin : List[float], optional
        Lowest mel-filter frequency for each scale, by default 0 everywhere
    mel_fmax : List[Optional[float]], optional
        Highest mel-filter frequency for each scale, by default None
        everywhere (forwarded to librosa, which treats None as sr / 2)
    window_type : str, optional
        Window name passed to ``scipy.signal.get_window``, by default "hann"

    Notes
    -----
    The per-scale lists are combined with ``zip``. If they differ in length,
    only the shortest common prefix of scales is used.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
    """

    def __init__(
        self,
        sampling_rate: int,
        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 0.0,
        log_weight: float = 1.0,
        pow: float = 1.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
        mel_fmax: List[Optional[float]] = [None, None, None, None, None, None, None],
        window_type: str = "hann",
    ):
        super().__init__()
        self.sampling_rate = sampling_rate

        STFTParams = namedtuple(
            "STFTParams",
            ["window_length", "hop_length", "window_type", "match_stride"],
        )
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        # Copy the per-scale lists so instances never alias (or mutate)
        # the shared default-argument list objects.
        self.n_mels = list(n_mels)
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = list(mel_fmin)
        self.mel_fmax = list(mel_fmax)
        self.pow = pow

    @staticmethod
    @functools.lru_cache(None)
    def get_window(
        window_type,
        window_length,
    ):
        """Return (and memoize) the numpy window from ``scipy.signal.get_window``."""
        return signal.get_window(window_type, window_length)

    @staticmethod
    @functools.lru_cache(None)
    def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
        """Return (and memoize) the librosa mel filterbank for these settings."""
        return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    def _stft_magnitude(
        self,
        wav,
        window_length,
        hop_length,
        match_stride,
        window_type,
    ):
        """Return |STFT| of a (B, C, T) signal as a (B, C, freq_bins, frames) tensor.

        Mirrors ``AudioSignal.stft`` from audiotools, including its
        ``match_stride`` padding and frame trimming.
        """
        B, C, T = wav.shape

        if match_stride:
            assert (
                hop_length == window_length // 4
            ), "For match_stride, hop must equal n_fft // 4"
            right_pad = math.ceil(T / hop_length) * hop_length - T
            pad = (window_length - hop_length) // 2
        else:
            right_pad = 0
            pad = 0

        wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")

        window = torch.from_numpy(self.get_window(window_type, window_length))
        window = window.to(wav.device).float()

        # Flatten batch/channel using the *padded* length: the original T is
        # stale once match_stride padding has been applied.
        stft = torch.stft(
            wav.reshape(-1, wav.shape[-1]),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        _, nf, nt = stft.shape
        stft = stft.reshape(B, C, nf, nt)

        if match_stride:
            # Drop the first two and last two frames, which exist only because
            # of the padding. Afterwards num_frames * hop_length == num_samples.
            stft = stft[..., 2:-2]

        return torch.abs(stft)

    def mel_spectrogram(
        self,
        wav,
        n_mels,
        fmin,
        fmax,
        window_length,
        hop_length,
        match_stride,
        window_type,
    ):
        """
        Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
        https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py

        Takes a (B, C, T) signal and returns a (B, C, n_mels, num_frames)
        mel-magnitude tensor.
        """
        magnitude = self._stft_magnitude(
            wav, window_length, hop_length, match_stride, window_type
        )

        # Recover n_fft from the number of one-sided frequency bins.
        nf = magnitude.shape[2]
        mel_basis = self.get_mel_filters(
            self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
        )
        mel_basis = torch.from_numpy(mel_basis).to(wav.device)

        # Project the frequency axis onto the mel filters:
        # (B, C, frames, nf) @ (nf, n_mels) -> (B, C, frames, n_mels).
        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
        return mel_spectrogram.transpose(-1, 2)

    def _scale_loss(self, x_mels: torch.Tensor, y_mels: torch.Tensor) -> torch.Tensor:
        """Weighted log-mel plus raw-mel distance for a single STFT scale."""
        x_logmels = x_mels.clamp(min=self.clamp_eps).pow(self.pow).log10()
        y_logmels = y_mels.clamp(min=self.clamp_eps).pow(self.pow).log10()
        loss = self.log_weight * self.loss_fn(x_logmels, y_logmels)
        # The magnitude term compares the *raw* mel magnitudes (as in DAC).
        # Skip it when disabled so it costs nothing and cannot inject 0 * inf.
        if self.mag_weight != 0:
            loss = loss + self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : torch.Tensor
            Estimate signal, shape (B, C, T)
        y : torch.Tensor
            Reference signal, shape (B, C, T)

        Returns
        -------
        torch.Tensor
            Mel loss summed over all scales (a float 0.0 if there are no
            scales).
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "n_mels": n_mels,
                "fmin": fmin,
                "fmax": fmax,
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "match_stride": s.match_stride,
                "window_type": s.window_type,
            }
            x_mels = self.mel_spectrogram(x, **kwargs)
            y_mels = self.mel_spectrogram(y, **kwargs)
            loss += self._scale_loss(x_mels, y_mels)
        return loss
# Loss functions
def feature_loss(
    fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
    """Feature-matching loss between real and generated discriminator features.

    For every discriminator and every layer, take the mean absolute difference
    between the real and generated feature maps. Sum those values and scale
    the total by 2.

    Returns the integer 0 when there are no feature maps to compare.
    """
    total = sum(
        torch.mean(torch.abs(real_layer - fake_layer))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_layer, fake_layer in zip(real_maps, fake_maps)
    )
    # The factor 2 is lambda=2.0 for the feature matching loss.
    return total * 2
def discriminator_loss(
    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[float], List[float]]:
    """Least-squares GAN discriminator loss.

    For each discriminator output pair, the real term is ``mean((1 - dr) ** 2)``
    and the generated term is ``mean(dg ** 2)``.

    Returns
    -------
    loss
        Sum of all real and generated terms. It is a tensor, or the integer 0
        when the inputs are empty.
    r_losses, g_losses
        Per-discriminator real / generated terms as Python floats, taken via
        ``.item()`` for logging. They are detached from the graph.
    """
    loss = 0
    r_losses: List[float] = []
    g_losses: List[float] = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses
def generator_loss(
    disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Least-squares GAN generator loss.

    Each discriminator output contributes ``mean((1 - dg) ** 2)``.

    Returns the sum of those terms, or the integer 0 for empty input. Also
    returns the list of per-output terms, still as tensors.
    """
    gen_losses = [torch.mean((1 - output) ** 2) for output in disc_outputs]
    total = sum(gen_losses)
    return total, gen_losses
|