File size: 5,022 Bytes
8c70653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import random

import torch
from torch.utils.data import Dataset

from TTS.encoder.utils.generic_utils import AugmentWAV


class EncoderDataset(Dataset):
    """Dataset of labelled utterances for encoder training.

    ``__getitem__`` only returns a small dict (file path + class name); the
    audio is loaded, randomly cropped to a fixed length, optionally augmented
    and converted to features in :meth:`collate_fn`.
    """

    def __init__(
        self,
        config,
        ap,
        meta_data,
        voice_len=1.6,
        num_classes_in_batch=64,
        num_utter_per_class=10,
        verbose=False,
        augmentation_config=None,
        use_torch_spec=None,
    ):
        """
        Args:
            config: configuration object; ``config.class_name_key`` names the
                item field that holds the class label.
            ap: audio processor object. Must expose ``sample_rate``,
                ``load_wav(filename, sr=...)`` and ``melspectrogram(wav)``.
            meta_data (list): list of dataset instances (dicts with an
                ``"audio_file"`` entry and the class label entry).
            voice_len (float): voice segment length in seconds.
            num_classes_in_batch (int): only reported in the verbose output.
            num_utter_per_class (int): minimum number of utterances a class
                needs in ``meta_data`` to be kept.
            verbose (bool): print diagnostic information.
            augmentation_config (dict): optional augmentation settings. ``"p"``
                is the probability of augmenting a sample; ``"additive"`` /
                ``"rir"`` enable ``AugmentWAV``; ``"gaussian"`` is stored as-is
                in ``gaussian_augmentation_config``.
            use_torch_spec: when truthy, ``collate_fn`` returns raw waveform
                segments instead of mel spectrograms.
        """
        super().__init__()
        self.config = config
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        # segment length expressed in samples
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_class = num_utter_per_class
        self.ap = ap
        self.verbose = verbose
        self.use_torch_spec = use_torch_spec
        self.classes, self.items = self.__parse_items()

        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

        # Data Augmentation
        self.augmentator = None
        self.gaussian_augmentation_config = None
        # always defined so the attribute exists even without an augmentation config
        self.data_augmentation_p = None
        if augmentation_config:
            self.data_augmentation_p = augmentation_config["p"]
            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                self.augmentator = AugmentWAV(ap, augmentation_config)

            if "gaussian" in augmentation_config.keys():
                self.gaussian_augmentation_config = augmentation_config["gaussian"]

        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Classes per Batch: {num_classes_in_batch}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num Classes: {len(self.classes)}")
            print(f" | > Classes: {self.classes}")

    def load_wav(self, filename):
        """Load ``filename`` through the audio processor at its own sample rate."""
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def __parse_items(self):
        """Filter the raw items and collect the class list.

        Classes with fewer than ``num_utter_per_class`` utterances are dropped,
        then every utterance that is not strictly longer than ``seq_len``
        samples is dropped. The class filter runs on the raw counts, so a kept
        class can end up with fewer items after the length filter.

        Returns:
            tuple: ``(classes, items)`` where ``classes`` is the sorted list of
            kept class names and ``items`` is a list of
            ``{"wav_file_path": ..., "class_name": ...}`` dicts.
        """
        # Single source of truth for the label field; both passes must agree,
        # otherwise items are matched against classes built from another key.
        label_key = self.config.class_name_key

        class_to_utters = {}
        for item in self.items:
            class_to_utters.setdefault(item[label_key], []).append(item["audio_file"])

        # keep only classes with at least self.num_utter_per_class samples
        classes = sorted(k for k, v in class_to_utters.items() if len(v) >= self.num_utter_per_class)
        kept_classes = set(classes)  # O(1) membership checks in the loop below

        new_items = []
        for item in self.items:
            path_ = item["audio_file"]
            class_name = item[label_key]
            # ignore filtered classes
            if class_name not in kept_classes:
                continue
            # ignore small audios (needs loading every file once)
            if self.load_wav(path_).shape[0] - self.seq_len <= 0:
                continue

            new_items.append({"wav_file_path": path_, "class_name": class_name})

        return classes, new_items

    def __len__(self):
        return len(self.items)

    def get_num_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def get_class_list(self):
        """Return the list of class names."""
        return self.classes

    def set_classes(self, classes):
        """Replace the class list and rebuild the name -> id mapping from it."""
        self.classes = classes
        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

    def get_map_classid_to_classname(self):
        """Return the inverse mapping, class id -> class name."""
        return {c_id: c_n for c_n, c_id in self.classname_to_classid.items()}

    def __getitem__(self, idx):
        return self.items[idx]

    def collate_fn(self, batch):
        """Turn a list of items into a ``(feats, labels)`` batch.

        Each utterance is loaded, cropped to a random ``seq_len``-sample
        window, augmented with probability ``data_augmentation_p`` when an
        augmentator is configured, and converted to a mel spectrogram unless
        ``use_torch_spec`` is truthy (then the raw segment is used).

        Args:
            batch (list): items as returned by ``__getitem__``.

        Returns:
            tuple: stacked float feature tensor and a ``LongTensor`` of class ids.
        """
        labels = []
        feats = []
        for item in batch:
            utter_path = item["wav_file_path"]
            class_name = item["class_name"]

            # get classid
            class_id = self.classname_to_classid[class_name]
            # load wav file
            wav = self.load_wav(utter_path)
            # randint is inclusive on both ends, so the last valid window is reachable
            offset = random.randint(0, wav.shape[0] - self.seq_len)
            wav = wav[offset : offset + self.seq_len]

            if (
                self.augmentator is not None
                and self.data_augmentation_p
                and random.random() < self.data_augmentation_p
            ):
                wav = self.augmentator.apply_one(wav)

            if not self.use_torch_spec:
                mel = self.ap.melspectrogram(wav)
                feats.append(torch.FloatTensor(mel))
            else:
                feats.append(torch.FloatTensor(wav))

            labels.append(class_id)

        feats = torch.stack(feats)
        labels = torch.LongTensor(labels)

        return feats, labels