# tts/vietTTS/nat/duration_trainer.py
# Uploaded by tobiccino (commit 12da6cc, 4.72 kB).
from collections import deque
from functools import partial
from typing import Deque

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm.auto import tqdm

from vietTTS.nat.config import DurationInput

from .config import FLAGS
from .data_loader import textgrid_data_loader
from .model import DurationModel
from .utils import load_latest_ckpt, print_flags, save_ckpt
def loss_fn(params, aux, rng, x: DurationInput, is_training=True):
    """Return the masked mean L1 loss between predicted and target durations.

    Args:
        params: Haiku parameters of ``DurationModel``.
        aux: Haiku state of ``DurationModel``.
        rng: PRNG key. When training, it is split to draw the token mask and
            the remainder is passed to ``net.apply``.
        x: a ``DurationInput`` batch. ``x.phonemes`` is indexed as
            (batch, time) and ``x.lengths`` as (batch,) by the mask
            construction below.
        is_training: enables random token masking and is forwarded to
            ``DurationModel(is_training=...)``.

    Returns:
        ``(loss, aux)``: the scalar loss and the updated Haiku state.
    """

    @hk.transform_with_state
    def net(x):
        return DurationModel(is_training=is_training)(x)

    if is_training:
        # Randomly replace tokens with the [WORD END] token during training
        # to avoid overfitting.
        m_rng, rng = jax.random.split(rng, 2)
        m = jax.random.bernoulli(m_rng, FLAGS.token_mask_prob, x.phonemes.shape)
        x = x._replace(phonemes=jnp.where(m, FLAGS.word_end_index, x.phonemes))
    durations, aux = net.apply(params, aux, rng, x)
    # Keep only positions inside each sequence's valid length.
    mask = jnp.arange(0, x.phonemes.shape[1])[None, :] < x.lengths[:, None]
    # Do NOT score [WORD END] positions. This also excludes tokens that were
    # randomly replaced with [WORD END] above.
    mask = jnp.where(x.phonemes == FLAGS.word_end_index, False, mask)
    masked_loss = jnp.abs(durations - x.durations) * mask
    # Clamp the denominator. If every position is masked out, sum(mask) == 0
    # and 0 / 0 = NaN would propagate through the gradients into the params.
    # With the clamp, such a batch contributes a loss of 0 instead.
    num_scored = jnp.maximum(jnp.sum(mask), 1)
    loss = jnp.sum(masked_loss) / num_scored
    return loss, aux
def _duration_forward(x):
    """Apply ``DurationModel`` constructed with ``is_training=False``."""
    model = DurationModel(is_training=False)
    return model(x)


# Jitted Haiku ``apply`` of the inference-mode duration model.
# Call as ``forward_fn(params, aux, rng, x)`` -> ``(durations, aux)``.
_duration_transformed = hk.transform_with_state(_duration_forward)
forward_fn = jax.jit(_duration_transformed.apply)
def predict_duration(params, aux, rng, x: DurationInput):
    """Run the inference-mode model on ``x``.

    Returns:
        ``(predicted, target)``: the model's duration prediction and
        ``x.durations`` passed through unchanged, for side-by-side comparison.
        The updated Haiku state is discarded.
    """
    predicted, _unused_aux = forward_fn(params, aux, rng, x)
    target = x.durations
    return predicted, target
# Evaluation loss: the same L1 objective with ``is_training=False``.
# This skips the random [WORD END] token masking.
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# ``(loss, aux), grads = loss_vag(params, aux, rng, x)``.
# Gradients are taken with respect to the first argument (params).
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

# Global-norm gradient clipping, followed by AdamW.
optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(
        FLAGS.duration_learning_rate,
        weight_decay=FLAGS.weight_decay,
    ),
)
@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
    """Perform one jitted optimization step on ``inputs``.

    Returns:
        ``(loss, (params, aux, rng, optim_state))``. Every element of the
        inner tuple is the updated value; the returned ``rng`` is a fresh key
        for the next step.
    """
    step_rng, next_rng = jax.random.split(rng)
    (loss, next_aux), grads = loss_vag(params, aux, step_rng, inputs)
    deltas, next_optim_state = optimizer.update(grads, optim_state, params)
    next_params = optax.apply_updates(params, deltas)
    carry = (next_params, next_aux, next_rng, next_optim_state)
    return loss, carry
def initial_state(batch, seed=42):
    """Create fresh parameters, Haiku state, PRNG key and optimizer state.

    Args:
        batch: an example input batch, passed to the Haiku ``init`` call.
        seed: integer PRNG seed. Defaults to 42, the previously hard-coded
            value.

    Returns:
        ``(params, aux, rng, optim_state)``.
    """
    # Split before use. Previously the same key initialized the parameters
    # and was then returned as the training key; JAX keys must not be reused.
    init_rng, rng = jax.random.split(jax.random.PRNGKey(seed))
    # NOTE(review): positional True is presumably ``is_training`` -- confirm
    # against DurationModel's signature.
    params, aux = hk.transform_with_state(lambda x: DurationModel(True)(x)).init(
        init_rng, batch
    )
    optim_state = optimizer.init(params)
    return params, aux, rng, optim_state
def plot_val_duration(step: int, batch, params, aux, rng):
    """Plot predicted vs. ground-truth durations for the first item of ``batch``.

    Only the first ``batch.lengths[0]`` positions are plotted. The figure is
    written to ``FLAGS.ckpt_dir / f"duration_{step:06d}.png"``.
    """
    fn = FLAGS.ckpt_dir / f"duration_{step:06d}.png"
    predicted_dur, gt_dur = predict_duration(params, aux, rng, batch)
    # Use a plain Python int so the slices below are ordinary static slices.
    length = int(batch.lengths[0])
    # Use an explicit figure instead of pyplot's implicit "current figure".
    fig, ax = plt.subplots()
    try:
        ax.plot(np.asarray(predicted_dur[0, :length]))
        ax.plot(np.asarray(gt_dur[0, :length]))
        ax.legend(["predicted", "gt"])
        ax.set_title("Phoneme durations")
        fig.savefig(fn)
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # during training do not accumulate open figures.
        plt.close(fig)
def train():
    """Train the duration model, resuming from the latest checkpoint if one exists.

    Every 10 steps one validation batch is scored. Every 1000 steps the
    running mean train and validation losses are printed, a duration plot of
    a validation batch is saved, and a checkpoint is written to
    ``FLAGS.ckpt_dir``.
    """
    train_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="train"
    )
    val_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="val"
    )
    # Bounded windows for the running-average losses shown in the log.
    # collections.deque replaces the deprecated typing.Deque alias.
    losses = deque(maxlen=1000)
    val_losses = deque(maxlen=100)
    latest_ckpt = load_latest_ckpt(FLAGS.ckpt_dir)
    if latest_ckpt is not None:
        last_step, params, aux, rng, optim_state = latest_ckpt
    else:
        last_step = -1
        print("Generate random initial states...")
        params, aux, rng, optim_state = initial_state(next(train_data_iter))
    tr = tqdm(
        range(last_step + 1, 1 + FLAGS.num_training_steps),
        total=1 + FLAGS.num_training_steps,
        initial=last_step + 1,
        ncols=80,
        desc="training",
    )
    for step in tr:
        batch = next(train_data_iter)
        loss, (params, aux, rng, optim_state) = update(
            params, aux, rng, optim_state, batch
        )
        # Keep the device scalar; converting to float here would force a
        # host sync on every step.
        losses.append(loss)
        if step % 10 == 0:
            val_loss, _ = val_loss_fn(params, aux, rng, next(val_data_iter))
            val_losses.append(val_loss)
        if step % 1000 == 0:
            # One stacked reduction instead of a Python-level sum that
            # dispatches one add per stored scalar.
            mean_loss = float(jnp.mean(jnp.stack(list(losses))))
            mean_val_loss = float(jnp.mean(jnp.stack(list(val_losses))))
            plot_val_duration(step, next(val_data_iter), params, aux, rng)
            tr.write(
                f" {step:>6d}/{FLAGS.num_training_steps:>6d} | train loss {mean_loss:.5f} | val loss {mean_val_loss:.5f}"
            )
            save_ckpt(step, params, aux, rng, optim_state, ckpt_dir=FLAGS.ckpt_dir)
if __name__ == "__main__":
    # Echo the configuration, make sure the checkpoint directory exists,
    # then start (or resume) training.
    print_flags(FLAGS.__dict__)
    ckpt_dir = FLAGS.ckpt_dir
    if not ckpt_dir.exists():
        print("Create checkpoint dir at", ckpt_dir)
        ckpt_dir.mkdir(parents=True, exist_ok=True)
    train()