import os import h5py import numpy as np from typing import Any, Tuple import torch import random from pytorch_lightning import LightningDataModule import torchaudio from torchaudio.functional import apply_codec from torch.utils.data import DataLoader, Dataset from typing import Any, Dict, Optional, Tuple def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50): """Return the wav RMS calculated only in the active portions""" mean_square = max(1e-20, torch.mean(mch_wav ** 2)) return 10 * np.log10(mean_square) def match2(x, d): assert x.dim()==2, x.shape assert d.dim()==2, d.shape minlen = min(x.shape[-1], d.shape[-1]) x, d = x[:,0:minlen], d[:,0:minlen] Fx = torch.fft.rfft(x, dim=-1) Fd = torch.fft.rfft(d, dim=-1) Phi = Fd*Fx.conj() Phi = Phi / (Phi.abs() + 1e-3) Phi[:,0] = 0 tmp = torch.fft.irfft(Phi, dim=-1) tau = torch.argmax(tmp.abs(),dim=-1).tolist() return tau def codec_simu(wav, sr=16000, options={'bitrate':'random','compression':'random', 'complexity':'random', 'vbr':'random'}): if options['bitrate'] == 'random': options['bitrate'] = random.choice([24000, 32000, 48000, 64000, 96000, 128000]) compression = int(options['bitrate']//1000) param = {'format': "mp3", "compression": compression} wav_encdec = apply_codec(wav, sr, **param) if wav_encdec.shape[-1] >= wav.shape[-1]: wav_encdec = wav_encdec[...,:wav.shape[-1]] else: wav_encdec = torch.cat([wav_encdec, wav[..., wav_encdec.shape[-1]:]], -1) tau = match2(wav, wav_encdec) wav_encdec = torch.roll(wav_encdec, -tau[0], -1) return wav_encdec def get_wav_files(root_dir): wav_files = [] for dirpath, dirnames, filenames in os.walk(root_dir): for filename in filenames: if filename.endswith('.wav'): if "musdb18hq" in dirpath and "mixture" not in filename: wav_files.append(os.path.join(dirpath, filename)) elif "moisesdb" in dirpath: wav_files.append(os.path.join(dirpath, filename)) return wav_files class MusdbMoisesdbDataset(Dataset): def __init__( self, data_dir: str, codec_type: str, codec_options: dict, sr: int = 16000, segments: int = 10, num_stems: int = 4, snr_range: Tuple[int, int] = (-10, 10), num_samples: int = 1000, ) -> None: self.data_dir = data_dir self.codec_type = codec_type self.codec_options = codec_options self.segments = int(segments * sr) self.sr = sr self.num_stems = num_stems self.snr_range = snr_range self.num_samples = num_samples self.instruments = [ "bass", "bowed_strings", "drums", "guitar", "other", "other_keys", "other_plucked", "percussion", "piano", "vocals", "wind" ] def __len__(self) -> int: return self.num_samples def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: if random.random() > 0.5: select_stems = random.randint(1, self.num_stems) select_stems = random.choices(self.instruments, k=select_stems) ori_wav = [] for stem in select_stems: h5path = random.choice(os.listdir(os.path.join(self.data_dir, stem))) datas = h5py.File(os.path.join(self.data_dir, stem, h5path), 'r')['data'] random_index = random.randint(0, datas.shape[0]-1) music_wav = torch.FloatTensor(datas[random_index]) start = random.randint(0, music_wav.shape[-1] - self.segments) music_wav = music_wav[:, start:start+self.segments] rescale_snr = random.randint(self.snr_range[0], self.snr_range[1]) music_wav = music_wav * np.sqrt(10**(rescale_snr/10)) ori_wav.append(music_wav) ori_wav = torch.stack(ori_wav).sum(0) else: h5path = random.choice(os.listdir(os.path.join(self.data_dir, "mixture"))) datas = h5py.File(os.path.join(self.data_dir, "mixture", h5path), 'r')['data'] random_index = random.randint(0, datas.shape[0]-1) music_wav = torch.FloatTensor(datas[random_index]) start = random.randint(0, music_wav.shape[-1] - self.segments) ori_wav = music_wav[:, start:start+self.segments] codec_wav = codec_simu(ori_wav, sr=self.sr, options=self.codec_options) max_scale = max(ori_wav.abs().max(), codec_wav.abs().max()) if max_scale > 0: ori_wav = ori_wav / max_scale codec_wav = codec_wav / max_scale return ori_wav, codec_wav class MusdbMoisesdbEval(Dataset): def __init__( self, data_dir: str ) -> None: self.data_path = os.listdir(data_dir) self.data_path = [os.path.join(data_dir, i) for i in self.data_path] def __len__(self) -> int: return len(self.data_path) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: ori_wav = torchaudio.load(self.data_path[idx]+"/ori_wav.wav")[0] codec_wav = torchaudio.load(self.data_path[idx]+"/codec_wav.wav")[0] return ori_wav, codec_wav, self.data_path[idx] class MusdbMoisesdbDataModule(LightningDataModule): def __init__( self, train_dir: str, eval_dir: str, codec_type: str, codec_options: dict, sr: int = 16000, segments: int = 10, num_stems: int = 4, snr_range: Tuple[int, int] = (-10, 10), num_samples: int = 1000, batch_size: int = 32, num_workers: int = 4, ) -> None: super().__init__() self.save_hyperparameters(logger=False) self.data_train: Optional[Dataset] = None self.data_val: Optional[Dataset] = None def setup(self, stage: Optional[str] = None) -> None: """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to `self.setup()` once the data is prepared and available for use. :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. """ # load and split datasets only if not loaded already if not self.data_train and not self.data_val: self.data_train = MusdbMoisesdbDataset( data_dir=self.hparams.train_dir, codec_type=self.hparams.codec_type, codec_options=self.hparams.codec_options, sr=self.hparams.sr, segments=self.hparams.segments, num_stems=self.hparams.num_stems, snr_range=self.hparams.snr_range, num_samples=self.hparams.num_samples, ) self.data_val = MusdbMoisesdbEval( data_dir=self.hparams.eval_dir ) def train_dataloader(self) -> DataLoader: return DataLoader( self.data_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, shuffle=True, pin_memory=True, ) def val_dataloader(self) -> DataLoader: return DataLoader( self.data_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, )