"""Train the acoustic model on multiple devices with ``jax.pmap``."""

import os
import pickle
from functools import partial
from typing import Deque

import fire
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import matplotlib.pyplot as plt
import optax
from tqdm.auto import tqdm

from .acoustic_trainer import initial_state, loss_vag, val_loss_fn
from .config import FLAGS
from .data_loader import load_textgrid_wav
from .dsp import MelFilter
from .utils import print_flags


def setup_colab_tpu():
    """Initialise the Colab TPU backend via ``jax.tools.colab_tpu.setup_tpu``."""
    jax.tools.colab_tpu.setup_tpu()


def _save_mel_plot(path, predicted_mel, gt_mel, attn):
    """Save a three-row figure (predicted mel, ground-truth mel, attention)."""
    plt.figure(figsize=(10, 10))
    for row, image in enumerate((predicted_mel, gt_mel, attn), start=1):
        plt.subplot(3, 1, row)
        plt.imshow(image.T, origin="lower", aspect="auto")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _save_checkpoint(ckpt_fn, step, params, aux, rng, optim_state):
    """Pickle the training state to ``ckpt_fn`` without clobbering it on failure.

    The state is written to a temporary sibling file first and then moved
    over ``ckpt_fn`` with ``os.replace``, so an interruption in the middle of
    the dump cannot leave a truncated "latest" checkpoint behind.
    """
    tmp_fn = ckpt_fn.with_name(ckpt_fn.name + ".tmp")
    with open(tmp_fn, "wb") as f:
        pickle.dump(
            {
                "step": step,
                "params": params,
                "aux": aux,
                "rng": rng,
                "optim_state": optim_state,
            },
            f,
        )
    os.replace(tmp_fn, ckpt_fn)


def train(
    batch_size: int = 32,
    steps_per_update: int = 10,
    learning_rate: float = 1024e-6,
):
    """Train acoustic model on multiple cores (TPU).

    Args:
        batch_size: Batch size used for validation; the training loader is
            asked for ``batch_size * num_devices * steps_per_update`` items
            per call, which are then split across devices and scan steps.
        steps_per_update: Number of optimizer steps executed inside one
            pmapped ``update`` call (the length of the ``lax.scan``). The
            global step counter advances by this amount per loop iteration.
        learning_rate: Initial learning rate of the staircase exponential
            decay schedule (multiplied by 0.5 every 50,000 steps).

    Side effects:
        Every 1000 steps writes ``mel_{step:06d}.png`` and
        ``acoustic_latest_ckpt.pickle`` into ``FLAGS.ckpt_dir``. If that
        checkpoint already exists at start-up, training resumes from it.
    """
    lr_schedule = optax.exponential_decay(learning_rate, 50_000, 0.5, staircase=True)
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(lr_schedule, weight_decay=FLAGS.weight_decay),
    )

    def update_step(prev_state, inputs):
        """Run one optimizer step; used as the body of ``lax.scan``."""
        params, aux, rng, optim_state = prev_state
        rng, new_rng = jax.random.split(rng)
        (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
        # Average gradients over the pmapped device axis.
        grads = jax.lax.pmean(grads, axis_name="i")
        updates, new_optim_state = optimizer.update(grads, optim_state, params)
        new_params = optax.apply_updates(params, updates)
        next_state = (new_params, new_aux, new_rng, new_optim_state)
        return next_state, loss

    @partial(jax.pmap, axis_name="i")
    def update(params, aux, rng, optim_state, inputs):
        """Scan ``update_step`` over the leading axis of ``inputs`` per device."""
        states, losses = jax.lax.scan(
            update_step, (params, aux, rng, optim_state), inputs
        )
        return states, jnp.mean(losses)

    print(jax.devices())
    num_devices = jax.device_count()

    train_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        batch_size * num_devices * steps_per_update,
        FLAGS.max_wave_len,
        "train",
    )
    val_data_iter = load_textgrid_wav(
        FLAGS.data_dir,
        FLAGS.max_phoneme_seq_len,
        batch_size,
        FLAGS.max_wave_len,
        "val",
    )
    melfilter = MelFilter(
        FLAGS.sample_rate,
        FLAGS.n_fft,
        FLAGS.mel_dim,
        FLAGS.fmin,
        FLAGS.fmax,
    )

    # Build the initial state from a single-example batch.
    # `jax.tree_util.tree_map` replaces the `jax.tree_map` alias, which was
    # deprecated and later removed from JAX.
    batch = next(train_data_iter)
    batch = jax.tree_util.tree_map(lambda x: x[:1], batch)
    batch = batch._replace(mels=melfilter(batch.wavs.astype(jnp.float32) / (2**15)))
    params, aux, rng, optim_state = initial_state(optimizer, batch)

    losses = Deque(maxlen=1000)
    val_losses = Deque(maxlen=100)

    # Chosen so that a fresh run starts the step range at 0.
    last_step = -steps_per_update

    # loading latest checkpoint
    ckpt_fn = FLAGS.ckpt_dir / "acoustic_latest_ckpt.pickle"
    if ckpt_fn.exists():
        print("Resuming from latest checkpoint at", ckpt_fn)
        # NOTE: unpickling executes arbitrary code; only resume from
        # checkpoint files you trust.
        with open(ckpt_fn, "rb") as f:
            dic = pickle.load(f)
        last_step, params, aux, rng, optim_state = (
            dic["step"],
            dic["params"],
            dic["aux"],
            dic["rng"],
            dic["optim_state"],
        )

    tr = tqdm(
        range(
            last_step + steps_per_update, FLAGS.num_training_steps + 1, steps_per_update
        ),
        desc="training",
        total=FLAGS.num_training_steps // steps_per_update + 1,
        initial=last_step // steps_per_update + 1,
    )

    # One copy of the training state per device.
    params, aux, rng, optim_state = jax.device_put_replicated(
        (params, aux, rng, optim_state), jax.devices()
    )

    def batch_reshape(batch):
        """Split the leading axis into (num_devices, steps_per_update, -1)."""
        return jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (num_devices, steps_per_update, -1) + x.shape[1:]),
            batch,
        )

    for step in tr:
        batch = next(train_data_iter)
        batch = batch_reshape(batch)
        (params, aux, rng, optim_state), loss = update(
            params, aux, rng, optim_state, batch
        )
        losses.append(loss)

        if step % 10 == 0:
            val_batch = next(val_data_iter)
            # Validation uses the first replica's copy of the state.
            val_loss, val_aux, predicted_mel, gt_mel = val_loss_fn(
                *jax.tree_util.tree_map(lambda x: x[0], (params, aux, rng)),
                val_batch,
            )
            val_losses.append(val_loss)
            attn = jax.device_get(val_aux["acoustic_model"]["attn"])
            predicted_mel = jax.device_get(predicted_mel[0])
            gt_mel = jax.device_get(gt_mel[0])

        # Every multiple of 1000 is also a multiple of 10, so the validation
        # outputs used below were refreshed in this same iteration.
        if step % 1000 == 0:
            loss = jnp.mean(sum(losses)).item() / len(losses)
            val_loss = sum(val_losses).item() / len(val_losses)
            tr.write(f"step {step} train loss {loss:.3f} val loss {val_loss:.3f}")

            # saving predicted mels
            _save_mel_plot(
                FLAGS.ckpt_dir / f"mel_{step:06d}.png", predicted_mel, gt_mel, attn
            )

            # saving checkpoint
            # Gather the first replica's state *before* touching the file so
            # a failure here cannot damage the existing checkpoint.
            params_, aux_, rng_, optim_state_ = jax.tree_util.tree_map(
                lambda x: x[0], (params, aux, rng, optim_state)
            )
            _save_checkpoint(ckpt_fn, step, params_, aux_, rng_, optim_state_)


if __name__ == "__main__":
    # we don't use these flags.
    del FLAGS.batch_size
    del FLAGS.learning_rate
    del FLAGS.duration_learning_rate
    del FLAGS.duration_lstm_dim
    del FLAGS.duration_embed_dropout_rate
    print_flags(FLAGS.__dict__)

    if "COLAB_TPU_ADDR" in os.environ:
        setup_colab_tpu()

    if not FLAGS.ckpt_dir.exists():
        print("Create checkpoint dir at", FLAGS.ckpt_dir)
        FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)

    fire.Fire(train)