|
import pickle |
|
from argparse import ArgumentParser |
|
from pathlib import Path |
|
|
|
import haiku as hk |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from tqdm.auto import tqdm |
|
from vietTTS.nat.config import AcousticInput |
|
|
|
from .config import FLAGS, AcousticInput |
|
from .data_loader import load_textgrid_wav |
|
from .dsp import MelFilter |
|
from .model import AcousticModel |
|
|
|
|
|
@hk.transform_with_state
def net(x):
    """Training-mode forward pass of the acoustic model.

    Wrapped by ``hk.transform_with_state``, so the module-level ``net`` exposes
    ``init``/``apply`` functions that thread Haiku parameters and state.

    Args:
        x: model input, passed straight to ``AcousticModel``.

    Returns:
        Whatever ``AcousticModel(is_training=True)`` returns for ``x``.
    """
    model = AcousticModel(is_training=True)
    return model(x)
|
|
|
|
|
@hk.transform_with_state
def val_net(x):
    """Evaluation-mode forward pass of the acoustic model.

    Identical to the training wrapper except that ``AcousticModel`` is built
    with ``is_training=False``. It is wrapped by ``hk.transform_with_state``, so
    ``val_net.apply(params, state, rng, x)`` returns ``(output, new_state)``.

    Args:
        x: model input, passed straight to ``AcousticModel``.

    Returns:
        Whatever ``AcousticModel(is_training=False)`` returns for ``x``.
    """
    eval_model = AcousticModel(is_training=False)
    return eval_model(x)
|
|
|
|
|
def forward_fn_(params, aux, rng, inputs: AcousticInput):
    """Teacher-forced, eval-mode acoustic-model prediction for one batch.

    Mel spectrograms are computed from the batch waveforms. They are shifted
    right by one frame (an all-zero frame is prepended and the last frame is
    dropped) and fed back as the model's mel input. Durations are rescaled by
    ``sample_rate / (n_fft // 4)``.

    Args:
        params: Haiku parameters for ``val_net``.
        aux: Haiku state for ``val_net``.
        rng: PRNG key forwarded to ``val_net.apply``.
        inputs: batch whose ``wavs`` and ``durations`` fields are read; its
            ``mels`` and ``durations`` fields are replaced before the call.

    Returns:
        The second element of the model's ``(mel1_hat, mel2_hat)`` output pair.
        The updated Haiku state is discarded.
    """
    hop_length = FLAGS.n_fft // 4
    mel_filter = MelFilter(
        FLAGS.sample_rate, FLAGS.n_fft, FLAGS.mel_dim, FLAGS.fmin, FLAGS.fmax
    )

    # Scale by 2**15; presumably the wavs are int16 PCM, which this maps to
    # roughly [-1, 1). TODO confirm the dtype produced by the data loader.
    waveform = inputs.wavs.astype(jnp.float32) / (2**15)
    target_mels = mel_filter(waveform)

    # The unpacking requires a rank-3 result: (batch, frames, mel_dim).
    batch_size, _, mel_dim = target_mels.shape
    go_frame = jnp.zeros((batch_size, 1, mel_dim), dtype=jnp.float32)
    # Teacher forcing: frame t of the input is ground-truth frame t-1.
    shifted_mels = jnp.concatenate((go_frame, target_mels[:, :-1, :]), axis=1)

    # Presumably durations are in seconds; this converts them to hop-sized
    # frame counts. TODO confirm units with the data loader.
    frame_durations = inputs.durations * FLAGS.sample_rate / hop_length

    model_inputs = inputs._replace(mels=shifted_mels, durations=frame_durations)
    (_, mel2_hat), _ = val_net.apply(params, aux, rng, model_inputs)
    return mel2_hat


# JIT-compiled entry point used by the generation loop.
forward_fn = jax.jit(forward_fn_)
|
|
|
|
|
def _load_acoustic_checkpoint(ckpt_fn):
    """Load ``(params, aux, rng)`` from a pickled acoustic-model checkpoint.

    Args:
        ckpt_fn: path of the pickle file to read.

    Returns:
        Tuple ``(params, aux, rng)`` taken from the checkpoint dict.
        Other entries (e.g. ``"step"``, ``"optim_state"``) are not needed for
        inference and are ignored.

    Raises:
        KeyError: if ``"params"``, ``"aux"`` or ``"rng"`` is missing.
    """
    print("Resuming from latest checkpoint at", ckpt_fn)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load checkpoints from a trusted directory.
    with open(ckpt_fn, "rb") as f:
        dic = pickle.load(f)
    return dic["params"], dic["aux"], dic["rng"]


def generate_gta(out_dir: Path):
    """Generate ground-truth-aligned (GTA) mels and save one ``.npy`` per utterance.

    Batches come from ``load_textgrid_wav(..., "gta")``. Weights are loaded
    from ``FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"``. For every
    utterance, the predicted mel is trimmed to ``wav_length // (n_fft // 4)``
    frames and transposed before saving. Assuming the prediction is
    (batch, frames, mel_dim), each file therefore holds a (mel_dim, frames)
    array.

    Args:
        out_dir: destination directory, created (with parents) if missing.
            A ``str`` path is also accepted.

    Note:
        The same ``rng`` from the checkpoint is reused for every batch.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    hop_length = FLAGS.n_fft // 4

    data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        FLAGS.batch_size,
        FLAGS.max_wave_len,
        "gta",
    )
    params, aux, rng = _load_acoustic_checkpoint(
        FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    )

    progress = tqdm(data_iter)
    for names, batch in progress:
        wav_lengths = batch.wav_lengths
        predicted_mel = forward_fn(params, aux, rng, batch)
        mel = jax.device_get(predicted_mel)  # copy to host for np.save
        for idx, fn in enumerate(names):
            file = out_dir / f"{fn}.npy"
            # tqdm.write keeps the progress bar intact while logging.
            progress.write(f"saving to file {file}")
            # Drop padded frames: keep only frames covered by the real audio.
            n_frames = wav_lengths[idx] // hop_length
            np.save(file, mel[idx, :n_frames].T)
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: read the output directory (default "gta"; argparse
    # runs string defaults through `type`, so it becomes a Path) and run GTA
    # generation.
    arg_parser = ArgumentParser()
    arg_parser.add_argument("-o", "--output-dir", type=Path, default="gta")
    cli_args = arg_parser.parse_args()
    generate_gta(cli_args.output_dir)
|
|