import pickle
from argparse import ArgumentParser
from pathlib import Path

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from .config import FLAGS, DurationInput
from .data_loader import load_phonemes_set
from .model import AcousticModel, DurationModel


def load_lexicon(fn):
    """Load a tab-separated `word<TAB>phonemes` lexicon file into a dict."""
    with open(fn, "r") as f:
        lines = f.readlines()
    lines = [l.lower().strip().split("\t") for l in lines]
    return dict(lines)
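
# A sketch of the lexicon format this loader assumes (hypothetical entries):
# each line is "word<TAB>space-separated phonemes", so the resulting dict
# maps a lowercased word to its pronunciation string, e.g.
#   load_lexicon(fn) -> {"hello": "hh ah0 l ow1", "world": "w er1 l d"}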


def predict_duration(tokens):
    """Predict the duration, in seconds, of each phoneme token."""

    def fwd_(x):
        return DurationModel(is_training=False)(x)

    forward_fn = jax.jit(hk.transform_with_state(fwd_).apply)
    # Load the latest duration-model checkpoint.
    with open(FLAGS.ckpt_dir / "duration_latest_ckpt.pickle", "rb") as f:
        dic = pickle.load(f)
    x = DurationInput(
        np.array(tokens, dtype=np.int32)[None, :],
        np.array([len(tokens)], dtype=np.int32),
        None,
    )
    return forward_fn(dic["params"], dic["aux"], dic["rng"], x)[0]
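
# A minimal usage sketch (assuming a trained checkpoint in FLAGS.ckpt_dir;
# text2tokens is defined below):
#   tokens = text2tokens("hello world", FLAGS.data_dir / "lexicon.txt")
#   durations = predict_duration(tokens)  # shape [1, len(tokens)], in seconds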


def text2tokens(text, lexicon_fn):
    """Convert raw text to a sequence of phoneme-token indices."""
    phonemes = load_phonemes_set()
    lexicon = load_lexicon(lexicon_fn)
    words = text.strip().lower().split()
    tokens = [FLAGS.sil_index]  # leading silence
    for word in words:
        if word in FLAGS.special_phonemes:
            tokens.append(phonemes.index(word))
        elif word in lexicon:
            # In-vocabulary word: look up its phoneme sequence.
            p = lexicon[word]
            p = p.split()
            p = [phonemes.index(pp) for pp in p]
            tokens.extend(p)
            tokens.append(FLAGS.word_end_index)
        else:
            # Out-of-vocabulary word: fall back to character-level phonemes.
            for p in word:
                if p in phonemes:
                    tokens.append(phonemes.index(p))
            tokens.append(FLAGS.word_end_index)
    tokens.append(FLAGS.sil_index)  # trailing silence
    return tokens
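
# For illustration, with hypothetical indices FLAGS.sil_index = 0 and
# FLAGS.word_end_index = 1, a two-word input produces a sequence like
#   [0, <phonemes of word 1>, 1, <phonemes of word 2>, 1, 0]
# i.e. leading silence, each word's phonemes followed by a word-end marker,
# then trailing silence.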


def predict_mel(tokens, durations):
    """Predict a mel-spectrogram from tokens and per-token durations."""
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    # last_step and optim_state are unused at inference time.
    last_step, params, aux, rng, optim_state = (
        dic["step"],
        dic["params"],
        dic["aux"],
        dic["rng"],
        dic["optim_state"],
    )

    @hk.transform_with_state
    def forward(tokens, durations, n_frames):
        net = AcousticModel(is_training=False)
        return net.inference(tokens, durations, n_frames)

    # Convert durations from seconds to mel frames (hop length = n_fft // 4).
    durations = durations * FLAGS.sample_rate / (FLAGS.n_fft // 4)
    n_frames = int(jnp.sum(durations).item())
    # n_frames must be static for jit; each new length triggers a recompile.
    predict_fn = jax.jit(forward.apply, static_argnums=[5])
    tokens = np.array(tokens, dtype=np.int32)[None, :]
    return predict_fn(params, aux, rng, tokens, durations, n_frames)[0]
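
# Worked example of the seconds-to-frames conversion above (hypothetical
# settings): with sample_rate = 16000 and n_fft = 1024 the hop is 256
# samples, so a 0.1 s phoneme spans 0.1 * 16000 / 256 = 6.25 mel frames.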


def text2mel(
    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
):
    """Synthesize a mel-spectrogram from text."""
    tokens = text2tokens(text, lexicon_fn)
    durations = predict_duration(tokens)
    # Enforce a minimum duration for silence tokens (the default of -1.0
    # leaves the predicted durations unchanged).
    durations = jnp.where(
        np.array(tokens)[None, :] == FLAGS.sil_index,
        jnp.clip(durations, a_min=silence_duration, a_max=None),
        durations,
    )
    # Word-end tokens are markers only; give them zero duration.
    durations = jnp.where(
        np.array(tokens)[None, :] == FLAGS.word_end_index, 0.0, durations
    )

    mels = predict_mel(tokens, durations)
    if tokens[-1] == FLAGS.sil_index:
        # Trim the trailing silence from the generated spectrogram.
        end_silence = durations[0, -1].item()
        silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
        mels = mels[:, : (mels.shape[1] - silence_frame)]
    return mels
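
# A minimal usage sketch (assuming trained checkpoints and a lexicon file are
# in place; 0.2 s is an arbitrary minimum silence length):
#   mel = text2mel("hello world", silence_duration=0.2)
#   mel[0] is the [time, n_mels] spectrogram of the single batch item.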


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--output", type=Path, required=True)
    args = parser.parse_args()
    mel = text2mel(args.text)

    # Save a visualization of the spectrogram to the requested output path.
    plt.figure(figsize=(10, 5))
    plt.imshow(mel[0].T, origin="lower", aspect="auto")
    plt.savefig(str(args.output))
    plt.close()

    # Dump the raw mel values to disk for the vocoder step.
    mel = jax.device_get(mel)
    mel.tofile("clip.mel")
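
# Note: ndarray.tofile() writes raw values with no header or shape metadata,
# so a consumer must already know the dtype and shape, e.g. (assuming float32
# model output and a hypothetical n_mels):
#   mel = np.fromfile("clip.mel", dtype=np.float32).reshape(1, -1, n_mels)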