File size: 13,758 Bytes
8c70653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results
@dataclass
class WavegradArgs(Coqpit):
    """Architecture hyper-parameters read by ``Wavegrad.__init__``.

    Attributes:
        in_channels: Input channels of the spectrogram convolution (``x_conv``).
        out_channels: Output channels of the final convolution (``out_conv``).
        use_weight_norm: When ``True`` the model applies weight normalization to its layers.
        y_conv_channels: Output channels of the first waveform convolution (``y_conv``).
        x_conv_channels: Output channels of ``x_conv``; also the input width of the first ``UBlock``.
        dblock_out_channels: Output channels of each ``DBlock``, in construction order.
        ublock_out_channels: Output channels of each ``UBlock``, in construction order.
        upsample_factors: One factor per ``UBlock``; their product is stored as the model's ``hop_len``.
        upsample_dilations: One list of dilation values per ``UBlock``.
    """

    in_channels: int = 80
    out_channels: int = 1
    use_weight_norm: bool = False
    y_conv_channels: int = 32
    x_conv_channels: int = 768
    dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512])
    ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128])
    upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2])
    upsample_dilations: List[List[int]] = field(
        default_factory=lambda: [
            [1, 2, 1, 2],
            [1, 2, 1, 2],
            [1, 2, 4, 8],
            [1, 2, 4, 8],
            [1, 2, 4, 8],
        ]
    )
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
Examples:
Initializing the model.
>>> from TTS.vocoder.configs import WavegradConfig
>>> config = WavegradConfig()
>>> model = Wavegrad(config)
Paper Abstract:
This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
Audio samples are available at https://wavegrad.github.io/.
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
    """Assemble the WaveGrad layers described by ``config.model_params``.

    Args:
        config: Model configuration. ``config.model_params`` supplies the channel sizes,
            upsampling factors and dilations used to build the layers.
    """
    super().__init__(config)
    self.config = config
    params = config.model_params
    self.use_weight_norm = params.use_weight_norm
    self.hop_len = np.prod(params.upsample_factors)

    # Noise-schedule state, populated later by ``compute_noise_level``.
    self.noise_level = None
    self.num_steps = None
    self.beta = None
    self.alpha = None
    self.alpha_hat = None
    self.c1 = None
    self.c2 = None
    self.sigma = None

    # dblocks: waveform branch, walked with the upsample factors in reverse order
    self.y_conv = Conv1d(1, params.y_conv_channels, 5, padding=2)
    self.dblocks = nn.ModuleList([])
    in_ch = params.y_conv_channels
    for out_ch, factor in zip(params.dblock_out_channels, reversed(params.upsample_factors)):
        self.dblocks.append(DBlock(in_ch, out_ch, factor))
        in_ch = out_ch

    # film: one modulation layer per ublock, sized from the reversed ublock widths
    self.film = nn.ModuleList([])
    in_ch = params.y_conv_channels
    for out_ch in reversed(params.ublock_out_channels):
        self.film.append(FiLM(in_ch, out_ch))
        in_ch = out_ch

    # ublocks: spectrogram branch
    self.ublocks = nn.ModuleList([])
    in_ch = params.x_conv_channels
    for out_ch, factor, dilations in zip(
        params.ublock_out_channels,
        params.upsample_factors,
        params.upsample_dilations,
    ):
        self.ublocks.append(UBlock(in_ch, out_ch, factor, dilations))
        in_ch = out_ch

    self.x_conv = Conv1d(params.in_channels, params.x_conv_channels, 3, padding=1)
    # ``out_ch`` still holds the width produced by the last loop iteration above.
    self.out_conv = Conv1d(out_ch, params.out_channels, 3, padding=1)

    if params.use_weight_norm:
        self.apply_weight_norm()
def forward(self, x, spectrogram, noise_scale):
    """Run the network on a noisy waveform conditioned on a spectrogram.

    Args:
        x: Noisy waveform fed to ``y_conv`` (``compute_y_n`` produces it as ``[B, 1, T]``).
        spectrogram: Conditioning features fed to ``x_conv``.
        noise_scale: Noise level passed to every FiLM layer.

    Returns:
        Output of ``out_conv`` applied after the last upsampling block.
    """
    # Waveform branch: collect one (shift, scale) pair before the first dblock and one after each.
    hidden = self.y_conv(x)
    modulations = [self.film[0](hidden, noise_scale)]
    for down, film in zip(self.dblocks, self.film[1:]):
        hidden = down(hidden)
        modulations.append(film(hidden, noise_scale))

    # Spectrogram branch: consume the pairs in reverse order of their creation.
    out = self.x_conv(spectrogram)
    for up, (film_shift, film_scale) in zip(self.ublocks, modulations[::-1]):
        out = up(out, film_shift, film_scale)
    return self.out_conv(out)
def load_noise_schedule(self, path):
    """Load a saved noise schedule and derive the sampling coefficients from it.

    Args:
        path: File readable by ``np.load`` that holds a pickled dict with a ``"beta"`` entry.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load schedules from trusted files.
    schedule = np.load(path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg
    self.compute_noise_level(schedule["beta"])
@torch.no_grad()
def inference(self, x, y_n=None):
    """Generate a waveform by iteratively refining noise with the loaded schedule.

    Args:
        x: Conditioning spectrogram.
        y_n: Optional starting signal; when omitted, Gaussian noise is drawn.

    Returns:
        The refined signal, clamped to ``[-1, 1]`` after every step.

    Shapes:
        x: :math:`[B, C , T]`
        y_n: :math:`[B, 1, T]`
    """
    batch_size = x.shape[0]
    if y_n is not None:
        y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
    else:
        y_n = torch.randn(batch_size, 1, self.hop_len * x.shape[-1])
    y_n = y_n.type_as(x)
    noise_levels = self.noise_level.to(x)

    # Walk the schedule from the last step down to step 0.
    for step in reversed(range(len(self.alpha))):
        predicted = self.forward(y_n, x, noise_levels[step].repeat(batch_size))
        y_n = self.c1[step] * (y_n - self.c2[step] * predicted)
        if step > 0:
            # Fresh noise is added on every step except the final one.
            y_n += self.sigma[step - 1] * torch.randn_like(y_n)
        y_n.clamp_(-1.0, 1.0)
    return y_n
def compute_y_n(self, y_0):
    """Compute noisy audio based on the noise schedule.

    Args:
        y_0: Clean audio; a 3-dimensional input has its dimension 1 squeezed before mixing.

    Returns:
        Tuple ``(noise, noisy_audio, noise_scale)``. ``noise`` and ``noisy_audio`` get a
        new dimension 1 re-inserted; ``noise_scale`` holds one value per batch item.
    """
    self.noise_level = self.noise_level.to(y_0)
    levels = self.noise_level
    if len(y_0.shape) == 3:
        y_0 = y_0.squeeze(1)
    batch = y_0.shape[0]

    # Pick a random pair of adjacent schedule entries per item, then a random point between them.
    step = torch.randint(0, self.num_steps - 1, [batch])
    lower = levels[step]
    upper = levels[step + 1]
    fraction = torch.rand(batch).to(y_0)
    noise_scale = (lower + fraction * (upper - lower)).unsqueeze(1)

    noise = torch.randn_like(y_0)
    signal_part = noise_scale * y_0
    noise_part = (1.0 - noise_scale**2) ** 0.5 * noise
    noisy_audio = signal_part + noise_part
    return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
def compute_noise_level(self, beta):
    """Compute noise schedule parameters.

    Derives every per-step coefficient used by ``compute_y_n`` and ``inference`` from
    the ``beta`` schedule and stores them on the instance as float32 tensors.

    Args:
        beta: 1-D array-like of per-step ``beta`` values (``np.ndarray``, list or tuple).

    Sets:
        num_steps: Number of schedule steps, ``len(beta)``.
        beta, alpha, alpha_hat: The schedule, ``1 - beta`` and its cumulative product.
        noise_level: ``sqrt(alpha_hat)``, one entry per step.
        c1, c2, sigma: Coefficients read by the ``inference`` loop; ``sigma`` has
            ``num_steps - 1`` entries.
    """
    # Accept any array-like schedule; an ndarray passes through unchanged.
    beta = np.asarray(beta)
    self.num_steps = len(beta)
    alpha = 1 - beta
    alpha_hat = np.cumprod(alpha)
    # One noise level per step. The former ``[1.0]``-prefixed variant was overwritten
    # immediately and never used, so it has been dropped.
    noise_level = alpha_hat**0.5
    # pylint: disable=not-callable
    self.beta = torch.tensor(beta.astype(np.float32))
    self.alpha = torch.tensor(alpha.astype(np.float32))
    self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
    self.noise_level = torch.tensor(noise_level.astype(np.float32))
    self.c1 = 1 / self.alpha**0.5
    self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
    self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5
def remove_weight_norm(self):
    """Strip weight normalization from every layer of the model.

    Layers with an empty ``state_dict`` are skipped. For the others,
    ``nn.utils.remove_weight_norm`` is tried first; if it raises ``ValueError`` the
    layer's own ``remove_weight_norm`` method is called instead.
    """
    for group in (self.dblocks, self.film, self.ublocks):
        for layer in group:
            if len(layer.state_dict()) == 0:
                continue
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                layer.remove_weight_norm()

    for conv in (self.x_conv, self.out_conv, self.y_conv):
        nn.utils.remove_weight_norm(conv)
def apply_weight_norm(self):
    """Apply weight normalization to every layer of the model.

    Blocks with a non-empty ``state_dict`` normalize themselves through their own
    ``apply_weight_norm``; the three top-level convolutions are re-bound to the
    result of ``weight_norm``.
    """
    for group in (self.dblocks, self.film, self.ublocks):
        for layer in group:
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()

    for attr in ("x_conv", "out_conv", "y_conv"):
        setattr(self, attr, weight_norm(getattr(self, attr)))
def load_checkpoint(
    self, config, checkpoint_path, eval=False, cache=False
):  # pylint: disable=unused-argument, redefined-builtin
    """Restore model weights and set up the matching noise schedule.

    Args:
        config: Mapping holding ``"test_noise_schedule"`` and ``"train_noise_schedule"``
            entries, each with ``min_val``, ``max_val`` and ``num_steps``.
        checkpoint_path: Location of the checkpoint, read with ``load_fsspec``.
        eval: When true, switch to eval mode, strip weight norm if it is enabled and
            use the test schedule; otherwise use the train schedule.
        cache: Forwarded to ``load_fsspec``.
    """
    state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    self.load_state_dict(state["model"])

    schedule_key = "train_noise_schedule"
    if eval:
        self.eval()
        assert not self.training
        if self.config.model_params.use_weight_norm:
            self.remove_weight_norm()
        schedule_key = "test_noise_schedule"

    schedule = config[schedule_key]
    betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
    self.compute_noise_level(betas)
def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
    """Run one training step: corrupt the waveform, predict the noise, score it.

    Args:
        batch: Dict with the conditioning features under ``"input"`` and the clean
            audio under ``"waveform"``.
        criterion: Callable invoked as ``criterion(noise, noise_hat)``.

    Returns:
        ``({"model_output": noise_hat}, {"loss": loss})``.
    """
    features, waveform = batch["input"], batch["waveform"]

    # Sample a noise scale and build the corrupted waveform.
    noise, noisy_waveform, noise_scale = self.compute_y_n(waveform)

    # Predict the injected noise from the corrupted waveform and the features.
    noise_hat = self.forward(noisy_waveform, features, noise_scale)

    outputs = {"model_output": noise_hat}
    losses = {"loss": criterion(noise, noise_hat)}
    return outputs, losses
def train_log(  # pylint: disable=no-self-use
    self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
    """Training-time logging hook; this model logs nothing and returns ``None``."""
    return None
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
    """Evaluate one batch by running ``train_step`` with gradient tracking disabled.

    Args:
        batch: Batch dict forwarded to ``train_step``.
        criterion: Loss callable forwarded to ``train_step``.

    Returns:
        Whatever ``train_step`` returns for this batch.
    """
    step_result = self.train_step(batch, criterion)
    return step_result
def eval_log(  # pylint: disable=no-self-use
    self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
) -> None:
    """Evaluation-time logging hook; this model logs nothing."""
    return None
def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
    """Synthesize the test sample(s) with the test noise schedule and plot the result.

    Args:
        assets: Dict providing the audio processor under ``"audio_processor"``.
        test_loader: Loader whose ``dataset.load_test_samples(1)`` yields the samples;
            each sample is indexed as ``sample[0]`` (features) and ``sample[1]`` (audio).
        outputs: Unused.

    Returns:
        ``(figures, {"test/audio": waveform})`` for the last sample processed.
    """
    audio_processor = assets["audio_processor"]

    # Switch to the test-time noise schedule before running inference.
    schedule = self.config["test_noise_schedule"]
    self.compute_noise_level(np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"]))

    for sample in test_loader.dataset.load_test_samples(1):
        # Add a batch dimension and move the features to the model's device.
        features = sample[0][None, :, :].to(next(self.parameters()).device)
        reference = sample[1][None, :]
        # compute voice
        predicted = self.inference(features)
        # compute spectrograms
        figures = plot_results(predicted, reference, audio_processor, "test")
        # Sample audio
        sample_voice = predicted[0].squeeze(0).detach().cpu().numpy()
    return figures, {"test/audio": sample_voice}
def get_optimizer(self):
    """Build the optimizer named in the config via the module-level ``get_optimizer`` helper."""
    cfg = self.config
    return get_optimizer(cfg.optimizer, cfg.optimizer_params, cfg.lr, self)
def get_scheduler(self, optimizer):
    """Build the LR scheduler named in the config via the module-level ``get_scheduler`` helper.

    Args:
        optimizer: Optimizer the scheduler is attached to.
    """
    cfg = self.config
    return get_scheduler(cfg.lr_scheduler, cfg.lr_scheduler_params, optimizer)
@staticmethod
def get_criterion():
return torch.nn.L1Loss()
@staticmethod
def format_batch(batch: Dict) -> Dict:
# return a whole audio segment
m, y = batch[0], batch[1]
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool, samples: List, verbose: bool, num_gpus: int):
    """Build the ``DataLoader`` that feeds training or evaluation.

    Args:
        config: Run configuration; only ``use_cache`` is read from it here, the other
            settings come from ``self.config``.
        assets: Dict providing the audio processor under ``"audio_processor"``.
        is_eval: ``True`` for the evaluation loader, ``False`` for training. Selects the
            dataset's ``is_training`` flag and the worker count.
        samples: Items handed to ``WaveGradDataset``.
        verbose: Forwarded to the dataset.
        num_gpus: With more than one GPU a ``DistributedSampler`` is used and loader-level
            shuffling is turned off.

    Returns:
        The configured ``DataLoader``.
    """
    ap = assets["audio_processor"]
    dataset = WaveGradDataset(
        ap=ap,
        items=samples,
        seq_len=self.config.seq_len,
        hop_len=ap.hop_length,
        pad_short=self.config.pad_short,
        conv_pad=self.config.conv_pad,
        is_training=not is_eval,
        return_segments=True,
        use_noise_augment=False,
        use_cache=config.use_cache,
        verbose=verbose,
    )
    # Shuffling and a sampler are mutually exclusive, so only one of them is active.
    distributed = num_gpus > 1
    sampler = DistributedSampler(dataset) if distributed else None
    num_workers = self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers
    return DataLoader(
        dataset,
        batch_size=self.config.batch_size,
        shuffle=not distributed,
        drop_last=False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=False,
    )
def on_epoch_start(self, trainer):  # pylint: disable=unused-argument
    """Reset the noise schedule to the training schedule at the start of an epoch.

    Args:
        trainer: Unused.
    """
    schedule = self.config["train_noise_schedule"]
    start, stop, count = schedule["min_val"], schedule["max_val"], schedule["num_steps"]
    self.compute_noise_level(np.linspace(start, stop, count))
@staticmethod
def init_from_config(config: "WavegradConfig"):
    """Create a ``Wavegrad`` model from ``config``.

    Args:
        config: Configuration object passed straight to the ``Wavegrad`` constructor.

    Returns:
        The newly constructed model.
    """
    model = Wavegrad(config)
    return model
|