from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch
import librosa


class AudioDataset(Dataset):
    """Dataset of raw audio waveforms loaded from disk with librosa.

    Each item is a fixed-length 1-D float32 waveform tensor plus its
    target label. Waveforms are optionally trimmed at the start
    (``skip_times``), cropped/padded to ``max_len`` samples, and
    normalized.

    Args:
        filepaths: Sequence of audio file paths.
        labels: Sequence of labels, one per file.
        skip_times: Optional per-file start offsets in seconds to skip
            (e.g. output of a VAD pass). ``None`` disables trimming.
        num_classes: Number of target classes (stored for callers;
            not used internally here).
        normalize: ``"std"`` divides by the waveform's std,
            ``"minmax"`` rescales to [0, 1]; anything else leaves the
            waveform as loaded.
        max_len: Fixed output length in samples.
        random_sampling: If True, crop/pad positions are random
            (training augmentation); must be False for validation.
        train: Whether this dataset is used for training.
        **kwargs: Forwarded to ``torch.utils.data.Dataset``.

    Raises:
        ValueError: If ``random_sampling`` is enabled while
            ``train`` is False (would make validation
            non-deterministic).
    """

    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        # Random crop/pad would make validation metrics non-deterministic.
        # Raise (not assert): asserts are stripped under `python -O`.
        if not self.train and self.random_sampling:
            raise ValueError("Ensure random_sampling is disabled for val")

    def __len__(self):
        return len(self.filepaths)

    def crop_or_pad(self, audio, max_len, random_sampling=True):
        """Return ``audio`` adjusted to exactly ``max_len`` samples.

        With ``random_sampling`` the pad split / crop offset is drawn
        uniformly (training augmentation); otherwise padding goes at
        the end and the crop starts 3/4 of the way into the surplus.
        """
        audio_len = audio.shape[0]
        if audio_len == max_len:
            return audio
        diff_len = abs(max_len - audio_len)
        if random_sampling:
            if audio_len < max_len:
                # Split the padding randomly between the two ends.
                # Upper bound is diff_len + 1 because np.random.randint's
                # high is exclusive — this lets the pad land entirely on
                # either side (pad1 == 0 or pad2 == 0).
                pad1 = np.random.randint(0, diff_len + 1)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            else:
                # Random crop; +1 so the last valid window (ending at the
                # final sample) can also be selected.
                idx = np.random.randint(0, diff_len + 1)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                # Deterministic: pad at the end only.
                audio = np.pad(audio, (0, diff_len), mode="constant")
            else:
                # Deterministic crop starting 3/4 into the surplus:
                # eq: l = (3x + t + x) => idx = 3x = (l - t) / 4 * 3
                idx = int(diff_len / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio

    def __getitem__(self, idx):
        """Load, trim, fix length, and normalize one waveform.

        Returns:
            dict with ``"audio"`` (float32 tensor of ``max_len``
            samples) and ``"target"`` (float32 scalar tensor).
        """
        # Load at native sample rate (sr=None keeps the file's rate).
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])

        # Trim start of audio (e.g. silence found by torchaudio VAD).
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time * sr):]

        # Ensure fixed length.
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)

        # maximum(..., 1e-6) guards against division by zero on
        # silent/constant waveforms.
        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)

        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }


def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
    """Build a DataLoader over an :class:`AudioDataset`.

    When ``distributed`` is True, a ``DistributedSampler`` is used and
    shuffling is delegated to it; otherwise the loader shuffles only
    for training.
    """
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )
    if distributed:
        # drop_last is set to True to validate properly
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        # Shuffling and sampler are mutually exclusive in DataLoader.
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader