import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from nn_layers import linear_module, location_layer
from utils import get_mask_from_lengths

torch.manual_seed(1234)

class AttentionNet(nn.Module):
    # Typical dimensions: attention_rnn_dim=1024, embedding_dim=512, attention_dim=128,
    # attention_location_n_filters=32, attention_location_kernel_size=31.
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(AttentionNet, self).__init__()
        # Projects the attention RNN state (the query) into the attention_dim-D space.
        self.query_layer = linear_module(attention_rnn_dim, attention_dim,
                                         bias=False, w_init_gain='tanh')
        # Projects the encoder outputs (the memory) into the same attention_dim-D hidden representation.
        self.memory_layer = linear_module(embedding_dim, attention_dim, bias=False,
                                          w_init_gain='tanh')
        # Projects the combined hidden representation into a scalar energy per encoder step.
        self.v = linear_module(attention_dim, 1, bias=False)
        # Convolutional layers that extract location features from the previous/cumulative
        # attention weights and project them into the attention_dim-D hidden representation.
        self.location_layer = location_layer(attention_location_n_filters,
                                             attention_location_kernel_size,
                                             attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: attention RNN output from the current decoder step (B, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: previous and cumulative attention weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))
        # Drop the trailing singleton dimension: (B, T_in, 1) -> (B, T_in).
        energies = energies.squeeze(-1)
        return energies
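
    # For reference, a sketch of the score computed above (standard location-sensitive
    # additive attention; W_q, W_m, W_f and v correspond to query_layer, memory_layer,
    # location_layer and self.v respectively):
    #
    #   e_{i,t} = v^T tanh(W_q q_i + W_m m_t + W_f f_{i,t})
    #
    # where q_i is the current attention RNN state, m_t the t-th encoder output, and
    # f_{i,t} the location features extracted from the previous/cumulative weights.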

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        # Set the energies of padded encoder positions to -inf so they receive zero weight.
        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # attention_weights is (B, T_in); unsqueeze(1) makes it (B, 1, T_in) so it can be
        # batch-multiplied with memory (B, T_in, encoder_embedding_dim).
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
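
# Shape walk-through for AttentionNet.forward (illustrative values only: B=16, T_in=150,
# attention_rnn_dim=1024, encoder_embedding_dim=512, attention_dim=128; location_layer is
# assumed to return (B, T_in, attention_dim) so the three terms can be summed):
#   attention_hidden_state: (16, 1024)   -> query_layer     -> (16, 1, 128)
#   processed_memory:       (16, 150, 128)
#   attention_weights_cat:  (16, 2, 150) -> location_layer  -> (16, 150, 128)
#   energies:               (16, 150)    -> softmax         -> attention_weights (16, 150)
#   bmm((16, 1, 150), (16, 150, 512))    -> attention_context (16, 512)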

class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        # Input sizes of the linear layers: in_dim followed by all layer sizes except the
        # last one (e.g. in_dim=80 and sizes=[256, 256] gives in_sizes=[80, 256]).
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [linear_module(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            # Dropout is kept active even at inference time (training=True), as in Tacotron 2.
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x
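
# Minimal usage sketch (illustrative values, assuming n_mel_channels=80 and prenet_dim=256):
#   prenet = Prenet(80, [256, 256])
#   y = prenet(torch.randn(16, 80))   # y.shape == (16, 256)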

class Decoder(nn.Module):
    def __init__(self, tacotron_hyperparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = tacotron_hyperparams['n_mel_channels']
        self.n_frames_per_step = tacotron_hyperparams['number_frames_step']
        self.encoder_embedding_dim = tacotron_hyperparams['encoder_embedding_dim']
        self.attention_rnn_dim = tacotron_hyperparams['attention_rnn_dim']  # 1024
        self.decoder_rnn_dim = tacotron_hyperparams['decoder_rnn_dim']  # 1024
        self.prenet_dim = tacotron_hyperparams['prenet_dim']
        self.max_decoder_steps = tacotron_hyperparams['max_decoder_steps']
        # Threshold on the sigmoid of the gate output above which decoding stops.
        self.gate_threshold = tacotron_hyperparams['gate_threshold']
        self.p_attention_dropout = tacotron_hyperparams['p_attention_dropout']
        self.p_decoder_dropout = tacotron_hyperparams['p_decoder_dropout']

        # Prenet: its input is one decoder step, i.e. n_mel_channels * n_frames_per_step
        # values, passed through two fully connected layers of prenet_dim units each.
        self.prenet = Prenet(
            tacotron_hyperparams['n_mel_channels'] * tacotron_hyperparams['number_frames_step'],
            [tacotron_hyperparams['prenet_dim'], tacotron_hyperparams['prenet_dim']])
        # Attention RNN. input_size: 256 + 512 (prenet output + attention context), hidden_size: 1024.
        self.attention_rnn = nn.LSTMCell(
            tacotron_hyperparams['prenet_dim'] + tacotron_hyperparams['encoder_embedding_dim'],
            tacotron_hyperparams['attention_rnn_dim'])
        # Attention layer: returns attention_context and attention_weights (the alignments).
        self.attention_layer = AttentionNet(
            tacotron_hyperparams['attention_rnn_dim'], tacotron_hyperparams['encoder_embedding_dim'],
            tacotron_hyperparams['attention_dim'], tacotron_hyperparams['attention_location_n_filters'],
            tacotron_hyperparams['attention_location_kernel_size'])
        # Decoder RNN. input_size: 1024 + 512 (attention RNN output + attention context), hidden_size: 1024.
        self.decoder_rnn = nn.LSTMCell(
            tacotron_hyperparams['attention_rnn_dim'] + tacotron_hyperparams['encoder_embedding_dim'],
            tacotron_hyperparams['decoder_rnn_dim'], 1)
        # Final linear projection of (decoder RNN output) 1024 + (attention_context) 512 onto the
        # spectral frame(s) of the current decoder step.
        self.linear_projection = linear_module(
            tacotron_hyperparams['decoder_rnn_dim'] + tacotron_hyperparams['encoder_embedding_dim'],
            tacotron_hyperparams['n_mel_channels'] * tacotron_hyperparams['number_frames_step'])
        # Gate: decides whether to continue decoding.
        self.gate_layer = linear_module(
            tacotron_hyperparams['decoder_rnn_dim'] + tacotron_hyperparams['encoder_embedding_dim'], 1,
            bias=True, w_init_gain='sigmoid')
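
    # Example of the hyperparameter dictionary this decoder expects (a sketch using the
    # typical Tacotron 2 values referenced in the comments above; adjust to your config):
    #
    #   tacotron_hyperparams = {
    #       'n_mel_channels': 80,
    #       'number_frames_step': 1,
    #       'encoder_embedding_dim': 512,
    #       'attention_rnn_dim': 1024,
    #       'decoder_rnn_dim': 1024,
    #       'prenet_dim': 256,
    #       'max_decoder_steps': 1000,
    #       'gate_threshold': 0.5,
    #       'p_attention_dropout': 0.1,
    #       'p_decoder_dropout': 0.1,
    #       'attention_dim': 128,
    #       'attention_location_n_filters': 32,
    #       'attention_location_kernel_size': 31,
    #   }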

    def get_go_frame(self, memory):
        """ Gets an all-zeros frame to use as the first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all-zeros frame
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # Group frames into chunks when working with more than one frame per step.
        # With n_frames_per_step == 1 this reshape is a no-op.
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs
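
    # Example (illustrative shapes, assuming n_mel_channels=80 and n_frames_per_step=1):
    #   parse_decoder_inputs(torch.randn(16, 80, 189)) has shape (189, 16, 80),
    #   i.e. one (B, n_mel_channels) slice per teacher-forced decoder step.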

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list of mel frames predicted at each decoder step
        gate_outputs: gate output energies
        alignments: list of attention weights from each decoder step

        RETURNS
        -------
        mel_outputs: mel spectrogram (B, n_mel_channels, T_out)
        gate_outputs: gate output energies (B, T_out)
        alignments: attention weights (B, T_out, T_in)
        """
        # (T_out, B, T_in) -> (B, T_out, T_in)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)
        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: prenet output for the previous mel frame (B, prenet_dim)

        RETURNS
        -------
        mel_output: predicted mel frame(s) for this step
        gate_output: gate output energies
        attention_weights: attention weights for this step
        """
        # Concatenate the prenet-processed previous frame (B, prenet_dim) with the previous
        # attention context (B, encoder_embedding_dim); all dimensions match except the last.
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        # Attention RNN step: update the attention hidden/cell states (initialized to zeros
        # in initialize_decoder_states) from the concatenated input.
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(self.attention_cell, self.p_attention_dropout, self.training)

        # Concatenate the previous step's attention weights with the weights accumulated over
        # all previous steps; these are the location features fed to the attention layer.
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        # Compute the new attention context and attention weights from the new attention
        # hidden state and the location features.
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)
        # Accumulate the attention weights for the next step's location features.
        self.attention_weights_cum += self.attention_weights

        # Decoder RNN step on (attention RNN output, attention context).
        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(self.decoder_cell, self.p_decoder_dropout, self.training)

        # Project (decoder hidden state, attention context) onto the mel frame(s) and the
        # stop-token (gate) logit for this step.
        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)
        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights
""" | |
# the decoder_output from ith step passes through the pre-net to compute new decoder hidden state and attention_ | |
# context (i+1)th | |
prenet_output = self.prenet(decoder_input) | |
# the decoder_input now is the concatenation of the pre-net output and the new (i+1)th attention_context | |
decoder_input = torch.cat((prenet_output, self.attention_context), -1) | |
# another LSTM Cell to compute the decoder hidden (i+1)th state from the decoder_input | |
self.decoder_hidden, self.decoder_cell = self.decoder_rnn( | |
decoder_input, (self.decoder_hidden, self.decoder_cell)) | |
# with new attention_context we concatenate again with the new (i+1)th decoder_hidden state. | |
decoder_hidden_attention_context = torch.cat( | |
(self.decoder_hidden, self.attention_context), dim=1) | |
# the (i+1)th output is a linear projection of the decoder hidden state with a weight matrix plus bias. | |
decoder_output = self.linear_projection( | |
decoder_hidden_attention_context) | |
# check whether (i+1)th state is the last of the sequence | |
gate_prediction = self.gate_layer(decoder_hidden_attention_context) | |
return decoder_output, gate_prediction, self.attention_weights""" | |
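
    # One decode() step at a glance (illustrative shapes for B=16 with the typical dimensions
    # noted above: prenet_dim=256, attention_rnn_dim=1024, encoder_embedding_dim=512,
    # decoder_rnn_dim=1024, n_mel_channels=80):
    #   cell_input        = [prenet(prev frame) | attention_context]        -> (16, 768)
    #   attention_hidden  = attention_rnn(cell_input, ...)                  -> (16, 1024)
    #   attention_context = attention_layer(attention_hidden, memory, ...)  -> (16, 512)
    #   decoder_hidden    = decoder_rnn([attention_hidden | context], ...)  -> (16, 1024)
    #   decoder_output    = linear_projection([decoder_hidden | context])   -> (16, 80)
    #   gate_prediction   = gate_layer([decoder_hidden | context])          -> (16, 1)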

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        # Prepend the go frame and run the prenet over the whole teacher-forced sequence at once.
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            # Teacher forcing: the input of each decoding step is the prenet-processed
            # ground-truth frame of the previous step, indexed from the parsed decoder
            # inputs of shape (T_out + 1, B, prenet_dim).
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            # Plain Python lists; += appends this step's outputs.
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            # Stop when the gate predicts end-of-sequence or the step limit is reached.
            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # The predicted frame becomes the next decoder input (no teacher forcing).
            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
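

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): runs one teacher-forced
    # decoder pass with hypothetical, typical Tacotron 2 dimensions (B=2, T_in=20, T_out=30).
    # It assumes the companion nn_layers and utils modules of this project are importable and
    # that get_mask_from_lengths works on the device the tensors live on (some implementations
    # build their index tensor on CUDA).
    hyperparams = {
        'n_mel_channels': 80,
        'number_frames_step': 1,
        'encoder_embedding_dim': 512,
        'attention_rnn_dim': 1024,
        'decoder_rnn_dim': 1024,
        'prenet_dim': 256,
        'max_decoder_steps': 1000,
        'gate_threshold': 0.5,
        'p_attention_dropout': 0.1,
        'p_decoder_dropout': 0.1,
        'attention_dim': 128,
        'attention_location_n_filters': 32,
        'attention_location_kernel_size': 31,
    }
    decoder = Decoder(hyperparams)

    memory = torch.randn(2, 20, 512)             # encoder outputs (B, T_in, encoder_embedding_dim)
    memory_lengths = torch.LongTensor([20, 15])  # true encoder lengths per batch element
    target_mels = torch.randn(2, 80, 30)         # ground-truth mels (B, n_mel_channels, T_out)

    mel_outputs, gate_outputs, alignments = decoder(memory, target_mels, memory_lengths)
    print(mel_outputs.shape)   # expected: (2, 80, 30)
    print(gate_outputs.shape)  # expected: (2, 30)
    print(alignments.shape)    # expected: (2, 30, 20)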