|
import os
import sys
import json
import logging
import random

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
|
|
|
logger = logging.getLogger(__name__)
SAMPLERATE = 16000
|
|
|
def prepare_data(data_original, save_json_train, save_json_valid, save_json_test, split_ratio=(80, 10, 10), seed=12):
    """Scan `data_original` for labeled WAV files and write train/valid/test JSON manifests."""
    random.seed(seed)

    if skip(save_json_train, save_json_valid, save_json_test):
        logger.info("Preparation completed in previous run, skipping.")
        return
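    # Expected corpus layout: one sub-folder per emotion label (the file
    # names below are illustrative; the folder names must match the
    # label_to_index map defined in dataio_prep):
    #
    #   data_original/
    #       angry/a01.wav
    #       happy/h01.wav
    #       ...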
|
|
|
|
|
    # Collect (path, label) pairs; each sub-folder of data_original is one label.
    wav_list = []
    for label in os.listdir(data_original):
        label_dir = os.path.join(data_original, label)
        if not os.path.isdir(label_dir):
            continue
        for audio_file in os.listdir(label_dir):
            if audio_file.endswith(".wav"):
                wav_file = os.path.join(label_dir, audio_file)
                if os.path.isfile(wav_file):
                    wav_list.append((wav_file, label))
                else:
                    logger.warning(f"Skipping invalid audio file: {wav_file}")
|
|
|
|
|
    random.shuffle(wav_list)
    n_total = len(wav_list)
    n_train = n_total * split_ratio[0] // 100
    n_valid = n_total * split_ratio[1] // 100

    train_set = wav_list[:n_train]
    valid_set = wav_list[n_train:n_train + n_valid]
    test_set = wav_list[n_train + n_valid:]
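    # Example of the split arithmetic: with n_total = 250 and
    # split_ratio = (80, 10, 10), n_train = 200, n_valid = 25, and the
    # remaining 25 files land in the test set.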
|
|
|
|
|
    create_json(train_set, save_json_train)
    create_json(valid_set, save_json_valid)
    create_json(test_set, save_json_test)

    logger.info(f"Created {save_json_train}, {save_json_valid}, and {save_json_test}")
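
# Example call (paths are illustrative):
#   prepare_data("/datasets/emotions", "train.json", "valid.json",
#                "test.json", split_ratio=(80, 10, 10), seed=12)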
|
|
|
|
|
def create_json(wav_list, json_file):
    """Write a JSON manifest with one entry per (wav_file, label) pair."""
    json_dict = {}
    for wav_file, label in wav_list:
        signal = sb.dataio.dataio.read_audio(wav_file)
        duration = signal.shape[0] / SAMPLERATE
        # Prefix the label so identical file names under different label
        # folders cannot collide on the same utterance id.
        uttid = f"{label}_{os.path.splitext(os.path.basename(wav_file))[0]}"
        json_dict[uttid] = {
            "wav": wav_file,
            "length": duration,
            "label": label,
        }

    with open(json_file, mode="w") as json_f:
        json.dump(json_dict, json_f, indent=2)

    logger.info(f"Created {json_file}")
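
# Each manifest entry produced above looks like this (path and duration
# are illustrative):
#   "angry_a01": {
#       "wav": "/datasets/emotions/angry/a01.wav",
#       "length": 3.52,
#       "label": "angry"
#   }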
|
|
|
|
|
def skip(*filenames):
    """Return True if every manifest file already exists."""
    for filename in filenames:
        if not os.path.isfile(filename):
            return False
    return True
|
|
|
|
|
class EmoIdBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + emotion classifier."""
        batch = batch.to(self.device)
        wavs, lens = batch.sig

        # wav2vec2 features are pooled over time, flattened, and passed
        # through the classification head.
        outputs = self.modules.wav2vec2(wavs, lens)
        outputs = self.hparams.avg_pool(outputs, lens)
        outputs = outputs.view(outputs.shape[0], -1)
        outputs = self.modules.output_mlp(outputs)
        outputs = self.hparams.log_softmax(outputs)
        return outputs
|
|
|
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss given predictions and emotion targets."""
        # batch is a SpeechBrain PaddedBatch: batch.emo_encoded is a
        # (data, lengths) pair, already collated and moved to the device
        # in compute_forward, so no per-sample dictionary lookup is needed.
        emo_encoded, _ = batch.emo_encoded

        loss = self.hparams.compute_cost(predictions, emo_encoded)

        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, emo_encoded)

        return loss
|
|
|
|
|
    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch."""
        self.loss_metric = sb.utils.metric_stats.MetricStats(metric=sb.nnet.losses.nll_loss)
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()
|
|
|
    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch."""
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            return

        stats = {"loss": stage_loss}
        if self.error_metrics is not None and len(self.error_metrics.scores) > 0:
            stats["error_rate"] = self.error_metrics.summarize("average")
        else:
            stats["error_rate"] = float("nan")

        if stage == sb.Stage.VALID:
            old_lr, new_lr = self.hparams.lr_annealing(stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, min_keys=["error_rate"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
|
|
|
    def init_optimizers(self):
        """Initializes the optimizer."""
        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
        self.optimizers_dict = {"model_optimizer": self.optimizer}
|
|
|
|
|
def dataio_prep(hparams):
    """Prepares the datasets to be used in the brain class."""

    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        """Load the signal from a WAV file."""
        sig = sb.dataio.dataio.read_audio(wav)
        return sig
|
|
|
|
|
    # NOTE: training targets come from the fixed label_to_index map below,
    # which keeps class indices stable across runs; the CategoricalEncoder
    # is still fitted and saved at the end of dataio_prep so the label set
    # can be recovered at inference time.
    label_encoder = sb.dataio.encoder.CategoricalEncoder()
    label_encoder.add_unk()

    label_to_index = {
        "angry": 0,
        "happy": 1,
        "neutral": 2,
        "sad": 3,
        "surprise": 4,
        "disgust": 5,
        "fear": 6,
    }
|
|
|
    @sb.utils.data_pipeline.takes("label")
    @sb.utils.data_pipeline.provides("label", "emo_encoded")
    def label_pipeline(label):
        """Encode the emotion label."""
        # Providing two keys requires one yield per key; a single
        # `yield label, tensor` would assign the whole tuple to "label".
        yield label
        if label not in label_to_index:
            raise ValueError(f"Unknown label encountered: {label}")
        # Yield a 1-dim tensor so PaddedBatch can collate the targets.
        yield torch.tensor([label_to_index[label]], dtype=torch.long)
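    # An item fetched from the resulting dataset then looks like
    # (values illustrative):
    #   {"id": "angry_a01", "sig": <tensor of shape [n_samples]>,
    #    "label": "angry", "emo_encoded": tensor([0])}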
|
|
|
|
|
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }

    for dataset_name, json_path in data_info.items():
        datasets[dataset_name] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=json_path,
            replacements={"data_root": hparams["data_original"]},
            dynamic_items=[audio_pipeline, label_pipeline],
            output_keys=["id", "sig", "label", "emo_encoded"],
        )
|
|
|
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[datasets["train"]],
        output_key="label",
    )

    return datasets
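
# Typical invocation (the hyperparameter file name is illustrative;
# run options such as --device are parsed by sb.parse_arguments):
#   python train.py hparams/train.yaml --device=cuda:0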
|
|
|
|
|
if __name__ == "__main__":
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    sb.utils.distributed.ddp_init_group(run_opts)

    try:
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)
        data_original = hparams.get("data_original")
        if data_original is None:
            raise ValueError("data_original path is not specified in the YAML configuration.")
        data_original = os.path.normpath(data_original)
        if not os.path.exists(data_original):
            raise ValueError(f"data_original path '{data_original}' does not exist.")
    except Exception as e:
        print(f"Error while loading hyperparameters: {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )
|
|
|
    if not hparams["skip_prep"]:
        prepare_kwargs = {
            "data_original": hparams["data_original"],
            "save_json_train": hparams["train_annotation"],
            "save_json_valid": hparams["valid_annotation"],
            "save_json_test": hparams["test_annotation"],
            "split_ratio": hparams["split_ratio"],
            "seed": hparams["seed"],
        }
        sb.utils.distributed.run_on_main(prepare_data, kwargs=prepare_kwargs)
|
|
|
    datasets = dataio_prep(hparams)

    hparams["wav2vec2"] = hparams["wav2vec2"].to(device=run_opts["device"])
    # When fine-tuning wav2vec2, optionally keep its convolutional
    # feature extractor frozen.
    if not hparams["freeze_wav2vec2"] and hparams["freeze_wav2vec2_conv"]:
        hparams["wav2vec2"].model.feature_extractor._freeze_parameters()
|
|
|
    emo_id_brain = EmoIdBrain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    emo_id_brain.fit(
        epoch_counter=emo_id_brain.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )

    test_stats = emo_id_brain.evaluate(
        test_set=datasets["test"],
        min_key="error_rate",
        test_loader_kwargs=hparams["dataloader_options"],
    )
|
|