Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import utils | |
from modules.FastDiff.module.FastDiff_model import FastDiff | |
from tasks.vocoder.vocoder_base import VocoderBaseTask | |
from utils import audio | |
from utils.hparams import hparams | |
from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule | |
class FastDiffTask(VocoderBaseTask):
    """Vocoder task that trains a FastDiff diffusion model and samples audio with it.

    Training and validation compute ``theta_timestep_loss`` using diffusion
    hyperparameters derived from a linear beta schedule
    (``hparams['beta_0']`` .. ``hparams['beta_T']`` over ``hparams['T']`` steps).
    Testing runs the reverse process with an inference noise schedule and writes
    the generated (and, when available, ground-truth) wavs under ``work_dir``.
    """

    # Linear inference schedules, keyed by the number of reverse steps N:
    # N -> (beta_start, beta_end), expanded with torch.linspace(start, end, N).
    _LINEAR_NOISE_SCHEDULES = {
        1000: (0.000001, 0.01),
        200: (0.0001, 0.02),
    }
    # Short inference schedules, keyed by N. Per the original authors' comment,
    # these were derived by the FastDiff noise predictor.
    _PRESET_NOISE_SCHEDULES = {
        8: (6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
            0.002387222135439515, 0.035597629845142365, 0.3681158423423767,
            0.4735414385795593, 0.5),
        6: (1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
            0.006634317338466644, 0.09357017278671265, 0.6000000238418579),
        4: (3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01),
        3: (9.0000e-05, 9.0000e-03, 6.0000e-01),
    }
    # Fallback number of reverse steps when hparams['N'] is missing or not an int.
    _DEFAULT_REVERSE_STEPS = 4

    def __init__(self):
        super().__init__()

    @staticmethod
    def _device():
        """Return the CUDA device when one is available, otherwise the CPU.

        On a GPU host ``.to(torch.device('cuda'))`` targets the current CUDA
        device, exactly like the previous hard-coded ``.cuda()``. The CPU
        fallback avoids a crash on GPU-less machines.
        """
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def build_model(self):
        """Instantiate FastDiff from hparams and precompute training diffusion hyperparameters.

        Side effects: sets ``self.model`` and ``self.diffusion_hyperparams``.

        Returns:
            The constructed ``FastDiff`` model (also stored on ``self.model``).
        """
        self.model = FastDiff(audio_channels=hparams['audio_channels'],
                              inner_channels=hparams['inner_channels'],
                              cond_channels=hparams['cond_channels'],
                              upsample_ratios=hparams['upsample_ratios'],
                              lvc_layers_each_block=hparams['lvc_layers_each_block'],
                              lvc_kernel_size=hparams['lvc_kernel_size'],
                              kpnet_hidden_channels=hparams['kpnet_hidden_channels'],
                              kpnet_conv_size=hparams['kpnet_conv_size'],
                              dropout=hparams['dropout'],
                              diffusion_step_embed_dim_in=hparams['diffusion_step_embed_dim_in'],
                              diffusion_step_embed_dim_mid=hparams['diffusion_step_embed_dim_mid'],
                              diffusion_step_embed_dim_out=hparams['diffusion_step_embed_dim_out'],
                              use_weight_norm=hparams['use_weight_norm'])
        utils.print_arch(self.model)
        device = self._device()
        # Training uses a linear beta schedule from beta_0 to beta_T over T steps.
        noise_schedule = torch.linspace(
            float(hparams["beta_0"]), float(hparams["beta_T"]), int(hparams["T"])).to(device)
        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
        # Only these entries are tensors that must live on the compute device;
        # any other entries are left untouched.
        for key in diffusion_hyperparams:
            if key in ("beta", "alpha", "sigma"):
                diffusion_hyperparams[key] = diffusion_hyperparams[key].to(device)
        self.diffusion_hyperparams = diffusion_hyperparams
        return self.model

    def _diffusion_loss(self, sample):
        """Return ``theta_timestep_loss`` for the batch's ``(mels, wavs)`` pair."""
        X = (sample['mels'], sample['wavs'])
        return theta_timestep_loss(self.model, X, self.diffusion_hyperparams)

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the diffusion training loss for one batch.

        Returns:
            ``(loss, {'loss': loss})``.
        """
        loss = self._diffusion_loss(sample)
        return loss, {'loss': loss}

    def validation_step(self, sample, batch_idx):
        """Compute the diffusion loss on a validation batch.

        Returns:
            ``(loss, {'loss': loss})``, the same shape as ``_training_step``.
        """
        loss = self._diffusion_loss(sample)
        return loss, {'loss': loss}

    def _select_noise_schedule(self):
        """Return the inference noise schedule to use in ``test_step``.

        A non-empty ``hparams['noise_schedule']`` takes precedence. Lists and
        tuples are converted to a float tensor on the compute device. Otherwise a
        preset is chosen from ``hparams['N']``, the number of reverse iterations.

        Raises:
            NotImplementedError: if ``N`` has no preset schedule.
        """
        device = self._device()
        noise_schedule = hparams.get('noise_schedule', '')
        if noise_schedule is not None and noise_schedule != '':
            if isinstance(noise_schedule, (list, tuple)):
                return torch.FloatTensor(noise_schedule).to(device)
            # NOTE(review): any other non-empty value is passed through unchanged,
            # as before; presumably already a tensor on the right device - confirm.
            return noise_schedule

        try:
            reverse_step = int(hparams.get('N'))
        except (TypeError, ValueError):  # N missing (None) or not an integer
            print('Please specify $N (the number of reverse iterations) in config file. '
                  f'Now denoise with {self._DEFAULT_REVERSE_STEPS} iterations.')
            reverse_step = self._DEFAULT_REVERSE_STEPS

        if reverse_step in self._LINEAR_NOISE_SCHEDULES:
            beta_start, beta_end = self._LINEAR_NOISE_SCHEDULES[reverse_step]
            return torch.linspace(beta_start, beta_end, reverse_step).to(device)
        if reverse_step in self._PRESET_NOISE_SCHEDULES:
            return torch.FloatTensor(self._PRESET_NOISE_SCHEDULES[reverse_step]).to(device)
        supported = sorted({*self._LINEAR_NOISE_SCHEDULES, *self._PRESET_NOISE_SCHEDULES})
        raise NotImplementedError(
            f'No inference noise schedule for N={reverse_step}; supported values: {supported}')

    @staticmethod
    def _save_normalized_wav(wav, path, eps=1e-8):
        """Peak-normalize ``wav`` to [-1, 1] and write it to ``path``.

        ``eps`` bounds the divisor so an all-zero waveform is written as silence
        instead of NaNs. Normalization is done in float32 so the epsilon does not
        underflow for half-precision input.
        """
        wav = wav.float()
        wav = wav / wav.abs().max().clamp(min=eps)
        audio.save_wav(wav.reshape(-1).cpu().numpy(), path, hparams['audio_sample_rate'])

    def test_step(self, sample, batch_idx):
        """Generate audio from the batch's mel spectrograms and save it as wav files.

        Files are written to ``<work_dir>/generated_<global_step>_<gen_dir_name>/``
        as ``<item_name>_pred.wav``. When ground-truth wavs are present, a
        ``<item_name>_gt.wav`` is written alongside each prediction.

        Returns:
            An empty dict (no test metrics are computed here).
        """
        mels = sample['mels']
        y = sample['wavs']
        loss_output = {}
        noise_schedule = self._select_noise_schedule()
        # The generated waveform spans hop_size samples per mel frame.
        audio_length = mels.shape[-1] * hparams["hop_size"]
        # Generate using the DDPM reverse process.
        # NOTE(review): the sampled size has a leading dimension of 1, so only the
        # first item of a batch is generated/saved; presumably test batches have
        # size 1 - confirm before raising the test batch size.
        y_ = sampling_given_noise_schedule(
            self.model, (1, 1, audio_length), self.diffusion_hyperparams, noise_schedule,
            condition=mels, ddim=False, return_sequence=False)
        gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(gen_dir, exist_ok=True)
        if len(y) == 0:
            # Inference from mel only: no ground truth available.
            for wav_pred, item_name in zip(y_, sample["item_name"]):
                self._save_normalized_wav(wav_pred, f'{gen_dir}/{item_name}_pred.wav')
        else:
            for wav_pred, wav_gt, item_name in zip(y_, y, sample["item_name"]):
                self._save_normalized_wav(wav_gt, f'{gen_dir}/{item_name}_gt.wav')
                self._save_normalized_wav(wav_pred, f'{gen_dir}/{item_name}_pred.wav')
        return loss_output

    def build_optimizer(self, model):
        """Create an AdamW optimizer over ``self.model`` using ``hparams['lr']`` and ``hparams['weight_decay']``.

        NOTE(review): the ``model`` argument is ignored in favour of ``self.model``;
        presumably the caller passes the same object - confirm.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(hparams['lr']), weight_decay=float(hparams['weight_decay']))
        return optimizer

    def compute_rtf(self, sample, generation_time, sample_rate=22050):
        """Compute the real-time factor (RTF) for a generated sample.

        RTF = generation_time / (sample.shape[-1] / sample_rate), i.e. generation
        time per unit of audio duration. If ``generation_time`` is in seconds,
        values below 1 mean faster than real time.

        Args:
            sample: generated audio whose last dimension is the time axis, in samples.
            generation_time: time spent generating ``sample`` (presumably seconds).
            sample_rate: samples per second, used to convert length to duration.
        """
        total_length = sample.shape[-1]
        return float(generation_time * sample_rate / total_length)