import os

import numpy as np
import torch

from modules.FastDiff.module.FastDiff_model import FastDiff
from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule
from tasks.tts.dataset_utils import FastSpeechWordDataset
from tasks.tts.tts_utils import load_data_preprocessor
from utils.ckpt_utils import load_ckpt
from utils.hparams import set_hparams
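
# BaseTTSInfer provides shared inference scaffolding for FastDiff-based TTS:
# it loads the phoneme dictionary and speaker map, builds the acoustic model
# (via subclass hooks) and the FastDiff vocoder, and turns raw text into
# batched model inputs. Subclasses implement build_model / forward_model.
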
class BaseTTSInfer:
    def __init__(self, hparams, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device
        self.data_dir = hparams['binary_data_dir']
        self.preprocessor, self.preprocess_args = load_data_preprocessor()
        self.ph_encoder = self.preprocessor.load_dict(self.data_dir)
        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)
        self.ds_cls = FastSpeechWordDataset
        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder, self.diffusion_hyperparams, self.noise_schedule = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)
    def build_model(self):
        """Build and return the acoustic model; implemented by subclasses."""
        raise NotImplementedError

    def forward_model(self, inp):
        """Run text-to-waveform inference on one preprocessed item; implemented by subclasses."""
        raise NotImplementedError
    def build_vocoder(self):
        base_dir = self.hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        config = set_hparams(config_path, global_hparams=False)
        vocoder = FastDiff(audio_channels=config['audio_channels'],
                           inner_channels=config['inner_channels'],
                           cond_channels=config['cond_channels'],
                           upsample_ratios=config['upsample_ratios'],
                           lvc_layers_each_block=config['lvc_layers_each_block'],
                           lvc_kernel_size=config['lvc_kernel_size'],
                           kpnet_hidden_channels=config['kpnet_hidden_channels'],
                           kpnet_conv_size=config['kpnet_conv_size'],
                           dropout=config['dropout'],
                           diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
                           diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
                           diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
                           use_weight_norm=config['use_weight_norm'])
        load_ckpt(vocoder, base_dir, 'model')

        # Init diffusion hyperparameters from a linear schedule.
        noise_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"]))
        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)

        if config['noise_schedule'] != '':
            noise_schedule = config['noise_schedule']
            if isinstance(noise_schedule, list):
                noise_schedule = torch.FloatTensor(noise_schedule)
        else:
            # Select a reverse-process schedule based on the number of iterations N.
            try:
                reverse_step = int(self.hparams.get('N'))
            except (TypeError, ValueError):
                print('Please specify $N (the number of reverse iterations) in the config file. '
                      'Falling back to denoising with 4 iterations.')
                reverse_step = 4
            if reverse_step == 1000:
                noise_schedule = torch.linspace(0.000001, 0.01, 1000)
            elif reverse_step == 200:
                noise_schedule = torch.linspace(0.0001, 0.02, 200)
            # The schedules below were derived by the noise predictor.
            # We will release the noise predictor training and noise scheduling code soon. Stay tuned!
            elif reverse_step == 8:
                noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
                                  0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593,
                                  0.5]
            elif reverse_step == 6:
                noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
                                  0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
            elif reverse_step == 4:
                noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
            elif reverse_step == 3:
                noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
            else:
                raise NotImplementedError

        if isinstance(noise_schedule, list):
            noise_schedule = torch.FloatTensor(noise_schedule)

        return vocoder, diffusion_hyperparams, noise_schedule
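    # Example (assumption: `N` is read from the experiment yaml/hparams): with
    #     N: 4
    # in the config and an empty `noise_schedule`, build_vocoder picks the
    # 4-iteration schedule derived by the noise predictor above.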
    def run_vocoder(self, c):
        c = c.transpose(2, 1)  # [B, T, n_mels] -> [B, n_mels, T]
        audio_length = c.shape[-1] * self.hparams["hop_size"]
        y = sampling_given_noise_schedule(
            self.vocoder, (1, 1, audio_length), self.diffusion_hyperparams, self.noise_schedule,
            condition=c, ddim=False, return_sequence=False)
        return y
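    # Example (sketch): for an acoustic-model mel output `mel` of shape
    # [1, T, n_mels], vocoding is a single call:
    #     wav = self.run_vocoder(mel)  # -> [1, 1, T * hop_size]
    # The shapes follow from the transpose and `audio_length` computation above.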
    def preprocess_input(self, inp):
        """
        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return: item dict with the phoneme sequence, encoded tokens, and speaker id
        """
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
        text_raw = inp['text']
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'SPK1')
        ph, txt = preprocessor.txt_to_ph(
            preprocessor.txt_processor, text_raw, preprocess_args)
        ph_token = self.ph_encoder.encode(ph)
        spk_id = self.spk_map[spk_name]
        item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 'ph_token': ph_token}
        item['ph_len'] = len(item['ph_token'])
        return item
    def input_to_batch(self, item):
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        # Wrap the int speaker id in a list to build a [1] tensor; passing a bare
        # int to torch.LongTensor would allocate an uninitialized tensor of that size.
        spk_ids = torch.LongTensor([item['spk_id']]).to(self.device)
        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
        }
        return batch
    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        inp = self.preprocess_input(inp)
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output
    @classmethod
    def example_run(cls):
        from utils.hparams import hparams as hp
        from utils.audio import save_wav
        set_hparams()
        inp = {
            'text': hp['text']
        }
        infer_ins = cls(hp)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
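

# A minimal sketch of a concrete subclass (not part of the original file).
# `MyAcousticModel`, its checkpoint directory, and the output key `mel_out`
# are assumptions for illustration; the real acoustic model and hparams come
# from the project config.
#
# class FastDiffTTSInfer(BaseTTSInfer):
#     def build_model(self):
#         model = MyAcousticModel(self.ph_encoder, self.hparams)  # hypothetical
#         load_ckpt(model, self.hparams['work_dir'], 'model')
#         return model
#
#     def forward_model(self, inp):
#         batch = self.input_to_batch(inp)
#         with torch.no_grad():
#             output = self.model(batch['txt_tokens'], spk_ids=batch['spk_ids'], infer=True)
#             wav = self.run_vocoder(output['mel_out'])  # hypothetical output key
#         return wav[0].cpu().numpy()
#
# if __name__ == '__main__':
#     FastDiffTTSInfer.example_run()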