# Source path: tts/vietTTS/nat/text2mel.py
# Web-viewer residue kept for provenance: "tobiccino's picture", "upload",
# commit 12da6cc, "raw", "history blame", 3.71 kB.
import pickle
from argparse import ArgumentParser
from pathlib import Path
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from .config import FLAGS, DurationInput
from .data_loader import load_phonemes_set
from .model import AcousticModel, DurationModel
def load_lexicon(fn):
    """Load a tab-separated lexicon file into a dictionary.

    Every non-blank line is lower-cased, stripped of surrounding whitespace
    and split on tab characters; the first field becomes the key and the
    second field the value.

    Args:
        fn: Path of the lexicon file (anything accepted by ``open``).

    Returns:
        A ``dict`` mapping the first tab-separated field of each line to the
        second one.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            tab-separated fields.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale, and use a context manager so the handle is closed.
    with open(fn, "r", encoding="utf-8") as f:
        entries = [line.lower().strip().split("\t") for line in f]
    # A blank (or whitespace-only) line splits into [""]; skip those instead
    # of letting dict() fail on them.
    return dict(entry for entry in entries if entry != [""])
def predict_duration(tokens):
    """Run the duration model on one token sequence.

    The model state is read from ``duration_latest_ckpt.pickle`` inside
    ``FLAGS.ckpt_dir`` (entries ``"params"``, ``"aux"`` and ``"rng"``).

    Args:
        tokens: Sequence of integer token ids; it is converted to an int32
            array and given a leading axis of size one.

    Returns:
        The first element of the value returned by the transformed model's
        ``apply`` function.
    """

    def _apply_duration_model(inputs):
        return DurationModel(is_training=False)(inputs)

    transformed = hk.transform_with_state(_apply_duration_model)
    jitted_apply = jax.jit(transformed.apply)

    ckpt_path = FLAGS.ckpt_dir / "duration_latest_ckpt.pickle"
    with open(ckpt_path, "rb") as ckpt_file:
        ckpt = pickle.load(ckpt_file)

    # Single-item batch: the token ids plus the length of the sequence.
    token_batch = np.array(tokens, dtype=np.int32)[None, :]
    token_lengths = np.array([len(tokens)], dtype=np.int32)
    model_input = DurationInput(token_batch, token_lengths, None)

    outputs = jitted_apply(ckpt["params"], ckpt["aux"], ckpt["rng"], model_input)
    return outputs[0]
def text2tokens(text, lexicon_fn):
    """Convert a text string into a list of phoneme indices.

    The text is stripped, lower-cased and split on whitespace. The result
    starts and ends with ``FLAGS.sil_index``. For each word:

    * a word found in ``FLAGS.special_phonemes`` contributes its own index in
      the phoneme set (no word-end marker is added);
    * a word found in the lexicon contributes the indices of its
      whitespace-separated lexicon entry, followed by
      ``FLAGS.word_end_index``;
    * any other word contributes the indices of those of its characters that
      are in the phoneme set, followed by ``FLAGS.word_end_index``.

    Args:
        text: Input text.
        lexicon_fn: Path of the lexicon file passed to ``load_lexicon``.

    Returns:
        A list of integer token indices.
    """
    phoneme_set = load_phonemes_set()
    word_to_phonemes = load_lexicon(lexicon_fn)
    words = text.strip().lower().split()

    token_ids = [FLAGS.sil_index]
    for w in words:
        if w in FLAGS.special_phonemes:
            token_ids.append(phoneme_set.index(w))
            continue
        if w in word_to_phonemes:
            token_ids += [
                phoneme_set.index(ph) for ph in word_to_phonemes[w].split()
            ]
        else:
            # Out-of-lexicon word: spell it out, dropping unknown characters.
            token_ids += [phoneme_set.index(ch) for ch in w if ch in phoneme_set]
        token_ids.append(FLAGS.word_end_index)
    token_ids.append(FLAGS.sil_index)  # silence
    return token_ids
def predict_mel(tokens, durations):
    """Run the acoustic model's inference on tokens and their durations.

    The model state is read from ``acoustic_latest_ckpt.pickle`` inside
    ``FLAGS.ckpt_dir``. Only the ``"params"``, ``"aux"`` and ``"rng"`` entries
    are used, so a checkpoint without training-only entries (step counter,
    optimizer state) can be loaded as well.

    Args:
        tokens: Sequence of integer token ids; it is converted to an int32
            array and given a leading axis of size one.
        durations: Array of per-token durations. It is multiplied by
            ``FLAGS.sample_rate / (FLAGS.n_fft // 4)`` before being passed to
            the model (presumably seconds -> frames; confirm with the model).

    Returns:
        The first element of the value returned by the transformed
        ``AcousticModel.inference`` call.
    """
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    # NOTE: pickle can execute arbitrary code while loading; only point
    # FLAGS.ckpt_dir at checkpoints you trust.
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    params, aux, rng = dic["params"], dic["aux"], dic["rng"]

    @hk.transform_with_state
    def forward(tokens, durations, n_frames):
        net = AcousticModel(is_training=False)
        return net.inference(tokens, durations, n_frames)

    durations = durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
    n_frames = int(jnp.sum(durations).item())
    # Positional argument 5 of apply(params, state, rng, tokens, durations,
    # n_frames) is n_frames; it is a Python int and is marked static.
    predict_fn = jax.jit(forward.apply, static_argnums=[5])
    tokens = np.array(tokens, dtype=np.int32)[None, :]
    return predict_fn(params, aux, rng, tokens, durations, n_frames)[0]
def text2mel(
    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
):
    """Convert text into the acoustic model's output (mel) array.

    Steps: tokenize the text, predict per-token durations, raise the
    durations of silence tokens to at least ``silence_duration``, force the
    durations of word-end tokens to zero, run the acoustic model, and finally
    cut the frames belonging to the trailing silence token off the result.

    Args:
        text: Input text.
        lexicon_fn: Path of the lexicon file; defaults to ``lexicon.txt``
            inside ``FLAGS.data_dir``.
        silence_duration: Lower bound applied to the predicted duration of
            every silence token, in the same unit as the duration model's
            output. The default of ``-1.0`` only changes predictions
            below -1.

    Returns:
        The array returned by ``predict_mel`` with the trailing-silence frames
        removed along axis 1.
    """
    tokens = text2tokens(text, lexicon_fn)
    durations = predict_duration(tokens)
    # Token ids with a leading axis, built once and reused for both masks.
    token_row = np.array(tokens)[None, :]
    # jnp.maximum is a lower-bound-only clip; it avoids jnp.clip's
    # a_min/a_max keyword names, which newer JAX releases deprecate.
    durations = jnp.where(
        token_row == FLAGS.sil_index,
        jnp.maximum(durations, silence_duration),
        durations,
    )
    durations = jnp.where(token_row == FLAGS.word_end_index, 0.0, durations)
    mels = predict_mel(tokens, durations)
    if tokens[-1] == FLAGS.sil_index:
        # Same duration scaling as in predict_mel, applied to the last token.
        end_silence = durations[0, -1].item()
        silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
        mels = mels[:, : (mels.shape[1] - silence_frame)]
    return mels
if __name__ == "__main__":
    # Command-line entry point: synthesize a mel array for --text, save a
    # picture of it to --output and dump the raw values to "clip.mel".
    cli = ArgumentParser()
    cli.add_argument("--text", type=str, required=True)
    cli.add_argument("--output", type=Path, required=True)
    args = cli.parse_args()

    mel = text2mel(args.text)

    # Plot the first item, transposed, with the origin at the bottom.
    figure, axes = plt.subplots(figsize=(10, 5))
    axes.imshow(mel[0].T, origin="lower", aspect="auto")
    figure.savefig(str(args.output))
    plt.close(figure)

    # Fetch the array from the device before writing it to disk.
    jax.device_get(mel).tofile("clip.mel")