Spaces:
Runtime error
Runtime error
from encoder.data_objects.random_cycler import RandomCycler | |
from encoder.data_objects.speaker_batch import SpeakerBatch | |
from encoder.data_objects.utterance_batch import UtteranceBatch | |
from encoder.data_objects.speaker import Speaker | |
from encoder.params_data import partials_n_frames | |
from torch.utils.data import Dataset, DataLoader | |
from pathlib import Path | |
from os import listdir | |
from os.path import isfile | |
import numpy as np | |
# TODO: improve with a pool of speakers for data efficiency | |
class Train_Dataset(Dataset):
    """Training dataset that serves speakers through a ``RandomCycler``.

    Each sub-directory found directly under ``datasets_root`` is wrapped in a
    ``Speaker``. Items are not looked up by index: every access simply pulls
    the next speaker from the cycler.
    """

    def __init__(self, datasets_root: Path):
        """Discover the speaker directories located under ``datasets_root``.

        Args:
            datasets_root: directory whose sub-directories each hold one
                preprocessed speaker.

        Raises:
            Exception: if ``datasets_root`` has no sub-directory.
        """
        self.root = datasets_root

        found_dirs = []
        for entry in self.root.glob("*"):
            if entry.is_dir():
                found_dirs.append(entry)

        if not found_dirs:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")

        self.speakers = list(map(Speaker, found_dirs))
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        """Return a fixed, very large nominal length (one hundred million)."""
        # __getitem__ never uses the index, so the length is only a nominal
        # bound on how many items may be requested.
        return 10 ** 8

    def __getitem__(self, index):
        """Return the next speaker from the cycler; ``index`` is ignored."""
        cycler = self.speaker_cycler
        return next(cycler)

    def get_logs(self):
        """Return the concatenated contents of every ``*.txt`` file in the root.

        Returns:
            A single string (empty when the root holds no ``*.txt`` file).
        """
        contents = [log_fpath.read_text() for log_fpath in self.root.glob("*.txt")]
        return "".join(contents)
class Dev_Dataset(Dataset):
    """Development dataset with one nominal item per discovered speaker.

    Speakers are loaded from the sub-directories of ``datasets_root`` and
    served through a ``RandomCycler``; the requested index is not used to
    pick a speaker.
    """

    def __init__(self, datasets_root: Path):
        """Load every speaker directory found directly under ``datasets_root``.

        Args:
            datasets_root: directory whose sub-directories each hold one
                preprocessed speaker.

        Raises:
            Exception: if ``datasets_root`` has no sub-directory.
        """
        self.root = datasets_root
        candidate_dirs = list(filter(lambda path: path.is_dir(), self.root.glob("*")))

        # Guard clause: nothing to load means the root is most likely wrong.
        if not candidate_dirs:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")

        speakers = []
        for directory in candidate_dirs:
            speakers.append(Speaker(directory))
        self.speakers = speakers
        self.speaker_cycler = RandomCycler(speakers)

    def __len__(self):
        """Return the number of loaded speakers."""
        speaker_count = len(self.speakers)
        return speaker_count

    def __getitem__(self, index):
        """Return the next speaker from the cycler; ``index`` is ignored."""
        cycler = self.speaker_cycler
        return next(cycler)
class DataLoader(DataLoader):
    """torch ``DataLoader`` that collates each batch into a ``SpeakerBatch``.

    NOTE: this class intentionally reuses the name of the torch class it
    extends; the base class is resolved before the name is rebound.
    """

    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, shuffle, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        """Configure the underlying torch loader.

        Args:
            dataset: dataset to draw items from.
            speakers_per_batch: forwarded to torch as ``batch_size``.
            utterances_per_speaker: stored on the instance and handed to
                ``SpeakerBatch`` when a batch is collated.
            shuffle, sampler, batch_sampler, num_workers, pin_memory, timeout,
                worker_init_fn: forwarded unchanged to the torch ``DataLoader``.
        """
        # Stored before delegating, exactly as the parent initializer expects
        # nothing from it; ``collate`` reads it later.
        self.utterances_per_speaker = utterances_per_speaker

        loader_options = {
            "dataset": dataset,
            "batch_size": speakers_per_batch,
            "shuffle": shuffle,
            "sampler": sampler,
            "batch_sampler": batch_sampler,
            "num_workers": num_workers,
            "collate_fn": self.collate,
            "pin_memory": pin_memory,
            "drop_last": False,  # incomplete final batches are kept
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
        }
        super().__init__(**loader_options)

    def collate(self, speakers):
        """Build a ``SpeakerBatch`` from the items gathered for one batch."""
        per_speaker = self.utterances_per_speaker
        return SpeakerBatch(speakers, per_speaker, partials_n_frames)