import os

import torch
import utils
from modules.FastDiff.module.FastDiff_model import FastDiff
from tasks.vocoder.vocoder_base import VocoderBaseTask
from utils import audio
from utils.hparams import hparams
from modules.FastDiff.module.util import (theta_timestep_loss, compute_hyperparams_given_schedule,
                                          sampling_given_noise_schedule)


class FastDiffTask(VocoderBaseTask):
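    """Diffusion-based vocoder task: trains FastDiff with a theta timestep loss
    and synthesizes waveforms via the reverse diffusion process at test time."""
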
    def __init__(self):
        super().__init__()

    def build_model(self):
        self.model = FastDiff(audio_channels=hparams['audio_channels'],
                              inner_channels=hparams['inner_channels'],
                              cond_channels=hparams['cond_channels'],
                              upsample_ratios=hparams['upsample_ratios'],
                              lvc_layers_each_block=hparams['lvc_layers_each_block'],
                              lvc_kernel_size=hparams['lvc_kernel_size'],
                              kpnet_hidden_channels=hparams['kpnet_hidden_channels'],
                              kpnet_conv_size=hparams['kpnet_conv_size'],
                              dropout=hparams['dropout'],
                              diffusion_step_embed_dim_in=hparams['diffusion_step_embed_dim_in'],
                              diffusion_step_embed_dim_mid=hparams['diffusion_step_embed_dim_mid'],
                              diffusion_step_embed_dim_out=hparams['diffusion_step_embed_dim_out'],
                              use_weight_norm=hparams['use_weight_norm'])
        utils.print_arch(self.model)

        # Initialize diffusion hyperparameters from a linear beta schedule (beta_0 to beta_T over T steps)
        noise_schedule = torch.linspace(float(hparams["beta_0"]), float(hparams["beta_T"]), int(hparams["T"])).cuda()
        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
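        # Note: compute_hyperparams_given_schedule is expected to return a dict
        # of per-step diffusion terms, including "beta", "alpha", and "sigma".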

        # move diffusion hyperparameters to GPU
        for key in diffusion_hyperparams:
            if key in ["beta", "alpha", "sigma"]:
                diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()
        self.diffusion_hyperparams = diffusion_hyperparams

        return self.model

    def _training_step(self, sample, batch_idx, optimizer_idx):
        mels = sample['mels']
        y = sample['wavs']
        X = (mels, y)
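        # theta_timestep_loss is presumed to sample random diffusion timesteps and
        # score the model's noise prediction against the true noise at those steps.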
        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
        return loss, {'loss': loss}

    def validation_step(self, sample, batch_idx):
        mels = sample['mels']
        y = sample['wavs']
        X = (mels, y)
        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
        return loss, {'loss': loss}

    def test_step(self, sample, batch_idx):
        mels = sample['mels']
        y = sample['wavs']
        loss_output = {}

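        # Use the noise schedule from the config when given; otherwise pick one
        # based on N, the number of reverse (denoising) iterations.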
        if hparams.get('noise_schedule'):
            noise_schedule = hparams['noise_schedule']
            if isinstance(noise_schedule, list):
                noise_schedule = torch.FloatTensor(noise_schedule).cuda()
        else:
            # Select Schedule
            try:
                reverse_step = int(hparams.get('N'))
            except (TypeError, ValueError):
                print('Please specify $N (the number of reverse iterations) in the config file. '
                      'Falling back to denoising with 4 iterations.')
                reverse_step = 4
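            # Linear beta schedules for many-step sampling.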
            if reverse_step == 1000:
                noise_schedule = torch.linspace(0.000001, 0.01, 1000).cuda()
            elif reverse_step == 200:
                noise_schedule = torch.linspace(0.0001, 0.02, 200).cuda()

            # Below are schedules derived by the noise predictor.
            # We will release the code for noise predictor training and noise scheduling soon. Please stay tuned!
            elif reverse_step == 8:
                noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
                                 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5]
            elif reverse_step == 6:
                noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
                                  0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
            elif reverse_step == 4:
                noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
            elif reverse_step == 3:
                noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
            else:
                raise NotImplementedError

        if isinstance(noise_schedule, list):
            noise_schedule = torch.FloatTensor(noise_schedule).cuda()

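        # waveform length in samples: number of mel frames times the hop size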
        audio_length = mels.shape[-1] * hparams["hop_size"]

        # generate using the DDPM reverse process
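        # note: the (1, 1, audio_length) output shape assumes a test batch size of 1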
        y_ = sampling_given_noise_schedule(
            self.model, (1, 1, audio_length), self.diffusion_hyperparams, noise_schedule,
            condition=mels, ddim=False, return_sequence=False)
        gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(gen_dir, exist_ok=True)

        if len(y) == 0:
            # Inference from mel
            for idx, (wav_pred, item_name) in enumerate(zip(y_, sample["item_name"])):
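                # peak-normalize to [-1, 1] before saving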
                wav_pred = wav_pred / wav_pred.abs().max()
                audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
                               hparams['audio_sample_rate'])
        else:
            for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])):
                wav_gt = wav_gt / wav_gt.abs().max()
                wav_pred = wav_pred / wav_pred.abs().max()
                audio.save_wav(wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav', hparams['audio_sample_rate'])
                audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', hparams['audio_sample_rate'])
        return loss_output

    def build_optimizer(self, model):
        # AdamW over all model parameters; lr and weight decay come from hparams
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=float(hparams['lr']), weight_decay=float(hparams['weight_decay']))
        return optimizer

    def compute_rtf(self, sample, generation_time, sample_rate=22050):
        """
        Computes the real-time factor (RTF) for a given sample:
        generation time divided by the duration of the generated audio.
        RTF < 1 means synthesis is faster than real time.
        """
        total_length = sample.shape[-1]
        return float(generation_time * sample_rate / total_length)