import numpy as np
import pandas as pd
import torch

from utmosv2.dataset._utils import (
    extend_audio,
    get_dataset_map,
    load_audio,
    select_random_start,
)


class SSLDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing fixed-length audio clips with MOS targets.

    Each item is built from one row of ``data``. The audio at
    ``row["file_path"]`` is loaded with ``load_audio``, then passed through
    ``extend_audio(..., type="tile")`` and ``select_random_start``, using a
    target length of ``int(cfg.dataset.ssl.duration * cfg.sr)``. It is
    returned together with ``row["mos"]`` as a float32 scalar tensor.

    Args:
        cfg: Configuration object. This class reads ``cfg.dataset.ssl.duration``
            and ``cfg.sr`` and passes ``cfg`` through to the audio helpers.
        data: One row per sample. It must contain ``"file_path"`` and
            ``"mos"`` columns.
        phase: Stored on the instance as ``self.phase``. It is not consulted
            by this class.
    """

    def __init__(self, cfg, data: pd.DataFrame, phase: str):
        self.cfg = cfg
        self.data = data
        self.phase = phase

    def __len__(self) -> int:
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(audio, mos)`` for the row at position ``idx``.

        ``mos`` is a float32 scalar tensor. ``audio`` is whatever
        ``select_random_start`` returns.
        """
        row = self.data.iloc[idx]
        y = load_audio(self.cfg, row["file_path"])
        # Target clip length. Presumably duration is in seconds and cfg.sr is
        # the sampling rate, giving a sample count; TODO confirm in the config.
        length = int(self.cfg.dataset.ssl.duration * self.cfg.sr)
        # NOTE(review): presumably "tile" repeats clips shorter than `length`,
        # and select_random_start crops a random `length`-sample window. Both
        # helpers live in utmosv2.dataset._utils. `phase` is not checked, so
        # the same (random) cropping runs in every phase.
        y = extend_audio(self.cfg, y, length, type="tile")
        y = select_random_start(y, length)
        target = torch.tensor(row["mos"], dtype=torch.float32)
        return y, target


class SSLExtDataset(SSLDataset):
    """``SSLDataset`` variant that also returns a one-hot dataset indicator.

    The indicator is a float32 vector with ``len(get_dataset_map(cfg))``
    entries. It is 1 at ``dataset_map[row["dataset"]]`` and 0 elsewhere, so
    ``data`` must also have a ``"dataset"`` column whose values index that map.
    """

    def __init__(self, cfg, data: pd.DataFrame, phase: str):
        super().__init__(cfg, data, phase)
        # Presumably maps dataset name -> integer position; verify in _utils.
        self.dataset_map = get_dataset_map(cfg)

    def __getitem__(self, idx: int):
        """Return ``(audio, dataset_one_hot, mos)`` for the row at ``idx``."""
        y, target = super().__getitem__(idx)
        # Build the float32 one-hot directly in torch, rather than allocating
        # a float64 NumPy array and copying it into a tensor.
        d = torch.zeros(len(self.dataset_map), dtype=torch.float32)
        d[self.dataset_map[self.data.iloc[idx]["dataset"]]] = 1
        return y, d, target