# lhzstar
# initial commits
# 6bc94ac
from datetime import datetime
from functools import partial
from pathlib import Path
from os.path import exists
import os
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from synthesizer import audio
from synthesizer.models.tacotron import Tacotron
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from synthesizer.utils import ValueWindow, data_parallel_workaround
from synthesizer.utils.plot import plot_spectrogram
from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
def np_now(x: torch.Tensor):
    """Return the contents of `x` as a NumPy array.

    The tensor is first detached from the autograd graph and copied to host
    memory, so this works for tensors that require grad or live on another device.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()
def time_string():
    """Return the current local date and time, formatted to the minute."""
    now = datetime.now()
    stamp = now.strftime("%Y-%m-%d %H:%M")
    return stamp
def sync(device: torch.device):
    """Wait for queued CUDA work on `device` to finish; do nothing for other device types.

    CUDA operations are asynchronous, so timing code needs this barrier to
    measure them correctly.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
          use_tb: bool, hparams):
    """Train the Tacotron synthesizer following the sessions in ``hparams.tts_schedule``.

    Each schedule entry is unpacked as ``(r, lr, max_step, batch_size)``. Sessions whose
    ``max_step`` has already been reached by the model are skipped. The checkpoint at
    ``models_dir / run_id / "synthesizer.pt"`` is overwritten only when the dev loss
    improves on the best value seen so far (persisted in ``synthesizer_loss/best_loss.npy``,
    relative to the current working directory).

    Args:
        run_id: Name of the run; outputs are written to ``models_dir / run_id``.
        syn_dir: Root of the preprocessed data. ``train/train.txt``, ``train/mels`` and
            ``train/embeds`` plus the matching ``dev/...`` entries are read from it.
        models_dir: Directory holding one sub-directory per run (created if missing).
        save_every: Run validation, and checkpoint on improvement, every this many
            batches within an epoch; ``0`` disables it.
        backup_every: Kept for interface compatibility; the periodic backup is
            currently disabled, so this value is not used.
        force_restart: Start from freshly initialised weights even if a checkpoint exists.
        use_tb: Log scalars through ``tf.summary`` (TensorFlow is imported lazily).
        hparams: Hyper-parameter object providing ``tts_schedule``, the ``tts_*`` model
            sizes, ``num_mels``, ``speaker_embedding_size`` and ``tts_clip_grad_norm``.

    Raises:
        ValueError: If CUDA is available and a scheduled batch size is not evenly
            divisible by the number of visible GPUs.
    """
    if use_tb:
        print("Use Tensorboard")
        import tensorflow as tf
        # Use the module-level `datetime` class; a function-local `import datetime`
        # would rebind that name to the module for the whole function body.
        log_dir = "log/vc/synthesizer/tensorboard/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        train_summary_writer = tf.summary.create_file_writer(log_dir)

    # Output layout for this run
    models_dir.mkdir(parents=True, exist_ok=True)
    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    for directory in (model_dir, plot_dir, wav_dir, mel_output_dir, meta_folder):
        directory.mkdir(exist_ok=True)

    weights_fpath = model_dir / "synthesizer.pt"
    train_metadata_fpath = syn_dir.joinpath("train/train.txt")
    dev_metadata_fpath = syn_dir.joinpath("dev/dev.txt")

    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(train_metadata_fpath))
    print("Using model: Tacotron")

    # Bookkeeping: moving averages over the last 100 steps
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)

    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")
        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=hparams.tts_dropout,
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())

    # Load the weights
    start_from_scratch = force_restart or not weights_fpath.exists()
    if start_from_scratch:
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s
                f.write("{}\n".format(symbol))
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)

    # Initialize the dataset
    train_mel_dir = syn_dir.joinpath("train/mels")
    train_embed_dir = syn_dir.joinpath("train/embeds")
    dev_mel_dir = syn_dir.joinpath("dev/mels")
    dev_embed_dir = syn_dir.joinpath("dev/embeds")
    train_dataset = SynthesizerDataset(train_metadata_fpath, train_mel_dir, train_embed_dir, hparams)
    dev_dataset = SynthesizerDataset(dev_metadata_fpath, dev_mel_dir, dev_embed_dir, hparams)

    # Best dev loss seen so far. Only resume it together with a resumed checkpoint:
    # a value left behind by an earlier run would otherwise keep a fresh model from
    # ever being saved, because saving requires beating that stale number.
    best_loss_file_path = "synthesizer_loss/best_loss.npy"
    if not start_from_scratch and exists(best_loss_file_path):
        best_loss = np.load(best_loss_file_path)[0]
    else:
        best_loss = 1000
    os.makedirs("synthesizer_loss", exist_ok=True)

    last_session_idx = len(hparams.tts_schedule) - 1
    for session_idx, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session
        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Are there no further sessions than the current one?
            if session_idx == last_session_idx:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            # There is a following session, go to it
            continue

        model.r = r

        # Begin the training
        simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size),
                      ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
        train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4,
                                      collate_fn=collate_fn, pin_memory=True)

        total_iters = len(train_dataset)
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs + 1):
            for batch_idx, (texts, mels, embeds, idx) in enumerate(train_dataloader, 1):
                start_time = time.time()

                # Generate stop tokens for training: 0 up to the utterance length
                # read from metadata column 4, 1 from there on.
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for row, meta_idx in enumerate(idx):
                    stop[row, :int(train_dataset.metadata[meta_idx][4]) - 1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass. The multi-GPU data_parallel_workaround path is
                # disabled, so the model always runs on the single `device`.
                m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                if hparams.tts_clip_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                    if np.isnan(grad_norm.cpu()):
                        print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                step_k = step // 1000

                msg = f"| Epoch: {epoch}/{epochs} ({batch_idx}/{steps_per_epoch}) | " \
                      f"Train Loss: {loss_window.average:#.4} | " \
                      f"{1./time_window.average:#.2} steps/s | Step: {step_k}k | "
                stream(msg)

                if use_tb:
                    with train_summary_writer.as_default():
                        tf.summary.scalar('train_loss', loss_window.average, step=step)
                        tf.summary.scalar('learning_rate', lr, step=step)

                torch.cuda.empty_cache()

                # NOTE(review): the cadence is keyed on the in-epoch batch index, not the
                # global step, so it restarts every epoch — confirm this is intended.
                if save_every != 0 and batch_idx % save_every == 0:
                    dev_loss = validate(dev_dataset, model, collate_fn)
                    msg = f"\n| Epoch: {epoch}/{epochs} ({batch_idx}/{steps_per_epoch}) | " \
                          f"Train Loss: {loss_window.average:#.4} | " \
                          f"Dev Loss: {dev_loss:#.4} | {1./time_window.average:#.2} steps/s | Step: {step_k}k | "
                    print(msg)

                    if use_tb:
                        with train_summary_writer.as_default():
                            tf.summary.scalar('val_loss', dev_loss, step=step)

                    # Keep only the best model. The optimizer state is saved with it so
                    # that resuming training doesn't produce artifacts.
                    if dev_loss < best_loss:
                        best_loss = dev_loss
                        np.save(best_loss_file_path, np.array([best_loss]))
                        model.save(weights_fpath, optimizer)

                # Train/dev sample generation through eval_model() is currently disabled.

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, if_dev = None):
    """Write debugging artefacts for one predicted sample and print its input text.

    Four outputs are produced, each named after `if_dev`, `step` and `sample_num`:
    an attention plot in `plot_dir`, the predicted mel as ``.npy`` in `mel_output_dir`,
    a waveform inverted from the predicted mel in `wav_dir`, and a predicted-vs-target
    spectrogram plot in `plot_dir` whose title includes `loss`.
    """
    # Attention plot
    attention_name = "{}_attention_step_{}_sample_{}".format(if_dev, step, sample_num)
    save_attention_multiple(attention, str(plot_dir / attention_name))

    # Predicted mel spectrogram, stored as a raw array (debug)
    mel_name = "{}-mel-prediction-step-{}_sample_{}.npy".format(if_dev, step, sample_num)
    np.save(str(mel_output_dir / mel_name), mel_prediction, allow_pickle=False)

    # Waveform reconstructed from the predicted mel (mel -> wav, debug)
    inverted_wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
    wav_name = "{}-step-{}-wave-from-mel_sample_{}.wav".format(if_dev, step, sample_num)
    audio.save_wav(inverted_wav, str(wav_dir / wav_name), sr=hparams.sample_rate)

    # Predicted and target spectrograms in a single figure (control purposes)
    plot_name = "{}-step-{}-mel-spectrogram_sample_{}.png".format(if_dev, step, sample_num)
    plot_title = "{}, {}, step={}, {} loss={:.5f}".format("Tacotron", time_string(), step, if_dev, loss)
    target_frames = target_spectrogram.size // hparams.num_mels
    plot_spectrogram(mel_prediction, str(plot_dir / plot_name), title=plot_title,
                     target_spectrogram=target_spectrogram,
                     max_len=target_frames)

    decoded_text = sequence_to_text(input_seq)
    print("Input at step {}: {}".format(step, decoded_text))
def validate(dataset, model, collate_fn, batch_size=32, num_workers=4):
    """Return the mean per-batch loss of `model` over `dataset`, without updating weights.

    The loss is the same sum used for training: MSE + L1 on the first mel output,
    MSE on the second mel output, and binary cross-entropy on the stop prediction.
    Batches are moved to whichever device the model's parameters live on, so this
    works on CPU as well as on GPU.

    Args:
        dataset: Dataset exposing ``metadata``; column 4 of each row is read as the
            utterance length used to build the stop targets.
        model: Module called as ``model(texts, mels, embeds)`` and returning
            ``(m1_hat, m2_hat, attention, stop_pred)``.
        collate_fn: Collate function yielding ``(texts, mels, embeds, idx)`` batches.
        batch_size: Number of samples per validation batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The average of the per-batch losses as a float.

    Raises:
        ValueError: If `dataset` yields no batches.
    """
    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
    losses = []

    model.eval()
    try:
        with torch.no_grad():
            for texts, mels, embeds, idx in dataloader:
                # Stop targets: 0 up to the utterance length, 1 from there on
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for row, meta_idx in enumerate(idx):
                    stop[row, :int(dataset.metadata[meta_idx][4]) - 1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass only
                m1_hat, m2_hat, _, stop_pred = model(texts, mels, embeds)

                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)
                losses.append((m1_loss + m2_loss + stop_loss).item())
    finally:
        # Always hand the model back in training mode, even if a batch failed.
        model.train()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if not losses:
        raise ValueError("validate() received an empty dataset; cannot compute a dev loss")
    return sum(losses) / len(losses)