# Source file: tts/TTS/vocoder/datasets/wavernn_dataset.py
# (Residue from the hosting web page: author "tobiccino", commit "update ui tacotron"
#  8c70653, links "raw / history / blame", file size 4.43 kB.)
import numpy as np
import torch
from torch.utils.data import Dataset
class WaveRNNDataset(Dataset):
    """
    WaveRNN Dataset that loads wav files (or precomputed features) and
    converts them to acoustic features on the fly.

    ``items`` is either a list of wav paths, in which case mel spectrograms and
    model inputs are computed with ``ap`` at load time, or a list of
    ``(wav_path, feat_path)`` pairs pointing at precomputed features. In that
    case the mel path is ``feat_path`` with ``/quant/`` replaced by ``/mel/``
    and the quantized-signal path is the reverse.

    Args:
        ap: audio processor; must provide ``load_wav``, ``melspectrogram``,
            ``mulaw_encode`` and ``quantize`` (the methods called below).
        items: list of wav paths, or list of ``(wav_path, feat_path)`` pairs.
        seq_len: number of audio samples per training segment; must be a
            multiple of ``hop_len``.
        hop_len: number of audio samples per mel frame.
        pad: number of extra mel frames of context on each side of a segment.
        mode: ``"gauss"`` / ``"mold"`` for raw float targets, or an int bit
            depth for quantized (or mu-law encoded) targets.
        mulaw: when ``mode`` is an int, mu-law encode instead of linearly
            quantizing.
        is_training: stored; not read within this class.
        verbose: stored; not read within this class.
        return_segments: if True, on-the-fly audio is padded up to the minimum
            length required for segment sampling; if False it is padded only by
            the ``pad`` context frames (full-utterance mode).
    """

    def __init__(
        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
    ):
        super().__init__()
        self.ap = ap
        # Plain paths -> compute features on the fly; (wav, feat) pairs -> load
        # precomputed features. An empty item list defaults to on-the-fly mode
        # instead of raising IndexError on items[0].
        self.compute_feat = len(items) == 0 or not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len  # mel frames covered by one segment
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose
        self.return_segments = return_segments
        assert self.seq_len % self.hop_len == 0

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        return self.load_item(index)

    def load_test_samples(self, num_samples):
        """Load the first ``num_samples`` items as full-length ``[mel, audio]`` pairs.

        Segment mode is temporarily disabled so whole utterances are returned.
        The previous ``return_segments`` setting is restored even if loading
        raises.
        """
        samples = []
        return_segments = self.return_segments
        self.return_segments = False
        try:
            for idx in range(num_samples):
                mel, audio, _ = self.load_item(idx)
                samples.append([mel, audio])
        finally:
            self.return_segments = return_segments
        return samples

    def load_item(self, index):
        """Return ``(mel, x_input, wavpath)`` for ``self.item_list[index]``.

        Loads precomputed features when feature paths were given, otherwise
        computes them on the fly from the wav file.

        Raises:
            RuntimeError: if ``self.mode`` is neither ``"gauss"``/``"mold"`` nor an int.
        """
        if self.compute_feat:
            return self._compute_item(index)
        return self._load_precomputed_item(index)

    def _compute_item(self, index):
        """Load a wav file, pad it if needed, and compute its mel and model input."""
        wavpath = self.item_list[index]
        audio = self.ap.load_wav(wavpath)
        if self.return_segments:
            # minimum audio length needed for segment sampling
            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
        else:
            # full utterance plus `pad` context frames on each side
            min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
        if audio.shape[0] < min_audio_len:
            # In full-utterance mode padding is always applied (when pad > 0),
            # so only warn when a segment-mode clip is genuinely too short.
            if self.return_segments:
                print(" [!] Instance is too short! : {}".format(wavpath))
            audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
        mel = self.ap.melspectrogram(audio)
        if self.mode in ["gauss", "mold"]:
            x_input = audio
        elif isinstance(self.mode, int):
            x_input = (
                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
            )
        else:
            raise RuntimeError(f"Unknown dataset mode - {self.mode}")
        return mel, x_input, wavpath

    def _load_precomputed_item(self, index):
        """Load precomputed mel / quantized features for a ``(wav_path, feat_path)`` item."""
        wavpath, feat_path = self.item_list[index]
        mel = np.load(feat_path.replace("/quant/", "/mel/"))
        if mel.shape[-1] < self.mel_len + 2 * self.pad:
            print(" [!] Instance is too short! : {}".format(wavpath))
            # Replace the short item with its successor (wrapping at the end of
            # the list) and unpack BOTH paths, so the mel, the signal and the
            # returned wavpath all refer to the same utterance.
            next_index = (index + 1) % len(self.item_list)
            self.item_list[index] = self.item_list[next_index]
            wavpath, feat_path = self.item_list[index]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))
        if self.mode in ["gauss", "mold"]:
            x_input = self.ap.load_wav(wavpath)
        elif isinstance(self.mode, int):
            x_input = np.load(feat_path.replace("/mel/", "/quant/"))
        else:
            raise RuntimeError(f"Unknown dataset mode - {self.mode}")
        return mel, x_input, wavpath

    def collate(self, batch):
        """Cut one random, aligned (mel, signal) window from each item and batch them.

        Args:
            batch: list of ``(mel, x_input, wavpath)`` tuples as returned by
                ``load_item``; ``mel`` is indexed as a 2-D array ``[:, frames]``.

        Returns:
            ``(x_input, mels, y_coarse)``:
                - ``x_input``: the first ``seq_len`` samples of each signal window.
                  For int modes it is linearly mapped from ``[0, 2**mode - 1]`` to
                  ``[-1, 1]``.
                - ``mels``: a FloatTensor of the mel windows.
                - ``y_coarse``: the signal window shifted by one sample.

        Raises:
            RuntimeError: if ``self.mode`` is neither ``"gauss"``/``"mold"`` nor an int.
        """
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
        # NOTE: np.random.randint raises ValueError if an item's offset is <= 0,
        # i.e. its mel is too short for one window.
        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        # signal windows start `pad` frames after their mel windows
        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
        mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
        # one extra sample so inputs and next-sample targets can be split below
        coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]
        mels = np.stack(mels).astype(np.float32)
        if self.mode in ["gauss", "mold"]:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
            x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
        else:
            raise RuntimeError(f"Unknown dataset mode - {self.mode}")
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse