import haiku as hk import jax import jax.numpy as jnp from jax.numpy import ndarray from .config import FLAGS, AcousticInput, DurationInput class TokenEncoder(hk.Module): """Encode phonemes/text to vector""" def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True): super().__init__() self.is_training = is_training self.embed = hk.Embed(vocab_size, lstm_dim) self.conv1 = hk.Conv1D(lstm_dim, 3, padding="SAME") self.conv2 = hk.Conv1D(lstm_dim, 3, padding="SAME") self.conv3 = hk.Conv1D(lstm_dim, 3, padding="SAME") self.bn1 = hk.BatchNorm(True, True, 0.9) self.bn2 = hk.BatchNorm(True, True, 0.9) self.bn3 = hk.BatchNorm(True, True, 0.9) self.lstm_fwd = hk.LSTM(lstm_dim) self.lstm_bwd = hk.ResetCore(hk.LSTM(lstm_dim)) self.dropout_rate = dropout_rate def __call__(self, x, lengths): x = self.embed(x) x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training)) if self.is_training: x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training)) if self.is_training: x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training)) if self.is_training: x = hk.dropout(hk.next_rng_key(), self.dropout_rate, x) B, L, _ = x.shape mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1) h0c0_fwd = self.lstm_fwd.initial_state(B) new_hx_fwd, _ = hk.dynamic_unroll(self.lstm_fwd, x, h0c0_fwd, time_major=False) x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1), (x, mask)) h0c0_bwd = self.lstm_bwd.initial_state(B) new_hx_bwd, _ = hk.dynamic_unroll( self.lstm_bwd, (x_bwd, mask_bwd), h0c0_bwd, time_major=False ) x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)), axis=-1) return x class DurationModel(hk.Module): """Duration model of phonemes.""" def __init__(self, is_training=True): super().__init__() self.is_training = is_training self.encoder = TokenEncoder( FLAGS.vocab_size, FLAGS.duration_lstm_dim, FLAGS.duration_embed_dropout_rate, is_training, ) self.projection = hk.Sequential( [hk.Linear(FLAGS.duration_lstm_dim), jax.nn.gelu, hk.Linear(1)] ) def __call__(self, inputs: DurationInput): x = self.encoder(inputs.phonemes, inputs.lengths) x = jnp.squeeze(self.projection(x), axis=-1) x = jax.nn.softplus(x) return x class AcousticModel(hk.Module): """Predict melspectrogram from aligned phonemes""" def __init__(self, is_training=True): super().__init__() self.is_training = is_training self.encoder = TokenEncoder( FLAGS.vocab_size, FLAGS.acoustic_encoder_dim, 0.5, is_training ) self.decoder = hk.deep_rnn_with_skip_connections( [hk.LSTM(FLAGS.acoustic_decoder_dim), hk.LSTM(FLAGS.acoustic_decoder_dim)] ) self.projection = hk.Linear(FLAGS.mel_dim) # prenet self.prenet_fc1 = hk.Linear(256, with_bias=False) self.prenet_fc2 = hk.Linear(256, with_bias=False) # posnet self.postnet_convs = [hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)] self.postnet_convs.append(hk.Conv1D(FLAGS.mel_dim, 5)) self.postnet_bns = [hk.BatchNorm(True, True, 0.9) for _ in range(4)] + [None] def prenet(self, x, dropout=0.5): x = jax.nn.relu(self.prenet_fc1(x)) x = hk.dropout(hk.next_rng_key(), dropout, x) x = jax.nn.relu(self.prenet_fc2(x)) x = hk.dropout(hk.next_rng_key(), dropout, x) return x def upsample(self, x, durations, L): ruler = jnp.arange(0, L)[None, :] # B, L end_pos = jnp.cumsum(durations, axis=1) mid_pos = end_pos - durations / 2 # B, T d2 = jnp.square((mid_pos[:, None, :] - ruler[:, :, None])) / 10.0 w = jax.nn.softmax(-d2, axis=-1) hk.set_state("attn", w[0]) x = jnp.einsum("BLT,BTD->BLD", w, x) return x def postnet(self, mel: ndarray) -> ndarray: x = mel for conv, bn in zip(self.postnet_convs, self.postnet_bns): x = conv(x) if bn is not None: x = bn(x, is_training=self.is_training) x = jnp.tanh(x) x = hk.dropout(hk.next_rng_key(), 0.5, x) if self.is_training else x return x def inference(self, tokens, durations, n_frames): B, L = tokens.shape lengths = jnp.array([L], dtype=jnp.int32) x = self.encoder(tokens, lengths) x = self.upsample(x, durations, n_frames) def loop_fn(inputs, state): cond = inputs prev_mel, hxcx = state prev_mel = self.prenet(prev_mel) x = jnp.concatenate((cond, prev_mel), axis=-1) x, new_hxcx = self.decoder(x, hxcx) x = self.projection(x) return x, (x, new_hxcx) state = ( jnp.zeros((B, FLAGS.mel_dim), dtype=jnp.float32), self.decoder.initial_state(B), ) x, _ = hk.dynamic_unroll(loop_fn, x, state, time_major=False) residual = self.postnet(x) return x + residual def __call__(self, inputs: AcousticInput): x = self.encoder(inputs.phonemes, inputs.lengths) x = self.upsample(x, inputs.durations, inputs.mels.shape[1]) mels = self.prenet(inputs.mels) x = jnp.concatenate((x, mels), axis=-1) B, L, _ = x.shape hx = self.decoder.initial_state(B) def zoneout_decoder(inputs, prev_state): x, mask = inputs x, state = self.decoder(x, prev_state) state = jax.tree_map( lambda m, s1, s2: s1 * m + s2 * (1 - m), mask, prev_state, state ) return x, state mask = jax.tree_map( lambda x: jax.random.bernoulli(hk.next_rng_key(), 0.1, (B, L, x.shape[-1])), hx, ) x, _ = hk.dynamic_unroll(zoneout_decoder, (x, mask), hx, time_major=False) x = self.projection(x) residual = self.postnet(x) return x, x + residual