Spaces:
Runtime error
Runtime error
File size: 2,104 Bytes
64e7f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from inference.base_tts_infer import BaseTTSInfer
from utils.ckpt_utils import load_ckpt, get_last_checkpoint
from utils.hparams import hparams
from modules.ProDiff.model.ProDiff import GaussianDiffusion
from usr.diff.net import DiffNet
import os
import numpy as np
from functools import partial
class ProDiffInfer(BaseTTSInfer):
    """Inference wrapper that builds the ProDiff student model and synthesizes audio.

    The student's diffusion schedule is derived from a teacher checkpoint
    (half the teacher's ``timesteps``, twice its ``timescale``). Afterwards the
    student weights are restored from ``hparams['work_dir']``.
    """

    @staticmethod
    def _student_schedule_from_teacher(ckpt_path):
        """Return ``(timesteps, timescale)`` for the student model.

        Loads the teacher checkpoint at ``ckpt_path`` onto the CPU and reads
        the ``timesteps`` and ``timescale`` entries of
        ``checkpoint["state_dict"]["model"]``. Both are converted to ``int``;
        the step count is halved (floor division) and the time scale doubled.
        """
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        teacher_state = torch.load(ckpt_path, map_location='cpu')["state_dict"]['model']
        steps = int(teacher_state['timesteps'].item())
        scale = int(teacher_state['timescale'].item())
        return steps // 2, scale * 2

    def build_model(self):
        """Construct the ProDiff ``GaussianDiffusion`` model and load its weights.

        Side effects:
            If ``<binary_data_dir>/train_f0s_mean_std.npy`` exists, the global
            ``hparams['f0_mean']`` and ``hparams['f0_std']`` are overwritten
            with the two stored values, converted to Python floats.

        Returns:
            The model, switched to eval mode, with the student timestep and
            timescale buffers registered and weights restored via ``load_ckpt``
            from ``hparams['work_dir']``.
        """
        stats_path = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
        if os.path.exists(stats_path):
            f0_mean, f0_std = np.load(stats_path)
            hparams['f0_mean'] = float(f0_mean)
            hparams['f0_std'] = float(f0_std)

        model = GaussianDiffusion(
            phone_encoder=self.ph_encoder,
            out_dims=80,
            denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

        # The student schedule is computed from the teacher checkpoint and stored
        # as float32 buffers. They are registered before load_ckpt, presumably so
        # the student checkpoint's keys line up with them — confirm against load_ckpt.
        student_steps, student_scale = self._student_schedule_from_teacher(hparams['teacher_ckpt'])
        model.register_buffer('timesteps', torch.tensor(student_steps, dtype=torch.float32))
        model.register_buffer('timescale', torch.tensor(student_scale, dtype=torch.float32))

        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')
        return model

    def forward_model(self, inp):
        """Synthesize a waveform for a single input item.

        Args:
            inp: raw input understood by ``self.input_to_batch``.

        Returns:
            The vocoder output as a NumPy array, moved to the CPU with
            size-1 dimensions squeezed out.
        """
        batch = self.input_to_batch(inp)
        with torch.no_grad():
            # batch['txt_tokens'] is presumably [B, T_t] token ids — confirm in input_to_batch.
            mel = self.model(batch['txt_tokens'], infer=True)['mel_out']
            wav = self.run_vocoder(mel)
        return wav.squeeze().cpu().numpy()
def _main():
    """Script entry point: run the ProDiff inference example."""
    ProDiffInfer.example_run()


if __name__ == '__main__':
    _main()
|