ProDiff / inference /ProDiff_Teacher.py
Rongjiehuang's picture
init
64e7f2f
import torch
from inference.base_tts_infer import BaseTTSInfer
from utils.ckpt_utils import load_ckpt, get_last_checkpoint
from utils.hparams import hparams
from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion
from usr.diff.net import DiffNet
import os
import numpy as np
class ProDiffTeacherInfer(BaseTTSInfer):
def build_model(self):
f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
if os.path.exists(f0_stats_fn):
hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
hparams['f0_mean'] = float(hparams['f0_mean'])
hparams['f0_std'] = float(hparams['f0_std'])
model = GaussianDiffusion(
phone_encoder=self.ph_encoder,
out_dims=80, denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
timesteps=hparams['timesteps'],
loss_type=hparams['diff_loss_type'],
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
)
model.eval()
load_ckpt(model, hparams['work_dir'], 'model')
return model
def forward_model(self, inp):
sample = self.input_to_batch(inp)
txt_tokens = sample['txt_tokens'] # [B, T_t]
with torch.no_grad():
output = self.model(txt_tokens, infer=True)
mel_out = output['mel_out']
wav_out = self.run_vocoder(mel_out)
wav_out = wav_out.squeeze().cpu().numpy()
return wav_out
if __name__ == '__main__':
ProDiffTeacherInfer.example_run()