# tts/vietTTS/nat/model.py
import haiku as hk
import jax
import jax.numpy as jnp
from jax.numpy import ndarray

from .config import FLAGS, AcousticInput, DurationInput
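
# The two Haiku networks below are vietTTS's "nat" models: a phoneme
# DurationModel and an autoregressive AcousticModel that upsamples the
# encoder outputs by phoneme durations instead of learned attention,
# in the spirit of Non-Attentive Tacotron.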

class TokenEncoder(hk.Module):
"""Encode phonemes/text to vector"""
def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
super().__init__()
self.is_training = is_training
self.embed = hk.Embed(vocab_size, lstm_dim)
self.conv1 = hk.Conv1D(lstm_dim, 3, padding="SAME")
self.conv2 = hk.Conv1D(lstm_dim, 3, padding="SAME")
self.conv3 = hk.Conv1D(lstm_dim, 3, padding="SAME")
self.bn1 = hk.BatchNorm(True, True, 0.9)
self.bn2 = hk.BatchNorm(True, True, 0.9)
self.bn3 = hk.BatchNorm(True, True, 0.9)
self.lstm_fwd = hk.LSTM(lstm_dim)
self.lstm_bwd = hk.ResetCore(hk.LSTM(lstm_dim))
self.dropout_rate = dropout_rate
def __call__(self, x, lengths):
x = self.embed(x)
x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training))
if self.is_training:
x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training))
if self.is_training:
x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training))
if self.is_training:
x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x)
B, L, _ = x.shape
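    # Manually unrolled bidirectional LSTM. The forward core runs as-is; the
    # backward core runs over time-flipped inputs inside hk.ResetCore, with
    # `mask` True from position `length - 1` onward, so its state is reset at
    # each sequence's last valid token and padding cannot leak into real steps.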
mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)
h0c0_fwd = self.lstm_fwd.initial_state(B)
new_hx_fwd, _ = hk.dynamic_unroll(self.lstm_fwd, x, h0c0_fwd, time_major=False)
x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1), (x, mask))
h0c0_bwd = self.lstm_bwd.initial_state(B)
new_hx_bwd, _ = hk.dynamic_unroll(
self.lstm_bwd, (x_bwd, mask_bwd), h0c0_bwd, time_major=False
)
x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)), axis=-1)
return x

class DurationModel(hk.Module):
"""Duration model of phonemes."""
def __init__(self, is_training=True):
super().__init__()
self.is_training = is_training
self.encoder = TokenEncoder(
FLAGS.vocab_size,
FLAGS.duration_lstm_dim,
FLAGS.duration_embed_dropout_rate,
is_training,
)
self.projection = hk.Sequential(
[hk.Linear(FLAGS.duration_lstm_dim), jax.nn.gelu, hk.Linear(1)]
)
def __call__(self, inputs: DurationInput):
x = self.encoder(inputs.phonemes, inputs.lengths)
x = jnp.squeeze(self.projection(x), axis=-1)
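    # softplus keeps the predicted per-phoneme durations positive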
x = jax.nn.softplus(x)
return x

class AcousticModel(hk.Module):
"""Predict melspectrogram from aligned phonemes"""
def __init__(self, is_training=True):
super().__init__()
self.is_training = is_training
self.encoder = TokenEncoder(
FLAGS.vocab_size, FLAGS.acoustic_encoder_dim, 0.5, is_training
)
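    # Two stacked LSTMs; Haiku's deep_rnn_with_skip_connections also feeds the
    # stack input to the second core and concatenates the cores' outputs.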
self.decoder = hk.deep_rnn_with_skip_connections(
[hk.LSTM(FLAGS.acoustic_decoder_dim), hk.LSTM(FLAGS.acoustic_decoder_dim)]
)
self.projection = hk.Linear(FLAGS.mel_dim)
# prenet
self.prenet_fc1 = hk.Linear(256, with_bias=False)
self.prenet_fc2 = hk.Linear(256, with_bias=False)
    # postnet
self.postnet_convs = [hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)]
self.postnet_convs.append(hk.Conv1D(FLAGS.mel_dim, 5))
self.postnet_bns = [hk.BatchNorm(True, True, 0.9) for _ in range(4)] + [None]
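    # The fifth conv projects back to mel_dim and has no BatchNorm (the None
    # entry in postnet_bns), matching Tacotron 2's postnet, whose last layer
    # is linear.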
def prenet(self, x, dropout=0.5):
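    # Dropout is applied unconditionally, i.e. during inference as well as
    # training; Tacotron 2 keeps prenet dropout on at synthesis time to add
    # output variation.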
x = jax.nn.relu(self.prenet_fc1(x))
x = hk.dropout(hk.next_rng_key(), dropout, x)
x = jax.nn.relu(self.prenet_fc2(x))
x = hk.dropout(hk.next_rng_key(), dropout, x)
return x
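
  # Duration-based upsampling: every output frame gets softmax weights over
  # phonemes that decay with the squared distance between the frame index and
  # each phoneme's midpoint (cumulative duration minus half its own duration).
  # The fixed /10.0 divisor is a temperature that sets the width of these
  # pseudo-Gaussian windows; hk.set_state("attn", ...) stashes the first batch
  # element's weight matrix so the alignment can be inspected or plotted.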
def upsample(self, x, durations, L):
ruler = jnp.arange(0, L)[None, :] # B, L
end_pos = jnp.cumsum(durations, axis=1)
mid_pos = end_pos - durations / 2 # B, T
d2 = jnp.square((mid_pos[:, None, :] - ruler[:, :, None])) / 10.0
w = jax.nn.softmax(-d2, axis=-1)
hk.set_state("attn", w[0])
x = jnp.einsum("BLT,BTD->BLD", w, x)
return x
def postnet(self, mel: ndarray) -> ndarray:
x = mel
for conv, bn in zip(self.postnet_convs, self.postnet_bns):
x = conv(x)
if bn is not None:
x = bn(x, is_training=self.is_training)
x = jnp.tanh(x)
x = hk.dropout(hk.next_rng_key(), 0.5, x) if self.is_training else x
return x
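
  # Autoregressive synthesis: encode tokens, upsample to n_frames, then decode
  # one frame at a time, feeding each predicted mel frame back through the
  # prenet. `lengths` is filled with the full token count, so inference
  # assumes an unpadded token sequence.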
def inference(self, tokens, durations, n_frames):
B, L = tokens.shape
lengths = jnp.array([L], dtype=jnp.int32)
x = self.encoder(tokens, lengths)
x = self.upsample(x, durations, n_frames)
def loop_fn(inputs, state):
cond = inputs
prev_mel, hxcx = state
prev_mel = self.prenet(prev_mel)
x = jnp.concatenate((cond, prev_mel), axis=-1)
x, new_hxcx = self.decoder(x, hxcx)
x = self.projection(x)
return x, (x, new_hxcx)
state = (
jnp.zeros((B, FLAGS.mel_dim), dtype=jnp.float32),
self.decoder.initial_state(B),
)
x, _ = hk.dynamic_unroll(loop_fn, x, state, time_major=False)
residual = self.postnet(x)
return x + residual
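
  # Training-time forward pass with teacher forcing: the ground-truth mel
  # frames are fed through the prenet as the decoder's autoregressive input.
  # Returns (mel, mel + postnet residual); Tacotron-style training typically
  # applies the loss to both outputs.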
def __call__(self, inputs: AcousticInput):
x = self.encoder(inputs.phonemes, inputs.lengths)
x = self.upsample(x, inputs.durations, inputs.mels.shape[1])
mels = self.prenet(inputs.mels)
x = jnp.concatenate((x, mels), axis=-1)
B, L, _ = x.shape
hx = self.decoder.initial_state(B)
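    # Zoneout regularization: with probability 0.1, each element of the
    # decoder state keeps its previous value instead of the freshly computed
    # one (mask == 1 selects prev_state in zoneout_decoder below).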
def zoneout_decoder(inputs, prev_state):
x, mask = inputs
x, state = self.decoder(x, prev_state)
state = jax.tree_map(
lambda m, s1, s2: s1 * m + s2 * (1 - m), mask, prev_state, state
)
return x, state
mask = jax.tree_map(
lambda x: jax.random.bernoulli(hk.next_rng_key(), 0.1, (B, L, x.shape[-1])),
hx,
)
x, _ = hk.dynamic_unroll(zoneout_decoder, (x, mask), hx, time_major=False)
x = self.projection(x)
residual = self.postnet(x)
return x, x + residual
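
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Haiku modules must be built
# and called inside a transformed function; since these models carry BatchNorm
# state and draw RNG keys via hk.next_rng_key(), hk.transform_with_state is
# the appropriate wrapper. The exact fields of DurationInput/AcousticInput are
# defined in .config, so the constructor call below is an assumption and kept
# as a comment for illustration only:
#
#   def duration_fn(inputs: DurationInput):
#     return DurationModel(is_training=False)(inputs)
#
#   forward = hk.transform_with_state(duration_fn)
#   rng = jax.random.PRNGKey(0)
#   phonemes = jnp.zeros((1, 16), dtype=jnp.int32)  # (batch, time)
#   lengths = jnp.array([16], dtype=jnp.int32)
#   inputs = DurationInput(phonemes=phonemes, lengths=lengths)  # hypothetical fields
#   params, state = forward.init(rng, inputs)
#   durations, state = forward.apply(params, state, rng, inputs)
# ---------------------------------------------------------------------------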