ProDiff / tasks /tts /tts.py
Rongjiehuang's picture
init
64e7f2f
raw
history blame
4.71 kB
from multiprocessing.pool import Pool
import matplotlib
from utils.pl_utils import data_loader
from utils.training_utils import RSQRTSchedule
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
from modules.fastspeech.pe import PitchExtractor
matplotlib.use('Agg')
import os
import numpy as np
from tqdm import tqdm
import torch.distributed as dist
from tasks.base_task import BaseTask
from utils.hparams import hparams
from utils.text_encoder import TokenTextEncoder
import json
import torch
import torch.optim
import torch.utils.data
import utils
class TtsTask(BaseTask):
    """Task base class for text-to-speech training and testing.

    Builds a phoneme encoder from ``hparams['binary_data_dir']`` and provides
    the optimizer / scheduler / dataloader factories, together with the
    test-time setup and teardown (vocoder, optional pitch extractor and the
    process pool used for saving results).
    """

    def __init__(self, *args, **kwargs):
        """Create the phoneme encoder and bookkeeping fields, then init ``BaseTask``.

        All positional and keyword arguments are forwarded unchanged to
        ``BaseTask.__init__``.
        """
        # These attributes are assigned *before* the parent constructor runs,
        # exactly as in the original ordering.
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        # Special token ids as reported by the encoder.
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Populated in ``test_start``.
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Return an ``RSQRTSchedule`` wrapping ``optimizer``."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create an AdamW optimizer over ``model.parameters()``.

        The learning rate is read from ``hparams['lr']``. The optimizer is
        also stored on ``self.optimizer`` before being returned.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a ``DataLoader`` that iterates over precomputed batches of indices.

        Args:
            dataset: Object providing ``ordered_indices()``, ``num_tokens``,
                ``collater`` and ``num_workers`` (all are used below).
            shuffle: If true, the order of the batches is shuffled with
                ``np.random.shuffle``.
            max_tokens: Optional per-device token budget; multiplied by the
                number of visible CUDA devices (at least 1).
            max_sentences: Optional per-device sentence budget; scaled the
                same way as ``max_tokens``.
            required_batch_size_multiple: Passed to ``utils.batch_by_size``;
                ``-1`` means "use the device count".
            endless: If true, the batch list is repeated 1000 times (freshly
                shuffled for every repetition when ``shuffle`` is set).
            batch_by_size: If true, batches come from ``utils.batch_by_size``;
                otherwise the ordered indices are cut into consecutive chunks
                of ``max_sentences`` items.

        Returns:
            A ``torch.utils.data.DataLoader`` using the batch list as its
            ``batch_sampler``.

        Raises:
            ValueError: If ``batch_by_size`` is false and ``max_sentences``
                is ``None`` (there is then no chunk size to split by).
        """
        # Treat a CPU-only machine as a single device.
        devices_cnt = torch.cuda.device_count() or 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt

        def shuffle_batches(batches):
            # In-place shuffle; the list is returned for call chaining.
            np.random.shuffle(batches)
            return batches

        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            if max_sentences is None:
                raise ValueError('max_sentences must be set when batch_by_size is False')
            batch_sampler = [indices[start:start + max_sentences]
                             for start in range(0, len(indices), max_sentences)]
        # Materialize once: the batches are iterated several times below, which
        # would silently produce nothing after the first pass if
        # ``utils.batch_by_size`` handed back a one-shot iterable.
        batch_sampler = list(batch_sampler)

        repeats = 1000  # number of passes that stands in for an "endless" loader
        if shuffle:
            # This shuffle is performed even when ``endless`` replaces its
            # result, so the NumPy RNG stream is consumed exactly as before.
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(repeats) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(repeats) for b in batches]

        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            # Each rank takes every ``num_replicas``-th item of a batch; batches
            # whose size is not divisible by the world size are dropped.
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``phone_set.json`` from ``data_dir`` and wrap it in a ``TokenTextEncoder``."""
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Context manager so the file handle is closed deterministically.
        with open(phone_list_file) as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Prepare test-time resources.

        Creates the 8-process pool and the futures list used for saving
        results, instantiates the vocoder selected by ``hparams`` and, when
        ``hparams['pe_enable']`` is truthy, loads the pitch extractor from
        ``hparams['pe_ckpt']`` onto CUDA in eval mode.
        """
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending result-saving jobs, shut the pool down and return ``{}``.

        ``outputs`` is accepted for interface compatibility and is not used here.
        """
        self.saving_result_pool.close()
        # ``get()`` blocks until the job finishes and re-raises any worker error.
        for future in tqdm(self.saving_results_futures):
            future.get()
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        """Return a float mask with the shape of ``target`` marking non-padding frames.

        A position along the last axis is treated as padding when the absolute
        values over that axis sum to zero; such positions get weight 0.0 and
        all others 1.0. The per-position weight is repeated across the last
        axis.
        """
        # target : B x T x mel
        # Assign weight 1.0 to all labels except for padding (id=0).
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
# Script entry point: only runs when this file is executed directly.
# ``start`` is not defined in this file (presumably inherited from BaseTask).
if __name__ == "__main__":
    TtsTask.start()