File size: 17,079 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
"""slakh Dataset Loader

.. admonition:: Dataset Info
    :class: dropdown

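    • This code is modified to use the Slakh2100 dataset resampled to 16 kHz.
    • Unlike the original slakh loader, this version treats drum tracks as pitched
      instruments (80 notes appear), i.e. drum notes are not skipped when loading
      notes. See the ``skip_drums=False`` arguments in ``Track.notes``,
      ``MultiTrack.notes`` and ``MultiTrack.multif0``.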

    The Synthesized Lakh (Slakh) Dataset is a dataset of multi-track audio and aligned 
    MIDI for music source separation and multi-instrument automatic transcription. 
    Individual MIDI tracks are synthesized from the Lakh MIDI Dataset v0.1 using 
    professional-grade sample-based virtual instruments, and the resulting audio is 
    mixed together to make musical mixtures. 
    
    The original release of Slakh, called Slakh2100, 
    contains 2100 automatically mixed tracks and accompanying, aligned MIDI files, 
    synthesized from 187 instrument patches categorized into 34 classes, totaling 
    145 hours of mixture data.

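    This loader supports three versions of Slakh:
    - Slakh2100-yourmt3-16k (default): Slakh2100 resampled to 16 kHz, with
      train/validation/test splits only
    - Slakh2100-redux: a deduplicated version of slakh2100 containing 1710 multitracks
    - baby-slakh: a mini version with 16k wav audio and only the first 20 tracks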

    This dataset was created at Mitsubishi Electric Research Labs (MERL) and
    Interactive Audio Lab at Northwestern University by Ethan Manilow, 
    Gordon Wichern, Prem Seetharaman, and Jonathan Le Roux.

    For more information see http://www.slakh.com/

"""
import os
from typing import BinaryIO, Optional, Tuple

from deprecated.sphinx import deprecated
import librosa
import numpy as np
import pretty_midi
from smart_open import open
import yaml

from mirdata import io, download_utils, jams_utils, core, annotations

#: BibTeX citation for the Slakh dataset paper.
BIBTEX = """
@inproceedings{manilow2019cutting,
  title={Cutting Music Source Separation Some {Slakh}: A Dataset to Study the Impact of Training Data Quality and Quantity},
  author={Manilow, Ethan and Wichern, Gordon and Seetharaman, Prem and Le Roux, Jonathan},
  booktitle={Proc. IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
  year={2019},
  organization={IEEE}
}
"""

#: Index file for each dataset version. The "default" and "test" entries are
#: plain strings naming another key of this dict rather than core.Index objects.
INDEXES = {
    "default":
        "2100-yourmt3-16k",
    "test":
        "baby",
    # Slakh2100 at 16 kHz (YourMT3); its index file is hosted on Zenodo (see url).
    # partial_download presumably names the REMOTES entries fetched for this
    # version — the values match REMOTES keys.
    "2100-yourmt3-16k":
        core.Index(
            filename="slakh_index_2100-yourmt3-16k.json",
            url="https://zenodo.org/record/7717249/files/slakh_index_2100-yourmt3-16k.json?download=1",
            checksum="fab898bd82827ddc4c3e4dbd7b7fcbd9",
            partial_download=["2100-yourmt3-16k"]),
    "2100-redux":
        core.Index(filename="slakh_index_2100-redux.json", partial_download=["2100-redux"]),
    "baby":
        core.Index(filename="slakh_index_baby.json", partial_download=["baby"]),
}

#: Downloadable archives, keyed by dataset version.
#: NOTE(review): the 32-hex-digit checksums look like md5 digests — confirm.
REMOTES = {
    "2100-yourmt3-16k":
        download_utils.RemoteFileMetadata(
            filename="slakh2100_yourmt3_16k.tar.gz",
            url="https://zenodo.org/record/7717249/files/slakh2100_yourmt3_16k.tar.gz?download=1",
            checksum="c44f9bcba07b3c6ddeaf604f45dc61c5",
        ),
    "2100-redux":
        download_utils.RemoteFileMetadata(
            filename="slakh2100_flac_redux.tar.gz",
            url="https://zenodo.org/record/4599666/files/slakh2100_flac_redux.tar.gz?download=1",
            checksum="f4b71b6c45ac9b506f59788456b3f0c4",
        ),
    "baby":
        download_utils.RemoteFileMetadata(
            filename="babyslakh_16k.tar.gz",
            url="https://zenodo.org/record/4603870/files/babyslakh_16k.tar.gz?download=1",
            checksum="311096dc2bde7d61c97e930edbfc7f78",
        ),
}

#: License under which the dataset is distributed.
LICENSE_INFO = """
Creative Commons Attribution 4.0 International
"""

#: Valid split names for the 2100-redux version (the split is read from index paths).
SPLITS = ["train", "validation", "test", "omitted"]
#: Valid split names for the 2100-yourmt3-16k version (no "omitted" split).
SPLITS_16K = ["train", "validation", "test"]

#: Mixing group to program number mapping
#: NOTE(review): 128 lies outside the General MIDI program range (0-127);
#: presumably the slakh metadata uses it for drum tracks — confirm.
MIXING_GROUPS = {
    "piano": [0, 1, 2, 3, 4, 5, 6, 7],
    "guitar": [24, 25, 26, 27, 28, 29, 30, 31],
    "bass": [32, 33, 34, 35, 36, 37, 38, 39],
    "drums": [128],
}


class Track(core.Track):
    """slakh Track class, for individual stems

    Attributes:
        audio_path (str or None): path to the track's audio file. For some unusual tracks,
            such as sound effects, there is no audio and this attribute is None.
        split (str or None): one of 'train', 'validation', 'test', or 'omitted'.
            'omitted' tracks are part of slakh2100-redux which were found to be
            duplicates in the original slakh2100. The 2100-yourmt3-16k version
            only has 'train', 'validation' and 'test'.
            In baby slakh there are no splits, so this attribute is None.
        data_split (str or None): equivalent to split (deprecated in 0.3.6)
        metadata_path (str): path to the multitrack's metadata file
        midi_path (str or None): path to the track's midi file. For some unusual tracks,
            such as sound effects, there is no midi and this attribute is None.
        mtrack_id (str): the track's multitrack id (the part of track_id before "-")
        track_id (str): track id
        instrument (str): MIDI instrument class, see link for details:
            https://en.wikipedia.org/wiki/General_MIDI#Program_change_events
        integrated_loudness (float): integrated loudness (dB) of this track
            as calculated by the ITU-R BS.1770-4 spec
        is_drum (bool): whether the "drum" flag is true for this MIDI track
        midi_program_name (str): MIDI instrument program name
        plugin_name (str): patch/plugin name that rendered the audio file
        mixing_group (str or None): which mixing group the track belongs to.
            One of the keys of MIXING_GROUPS, or None if the program number
            is not listed in any group.
        program_number (int): MIDI instrument program number

    Cached Properties:
        midi (PrettyMIDI): midi data used to generate the audio
        notes (NoteData or None): note representation of the midi data,
            drum notes included. If there are no notes in the midi file, returns None.
        multif0 (MultiF0Data or None): multif0 representation of the midi data.
            If there are no notes in the midi file, returns None.

    """

    def __init__(self, track_id, data_home, dataset_name, index, metadata):

        super().__init__(
            track_id,
            data_home,
            dataset_name=dataset_name,
            index=index,
            metadata=metadata,
        )

        self.mtrack_id = self.track_id.split("-")[0]
        self.audio_path = self.get_path("audio")
        self.midi_path = self.get_path("midi")
        self.metadata_path = self.get_path("metadata")

        # The split (train/validation/test/omitted) is the second component of the
        # relative metadata path stored in the index. baby slakh has no splits,
        # so split stays None there.
        self.split = None
        version = index["version"]
        if "2100-redux" in version:
            valid_splits = SPLITS
        elif "2100-yourmt3" in version:
            valid_splits = SPLITS_16K
        else:
            valid_splits = None

        if valid_splits is not None:
            # Relative index paths may use "/" even on hosts where os.sep is "\\";
            # normalise to "/" so the split component is found on any OS.
            rel_path = self._track_paths["metadata"][0].replace(os.sep, "/")
            self.split = rel_path.split("/")[1]
            assert self.split in valid_splits, "{} not a valid split - should be one of {}.".format(
                self.split, valid_splits)

        self.data_split = self.split  # deprecated in 0.3.6

    @core.cached_property
    def _track_metadata(self) -> dict:
        """This stem's entry from the multitrack YAML metadata file.

        Raises:
            FileNotFoundError: if the metadata file does not exist.
        """
        try:
            with open(self.metadata_path, "r") as fhandle:
                metadata = yaml.safe_load(fhandle)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"track metadata for {self.track_id} not found. Did you run .download()?"
            ) from err
        # the part of the track id after "-" is the stem key in the metadata file
        return metadata["stems"][self.track_id.split("-")[1]]

    @property
    def instrument(self) -> Optional[str]:
        return self._track_metadata.get("inst_class")

    @property
    def integrated_loudness(self) -> Optional[float]:
        return self._track_metadata.get("integrated_loudness")

    @property
    def is_drum(self) -> Optional[bool]:
        return self._track_metadata.get("is_drum")

    @property
    def midi_program_name(self) -> Optional[str]:
        return self._track_metadata.get("midi_program_name")

    @property
    def plugin_name(self) -> Optional[str]:
        return self._track_metadata.get("plugin_name")

    @property
    def program_number(self) -> Optional[int]:
        return self._track_metadata.get("program_num")

    @property
    def mixing_group(self) -> Optional[str]:
        # first MIXING_GROUPS key whose program list contains this program number
        return next(
            (group for group, programs in MIXING_GROUPS.items()
             if self.program_number in programs),
            None,
        )

    @core.cached_property
    def midi(self) -> Optional[pretty_midi.PrettyMIDI]:
        return io.load_midi(self.midi_path)

    @core.cached_property
    def notes(self) -> Optional[annotations.NoteData]:
        # drum notes are kept (skip_drums=False) in this 16k variant
        return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False)

    @core.cached_property
    def multif0(self) -> Optional[annotations.MultiF0Data]:
        # NOTE(review): drums are skipped here (skip_drums=True) while `notes` and
        # MultiTrack.multif0 keep them (skip_drums=False); confirm this is intended.
        return io.load_multif0_from_midi(
            self.midi_path, self.midi, skip_drums=True, pitch_bend=False)

    @property
    def audio(self) -> Optional[Tuple[np.ndarray, float]]:
        """The track's audio

        Returns:
            * np.ndarray - audio signal
            * float - sample rate

        """
        return load_audio(self.audio_path)

    def to_jams(self):
        """Jams: the track's data in jams format"""
        return jams_utils.jams_converter(
            audio_path=self.audio_path,
            note_data=[(self.notes, "Notes")],
        )


class MultiTrack(core.MultiTrack):
    """slakh multitrack class, containing information about the mix and
    the set of associated stems

    Attributes:
        mtrack_id (str): track id
        tracks (dict): {track_id: Track}
        track_audio_property (str): the name of the attribute of Track which
            returns the audio to be mixed
        mix_path (str): path to the multitrack mix audio
        midi_path (str): path to the full midi data used to generate the mixture
        metadata_path (str): path to the multitrack metadata file
        split (str or None): one of 'train', 'validation', 'test', or 'omitted'.
            'omitted' tracks are part of slakh2100-redux which were found to be
            duplicates in the original slakh2100. The 2100-yourmt3-16k version
            only has 'train', 'validation' and 'test'. None for baby slakh.
        data_split (str or None): equivalent to split (deprecated in 0.3.6)
        uuid (str): File name of the original MIDI file from Lakh, sans extension
        lakh_midi_dir (str): Path to the original MIDI file from a fresh download of Lakh
        normalized (bool): whether the mix and stems were normalized according to the ITU-R BS.1770-4 spec
        overall_gain (float): gain applied to every stem to make sure mixture does not clip when stems are summed

    Cached Properties:
        midi (PrettyMIDI): midi data used to generate the mixture audio
        notes (NoteData): note representation of the midi data (drum notes included)
        multif0 (MultiF0Data): multif0 representation of the midi data (drum notes included)

    """

    def __init__(self, mtrack_id, data_home, dataset_name, index, track_class, metadata):
        super().__init__(
            mtrack_id=mtrack_id,
            data_home=data_home,
            dataset_name=dataset_name,
            index=index,
            track_class=track_class,
            metadata=metadata,
        )
        self.mix_path = self.get_path("mix")
        self.midi_path = self.get_path("midi")
        self.metadata_path = self.get_path("metadata")

        # The split is the second component of the relative mix path stored in
        # the index. baby slakh has no splits, so split stays None there.
        self.split = None
        version = index["version"]
        if "2100-redux" in version:
            valid_splits = SPLITS
        elif "2100-yourmt3" in version:
            valid_splits = SPLITS_16K
        else:
            valid_splits = None

        if valid_splits is not None:
            # Relative index paths may use "/" even on hosts where os.sep is "\\";
            # normalise to "/" so the split component is found on any OS.
            rel_path = self._multitrack_paths["mix"][0].replace(os.sep, "/")
            self.split = rel_path.split("/")[1]
            assert self.split in valid_splits, "{} not in {}".format(self.split, valid_splits)

        self.data_split = self.split  # deprecated in 0.3.6

    @property
    def track_audio_property(self) -> str:
        return "audio"

    @core.cached_property
    def _multitrack_metadata(self) -> dict:
        """Parsed YAML metadata of the multitrack.

        Raises:
            FileNotFoundError: if the metadata file does not exist.
        """
        try:
            with open(self.metadata_path, "r") as fhandle:
                metadata = yaml.safe_load(fhandle)
        except FileNotFoundError as err:
            raise FileNotFoundError("Metadata not found. Did you run .download()?") from err
        return metadata

    @property
    def uuid(self) -> Optional[str]:
        return self._multitrack_metadata.get("UUID")

    @property
    def lakh_midi_dir(self) -> Optional[str]:
        return self._multitrack_metadata.get("lmd_midi_dir")

    @property
    def normalized(self) -> Optional[bool]:
        return self._multitrack_metadata.get("normalized")

    @property
    def overall_gain(self) -> Optional[float]:
        return self._multitrack_metadata.get("overall_gain")

    @core.cached_property
    def midi(self) -> Optional[pretty_midi.PrettyMIDI]:
        return io.load_midi(self.midi_path)

    @core.cached_property
    def notes(self) -> Optional[annotations.NoteData]:
        # drum notes are kept (skip_drums=False) in this 16k variant
        return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False)

    @core.cached_property
    def multif0(self) -> Optional[annotations.MultiF0Data]:
        # TODO: setting pitch_bend to False by default, but there are some
        # patches that render pitch bend in the audio.
        return io.load_multif0_from_midi(
            self.midi_path, self.midi, skip_drums=False, pitch_bend=False)

    @property
    def audio(self) -> Optional[Tuple[np.ndarray, float]]:
        """The track's audio

        Returns:
            * np.ndarray - audio signal
            * float - sample rate

        """
        return load_audio(self.mix_path)

    def to_jams(self):
        """Jams: the track's data in jams format"""
        return jams_utils.jams_converter(
            audio_path=self.mix_path,
            note_data=[(self.notes, "Notes")],
        )

    def get_submix_by_group(self, target_groups):
        """Create submixes grouped by instrument type. Creates one submix
        per target group, plus one additional "other" group for any remaining sources.
        Only tracks with available audio are mixed.

        Args:
            target_groups (list): List of target groups. Elements should be one of
                MIXING_GROUPS, e.g. ["bass", "guitar"]

        Returns:
            * submixes (dict): {group: audio_signal} of submixes; a group with
              no tracks maps to None
            * groups (dict): {group: list of track ids} of submixes

        """
        groups = {}
        submixes = {}
        tracks_with_audio = [track for track in self.tracks.values() if track.audio_path]
        # ids already placed in a target group; a set keeps the "other" lookup O(1)
        assigned = set()
        for group in target_groups:
            groups[group] = [
                track.track_id for track in tracks_with_audio if track.mixing_group == group
            ]
            assigned.update(groups[group])
            submixes[group] = self.get_target(groups[group]) if groups[group] else None

        groups["other"] = [
            track.track_id for track in tracks_with_audio if track.track_id not in assigned
        ]
        submixes["other"] = self.get_target(groups["other"]) if groups["other"] else None
        return submixes, groups


@io.coerce_to_bytes_io
def load_audio(fhandle: BinaryIO) -> Tuple[np.ndarray, float]:
    """Read a slakh audio file without resampling or downmixing.

    Args:
        fhandle (str or file-like): path or file-like object pointing to an audio file

    Returns:
        * np.ndarray - the audio signal, all channels kept (mono=False)
        * float - The sample rate of the audio file (sr=None: no resampling)

    """
    signal, sample_rate = librosa.load(fhandle, mono=False, sr=None)
    return signal, sample_rate


@core.docstring_inherit(core.Dataset)
class Dataset(core.Dataset):
    """
    The slakh dataset
    """

    def __init__(self, data_home=None, version="default"):
        super().__init__(
            data_home,
            version,
            name="slakh",
            track_class=Track,
            multitrack_class=MultiTrack,
            bibtex=BIBTEX,
            indexes=INDEXES,
            remotes=REMOTES,
            license_info=LICENSE_INFO,
        )

    # The methods below are deprecated thin wrappers kept for backward
    # compatibility; each forwards its arguments unchanged to the function
    # named in its deprecation reason.

    @deprecated(
        reason="Use mirdata.datasets.slakh.load_audio",
        version="0.3.4",
    )
    def load_audio(self, *args, **kwargs):
        return load_audio(*args, **kwargs)

    # This module defines no load_midi; the wrapper delegates to mirdata.io,
    # so the deprecation message points there.
    @deprecated(
        reason="Use mirdata.io.load_midi",
        version="0.3.4",
    )
    def load_midi(self, *args, **kwargs):
        return io.load_midi(*args, **kwargs)

    @deprecated(
        reason="Use mirdata.io.load_notes_from_midi",
        version="0.3.4",
    )
    def load_notes_from_midi(self, *args, **kwargs):
        return io.load_notes_from_midi(*args, **kwargs)

    @deprecated(
        reason="Use mirdata.io.load_multif0_from_midi",
        version="0.3.4",
    )
    def load_multif0_from_midi(self, *args, **kwargs):
        return io.load_multif0_from_midi(*args, **kwargs)