# Source: tts/vietTTS/nat/gta.py (uploaded by tobiccino, commit 12da6cc, 2.39 kB)
import pickle
from argparse import ArgumentParser
from pathlib import Path
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
from vietTTS.nat.config import AcousticInput
from .config import FLAGS, AcousticInput
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .model import AcousticModel
@hk.transform_with_state
def net(x):
    """Run a training-mode acoustic model on ``x``.

    The model is constructed with ``is_training=True``. The decorator
    turns this function into a Haiku transformed object, so callers go
    through its ``init``/``apply`` functions rather than calling it directly.
    """
    training_model = AcousticModel(is_training=True)
    return training_model(x)
@hk.transform_with_state
def val_net(x):
    """Run an inference-mode acoustic model on ``x``.

    The model is constructed with ``is_training=False``. The decorator
    turns this function into a Haiku transformed object; ``forward_fn_``
    invokes it through ``val_net.apply(params, aux, rng, inputs)``.
    """
    eval_model = AcousticModel(is_training=False)
    return eval_model(x)
def forward_fn_(params, aux, rng, inputs: AcousticInput):
    """Predict mel frames for a batch, feeding the real mels as model input.

    The target mel-spectrogram is computed from ``inputs.wavs``, shifted one
    frame to the right (with an all-zero first frame) and handed to the
    inference-mode network together with durations converted to frame counts.

    Args:
        params: Network parameters passed to ``val_net.apply``.
        aux: Network state passed to ``val_net.apply``.
        rng: Random key passed to ``val_net.apply``.
        inputs: Batch to process; its ``wavs`` and ``durations`` fields are
            read, and ``mels``/``durations`` are replaced before the call.

    Returns:
        The second of the two mel predictions returned by the network.
    """
    # One mel frame per hop; the hop is a quarter of the FFT size.
    hop_length = FLAGS.n_fft // 4

    to_mel = MelFilter(
        FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
    )
    # Scale by 2**15 — presumably the waveforms are 16-bit PCM; TODO confirm.
    target_mels = to_mel(inputs.wavs.astype(jnp.float32) / (2**15))

    # Unpacking also enforces that the mel tensor has exactly three axes.
    batch_size, _, mel_dim = target_mels.shape

    # Shift right by one frame so each step sees the previous real frame.
    first_frame = jnp.zeros((batch_size, 1, mel_dim), dtype=jnp.float32)
    shifted_mels = jnp.concatenate(
        (first_frame, target_mels[:, :-1, :]), axis=1
    )

    # Durations appear to be in seconds here (verify against the data loader);
    # multiplying by sample_rate / hop_length expresses them in mel frames.
    frame_durations = inputs.durations * FLAGS.sample_rate / hop_length

    model_inputs = inputs._replace(mels=shifted_mels, durations=frame_durations)
    (_, second_mel_hat), _ = val_net.apply(params, aux, rng, model_inputs)
    return second_mel_hat
# JIT-compiled wrapper around forward_fn_; this is the callable that
# generate_gta invokes for every batch.
forward_fn = jax.jit(forward_fn_)
def generate_gta(out_dir: Path):
    """Predict mels for the dataset and save one ``.npy`` file per utterance.

    Loads the latest acoustic checkpoint from ``FLAGS.ckpt_dir``, iterates the
    data loader in ``"gta"`` mode (presumably "ground-truth aligned"), runs
    ``forward_fn`` on every batch and writes each prediction, trimmed to the
    utterance's own length and transposed, to ``out_dir/<name>.npy``.

    Args:
        out_dir: Destination directory; created (with parents) if missing.
            A plain string path is accepted as well.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks ``"params"``, ``"aux"`` or ``"rng"``.
    """
    # Accept str as well as Path so callers outside the CLI need not convert.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        FLAGS.batch_size,
        FLAGS.max_wave_len,
        "gta",
    )

    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    print("Resuming from latest checkpoint at", ckpt_fn)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point FLAGS.ckpt_dir at checkpoints from a trusted source.
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    # Only the inference-relevant entries are needed; the training step and
    # optimizer state are deliberately not required to be present.
    params, aux, rng = dic["params"], dic["aux"], dic["rng"]

    # Samples per mel frame; constant for the whole run, so compute it once.
    hop_length = FLAGS.n_fft // 4

    tr = tqdm(data_iter)
    for names, batch in tr:
        lengths = batch.wav_lengths
        predicted_mel = forward_fn(params, aux, rng, batch)
        # Copy the prediction to host memory before slicing and saving.
        mel = jax.device_get(predicted_mel)
        for idx, fn in enumerate(names):
            file = out_dir / f"{fn}.npy"
            tr.write(f"saving to file {file}")
            # Drop padded frames: keep only those covered by the real waveform.
            n_frames = lengths[idx] // hop_length
            np.save(file, mel[idx, :n_frames].T)
if __name__ == "__main__":
    # Script entry point: the only option is the output directory, which
    # defaults to "gta" and is converted to a Path by argparse.
    cli = ArgumentParser()
    cli.add_argument("-o", "--output-dir", type=Path, default="gta")
    cli_args = cli.parse_args()
    generate_gta(cli_args.output_dir)