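"""Training script for the phoneme duration model.

Trains DurationModel on TextGrid alignments using a masked L1 loss on
phoneme durations, with periodic validation, duration plots and checkpointing.
"""
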
from collections import deque
from functools import partial

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm.auto import tqdm
from vietTTS.nat.config import DurationInput

from .config import FLAGS
from .data_loader import textgrid_data_loader
from .model import DurationModel
from .utils import load_latest_ckpt, print_flags, save_ckpt


def loss_fn(params, aux, rng, x: DurationInput, is_training=True):
    """Return the masked L1 loss between predicted and ground-truth phoneme durations."""

    @hk.transform_with_state
    def net(x):
        return DurationModel(is_training=is_training)(x)

    if is_training:
        # during training, randomly replace phoneme tokens with the
        # [WORD END] token to reduce overfitting
        m_rng, rng = jax.random.split(rng, 2)
        m = jax.random.bernoulli(m_rng, FLAGS.token_mask_prob, x.phonemes.shape)
        x = x._replace(phonemes=jnp.where(m, FLAGS.word_end_index, x.phonemes))
    durations, aux = net.apply(params, aux, rng, x)
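    # only count loss at positions within each sequence's true length (padding is masked out)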
    mask = jnp.arange(0, x.phonemes.shape[1])[None, :] < x.lengths[:, None]
    # exclude [WORD END] positions from the loss
    mask = jnp.where(x.phonemes == FLAGS.word_end_index, False, mask)
    masked_loss = jnp.abs(durations - x.durations) * mask
    loss = jnp.sum(masked_loss) / jnp.sum(mask)
    return loss, aux


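# jitted inference-mode forward pass, used when predicting validation durations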
forward_fn = jax.jit(
    hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply
)


def predict_duration(params, aux, rng, x: DurationInput):
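    """Run the inference-mode model on a batch; return predicted and ground-truth durations."""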
    d, _ = forward_fn(params, aux, rng, x)
    return d, x.durations


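# jitted validation loss: inference-mode model, no random token masking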
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

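# loss and gradients in one pass; has_aux because loss_fn also returns the Haiku state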
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

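# AdamW with global-norm gradient clipping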
optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay),
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
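    """Run one optimizer update step; return the loss and the new (params, aux, rng, optim_state)."""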
    rng, new_rng = jax.random.split(rng)
    (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
    updates, new_optim_state = optimizer.update(grads, optim_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, (new_params, new_aux, new_rng, new_optim_state)


def initial_state(batch):
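    """Initialize model parameters, Haiku state, the RNG key and optimizer state from a sample batch."""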
    rng = jax.random.PRNGKey(42)
    params, aux = hk.transform_with_state(lambda x: DurationModel(True)(x)).init(
        rng, batch
    )
    optim_state = optimizer.init(params)
    return params, aux, rng, optim_state


def plot_val_duration(step: int, batch, params, aux, rng):
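    """Plot predicted vs. ground-truth durations for the first utterance in the batch and save the figure."""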
    fn = FLAGS.ckpt_dir / f"duration_{step:06d}.png"
    predicted_dur, gt_dur = predict_duration(params, aux, rng, batch)
    L = batch.lengths[0]
    x = np.arange(0, L) * 3
    plt.plot(predicted_dur[0, :L])
    plt.plot(gt_dur[0, :L])
    plt.legend(["predicted", "gt"])
    plt.title("Phoneme durations")
    plt.savefig(fn)
    plt.close()


def train():
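    """Run the training loop, resuming from the latest checkpoint when one exists."""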
    train_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="train"
    )
    val_data_iter = textgrid_data_loader(
        FLAGS.data_dir, FLAGS.max_phoneme_seq_len, FLAGS.batch_size, mode="val"
    )
    losses = deque(maxlen=1000)
    val_losses = deque(maxlen=100)
    latest_ckpt = load_latest_ckpt(FLAGS.ckpt_dir)
    if latest_ckpt is not None:
        last_step, params, aux, rng, optim_state = latest_ckpt
    else:
        last_step = -1
        print("Generating random initial states...")
        params, aux, rng, optim_state = initial_state(next(train_data_iter))

    tr = tqdm(
        range(last_step + 1, 1 + FLAGS.num_training_steps),
        total=1 + FLAGS.num_training_steps,
        initial=last_step + 1,
        ncols=80,
        desc="training",
    )
    for step in tr:
        batch = next(train_data_iter)
        loss, (params, aux, rng, optim_state) = update(
            params, aux, rng, optim_state, batch
        )
        losses.append(loss)

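        # periodically evaluate on a validation batch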
        if step % 10 == 0:
            val_loss, _ = val_loss_fn(params, aux, rng, next(val_data_iter))
            val_losses.append(val_loss)

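        # periodically log average losses, plot validation durations and save a checkpoint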
        if step % 1000 == 0:
            loss = sum(losses).item() / len(losses)
            val_loss = sum(val_losses).item() / len(val_losses)
            plot_val_duration(step, next(val_data_iter), params, aux, rng)
            tr.write(
                f" {step:>6d}/{FLAGS.num_training_steps:>6d} | train loss {loss:.5f} | val loss {val_loss:.5f}"
            )
            save_ckpt(step, params, aux, rng, optim_state, ckpt_dir=FLAGS.ckpt_dir)


if __name__ == "__main__":
    print_flags(FLAGS.__dict__)
    if not FLAGS.ckpt_dir.exists():
        print("Creating checkpoint directory at", FLAGS.ckpt_dir)
        FLAGS.ckpt_dir.mkdir(parents=True, exist_ok=True)
    train()