Upload 8 files
Browse files- hparams/train.yaml +35 -0
- hparams/train_with_wav2vec.yaml +112 -0
- model.pth +3 -0
- prepare.py +52 -0
- results/train_with_wav2vec2/1993/test.json +1 -0
- results/train_with_wav2vec2/1993/train.json +1 -0
- results/train_with_wav2vec2/1993/valid.json +1 -0
- train_with_wav2vec.py +302 -0
hparams/train.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seed: 1993
|
2 |
+
__set_seed: !apply:torch.manual_seed [!ref <seed>]
|
3 |
+
|
4 |
+
# Dataset will be downloaded to the `data_original`
|
5 |
+
data_original: D:/voice-emo/dat/
|
6 |
+
output_folder: results/train_with_wav2vec2/1993
|
7 |
+
save_folder: !ref <output_folder>/save
|
8 |
+
train_log: !ref <output_folder>/train_log.txt
|
9 |
+
|
10 |
+
# URL for the wav2vec2 model
|
11 |
+
wav2vec2_hub: facebook/wav2vec2-base
|
12 |
+
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
|
13 |
+
|
14 |
+
# Path where data manifest files will be stored
|
15 |
+
train_annotation: !ref <output_folder>/train.json
|
16 |
+
valid_annotation: !ref <output_folder>/valid.json
|
17 |
+
test_annotation: !ref <output_folder>/test.json
|
18 |
+
split_ratio: [80, 10, 10]
|
19 |
+
skip_prep: False
|
20 |
+
|
21 |
+
number_of_epochs: 5
|
22 |
+
batch_size: 4
|
23 |
+
lr: 0.0001
|
24 |
+
lr_wav2vec2: 0.00001
|
25 |
+
|
26 |
+
dataloader_options:
|
27 |
+
batch_size: !ref <batch_size>
|
28 |
+
shuffle: True
|
29 |
+
num_workers: 0
|
30 |
+
drop_last: False
|
31 |
+
|
32 |
+
encoder_dim: 768
|
33 |
+
|
34 |
+
# Number of emotions
|
35 |
+
out_n_neurons: 7 # (anger, disgust, fear, happy, neutral, sad, surprise)
|
hparams/train_with_wav2vec.yaml
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seed: 1993
|
2 |
+
__set_seed: !apply:torch.manual_seed [!ref <seed>]
|
3 |
+
|
4 |
+
# Dataset will be downloaded to the `data_original`
|
5 |
+
data_original: D:/voice-emo/dat/
|
6 |
+
output_folder: !ref results/train_with_wav2vec2/<seed>
|
7 |
+
save_folder: !ref <output_folder>/save
|
8 |
+
train_log: !ref <output_folder>/train_log.txt
|
9 |
+
|
10 |
+
# URL for the wav2vec2 model, you can change to benchmark different models
|
11 |
+
# Important: we use wav2vec2 base and not the fine-tuned one with ASR task
|
12 |
+
# This allows you to have ~4% improvement
|
13 |
+
wav2vec2_hub: facebook/wav2vec2-base
|
14 |
+
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
|
15 |
+
|
16 |
+
# Path where data manifest files will be stored
|
17 |
+
train_annotation: !ref <output_folder>/train.json
|
18 |
+
valid_annotation: !ref <output_folder>/valid.json
|
19 |
+
test_annotation: !ref <output_folder>/test.json
|
20 |
+
split_ratio: [80, 10, 10]
|
21 |
+
skip_prep: False
|
22 |
+
|
23 |
+
# The train logger writes training statistics to a file, as well as stdout.
|
24 |
+
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
|
25 |
+
save_file: !ref <train_log>
|
26 |
+
|
27 |
+
ckpt_interval_minutes: 15 # save checkpoint every N min
|
28 |
+
|
29 |
+
####################### Training Parameters ####################################
|
30 |
+
number_of_epochs: 30
|
31 |
+
batch_size: 4
|
32 |
+
lr: 0.0001
|
33 |
+
lr_wav2vec2: 0.00001
|
34 |
+
|
35 |
+
# Freeze all wav2vec2
|
36 |
+
freeze_wav2vec2: False
|
37 |
+
# Set to true to freeze the CONV part of the wav2vec2 model
|
38 |
+
# We see an improvement of 2% with freezing CNNs
|
39 |
+
freeze_wav2vec2_conv: True
|
40 |
+
|
41 |
+
####################### Model Parameters #######################################
|
42 |
+
encoder_dim: 768
|
43 |
+
|
44 |
+
# Number of emotions
|
45 |
+
out_n_neurons: 7 # (anger, disgust, fear, happy, neutral, sad, suprise )
|
46 |
+
|
47 |
+
dataloader_options:
|
48 |
+
batch_size: !ref <batch_size>
|
49 |
+
shuffle: True
|
50 |
+
num_workers: 2 # 2 on Linux but 0 works on Windows
|
51 |
+
drop_last: False
|
52 |
+
|
53 |
+
# Wav2vec2 encoder
|
54 |
+
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
|
55 |
+
source: !ref <wav2vec2_hub>
|
56 |
+
output_norm: True
|
57 |
+
freeze: !ref <freeze_wav2vec2>
|
58 |
+
freeze_feature_extractor: !ref <freeze_wav2vec2_conv>
|
59 |
+
save_path: !ref <wav2vec2_folder>
|
60 |
+
|
61 |
+
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
|
62 |
+
return_std: False
|
63 |
+
|
64 |
+
output_mlp: !new:speechbrain.nnet.linear.Linear
|
65 |
+
input_size: !ref <encoder_dim>
|
66 |
+
n_neurons: !ref <out_n_neurons>
|
67 |
+
bias: False
|
68 |
+
|
69 |
+
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
|
70 |
+
limit: !ref <number_of_epochs>
|
71 |
+
|
72 |
+
modules:
|
73 |
+
wav2vec2: !ref <wav2vec2>
|
74 |
+
output_mlp: !ref <output_mlp>
|
75 |
+
|
76 |
+
model: !new:torch.nn.ModuleList
|
77 |
+
- [!ref <output_mlp>]
|
78 |
+
|
79 |
+
log_softmax: !new:speechbrain.nnet.activations.Softmax
|
80 |
+
apply_log: True
|
81 |
+
|
82 |
+
compute_cost: !name:speechbrain.nnet.losses.nll_loss
|
83 |
+
|
84 |
+
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
|
85 |
+
metric: !name:speechbrain.nnet.losses.classification_error
|
86 |
+
reduction: batch
|
87 |
+
|
88 |
+
opt_class: !name:torch.optim.Adam
|
89 |
+
lr: !ref <lr>
|
90 |
+
|
91 |
+
wav2vec2_opt_class: !name:torch.optim.Adam
|
92 |
+
lr: !ref <lr_wav2vec2>
|
93 |
+
|
94 |
+
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
|
95 |
+
initial_value: !ref <lr>
|
96 |
+
improvement_threshold: 0.0025
|
97 |
+
annealing_factor: 0.9
|
98 |
+
patient: 0
|
99 |
+
|
100 |
+
lr_annealing_wav2vec2: !new:speechbrain.nnet.schedulers.NewBobScheduler
|
101 |
+
initial_value: !ref <lr_wav2vec2>
|
102 |
+
improvement_threshold: 0.0025
|
103 |
+
annealing_factor: 0.9
|
104 |
+
|
105 |
+
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
|
106 |
+
checkpoints_dir: !ref <save_folder>
|
107 |
+
recoverables:
|
108 |
+
model: !ref <model>
|
109 |
+
wav2vec2: !ref <wav2vec2>
|
110 |
+
lr_annealing_output: !ref <lr_annealing>
|
111 |
+
lr_annealing_wav2vec2: !ref <lr_annealing_wav2vec2>
|
112 |
+
counter: !ref <epoch_counter>
|
model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6dd3e3ab14987cd5124407a58a263504ee5b7540727dadbdb04f6481f3775b1f
|
3 |
+
size 755087318
|
prepare.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import logging
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
def prepare_data(data_original, save_json_train, save_json_valid, save_json_test, split_ratio=[80, 10, 10], seed=12):
|
9 |
+
# Setting seeds for reproducible code.
|
10 |
+
random.seed(seed)
|
11 |
+
|
12 |
+
# Check if data preparation has already been done (skip if files exist)
|
13 |
+
if os.path.exists(save_json_train) and os.path.exists(save_json_valid) and os.path.exists(save_json_test):
|
14 |
+
logger.info("Preparation completed in previous run, skipping.")
|
15 |
+
return
|
16 |
+
|
17 |
+
# Collect audio files and labels
|
18 |
+
wav_list = []
|
19 |
+
labels = os.listdir(data_original)
|
20 |
+
|
21 |
+
for label in labels:
|
22 |
+
label_dir = os.path.join(data_original, label)
|
23 |
+
if os.path.isdir(label_dir):
|
24 |
+
for audio_file in os.listdir(label_dir):
|
25 |
+
if audio_file.endswith('.wav'):
|
26 |
+
wav_file = os.path.join(label_dir, audio_file)
|
27 |
+
if os.path.isfile(wav_file):
|
28 |
+
wav_list.append((wav_file, label))
|
29 |
+
else:
|
30 |
+
logger.warning(f"Skipping invalid audio file: {wav_file}")
|
31 |
+
|
32 |
+
# Shuffle and split the data
|
33 |
+
random.shuffle(wav_list)
|
34 |
+
n_total = len(wav_list)
|
35 |
+
n_train = n_total * split_ratio[0] // 100
|
36 |
+
n_valid = n_total * split_ratio[1] // 100
|
37 |
+
|
38 |
+
train_set = wav_list[:n_train]
|
39 |
+
valid_set = wav_list[n_train:n_train + n_valid]
|
40 |
+
test_set = wav_list[n_train + n_valid:]
|
41 |
+
|
42 |
+
# Create JSON files for train, valid, and test sets
|
43 |
+
create_json(train_set, save_json_train)
|
44 |
+
create_json(valid_set, save_json_valid)
|
45 |
+
create_json(test_set, save_json_test)
|
46 |
+
|
47 |
+
logger.info(f"Created {save_json_train}, {save_json_valid}, and {save_json_test}")
|
48 |
+
|
49 |
+
def create_json(data, json_file):
|
50 |
+
data_dict = {str(idx): {'wav': wav, 'label': label} for idx, (wav, label) in enumerate(data)}
|
51 |
+
with open(json_file, 'w') as f:
|
52 |
+
json.dump(data_dict, f)
|
results/train_with_wav2vec2/1993/test.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"0": {"wav": "D:/voice-emo/dat/surprise\\JK_su14.wav", "label": "surprise"}, "1": {"wav": "D:/voice-emo/dat/disgust\\1001_IWW_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "2": {"wav": "D:/voice-emo/dat/happy\\1001_IWL_HAP_XX.wav", "label": "happy"}, "3": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_HI_noise_augmented.wav", "label": "disgust"}, "4": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_LO.wav", "label": "sad"}, "5": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_LO_noise_augmented.wav", "label": "sad"}, "6": {"wav": "D:/voice-emo/dat/fear\\1001_IOM_FEA_XX_pitch_augmented.wav", "label": "fear"}, "7": {"wav": "D:/voice-emo/dat/neutral\\1001_IEO_NEU_XX.wav", "label": "neutral"}, "8": {"wav": "D:/voice-emo/dat/disgust\\1001_TAI_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "9": {"wav": "D:/voice-emo/dat/fear\\1001_WSI_FEA_XX_noise_augmented.wav", "label": "fear"}, "10": {"wav": "D:/voice-emo/dat/angry\\1001_IOM_ANG_XX.wav", "label": "angry"}, "11": {"wav": "D:/voice-emo/dat/angry\\1001_WSI_ANG_XX_stretch_augmented.wav", "label": "angry"}, "12": {"wav": "D:/voice-emo/dat/disgust\\1001_IOM_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "13": {"wav": "D:/voice-emo/dat/angry\\1001_MTI_ANG_XX.wav", "label": "angry"}, "14": {"wav": "D:/voice-emo/dat/surprise\\DC_su01_noise_augmented.wav", "label": "surprise"}, "15": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_HI_stretch_augmented.wav", "label": "fear"}, "16": {"wav": "D:/voice-emo/dat/neutral\\1001_TIE_NEU_XX_noise_augmented.wav", "label": "neutral"}, "17": {"wav": "D:/voice-emo/dat/disgust\\1001_TAI_DIS_XX_noise_augmented.wav", "label": "disgust"}, "18": {"wav": "D:/voice-emo/dat/angry\\1001_DFA_ANG_XX_stretch_augmented.wav", "label": "angry"}, "19": {"wav": "D:/voice-emo/dat/neutral\\1001_DFA_NEU_XX.wav", "label": "neutral"}, "20": {"wav": "D:/voice-emo/dat/surprise\\DC_su03.wav", "label": "surprise"}, "21": {"wav": "D:/voice-emo/dat/fear\\1001_TIE_FEA_XX_stretch_augmented.wav", "label": "fear"}, "22": {"wav": "D:/voice-emo/dat/fear\\1001_IWL_FEA_XX_noise_augmented.wav", "label": "fear"}, "23": {"wav": "D:/voice-emo/dat/fear\\1001_IWL_FEA_XX_stretch_augmented.wav", "label": "fear"}, "24": {"wav": "D:/voice-emo/dat/happy\\1001_MTI_HAP_XX_pitch_augmented.wav", "label": "happy"}, "25": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_MD_noise_augmented.wav", "label": "happy"}, "26": {"wav": "D:/voice-emo/dat/sad\\1001_ITH_SAD_XX_stretch_augmented.wav", "label": "sad"}, "27": {"wav": "D:/voice-emo/dat/happy\\1001_MTI_HAP_XX.wav", "label": "happy"}, "28": {"wav": "D:/voice-emo/dat/disgust\\1001_ITH_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "29": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_LO_stretch_augmented.wav", "label": "happy"}, "30": {"wav": "D:/voice-emo/dat/happy\\1001_DFA_HAP_XX_pitch_augmented.wav", "label": "happy"}, "31": {"wav": "D:/voice-emo/dat/disgust\\1001_IWL_DIS_XX_noise_augmented.wav", "label": "disgust"}, "32": {"wav": "D:/voice-emo/dat/disgust\\1001_DFA_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "33": {"wav": "D:/voice-emo/dat/angry\\1001_IOM_ANG_XX_pitch_augmented.wav", "label": "angry"}, "34": {"wav": "D:/voice-emo/dat/neutral\\1001_IEO_NEU_XX_noise_augmented.wav", "label": "neutral"}, "35": {"wav": "D:/voice-emo/dat/fear\\1001_ITH_FEA_XX_pitch_augmented.wav", "label": "fear"}, "36": {"wav": "D:/voice-emo/dat/sad\\1001_TIE_SAD_XX_noise_augmented.wav", "label": "sad"}, "37": {"wav": "D:/voice-emo/dat/neutral\\1001_IWL_NEU_XX_pitch_augmented.wav", "label": "neutral"}}
|
results/train_with_wav2vec2/1993/train.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"0": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_HI_noise_augmented.wav", "label": "angry"}, "1": {"wav": "D:/voice-emo/dat/angry\\1001_DFA_ANG_XX_pitch_augmented.wav", "label": "angry"}, "2": {"wav": "D:/voice-emo/dat/angry\\1001_ITH_ANG_XX.wav", "label": "angry"}, "3": {"wav": "D:/voice-emo/dat/fear\\1001_IWW_FEA_XX_noise_augmented.wav", "label": "fear"}, "4": {"wav": "D:/voice-emo/dat/neutral\\1001_MTI_NEU_XX.wav", "label": "neutral"}, "5": {"wav": "D:/voice-emo/dat/surprise\\DC_su08_noise_augmented.wav", "label": "surprise"}, "6": {"wav": "D:/voice-emo/dat/surprise\\DC_su04.wav", "label": "surprise"}, "7": {"wav": "D:/voice-emo/dat/surprise\\DC_su01_pitch_augmented.wav", "label": "surprise"}, "8": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_MD_pitch_augmented.wav", "label": "angry"}, "9": {"wav": "D:/voice-emo/dat/fear\\1001_DFA_FEA_XX_stretch_augmented.wav", "label": "fear"}, "10": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_HI_stretch_augmented.wav", "label": "disgust"}, "11": {"wav": "D:/voice-emo/dat/surprise\\DC_su07_noise_augmented.wav", "label": "surprise"}, "12": {"wav": "D:/voice-emo/dat/neutral\\1001_ITH_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "13": {"wav": "D:/voice-emo/dat/fear\\1001_TIE_FEA_XX.wav", "label": "fear"}, "14": {"wav": "D:/voice-emo/dat/happy\\1001_TAI_HAP_XX_noise_augmented.wav", "label": "happy"}, "15": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_LO_pitch_augmented.wav", "label": "happy"}, "16": {"wav": "D:/voice-emo/dat/neutral\\1001_ITH_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "17": {"wav": "D:/voice-emo/dat/sad\\1001_ITS_SAD_XX_noise_augmented.wav", "label": "sad"}, "18": {"wav": "D:/voice-emo/dat/neutral\\1001_ITS_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "19": {"wav": "D:/voice-emo/dat/sad\\1001_TAI_SAD_XX_pitch_augmented.wav", "label": "sad"}, "20": {"wav": "D:/voice-emo/dat/sad\\1001_ITH_SAD_XX_pitch_augmented.wav", "label": "sad"}, "21": {"wav": "D:/voice-emo/dat/angry\\1001_ITS_ANG_XX.wav", "label": "angry"}, "22": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_HI.wav", "label": "happy"}, "23": {"wav": "D:/voice-emo/dat/sad\\1001_IOM_SAD_XX_pitch_augmented.wav", "label": "sad"}, "24": {"wav": "D:/voice-emo/dat/sad\\1001_TSI_SAD_XX_stretch_augmented.wav", "label": "sad"}, "25": {"wav": "D:/voice-emo/dat/surprise\\DC_su08.wav", "label": "surprise"}, "26": {"wav": "D:/voice-emo/dat/sad\\1001_TSI_SAD_XX.wav", "label": "sad"}, "27": {"wav": "D:/voice-emo/dat/sad\\1001_TSI_SAD_XX_noise_augmented.wav", "label": "sad"}, "28": {"wav": "D:/voice-emo/dat/neutral\\1001_TSI_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "29": {"wav": "D:/voice-emo/dat/angry\\1001_TSI_ANG_XX_pitch_augmented.wav", "label": "angry"}, "30": {"wav": "D:/voice-emo/dat/fear\\1001_ITH_FEA_XX_noise_augmented.wav", "label": "fear"}, "31": {"wav": "D:/voice-emo/dat/neutral\\1001_ITS_NEU_XX_noise_augmented.wav", "label": "neutral"}, "32": {"wav": "D:/voice-emo/dat/neutral\\1001_TAI_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "33": {"wav": "D:/voice-emo/dat/angry\\1001_TSI_ANG_XX_noise_augmented.wav", "label": "angry"}, "34": {"wav": "D:/voice-emo/dat/disgust\\1001_ITS_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "35": {"wav": "D:/voice-emo/dat/surprise\\DC_su04_noise_augmented.wav", "label": "surprise"}, "36": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_MD_noise_augmented.wav", "label": "sad"}, "37": {"wav": "D:/voice-emo/dat/surprise\\DC_su08_pitch_augmented.wav", "label": "surprise"}, "38": {"wav": "D:/voice-emo/dat/sad\\1001_IWL_SAD_XX_pitch_augmented.wav", "label": "sad"}, "39": {"wav": "D:/voice-emo/dat/fear\\1001_DFA_FEA_XX.wav", "label": "fear"}, "40": {"wav": "D:/voice-emo/dat/happy\\1001_ITH_HAP_XX.wav", "label": "happy"}, "41": {"wav": "D:/voice-emo/dat/disgust\\1001_MTI_DIS_XX.wav", "label": "disgust"}, "42": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_LO_noise_augmented.wav", "label": "happy"}, "43": {"wav": "D:/voice-emo/dat/surprise\\DC_su03_stretch_augmented.wav", "label": "surprise"}, "44": {"wav": "D:/voice-emo/dat/surprise\\DC_su04_pitch_augmented.wav", "label": "surprise"}, "45": {"wav": "D:/voice-emo/dat/surprise\\DC_su07_pitch_augmented.wav", "label": "surprise"}, "46": {"wav": "D:/voice-emo/dat/happy\\1001_IWW_HAP_XX_stretch_augmented.wav", "label": "happy"}, "47": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_MD.wav", "label": "fear"}, "48": {"wav": "D:/voice-emo/dat/angry\\1001_TSI_ANG_XX_stretch_augmented.wav", "label": "angry"}, "49": {"wav": "D:/voice-emo/dat/neutral\\1001_IOM_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "50": {"wav": "D:/voice-emo/dat/fear\\1001_IOM_FEA_XX_noise_augmented.wav", "label": "fear"}, "51": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_LO_noise_augmented.wav", "label": "disgust"}, "52": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_MD_noise_augmented.wav", "label": "disgust"}, "53": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_MD_stretch_augmented.wav", "label": "disgust"}, "54": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_LO_pitch_augmented.wav", "label": "fear"}, "55": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_LO_pitch_augmented.wav", "label": "disgust"}, "56": {"wav": "D:/voice-emo/dat/disgust\\1001_IWW_DIS_XX.wav", "label": "disgust"}, "57": {"wav": "D:/voice-emo/dat/angry\\1001_IWL_ANG_XX_noise_augmented.wav", "label": "angry"}, "58": {"wav": "D:/voice-emo/dat/happy\\1001_TAI_HAP_XX_stretch_augmented.wav", "label": "happy"}, "59": {"wav": "D:/voice-emo/dat/neutral\\1001_ITH_NEU_XX_noise_augmented.wav", "label": "neutral"}, "60": {"wav": "D:/voice-emo/dat/happy\\1001_IOM_HAP_XX_stretch_augmented.wav", "label": "happy"}, "61": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_HI_pitch_augmented.wav", "label": "sad"}, "62": {"wav": "D:/voice-emo/dat/sad\\1001_TIE_SAD_XX.wav", "label": "sad"}, "63": {"wav": "D:/voice-emo/dat/angry\\1001_MTI_ANG_XX_stretch_augmented.wav", "label": "angry"}, "64": {"wav": "D:/voice-emo/dat/disgust\\1001_ITH_DIS_XX_noise_augmented.wav", "label": "disgust"}, "65": {"wav": "D:/voice-emo/dat/neutral\\1001_TAI_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "66": {"wav": "D:/voice-emo/dat/fear\\1001_ITH_FEA_XX.wav", "label": "fear"}, "67": {"wav": "D:/voice-emo/dat/surprise\\DC_su09.wav", "label": "surprise"}, "68": {"wav": "D:/voice-emo/dat/sad\\1001_DFA_SAD_XX_pitch_augmented.wav", "label": "sad"}, "69": {"wav": "D:/voice-emo/dat/surprise\\DC_su05_stretch_augmented.wav", "label": "surprise"}, "70": {"wav": "D:/voice-emo/dat/neutral\\1001_IWL_NEU_XX_noise_augmented.wav", "label": "neutral"}, "71": {"wav": "D:/voice-emo/dat/disgust\\1001_IOM_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "72": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_LO_noise_augmented.wav", "label": "angry"}, "73": {"wav": "D:/voice-emo/dat/sad\\1001_IWW_SAD_XX_pitch_augmented.wav", "label": "sad"}, "74": {"wav": "D:/voice-emo/dat/angry\\1001_ITS_ANG_XX_stretch_augmented.wav", "label": "angry"}, "75": {"wav": "D:/voice-emo/dat/sad\\1001_WSI_SAD_XX_stretch_augmented.wav", "label": "sad"}, "76": {"wav": "D:/voice-emo/dat/neutral\\1001_IEO_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "77": {"wav": "D:/voice-emo/dat/neutral\\1001_WSI_NEU_XX.wav", "label": "neutral"}, "78": {"wav": "D:/voice-emo/dat/disgust\\1001_MTI_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "79": {"wav": "D:/voice-emo/dat/disgust\\1001_WSI_DIS_XX_noise_augmented.wav", "label": "disgust"}, "80": {"wav": "D:/voice-emo/dat/neutral\\1001_IOM_NEU_XX.wav", "label": "neutral"}, "81": {"wav": "D:/voice-emo/dat/surprise\\DC_su04_stretch_augmented.wav", "label": "surprise"}, "82": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_LO_pitch_augmented.wav", "label": "sad"}, "83": {"wav": "D:/voice-emo/dat/fear\\1001_WSI_FEA_XX.wav", "label": "fear"}, "84": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_LO_stretch_augmented.wav", "label": "disgust"}, "85": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_HI_pitch_augmented.wav", "label": "fear"}, "86": {"wav": "D:/voice-emo/dat/neutral\\1001_IWL_NEU_XX.wav", "label": "neutral"}, "87": {"wav": "D:/voice-emo/dat/neutral\\1001_IWW_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "88": {"wav": "D:/voice-emo/dat/angry\\1001_WSI_ANG_XX_pitch_augmented.wav", "label": "angry"}, "89": {"wav": "D:/voice-emo/dat/angry\\1001_ITH_ANG_XX_pitch_augmented.wav", "label": "angry"}, "90": {"wav": "D:/voice-emo/dat/happy\\1001_TIE_HAP_XX_pitch_augmented.wav", "label": "happy"}, "91": {"wav": "D:/voice-emo/dat/neutral\\1001_TAI_NEU_XX.wav", "label": "neutral"}, "92": {"wav": "D:/voice-emo/dat/disgust\\1001_IOM_DIS_XX_noise_augmented.wav", "label": "disgust"}, "93": {"wav": "D:/voice-emo/dat/angry\\1001_IWL_ANG_XX_stretch_augmented.wav", "label": "angry"}, "94": {"wav": "D:/voice-emo/dat/fear\\1001_ITS_FEA_XX_stretch_augmented.wav", "label": "fear"}, "95": {"wav": "D:/voice-emo/dat/surprise\\DC_su06_pitch_augmented.wav", "label": "surprise"}, "96": {"wav": "D:/voice-emo/dat/sad\\1001_TAI_SAD_XX.wav", "label": "sad"}, "97": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_HI_noise_augmented.wav", "label": "sad"}, "98": {"wav": "D:/voice-emo/dat/surprise\\DC_su06_stretch_augmented.wav", "label": "surprise"}, "99": {"wav": "D:/voice-emo/dat/angry\\1001_TSI_ANG_XX.wav", "label": "angry"}, "100": {"wav": "D:/voice-emo/dat/neutral\\1001_TIE_NEU_XX.wav", "label": "neutral"}, "101": {"wav": "D:/voice-emo/dat/sad\\1001_MTI_SAD_XX_pitch_augmented.wav", "label": "sad"}, "102": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_MD_stretch_augmented.wav", "label": "angry"}, "103": {"wav": "D:/voice-emo/dat/surprise\\DC_su09_pitch_augmented.wav", "label": "surprise"}, "104": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_LO_stretch_augmented.wav", "label": "fear"}, "105": {"wav": "D:/voice-emo/dat/fear\\1001_IWL_FEA_XX_pitch_augmented.wav", "label": "fear"}, "106": {"wav": "D:/voice-emo/dat/angry\\1001_ITH_ANG_XX_noise_augmented.wav", "label": "angry"}, "107": {"wav": "D:/voice-emo/dat/sad\\1001_IWL_SAD_XX.wav", "label": "sad"}, "108": {"wav": "D:/voice-emo/dat/disgust\\1001_TIE_DIS_XX.wav", "label": "disgust"}, "109": {"wav": "D:/voice-emo/dat/fear\\1001_TAI_FEA_XX_stretch_augmented.wav", "label": "fear"}, "110": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_MD_noise_augmented.wav", "label": "angry"}, "111": {"wav": "D:/voice-emo/dat/neutral\\1001_MTI_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "112": {"wav": "D:/voice-emo/dat/disgust\\1001_TAI_DIS_XX.wav", "label": "disgust"}, "113": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_HI_noise_augmented.wav", "label": "fear"}, "114": {"wav": "D:/voice-emo/dat/disgust\\1001_IOM_DIS_XX.wav", "label": "disgust"}, "115": {"wav": "D:/voice-emo/dat/sad\\1001_TIE_SAD_XX_pitch_augmented.wav", "label": "sad"}, "116": {"wav": "D:/voice-emo/dat/disgust\\1001_TSI_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "117": {"wav": "D:/voice-emo/dat/disgust\\1001_TSI_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "118": {"wav": "D:/voice-emo/dat/disgust\\1001_IWW_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "119": {"wav": "D:/voice-emo/dat/angry\\1001_MTI_ANG_XX_pitch_augmented.wav", "label": "angry"}, "120": {"wav": "D:/voice-emo/dat/sad\\1001_TAI_SAD_XX_noise_augmented.wav", "label": "sad"}, "121": {"wav": "D:/voice-emo/dat/disgust\\1001_TIE_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "122": {"wav": "D:/voice-emo/dat/surprise\\DC_su02_stretch_augmented.wav", "label": "surprise"}, "123": {"wav": "D:/voice-emo/dat/surprise\\DC_su10_noise_augmented.wav", "label": "surprise"}, "124": {"wav": "D:/voice-emo/dat/disgust\\1001_WSI_DIS_XX.wav", "label": "disgust"}, "125": {"wav": "D:/voice-emo/dat/happy\\1001_WSI_HAP_XX_noise_augmented.wav", "label": "happy"}, "126": {"wav": "D:/voice-emo/dat/angry\\1001_DFA_ANG_XX_noise_augmented.wav", "label": "angry"}, "127": {"wav": "D:/voice-emo/dat/fear\\1001_ITS_FEA_XX_noise_augmented.wav", "label": "fear"}, "128": {"wav": "D:/voice-emo/dat/happy\\1001_TSI_HAP_XX_pitch_augmented.wav", "label": "happy"}, "129": {"wav": "D:/voice-emo/dat/happy\\1001_ITS_HAP_XX_stretch_augmented.wav", "label": "happy"}, "130": {"wav": "D:/voice-emo/dat/sad\\1001_WSI_SAD_XX.wav", "label": "sad"}, "131": {"wav": "D:/voice-emo/dat/fear\\1001_TAI_FEA_XX_noise_augmented.wav", "label": "fear"}, "132": {"wav": "D:/voice-emo/dat/angry\\1001_DFA_ANG_XX.wav", "label": "angry"}, "133": {"wav": "D:/voice-emo/dat/sad\\1001_WSI_SAD_XX_pitch_augmented.wav", "label": "sad"}, "134": {"wav": "D:/voice-emo/dat/angry\\1001_MTI_ANG_XX_noise_augmented.wav", "label": "angry"}, "135": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_MD.wav", "label": "disgust"}, "136": {"wav": "D:/voice-emo/dat/sad\\1001_MTI_SAD_XX_noise_augmented.wav", "label": "sad"}, "137": {"wav": "D:/voice-emo/dat/neutral\\1001_DFA_NEU_XX_noise_augmented.wav", "label": "neutral"}, "138": {"wav": "D:/voice-emo/dat/fear\\1001_MTI_FEA_XX.wav", "label": "fear"}, "139": {"wav": "D:/voice-emo/dat/sad\\1001_TSI_SAD_XX_pitch_augmented.wav", "label": "sad"}, "140": {"wav": "D:/voice-emo/dat/disgust\\1001_ITS_DIS_XX.wav", "label": "disgust"}, "141": {"wav": "D:/voice-emo/dat/neutral\\1001_WSI_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "142": {"wav": "D:/voice-emo/dat/fear\\1001_TAI_FEA_XX_pitch_augmented.wav", "label": "fear"}, "143": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_MD_stretch_augmented.wav", "label": "sad"}, "144": {"wav": "D:/voice-emo/dat/angry\\1001_ITS_ANG_XX_pitch_augmented.wav", "label": "angry"}, "145": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_HI_stretch_augmented.wav", "label": "sad"}, "146": {"wav": "D:/voice-emo/dat/happy\\1001_ITS_HAP_XX_noise_augmented.wav", "label": "happy"}, "147": {"wav": "D:/voice-emo/dat/angry\\1001_TIE_ANG_XX_stretch_augmented.wav", "label": "angry"}, "148": {"wav": "D:/voice-emo/dat/happy\\1001_TIE_HAP_XX.wav", "label": "happy"}, "149": {"wav": "D:/voice-emo/dat/fear\\1001_WSI_FEA_XX_stretch_augmented.wav", "label": "fear"}, "150": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_LO.wav", "label": "disgust"}, "151": {"wav": "D:/voice-emo/dat/sad\\1001_IOM_SAD_XX.wav", "label": "sad"}, "152": {"wav": "D:/voice-emo/dat/sad\\1001_MTI_SAD_XX_stretch_augmented.wav", "label": "sad"}, "153": {"wav": "D:/voice-emo/dat/happy\\1001_IOM_HAP_XX_pitch_augmented.wav", "label": "happy"}, "154": {"wav": "D:/voice-emo/dat/happy\\1001_IOM_HAP_XX_noise_augmented.wav", "label": "happy"}, "155": {"wav": "D:/voice-emo/dat/sad\\1001_MTI_SAD_XX.wav", "label": "sad"}, "156": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_MD_noise_augmented.wav", "label": "fear"}, "157": {"wav": "D:/voice-emo/dat/sad\\1001_ITS_SAD_XX_stretch_augmented.wav", "label": "sad"}, "158": {"wav": "D:/voice-emo/dat/sad\\1001_IWL_SAD_XX_noise_augmented.wav", "label": "sad"}, "159": {"wav": "D:/voice-emo/dat/neutral\\1001_IWW_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "160": {"wav": "D:/voice-emo/dat/angry\\1001_ITH_ANG_XX_stretch_augmented.wav", "label": "angry"}, "161": {"wav": "D:/voice-emo/dat/happy\\1001_MTI_HAP_XX_noise_augmented.wav", "label": "happy"}, "162": {"wav": "D:/voice-emo/dat/angry\\1001_WSI_ANG_XX.wav", "label": "angry"}, "163": {"wav": "D:/voice-emo/dat/neutral\\1001_TIE_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "164": {"wav": "D:/voice-emo/dat/sad\\1001_WSI_SAD_XX_noise_augmented.wav", "label": "sad"}, "165": {"wav": "D:/voice-emo/dat/angry\\1001_IWW_ANG_XX_pitch_augmented.wav", "label": "angry"}, "166": {"wav": "D:/voice-emo/dat/happy\\1001_IWW_HAP_XX.wav", "label": "happy"}, "167": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_HI_pitch_augmented.wav", "label": "angry"}, "168": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_MD_stretch_augmented.wav", "label": "happy"}, "169": {"wav": "D:/voice-emo/dat/sad\\1001_ITH_SAD_XX.wav", "label": "sad"}, "170": {"wav": "D:/voice-emo/dat/happy\\1001_TAI_HAP_XX.wav", "label": "happy"}, "171": {"wav": "D:/voice-emo/dat/fear\\1001_IWW_FEA_XX_pitch_augmented.wav", "label": "fear"}, "172": {"wav": "D:/voice-emo/dat/sad\\1001_ITS_SAD_XX.wav", "label": "sad"}, "173": {"wav": "D:/voice-emo/dat/angry\\1001_TIE_ANG_XX.wav", "label": "angry"}, "174": {"wav": "D:/voice-emo/dat/disgust\\1001_MTI_DIS_XX_noise_augmented.wav", "label": "disgust"}, "175": {"wav": "D:/voice-emo/dat/surprise\\DC_su06.wav", "label": "surprise"}, "176": {"wav": "D:/voice-emo/dat/angry\\1001_ITS_ANG_XX_noise_augmented.wav", "label": "angry"}, "177": {"wav": "D:/voice-emo/dat/neutral\\1001_IWL_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "178": {"wav": "D:/voice-emo/dat/neutral\\1001_ITS_NEU_XX.wav", "label": "neutral"}, "179": {"wav": "D:/voice-emo/dat/disgust\\1001_IWW_DIS_XX_noise_augmented.wav", "label": "disgust"}, "180": {"wav": "D:/voice-emo/dat/neutral\\1001_MTI_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "181": {"wav": "D:/voice-emo/dat/sad\\1001_IWW_SAD_XX.wav", "label": "sad"}, "182": {"wav": "D:/voice-emo/dat/fear\\1001_TSI_FEA_XX_pitch_augmented.wav", "label": "fear"}, "183": {"wav": "D:/voice-emo/dat/surprise\\DC_su08_stretch_augmented.wav", "label": "surprise"}, "184": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_HI.wav", "label": "fear"}, "185": {"wav": "D:/voice-emo/dat/happy\\1001_ITS_HAP_XX_pitch_augmented.wav", "label": "happy"}, "186": {"wav": "D:/voice-emo/dat/surprise\\DC_su05_pitch_augmented.wav", "label": "surprise"}, "187": {"wav": "D:/voice-emo/dat/fear\\1001_IWW_FEA_XX.wav", "label": "fear"}, "188": {"wav": "D:/voice-emo/dat/disgust\\1001_TSI_DIS_XX.wav", "label": "disgust"}, "189": {"wav": "D:/voice-emo/dat/neutral\\1001_TSI_NEU_XX.wav", "label": "neutral"}, "190": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_LO_noise_augmented.wav", "label": "fear"}, "191": {"wav": "D:/voice-emo/dat/happy\\1001_TIE_HAP_XX_noise_augmented.wav", "label": "happy"}, "192": {"wav": "D:/voice-emo/dat/happy\\1001_DFA_HAP_XX.wav", "label": "happy"}, "193": {"wav": "D:/voice-emo/dat/sad\\1001_DFA_SAD_XX_noise_augmented.wav", "label": "sad"}, "194": {"wav": "D:/voice-emo/dat/fear\\1001_TAI_FEA_XX.wav", "label": "fear"}, "195": {"wav": "D:/voice-emo/dat/angry\\1001_IOM_ANG_XX_stretch_augmented.wav", "label": "angry"}, "196": {"wav": "D:/voice-emo/dat/disgust\\1001_ITS_DIS_XX_noise_augmented.wav", "label": "disgust"}, "197": {"wav": "D:/voice-emo/dat/disgust\\1001_DFA_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "198": {"wav": "D:/voice-emo/dat/surprise\\DC_su09_noise_augmented.wav", "label": "surprise"}, "199": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_HI.wav", "label": "sad"}, "200": {"wav": "D:/voice-emo/dat/disgust\\1001_WSI_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "201": {"wav": "D:/voice-emo/dat/fear\\1001_WSI_FEA_XX_pitch_augmented.wav", "label": "fear"}, "202": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_HI_pitch_augmented.wav", "label": "disgust"}, "203": {"wav": "D:/voice-emo/dat/surprise\\DC_su02_pitch_augmented.wav", "label": "surprise"}, "204": {"wav": "D:/voice-emo/dat/fear\\1001_ITS_FEA_XX_pitch_augmented.wav", "label": "fear"}, "205": {"wav": "D:/voice-emo/dat/fear\\1001_TSI_FEA_XX_stretch_augmented.wav", "label": "fear"}, "206": {"wav": "D:/voice-emo/dat/happy\\1001_ITS_HAP_XX.wav", "label": "happy"}, "207": {"wav": "D:/voice-emo/dat/disgust\\1001_IWL_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "208": {"wav": "D:/voice-emo/dat/fear\\1001_IWL_FEA_XX.wav", "label": "fear"}, "209": {"wav": "D:/voice-emo/dat/happy\\1001_WSI_HAP_XX.wav", "label": "happy"}, "210": {"wav": "D:/voice-emo/dat/angry\\1001_TAI_ANG_XX_pitch_augmented.wav", "label": "angry"}, "211": {"wav": "D:/voice-emo/dat/disgust\\1001_TSI_DIS_XX_noise_augmented.wav", "label": "disgust"}, "212": {"wav": "D:/voice-emo/dat/happy\\1001_IWL_HAP_XX_noise_augmented.wav", "label": "happy"}, "213": {"wav": "D:/voice-emo/dat/happy\\1001_TIE_HAP_XX_stretch_augmented.wav", "label": "happy"}, "214": {"wav": "D:/voice-emo/dat/surprise\\DC_su06_noise_augmented.wav", "label": "surprise"}, "215": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_MD_pitch_augmented.wav", "label": "fear"}, "216": {"wav": "D:/voice-emo/dat/angry\\1001_IWW_ANG_XX_stretch_augmented.wav", "label": "angry"}, "217": {"wav": "D:/voice-emo/dat/disgust\\1001_DFA_DIS_XX.wav", "label": "disgust"}, "218": {"wav": "D:/voice-emo/dat/fear\\1001_TSI_FEA_XX_noise_augmented.wav", "label": "fear"}, "219": {"wav": "D:/voice-emo/dat/disgust\\1001_TAI_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "220": {"wav": "D:/voice-emo/dat/angry\\1001_IWW_ANG_XX_noise_augmented.wav", "label": "angry"}, "221": {"wav": "D:/voice-emo/dat/sad\\1001_IOM_SAD_XX_stretch_augmented.wav", "label": "sad"}, "222": {"wav": "D:/voice-emo/dat/fear\\1001_DFA_FEA_XX_noise_augmented.wav", "label": "fear"}, "223": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_LO.wav", "label": "fear"}, "224": {"wav": "D:/voice-emo/dat/surprise\\DC_su07.wav", "label": "surprise"}, "225": {"wav": "D:/voice-emo/dat/sad\\1001_DFA_SAD_XX.wav", "label": "sad"}, "226": {"wav": "D:/voice-emo/dat/fear\\1001_MTI_FEA_XX_pitch_augmented.wav", "label": "fear"}, "227": {"wav": "D:/voice-emo/dat/neutral\\1001_ITH_NEU_XX.wav", "label": "neutral"}, "228": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_LO_pitch_augmented.wav", "label": "angry"}, "229": {"wav": "D:/voice-emo/dat/surprise\\DC_su10.wav", "label": "surprise"}, "230": {"wav": "D:/voice-emo/dat/disgust\\1001_DFA_DIS_XX_noise_augmented.wav", "label": "disgust"}, "231": {"wav": "D:/voice-emo/dat/disgust\\1001_IWL_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "232": {"wav": "D:/voice-emo/dat/disgust\\1001_ITH_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "233": {"wav": "D:/voice-emo/dat/fear\\1001_IWW_FEA_XX_stretch_augmented.wav", "label": "fear"}, "234": {"wav": "D:/voice-emo/dat/neutral\\1001_IWW_NEU_XX_noise_augmented.wav", "label": "neutral"}, "235": {"wav": "D:/voice-emo/dat/sad\\1001_ITS_SAD_XX_pitch_augmented.wav", "label": "sad"}, "236": {"wav": "D:/voice-emo/dat/angry\\1001_IWW_ANG_XX.wav", "label": "angry"}, "237": {"wav": "D:/voice-emo/dat/surprise\\DC_su01_stretch_augmented.wav", "label": "surprise"}, "238": {"wav": "D:/voice-emo/dat/fear\\1001_TSI_FEA_XX.wav", "label": "fear"}, "239": {"wav": "D:/voice-emo/dat/neutral\\1001_IEO_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "240": {"wav": "D:/voice-emo/dat/sad\\1001_IWW_SAD_XX_stretch_augmented.wav", "label": "sad"}, "241": {"wav": "D:/voice-emo/dat/surprise\\DC_su01.wav", "label": "surprise"}, "242": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_LO_stretch_augmented.wav", "label": "angry"}, "243": {"wav": "D:/voice-emo/dat/disgust\\1001_TIE_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "244": {"wav": "D:/voice-emo/dat/sad\\1001_IOM_SAD_XX_noise_augmented.wav", "label": "sad"}, "245": {"wav": "D:/voice-emo/dat/fear\\1001_IOM_FEA_XX.wav", "label": "fear"}, "246": {"wav": "D:/voice-emo/dat/sad\\1001_TAI_SAD_XX_stretch_augmented.wav", "label": "sad"}, "247": {"wav": "D:/voice-emo/dat/disgust\\1001_TIE_DIS_XX_noise_augmented.wav", "label": "disgust"}, "248": {"wav": "D:/voice-emo/dat/disgust\\1001_WSI_DIS_XX_stretch_augmented.wav", "label": "disgust"}, "249": {"wav": "D:/voice-emo/dat/sad\\1001_IWL_SAD_XX_stretch_augmented.wav", "label": "sad"}, "250": {"wav": "D:/voice-emo/dat/happy\\1001_TSI_HAP_XX.wav", "label": "happy"}, "251": {"wav": "D:/voice-emo/dat/fear\\1001_ITH_FEA_XX_stretch_augmented.wav", "label": "fear"}, "252": {"wav": "D:/voice-emo/dat/fear\\1001_TIE_FEA_XX_pitch_augmented.wav", "label": "fear"}, "253": {"wav": "D:/voice-emo/dat/angry\\1001_WSI_ANG_XX_noise_augmented.wav", "label": "angry"}, "254": {"wav": "D:/voice-emo/dat/angry\\1001_TAI_ANG_XX_noise_augmented.wav", "label": "angry"}, "255": {"wav": "D:/voice-emo/dat/happy\\1001_WSI_HAP_XX_stretch_augmented.wav", "label": "happy"}, "256": {"wav": "D:/voice-emo/dat/neutral\\1001_TAI_NEU_XX_noise_augmented.wav", "label": "neutral"}, "257": {"wav": "D:/voice-emo/dat/surprise\\DC_su03_noise_augmented.wav", "label": "surprise"}, "258": {"wav": "D:/voice-emo/dat/happy\\1001_MTI_HAP_XX_stretch_augmented.wav", "label": "happy"}, "259": {"wav": "D:/voice-emo/dat/neutral\\1001_IOM_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "260": {"wav": "D:/voice-emo/dat/happy\\1001_WSI_HAP_XX_pitch_augmented.wav", "label": "happy"}, "261": {"wav": "D:/voice-emo/dat/happy\\1001_ITH_HAP_XX_stretch_augmented.wav", "label": "happy"}, "262": {"wav": "D:/voice-emo/dat/sad\\1001_ITH_SAD_XX_noise_augmented.wav", "label": "sad"}, "263": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_LO_stretch_augmented.wav", "label": "sad"}, "264": {"wav": "D:/voice-emo/dat/angry\\1001_TIE_ANG_XX_pitch_augmented.wav", "label": "angry"}, "265": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_HI_stretch_augmented.wav", "label": "happy"}, "266": {"wav": "D:/voice-emo/dat/neutral\\1001_TSI_NEU_XX_noise_augmented.wav", "label": "neutral"}, "267": {"wav": "D:/voice-emo/dat/happy\\1001_IWW_HAP_XX_noise_augmented.wav", "label": "happy"}, "268": {"wav": "D:/voice-emo/dat/angry\\1001_IWL_ANG_XX.wav", "label": "angry"}, "269": {"wav": "D:/voice-emo/dat/surprise\\DC_su09_stretch_augmented.wav", "label": "surprise"}, "270": {"wav": "D:/voice-emo/dat/surprise\\DC_su10_pitch_augmented.wav", "label": "surprise"}, "271": {"wav": "D:/voice-emo/dat/neutral\\1001_WSI_NEU_XX_noise_augmented.wav", "label": "neutral"}, "272": {"wav": "D:/voice-emo/dat/surprise\\DC_su05_noise_augmented.wav", "label": "surprise"}, "273": {"wav": "D:/voice-emo/dat/angry\\1001_TAI_ANG_XX_stretch_augmented.wav", "label": "angry"}, "274": {"wav": "D:/voice-emo/dat/angry\\1001_TAI_ANG_XX.wav", "label": "angry"}, "275": {"wav": "D:/voice-emo/dat/happy\\1001_TAI_HAP_XX_pitch_augmented.wav", "label": "happy"}, "276": {"wav": "D:/voice-emo/dat/fear\\1001_ITS_FEA_XX.wav", "label": "fear"}, "277": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_LO.wav", "label": "happy"}, "278": {"wav": "D:/voice-emo/dat/surprise\\DC_su02_noise_augmented.wav", "label": "surprise"}, "279": {"wav": "D:/voice-emo/dat/neutral\\1001_IWW_NEU_XX.wav", "label": "neutral"}, "280": {"wav": "D:/voice-emo/dat/neutral\\1001_DFA_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "281": {"wav": "D:/voice-emo/dat/happy\\1001_DFA_HAP_XX_stretch_augmented.wav", "label": "happy"}, "282": {"wav": "D:/voice-emo/dat/angry\\1001_IOM_ANG_XX_noise_augmented.wav", "label": "angry"}, "283": {"wav": "D:/voice-emo/dat/fear\\1001_TIE_FEA_XX_noise_augmented.wav", "label": "fear"}, "284": {"wav": "D:/voice-emo/dat/neutral\\1001_WSI_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "285": {"wav": "D:/voice-emo/dat/surprise\\DC_su07_stretch_augmented.wav", "label": "surprise"}, "286": {"wav": "D:/voice-emo/dat/fear\\1001_MTI_FEA_XX_noise_augmented.wav", "label": "fear"}, "287": {"wav": "D:/voice-emo/dat/fear\\1001_DFA_FEA_XX_pitch_augmented.wav", "label": "fear"}, "288": {"wav": "D:/voice-emo/dat/neutral\\1001_TIE_NEU_XX_stretch_augmented.wav", "label": "neutral"}, "289": {"wav": "D:/voice-emo/dat/surprise\\DC_su02.wav", "label": "surprise"}, "290": {"wav": "D:/voice-emo/dat/disgust\\1001_MTI_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "291": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_MD_pitch_augmented.wav", "label": "disgust"}, "292": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_HI_pitch_augmented.wav", "label": "happy"}, "293": {"wav": "D:/voice-emo/dat/sad\\1001_TIE_SAD_XX_stretch_augmented.wav", "label": "sad"}}
|
results/train_with_wav2vec2/1993/valid.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"0": {"wav": "D:/voice-emo/dat/happy\\1001_IWW_HAP_XX_pitch_augmented.wav", "label": "happy"}, "1": {"wav": "D:/voice-emo/dat/happy\\1001_ITH_HAP_XX_noise_augmented.wav", "label": "happy"}, "2": {"wav": "D:/voice-emo/dat/disgust\\1001_ITS_DIS_XX_pitch_augmented.wav", "label": "disgust"}, "3": {"wav": "D:/voice-emo/dat/happy\\1001_DFA_HAP_XX_noise_augmented.wav", "label": "happy"}, "4": {"wav": "D:/voice-emo/dat/disgust\\1001_IWL_DIS_XX.wav", "label": "disgust"}, "5": {"wav": "D:/voice-emo/dat/disgust\\1001_ITH_DIS_XX.wav", "label": "disgust"}, "6": {"wav": "D:/voice-emo/dat/neutral\\1001_DFA_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "7": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_MD.wav", "label": "happy"}, "8": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_HI.wav", "label": "angry"}, "9": {"wav": "D:/voice-emo/dat/happy\\1001_TSI_HAP_XX_noise_augmented.wav", "label": "happy"}, "10": {"wav": "D:/voice-emo/dat/sad\\1001_DFA_SAD_XX_stretch_augmented.wav", "label": "sad"}, "11": {"wav": "D:/voice-emo/dat/neutral\\1001_IOM_NEU_XX_noise_augmented.wav", "label": "neutral"}, "12": {"wav": "D:/voice-emo/dat/sad\\1001_IWW_SAD_XX_noise_augmented.wav", "label": "sad"}, "13": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_LO.wav", "label": "angry"}, "14": {"wav": "D:/voice-emo/dat/happy\\1001_IWL_HAP_XX_pitch_augmented.wav", "label": "happy"}, "15": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_MD_pitch_augmented.wav", "label": "sad"}, "16": {"wav": "D:/voice-emo/dat/angry\\1001_TIE_ANG_XX_noise_augmented.wav", "label": "angry"}, "17": {"wav": "D:/voice-emo/dat/surprise\\DC_su03_pitch_augmented.wav", "label": "surprise"}, "18": {"wav": "D:/voice-emo/dat/neutral\\1001_ITS_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "19": {"wav": "D:/voice-emo/dat/fear\\1001_IOM_FEA_XX_stretch_augmented.wav", "label": "fear"}, "20": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_HI_noise_augmented.wav", "label": "happy"}, "21": {"wav": "D:/voice-emo/dat/neutral\\1001_TSI_NEU_XX_pitch_augmented.wav", "label": "neutral"}, "22": {"wav": "D:/voice-emo/dat/sad\\1001_IEO_SAD_MD.wav", "label": "sad"}, "23": {"wav": "D:/voice-emo/dat/fear\\1001_MTI_FEA_XX_stretch_augmented.wav", "label": "fear"}, "24": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_MD.wav", "label": "angry"}, "25": {"wav": "D:/voice-emo/dat/surprise\\DC_su05.wav", "label": "surprise"}, "26": {"wav": "D:/voice-emo/dat/fear\\1001_IEO_FEA_MD_stretch_augmented.wav", "label": "fear"}, "27": {"wav": "D:/voice-emo/dat/happy\\1001_IEO_HAP_MD_pitch_augmented.wav", "label": "happy"}, "28": {"wav": "D:/voice-emo/dat/happy\\1001_TSI_HAP_XX_stretch_augmented.wav", "label": "happy"}, "29": {"wav": "D:/voice-emo/dat/angry\\1001_IWL_ANG_XX_pitch_augmented.wav", "label": "angry"}, "30": {"wav": "D:/voice-emo/dat/disgust\\1001_IEO_DIS_HI.wav", "label": "disgust"}, "31": {"wav": "D:/voice-emo/dat/happy\\1001_ITH_HAP_XX_pitch_augmented.wav", "label": "happy"}, "32": {"wav": "D:/voice-emo/dat/happy\\1001_IOM_HAP_XX.wav", "label": "happy"}, "33": {"wav": "D:/voice-emo/dat/angry\\1001_IEO_ANG_HI_stretch_augmented.wav", "label": "angry"}, "34": {"wav": "D:/voice-emo/dat/happy\\1001_IWL_HAP_XX_stretch_augmented.wav", "label": "happy"}, "35": {"wav": "D:/voice-emo/dat/neutral\\1001_MTI_NEU_XX_noise_augmented.wav", "label": "neutral"}}
|
train_with_wav2vec.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
import speechbrain as sb
|
5 |
+
from hyperpyyaml import load_hyperpyyaml
|
6 |
+
import json
|
7 |
+
import random
|
8 |
+
import torch
|
9 |
+
from sklearn.preprocessing import LabelEncoder
|
10 |
+
|
11 |
+
# Check if GPU is available
|
12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
print(f"Using device: {device}")
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
SAMPLERATE = 16000
|
17 |
+
|
18 |
+
def prepare_data(data_original, save_json_train, save_json_valid, save_json_test, split_ratio=[80, 10, 10], seed=12):
|
19 |
+
# Setting seeds for reproducible code.
|
20 |
+
random.seed(seed)
|
21 |
+
|
22 |
+
# Check if data preparation has already been done (skip if files exist)
|
23 |
+
if skip(save_json_train, save_json_valid, save_json_test):
|
24 |
+
logger.info("Preparation completed in previous run, skipping.")
|
25 |
+
return
|
26 |
+
|
27 |
+
# Collect audio files and labels
|
28 |
+
wav_list = []
|
29 |
+
labels = os.listdir(data_original)
|
30 |
+
|
31 |
+
for label in labels:
|
32 |
+
label_dir = os.path.join(data_original, label)
|
33 |
+
if os.path.isdir(label_dir):
|
34 |
+
for audio_file in os.listdir(label_dir):
|
35 |
+
if audio_file.endswith('.wav'):
|
36 |
+
wav_file = os.path.join(label_dir, audio_file)
|
37 |
+
if os.path.isfile(wav_file):
|
38 |
+
wav_list.append((wav_file, label))
|
39 |
+
else:
|
40 |
+
logger.warning(f"Skipping invalid audio file: {wav_file}")
|
41 |
+
|
42 |
+
# Shuffle and split the data
|
43 |
+
random.shuffle(wav_list)
|
44 |
+
n_total = len(wav_list)
|
45 |
+
n_train = n_total * split_ratio[0] // 100
|
46 |
+
n_valid = n_total * split_ratio[1] // 100
|
47 |
+
|
48 |
+
train_set = wav_list[:n_train]
|
49 |
+
valid_set = wav_list[n_train:n_train + n_valid]
|
50 |
+
test_set = wav_list[n_train + n_valid:]
|
51 |
+
|
52 |
+
# Create JSON files for train, valid, and test sets
|
53 |
+
create_json(train_set, save_json_train)
|
54 |
+
create_json(valid_set, save_json_valid)
|
55 |
+
create_json(test_set, save_json_test)
|
56 |
+
|
57 |
+
logger.info(f"Created {save_json_train}, {save_json_valid}, and {save_json_test}")
|
58 |
+
|
59 |
+
|
60 |
+
def create_json(wav_list, json_file):
|
61 |
+
json_dict = {}
|
62 |
+
for wav_file, label in wav_list:
|
63 |
+
signal = sb.dataio.dataio.read_audio(wav_file)
|
64 |
+
duration = signal.shape[0] / SAMPLERATE
|
65 |
+
uttid = os.path.splitext(os.path.basename(wav_file))[0]
|
66 |
+
|
67 |
+
json_dict[uttid] = {
|
68 |
+
"wav": wav_file,
|
69 |
+
"length": duration,
|
70 |
+
"label": label,
|
71 |
+
}
|
72 |
+
|
73 |
+
with open(json_file, mode="w") as json_f:
|
74 |
+
json.dump(json_dict, json_f, indent=2)
|
75 |
+
|
76 |
+
logger.info(f"Created {json_file}")
|
77 |
+
|
78 |
+
|
79 |
+
def skip(*filenames):
|
80 |
+
for filename in filenames:
|
81 |
+
if not os.path.isfile(filename):
|
82 |
+
return False
|
83 |
+
return True
|
84 |
+
|
85 |
+
|
86 |
+
class EmoIdBrain(sb.Brain):
|
87 |
+
def compute_forward(self, batch, stage):
|
88 |
+
"""Computation pipeline based on an encoder + emotion classifier."""
|
89 |
+
batch = batch.to(self.device)
|
90 |
+
wavs, lens = batch.sig
|
91 |
+
|
92 |
+
outputs = self.modules.wav2vec2(wavs, lens)
|
93 |
+
|
94 |
+
# Apply pooling and MLP layers
|
95 |
+
outputs = self.hparams.avg_pool(outputs, lens)
|
96 |
+
outputs = outputs.view(outputs.shape[0], -1)
|
97 |
+
outputs = self.modules.output_mlp(outputs)
|
98 |
+
outputs = self.hparams.log_softmax(outputs)
|
99 |
+
|
100 |
+
return outputs
|
101 |
+
|
102 |
+
def compute_objectives(self, predictions, batch, stage):
|
103 |
+
emo_encoded_list = []
|
104 |
+
|
105 |
+
for sample in batch:
|
106 |
+
# Check if 'emo_encoded' exists in the sample
|
107 |
+
if 'emo_encoded' in sample:
|
108 |
+
emo_encoded_list.append(sample['emo_encoded'])
|
109 |
+
else:
|
110 |
+
# Log a warning and skip this sample if 'emo_encoded' is missing
|
111 |
+
logging.warning(f"'emo_encoded' key not found in sample: {sample}")
|
112 |
+
|
113 |
+
if not emo_encoded_list:
|
114 |
+
# If no valid 'emo_encoded' values were found in the batch, raise an error
|
115 |
+
raise ValueError("No valid 'emo_encoded' values found in the batch.")
|
116 |
+
|
117 |
+
# Convert emo_encoded_list to a torch tensor
|
118 |
+
emo_encoded = torch.tensor(emo_encoded_list, dtype=torch.long)
|
119 |
+
|
120 |
+
# Ensure emo_encoded is a tensor
|
121 |
+
if not isinstance(emo_encoded, torch.Tensor):
|
122 |
+
raise TypeError(f"Unsupported label type encountered: {type(emo_encoded)}")
|
123 |
+
|
124 |
+
# Perform any necessary operations with emo_encoded here
|
125 |
+
loss = self.hparams.compute_cost(predictions, emo_encoded)
|
126 |
+
|
127 |
+
if stage != sb.Stage.TRAIN:
|
128 |
+
self.error_metrics.append(batch.id, predictions, emo_encoded)
|
129 |
+
|
130 |
+
return loss
|
131 |
+
|
132 |
+
|
133 |
+
def on_stage_start(self, stage, epoch=None):
|
134 |
+
"""Gets called at the beginning of each epoch."""
|
135 |
+
self.loss_metric = sb.utils.metric_stats.MetricStats(metric=sb.nnet.losses.nll_loss)
|
136 |
+
if stage != sb.Stage.TRAIN:
|
137 |
+
self.error_metrics = self.hparams.error_stats()
|
138 |
+
|
139 |
+
def on_stage_end(self, stage, stage_loss, epoch=None):
|
140 |
+
"""Gets called at the end of an epoch."""
|
141 |
+
if stage == sb.Stage.TRAIN:
|
142 |
+
self.train_loss = stage_loss
|
143 |
+
else:
|
144 |
+
stats = {
|
145 |
+
"loss": stage_loss,
|
146 |
+
}
|
147 |
+
if self.error_metrics is not None and len(self.error_metrics.scores) > 0:
|
148 |
+
# Calculate error rate only if there are scores in the error_metrics
|
149 |
+
stats["error_rate"] = self.error_metrics.summarize("average")
|
150 |
+
else:
|
151 |
+
# Handle case where error_metrics are None or empty
|
152 |
+
stats["error_rate"] = float('nan') # Set error_rate to NaN if no scores available
|
153 |
+
|
154 |
+
if stage == sb.Stage.VALID:
|
155 |
+
old_lr, new_lr = self.hparams.lr_annealing(stats["error_rate"])
|
156 |
+
sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
|
157 |
+
self.hparams.train_logger.log_stats(
|
158 |
+
{"Epoch": epoch, "lr": old_lr},
|
159 |
+
train_stats={"loss": self.train_loss},
|
160 |
+
valid_stats=stats,
|
161 |
+
)
|
162 |
+
self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error_rate"])
|
163 |
+
elif stage == sb.Stage.TEST:
|
164 |
+
self.hparams.train_logger.log_stats(
|
165 |
+
{"Epoch loaded": self.hparams.epoch_counter.current},
|
166 |
+
test_stats=stats,
|
167 |
+
)
|
168 |
+
|
169 |
+
def init_optimizers(self):
|
170 |
+
"""Initializes the optimizer."""
|
171 |
+
self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
|
172 |
+
if self.checkpointer is not None:
|
173 |
+
self.checkpointer.add_recoverable("optimizer", self.optimizer)
|
174 |
+
self.optimizers_dict = {"model_optimizer": self.optimizer}
|
175 |
+
|
176 |
+
|
177 |
+
def dataio_prep(hparams):
|
178 |
+
"""Prepares the datasets to be used in the brain class."""
|
179 |
+
|
180 |
+
# Define the audio processing pipeline
|
181 |
+
@sb.utils.data_pipeline.takes("wav")
|
182 |
+
@sb.utils.data_pipeline.provides("sig")
|
183 |
+
def audio_pipeline(wav):
|
184 |
+
"""Load the signal from a WAV file."""
|
185 |
+
sig = sb.dataio.dataio.read_audio(wav)
|
186 |
+
return sig
|
187 |
+
|
188 |
+
# Initialize the label encoder
|
189 |
+
label_encoder = sb.dataio.encoder.CategoricalEncoder()
|
190 |
+
label_encoder.add_unk()
|
191 |
+
|
192 |
+
label_to_index = {
|
193 |
+
'angry': 0,
|
194 |
+
'happy': 1,
|
195 |
+
'neutral': 2,
|
196 |
+
'sad': 3,
|
197 |
+
'surprise': 4,
|
198 |
+
'disgust': 5,
|
199 |
+
'fear': 6
|
200 |
+
}
|
201 |
+
|
202 |
+
@sb.utils.data_pipeline.takes("label")
|
203 |
+
@sb.utils.data_pipeline.provides("label", "emo_encoded")
|
204 |
+
def label_pipeline(label):
|
205 |
+
"""Encode the emotion label."""
|
206 |
+
if label in label_to_index:
|
207 |
+
emo_encoded = label_to_index[label]
|
208 |
+
else:
|
209 |
+
raise ValueError(f"Unknown label encountered: {label}")
|
210 |
+
|
211 |
+
yield label, torch.tensor(emo_encoded, dtype=torch.long)
|
212 |
+
|
213 |
+
# Define datasets dictionary
|
214 |
+
datasets = {}
|
215 |
+
data_info = {
|
216 |
+
"train": hparams["train_annotation"],
|
217 |
+
"valid": hparams["valid_annotation"],
|
218 |
+
"test": hparams["test_annotation"],
|
219 |
+
}
|
220 |
+
|
221 |
+
# Load datasets and apply pipelines
|
222 |
+
for dataset_name, json_path in data_info.items():
|
223 |
+
datasets[dataset_name] = sb.dataio.dataset.DynamicItemDataset.from_json(
|
224 |
+
json_path=json_path,
|
225 |
+
replacements={"data_root": hparams["data_original"]},
|
226 |
+
dynamic_items=[audio_pipeline, label_pipeline],
|
227 |
+
output_keys=["id", "sig", "label", "emo_encoded"],
|
228 |
+
)
|
229 |
+
|
230 |
+
lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
|
231 |
+
label_encoder.load_or_create(
|
232 |
+
path=lab_enc_file,
|
233 |
+
from_didatasets=[datasets["train"]],
|
234 |
+
output_key="label",
|
235 |
+
)
|
236 |
+
|
237 |
+
return datasets
|
238 |
+
|
239 |
+
|
240 |
+
if __name__ == "__main__":
|
241 |
+
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
|
242 |
+
sb.utils.distributed.ddp_init_group(run_opts)
|
243 |
+
|
244 |
+
try:
|
245 |
+
with open(hparams_file) as fin:
|
246 |
+
hparams = load_hyperpyyaml(fin, overrides)
|
247 |
+
data_original = hparams.get("data_original")
|
248 |
+
if data_original is not None:
|
249 |
+
data_original = os.path.normpath(data_original)
|
250 |
+
if not os.path.exists(data_original):
|
251 |
+
raise ValueError(f"data_original path '{data_original}' does not exist.")
|
252 |
+
else:
|
253 |
+
raise ValueError("data_original path is not specified in the YAML configuration.")
|
254 |
+
|
255 |
+
except Exception as e:
|
256 |
+
print("Error occurred", e)
|
257 |
+
sys.exit(1)
|
258 |
+
|
259 |
+
sb.create_experiment_directory(
|
260 |
+
experiment_directory=hparams["output_folder"],
|
261 |
+
hyperparams_to_save=hparams_file,
|
262 |
+
overrides=overrides,
|
263 |
+
)
|
264 |
+
|
265 |
+
if not hparams["skip_prep"]:
|
266 |
+
prepare_kwargs = {
|
267 |
+
"data_original": hparams["data_original"],
|
268 |
+
"save_json_train": hparams["train_annotation"],
|
269 |
+
"save_json_valid": hparams["valid_annotation"],
|
270 |
+
"save_json_test": hparams["test_annotation"],
|
271 |
+
"split_ratio": hparams["split_ratio"],
|
272 |
+
"seed": hparams["seed"],
|
273 |
+
}
|
274 |
+
sb.utils.distributed.run_on_main(prepare_data, kwargs=prepare_kwargs)
|
275 |
+
|
276 |
+
datasets = dataio_prep(hparams)
|
277 |
+
|
278 |
+
hparams["wav2vec2"] = hparams["wav2vec2"].to(device=run_opts["device"])
|
279 |
+
if not hparams["freeze_wav2vec2"] and hparams["freeze_wav2vec2_conv"]:
|
280 |
+
hparams["wav2vec2"].model.feature_extractor._freeze_parameters()
|
281 |
+
|
282 |
+
emo_id_brain = EmoIdBrain(
|
283 |
+
modules=hparams["modules"],
|
284 |
+
opt_class=hparams["opt_class"],
|
285 |
+
hparams=hparams,
|
286 |
+
run_opts=run_opts,
|
287 |
+
checkpointer=hparams["checkpointer"],
|
288 |
+
)
|
289 |
+
|
290 |
+
emo_id_brain.fit(
|
291 |
+
epoch_counter=emo_id_brain.hparams.epoch_counter,
|
292 |
+
train_set=datasets["train"],
|
293 |
+
valid_set=datasets["valid"],
|
294 |
+
train_loader_kwargs=hparams["dataloader_options"],
|
295 |
+
valid_loader_kwargs=hparams["dataloader_options"],
|
296 |
+
)
|
297 |
+
|
298 |
+
test_stats = emo_id_brain.evaluate(
|
299 |
+
test_set=datasets["test"],
|
300 |
+
min_key="error_rate",
|
301 |
+
test_loader_kwargs=hparams["dataloader_options"],
|
302 |
+
)
|