from math import sqrt

import torch
from torch import nn

from Encoder import Encoder
from Decoder import Decoder
from Postnet import Postnet
from GST import GST
from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32

# Seed PyTorch's global RNG at import time so weight initialisation
# (e.g. the embedding init below) is reproducible across runs.
torch.manual_seed(1234)


class tacotron_2(nn.Module):
    """Tacotron 2 acoustic model conditioned on Global Style Tokens (GST).

    Text symbols are embedded and encoded; the encoder outputs are shifted by
    a GST style embedding, decoded into mel frames and gate outputs, and the
    mel frames are refined by a residual postnet.
    """

    def __init__(self, tacotron_hyperparams):
        """Build the model from a hyperparameter dictionary.

        Args:
            tacotron_hyperparams: mapping that must provide at least
                ``mask_padding``, ``fp16_run``, ``n_mel_channels``,
                ``number_frames_step``, ``n_symbols`` and
                ``symbols_embedding_length``. It is also passed unchanged to
                the Encoder, Decoder, Postnet and GST sub-modules.
        """
        super().__init__()
        self.mask_padding = tacotron_hyperparams['mask_padding']
        self.fp16_run = tacotron_hyperparams['fp16_run']
        self.n_mel_channels = tacotron_hyperparams['n_mel_channels']
        self.n_frames_per_step = tacotron_hyperparams['number_frames_step']

        n_symbols = tacotron_hyperparams['n_symbols']
        embedding_dim = tacotron_hyperparams['symbols_embedding_length']
        self.embedding = nn.Embedding(n_symbols, embedding_dim)
        # Xavier/Glorot uniform initialisation of the symbol embedding:
        # bound = sqrt(3) * sqrt(2 / (fan_in + fan_out)) = sqrt(6 / (fan_in + fan_out)).
        std = sqrt(2.0 / (n_symbols + embedding_dim))
        val = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-val, val)

        self.encoder = Encoder(tacotron_hyperparams)
        self.decoder = Decoder(tacotron_hyperparams)
        self.postnet = Postnet(tacotron_hyperparams)
        self.gst = GST(tacotron_hyperparams)

    def parse_batch(self, batch):
        """Move a collated batch to the GPU and split it into inputs and targets.

        Args:
            batch: 6-tuple ``(text_padded, input_lengths, mel_padded,
                gate_padded, output_lengths, prosody_padded)``. The last item
                holds the prosody features used to train the GST tokens.

        Returns:
            ``(inputs, targets)``. ``inputs`` is ``(text_padded, input_lengths,
            mel_padded, max_len, output_lengths, prosody_padded)``, the layout
            ``forward`` unpacks. ``targets`` is ``(mel_padded, gate_padded)``.
        """
        (text_padded, input_lengths, mel_padded, gate_padded,
         output_lengths, prosody_padded) = batch
        text_padded = to_gpu(text_padded).long()
        # .item() turns the 0-d max tensor into a plain Python number.
        max_len = int(torch.max(input_lengths.data).item())
        input_lengths = to_gpu(input_lengths).long()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        prosody_padded = to_gpu(prosody_padded).float()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths, prosody_padded),
            (mel_padded, gate_padded))

    def parse_input(self, inputs):
        """Convert ``inputs`` with ``fp32_to_fp16`` when running in fp16 mode.

        Otherwise return them unchanged.
        """
        return fp32_to_fp16(inputs) if self.fp16_run else inputs

    def parse_output(self, outputs, output_lengths=None):
        """Mask padded frames in the model outputs and undo fp16 if needed.

        Masking is applied only when ``mask_padding`` is enabled and
        ``output_lengths`` is given. In that case, padded positions of
        ``outputs[0]`` and ``outputs[1]`` are zeroed in place, and those of
        ``outputs[2]`` are set to 1e3. When ``fp16_run`` is set, the list is
        passed through ``fp16_to_fp32``.

        Args:
            outputs: list whose first three entries are the mel outputs, the
                postnet mel outputs and the gate outputs. Any further entries
                are not touched by the masking.
            output_lengths: per-utterance frame counts, or None to skip masking.

        Returns:
            The (possibly converted) outputs.
        """
        if self.mask_padding and output_lengths is not None:
            # NOTE(review): assumes get_mask_from_lengths marks valid frames with
            # True, so its negation selects the padding -- confirm in utils.
            mask = ~get_mask_from_lengths(output_lengths)
            # Repeat the 2-D mask once per mel channel, then reorder it to
            # (mask.size(0), n_mel_channels, mask.size(1)) to line up with mels.
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            # Large gate energy on padded frames (presumably read as "stop").
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return fp16_to_fp32(outputs) if self.fp16_run else outputs

    @staticmethod
    def _add_style_embedding(encoder_outputs, style_embedding):
        """Broadcast a GST style embedding over time and add it to encoder outputs.

        A 2-D ``[N, E]`` embedding is first reshaped to ``[N, 1, E]``.
        Otherwise ``expand_as`` would align the batch dimension with the time
        axis of the encoder outputs. That either fails, or, when N equals the
        number of frames, silently gives each frame the wrong utterance's
        style. Embeddings that are already 3-D are used as-is.

        NOTE(review): assumes ``encoder_outputs`` is (N, T, E), as the
        ``[N, 512]`` GST comment implies -- TODO confirm against Encoder.
        """
        if style_embedding.dim() == 2:
            style_embedding = style_embedding.unsqueeze(1)
        return encoder_outputs + style_embedding.expand_as(encoder_outputs)

    def forward(self, inputs):
        """Run a teacher-forced training pass.

        Args:
            inputs: the first element of ``parse_batch``'s return value, i.e.
                ``(text, input_lengths, mel_targets, max_len, output_lengths,
                prosody_features)``.

        Returns:
            ``parse_output`` applied to ``[mel_outputs, mel_outputs_postnet,
            gate_outputs, alignments, gst_scores]`` with padding masked by
            ``output_lengths``.
        """
        (inputs, input_lengths, targets, _max_len,
         output_lengths, gst_prosody_padded) = self.parse_input(inputs)
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        # Condition the encoder outputs on a style embedding computed from the
        # prosody features before they reach the decoder.
        gst_style_embedding, gst_scores = self.gst(gst_prosody_padded, output_lengths)
        encoder_outputs = self._add_style_embedding(encoder_outputs, gst_style_embedding)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)

        # The postnet predicts a residual correction to the decoder's mels.
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, gst_scores],
            output_lengths)

    def inference(self, inputs, gst_scores):
        """Synthesise mel frames from text with a caller-chosen GST style.

        Args:
            inputs: symbol ids to synthesise.
            gst_scores: style-token scores handed to ``self.gst.inference``;
                must be a torch tensor.

        Returns:
            ``parse_output`` applied to ``[mel_outputs, mel_outputs_postnet,
            gate_outputs, alignments]`` (no padding mask is applied).
        """
        inputs = self.parse_input(inputs)
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

        # GST inference: derive the style embedding from the given scores.
        gst_style_embedding = self.gst.inference(gst_scores)
        encoder_outputs = self._add_style_embedding(encoder_outputs, gst_style_embedding)

        mel_outputs, gate_outputs, alignments = self.decoder.inference(encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])