kjysmu committed (verified)
Commit 2822008 · 1 Parent(s): e163262

Upload app.py

Files changed (1):
  1. app.py +611 -0
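
app.py is a self-contained Gradio demo built around a `Music2emo` class. For reference, the class can also be driven directly without the UI; the snippet below is a minimal sketch, assuming the repository's `utils/`, `model/`, `saved_models/` and `inference/data/` assets sit next to the script, and using a placeholder audio path:

    music2emo = Music2emo(device="cpu")  # defaults to "cuda:0" when a GPU is present
    result = music2emo.predict("your_song.mp3", threshold=0.5)
    print(result["predicted_moods"], result["valence"], result["arousal"])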
app.py ADDED
@@ -0,0 +1,611 @@
+ import os
+ import shutil
+ import json
+ import torch
+ import torchaudio
+ import numpy as np
+ import logging
+ import warnings
+ import subprocess
+ import math
+ import random
+ import time
+ from pathlib import Path
+ from tqdm import tqdm
+ from PIL import Image
+ from huggingface_hub import snapshot_download
+ from omegaconf import DictConfig
+ import hydra
+ from hydra.utils import to_absolute_path
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
+ import mir_eval
+ import pretty_midi as pm
+ import gradio as gr
+ from gradio import Markdown
+ from music21 import converter
+ import torchaudio.transforms as T
+
+ # Custom utility imports
+ from utils import logger
+ from utils.btc_model import BTC_model
+ from utils.transformer_modules import *
+ from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
+ from utils.hparams import HParams
+ from utils.mir_eval_modules import (
+     audio_file_to_features, idx2chord, idx2voca_chord,
+     get_audio_paths, get_lab_paths
+ )
+ from utils.mert import FeatureExtractorMERT
+ from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
+
+ # Suppress unnecessary warnings and logs
+ warnings.filterwarnings("ignore")
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+
+ # Pitch-class and key-normalization lookup tables
+ PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+
+ pitch_num_dic = {
+     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
+     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
+ }
+
+ minor_major_dic = {
+     'D-': 'C#', 'E-': 'D#', 'G-': 'F#', 'A-': 'G#', 'B-': 'A#'
+ }
+ minor_major_dic2 = {
+     'Db': 'C#', 'Eb': 'D#', 'Gb': 'F#', 'Ab': 'G#', 'Bb': 'A#'
+ }
+
+ shift_major_dic = {
+     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
+     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
+ }
+
+ shift_minor_dic = {
+     'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5,
+     'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11,
+ }
+
+ flat_to_sharp_mapping = {
+     "Cb": "B",
+     "Db": "C#",
+     "Eb": "D#",
+     "Fb": "E",
+     "Gb": "F#",
+     "Ab": "G#",
+     "Bb": "A#"
+ }
+
+ # Audio segmentation / resampling settings for MERT feature extraction
+ segment_duration = 30
+ resample_rate = 24000
+ is_split = True
+
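+ # normalize_chord: read a BTC .lab chord annotation and transpose every chord
+ # into a common key (C major or A minor), using the shift tables above and the
+ # key estimated for the track. Returns the transposed lines, ready to be
+ # written out as a *_norm.lab file.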
+ def normalize_chord(file_path, key, key_type='major'):
+     with open(file_path, 'r') as f:
+         lines = f.readlines()
+
+     if key == "None":
+         new_key = "C major"
+         shift = 0
+     else:
+         if len(key) == 1:
+             key = key[0].upper()
+         else:
+             key = key[0].upper() + key[1:]
+
+         if key in minor_major_dic2:
+             key = minor_major_dic2[key]
+
+         shift = 0
+
+         if key_type == "major":
+             new_key = "C major"
+             shift = shift_major_dic[key]
+         else:
+             new_key = "A minor"
+             shift = shift_minor_dic[key]
+
+     converted_lines = []
+     for line in lines:
+         if line.strip():  # Skip empty lines
+             parts = line.split()
+             start_time = parts[0]
+             end_time = parts[1]
+             chord = parts[2]  # The chord is in the 3rd column
+             if chord == "N":
+                 newchordnorm = "N"
+             elif chord == "X":
+                 newchordnorm = "X"
+             elif ":" in chord:
+                 pitch = chord.split(":")[0]
+                 attr = chord.split(":")[1]
+                 pnum = pitch_num_dic[pitch]
+                 new_idx = (pnum - shift) % 12
+                 newchord = PITCH_CLASS[new_idx]
+                 newchordnorm = newchord + ":" + attr
+             else:
+                 pitch = chord
+                 pnum = pitch_num_dic[pitch]
+                 new_idx = (pnum - shift) % 12
+                 newchord = PITCH_CLASS[new_idx]
+                 newchordnorm = newchord
+
+             converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
+
+     return converted_lines
+
+ def sanitize_key_signature(key):
+     return key.replace('-', 'b')
+
+ def resample_waveform(waveform, original_sample_rate, target_sample_rate):
+     if original_sample_rate != target_sample_rate:
+         resampler = T.Resample(original_sample_rate, target_sample_rate)
+         return resampler(waveform), target_sample_rate
+     return waveform, original_sample_rate
+
+ def split_audio(waveform, sample_rate):
+     segment_samples = segment_duration * sample_rate
+     total_samples = waveform.size(0)
+
+     segments = []
+     for start in range(0, total_samples, segment_samples):
+         end = start + segment_samples
+         if end <= total_samples:
+             segment = waveform[start:end]
+             segments.append(segment)
+
+     # In case audio length is shorter than segment length.
+     if len(segments) == 0:
+         segment = waveform
+         segments.append(segment)
+
+     return segments
+
+
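+ # Music2emo wraps the full inference pipeline: MERT embeddings (layers 5-6,
+ # averaged over 30 s segments), BTC large-vocabulary chord recognition,
+ # music21 key estimation, and a multi-task feed-forward head that jointly
+ # predicts mood tags (classification) and valence/arousal (regression).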
+ class Music2emo:
+     def __init__(
+         self,
+         name="amaai-lab/music2emo",
+         device="cuda:0",
+         cache_dir=None,
+         local_files_only=False,
+     ):
+
+         # use_cuda = torch.cuda.is_available()
+         # self.device = torch.device("cuda" if use_cuda else "cpu")
+         model_weights = "saved_models/J_all.ckpt"
+         self.device = device
+
+         self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
+         self.model_weights = model_weights
+
+         self.music2emo_model = FeedforwardModelMTAttnCK(
+             input_size=768 * 2,
+             output_size_classification=56,
+             output_size_regression=2
+         )
+
+         checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
+         state_dict = checkpoint["state_dict"]
+
+         # Adjust the keys in the state_dict
+         state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
+
+         # Filter state_dict to match model's keys
+         model_keys = set(self.music2emo_model.state_dict().keys())
+         filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
+
+         # Load the filtered state_dict and set the model to evaluation mode
+         self.music2emo_model.load_state_dict(filtered_state_dict)
+
+         self.music2emo_model.to(self.device)
+         self.music2emo_model.eval()
+
+     def predict(self, audio, threshold=0.5):
+
+         feature_dir = Path("./temp_out")
+         output_dir = Path("./output")
+
+         if feature_dir.exists():
+             shutil.rmtree(str(feature_dir))
+         if output_dir.exists():
+             shutil.rmtree(str(output_dir))
+
+         feature_dir.mkdir(parents=True)
+         output_dir.mkdir(parents=True)
+
+         warnings.filterwarnings('ignore')
+         logger.logging_verbosity(1)
+
+         mert_dir = feature_dir / "mert"
+         mert_dir.mkdir(parents=True)
+
+         # Load audio, downmix to mono, and resample to the MERT input rate
+         waveform, sample_rate = torchaudio.load(audio)
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0).unsqueeze(0)
+         waveform = waveform.squeeze()
+         waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
+
+         if is_split:
+             segments = split_audio(waveform, sample_rate)
+             for i, segment in enumerate(segments):
+                 segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
+                 self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
+         else:
+             segment_save_path = os.path.join(mert_dir, "segment_0.npy")
+             self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
+
+         embeddings = []
+         layers_to_extract = [5, 6]
+         segment_embeddings = []
+         for filename in sorted(os.listdir(mert_dir)):  # Sort files to ensure sequential order
+             file_path = os.path.join(mert_dir, filename)
+             if os.path.isfile(file_path) and filename.endswith('.npy'):
+                 segment = np.load(file_path)
+                 concatenated_features = np.concatenate(
+                     [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
+                 )
+                 concatenated_features = np.squeeze(concatenated_features)  # Shape: 768 * 2 = 1536
+                 segment_embeddings.append(concatenated_features)
+
+         # Average the per-segment MERT embeddings into one track-level embedding
+         segment_embeddings = np.array(segment_embeddings)
+         if len(segment_embeddings) > 0:
+             final_embedding_mert = np.mean(segment_embeddings, axis=0)
+         else:
+             final_embedding_mert = np.zeros((1536,))
+
+         final_embedding_mert = torch.from_numpy(final_embedding_mert)
+         final_embedding_mert = final_embedding_mert.to(self.device)
+
+         # --- Chord feature extract ---
+         config = HParams.load("./inference/data/run_config.yaml")
+         config.feature['large_voca'] = True
+         config.model['num_chords'] = 170
+         model_file = './inference/data/btc_model_large_voca.pt'
+         idx_to_chord = idx2voca_chord()
+         model = BTC_model(config=config.model).to(self.device)
+
+         if os.path.isfile(model_file):
+             checkpoint = torch.load(model_file, map_location=self.device)
+             mean = checkpoint['mean']
+             std = checkpoint['std']
+             model.load_state_dict(checkpoint['model'])
+
+         audio_path = audio
+         audio_id = audio_path.split("/")[-1][:-4]
+         try:
+             feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
+         except Exception:
+             logger.info("audio file failed to load : %s" % audio_path)
+             raise
+
+         logger.info("audio file loaded and feature computation success : %s" % audio_path)
+
+         feature = feature.T
+         feature = (feature - mean) / std
+         time_unit = feature_per_second
+         n_timestep = config.model['timestep']
+
+         num_pad = n_timestep - (feature.shape[0] % n_timestep)
+         feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
+         num_instance = feature.shape[0] // n_timestep
+
+         start_time = 0.0
+         lines = []
+         with torch.no_grad():
+             model.eval()
+             feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
+             for t in range(num_instance):
+                 self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
+                 prediction, _ = model.output_layer(self_attn_output)
+                 prediction = prediction.squeeze()
+                 for i in range(n_timestep):
+                     if t == 0 and i == 0:
+                         prev_chord = prediction[i].item()
+                         continue
+                     if prediction[i].item() != prev_chord:
+                         lines.append(
+                             '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                         start_time = time_unit * (n_timestep * t + i)
+                         prev_chord = prediction[i].item()
+                     if t == num_instance - 1 and i + num_pad == n_timestep:
+                         if start_time != time_unit * (n_timestep * t + i):
+                             lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                         break
+
+         save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
+         with open(save_path, 'w') as f:
+             for line in lines:
+                 f.write(line)
+
+         # logger.info("label file saved : %s" % save_path)
+
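+         # Convert the .lab chord segments into a rough piano-roll MIDI file: for
+         # each pitch class, contiguous intervals whose chord contains that pitch
+         # become one note. This MIDI is only used for music21 key estimation below.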
+         # lab file to midi file
+         starts, ends, pitchs = list(), list(), list()
+
+         intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
+         for p in range(12):
+             for i, (interval, chord) in enumerate(zip(intervals, chords)):
+                 root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
+                 tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
+                 if i == 0:
+                     start_time = interval[0]
+                     label = tmp_label
+                     continue
+                 if tmp_label != label:
+                     if label == 1.0:
+                         starts.append(start_time)
+                         ends.append(interval[0])
+                         pitchs.append(p + 48)
+                     start_time = interval[0]
+                     label = tmp_label
+                 if i == (len(intervals) - 1):
+                     if label == 1.0:
+                         starts.append(start_time)
+                         ends.append(interval[1])
+                         pitchs.append(p + 48)
+
+         midi = pm.PrettyMIDI()
+         instrument = pm.Instrument(program=0)
+
+         for start, end, pitch in zip(starts, ends, pitchs):
+             pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
+             instrument.notes.append(pm_note)
+
+         midi.instruments.append(instrument)
+         midi.write(save_path.replace('.lab', '.midi'))
+
+         tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
+         mode_signatures = ["major", "minor"]  # Major and minor modes
+
+         tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
+         mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
+         idx_to_tonic = {idx: tonic for tonic, idx in tonic_to_idx.items()}
+         idx_to_mode = {idx: mode for mode, idx in mode_to_idx.items()}
+
+         with open('inference/data/chord.json', 'r') as f:
+             chord_to_idx = json.load(f)
+         with open('inference/data/chord_inv.json', 'r') as f:
+             idx_to_chord = json.load(f)
+             idx_to_chord = {int(k): v for k, v in idx_to_chord.items()}  # Ensure keys are ints
+         with open('inference/data/chord_root.json') as json_file:
+             chordRootDic = json.load(json_file)
+         with open('inference/data/chord_attr.json') as json_file:
+             chordAttrDic = json.load(json_file)
+
+         try:
+             midi_file = converter.parse(save_path.replace('.lab', '.midi'))
+             key_signature = str(midi_file.analyze('key'))
+         except Exception:
+             key_signature = "None"
+
+         key_parts = key_signature.split()
+         key_signature = sanitize_key_signature(key_parts[0])  # Sanitize key signature (tonic only)
+         key_type = key_parts[1] if len(key_parts) > 1 else 'major'
+
+         # --- Key feature (Tonic and Mode separation) ---
+         if key_signature == "None":
+             mode = "major"
+         else:
+             mode = key_type  # key_signature holds only the tonic; key_type carries the mode
+
+         encoded_mode = mode_to_idx.get(mode, 0)
+         mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)
+
+         converted_lines = normalize_chord(save_path, key_signature, key_type)
+
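+         # Write the key-normalized chord sequence, then encode each chord as
+         # integer IDs (full chord, root, attribute) and pad/truncate the
+         # sequences to a fixed length of 100 before converting them to tensors.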
+         lab_norm_path = save_path[:-4] + "_norm.lab"
+
+         # Write the converted lines to the new file
+         with open(lab_norm_path, 'w') as f:
+             f.writelines(converted_lines)
+
+         chords = []
+
+         if not os.path.exists(lab_norm_path):
+             chords.append((float(0), float(0), "N"))
+         else:
+             with open(lab_norm_path, 'r') as file:
+                 for line in file:
+                     start, end, chord = line.strip().split()
+                     chords.append((float(start), float(end), chord))
+
+         encoded = []
+         encoded_root = []
+         encoded_attr = []
+         durations = []
+
+         for start, end, chord in chords:
+             chord_arr = chord.split(":")
+             if len(chord_arr) == 1:
+                 chordRootID = chordRootDic[chord_arr[0]]
+                 if chord_arr[0] == "N" or chord_arr[0] == "X":
+                     chordAttrID = 0
+                 else:
+                     chordAttrID = 1
+             elif len(chord_arr) == 2:
+                 chordRootID = chordRootDic[chord_arr[0]]
+                 chordAttrID = chordAttrDic[chord_arr[1]]
+             encoded_root.append(chordRootID)
+             encoded_attr.append(chordAttrID)
+
+             if chord in chord_to_idx:
+                 encoded.append(chord_to_idx[chord])
+             else:
+                 print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
+
+             durations.append(end - start)  # Compute duration
+
+         encoded_chords = np.array(encoded)
+         encoded_chords_root = np.array(encoded_root)
+         encoded_chords_attr = np.array(encoded_attr)
+
+         # Maximum sequence length for chords
+         max_sequence_length = 100  # Define this globally or as a parameter
+
+         # Truncate or pad chord sequences
+         if len(encoded_chords) > max_sequence_length:
+             # Truncate to max length
+             encoded_chords = encoded_chords[:max_sequence_length]
+             encoded_chords_root = encoded_chords_root[:max_sequence_length]
+             encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
+         else:
+             # Pad with zeros (padding value for chords)
+             padding = [0] * (max_sequence_length - len(encoded_chords))
+             encoded_chords = np.concatenate([encoded_chords, padding])
+             encoded_chords_root = np.concatenate([encoded_chords_root, padding])
+             encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
+
+         # Convert to tensor
+         chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
+         chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
+         chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
+
+         model_input_dic = {
+             "x_mert": final_embedding_mert.unsqueeze(0),
+             "x_chord": chords_tensor.unsqueeze(0),
+             "x_chord_root": chords_root_tensor.unsqueeze(0),
+             "x_chord_attr": chords_attr_tensor.unsqueeze(0),
+             "x_key": mode_tensor.unsqueeze(0)
+         }
+
+         model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
+         classification_output, regression_output = self.music2emo_model(model_input_dic)
+         probs = torch.sigmoid(classification_output)
+
+         # Decode mood tags above the threshold and read off valence/arousal
+         tag_list = np.load("./inference/data/tag_list.npy")
+         tag_list = tag_list[127:]
+         mood_list = [t.replace("mood/theme---", "") for t in tag_list]
+         predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]
+         valence, arousal = regression_output.squeeze().tolist()
+
+         model_output_dic = {
+             "valence": valence,
+             "arousal": arousal,
+             "predicted_moods": predicted_moods
+         }
+
+         return model_output_dic
+
+
+ def format_prediction(model_output_dic):
+     """Format the model output in a more readable and attractive format"""
+     valence = model_output_dic["valence"]
+     arousal = model_output_dic["arousal"]
+     moods = model_output_dic["predicted_moods"]
+
+     # Create a formatted string with emojis and proper formatting
+     output_text = """
+ 🎵 **Music Emotion Recognition Results** 🎵
+ --------------------------------------------------
+ 🎭 **Predicted Mood Tags:** {}
+ 💖 **Valence:** {:.2f} (Scale: 1-9)
+ ⚡ **Arousal:** {:.2f} (Scale: 1-9)
+ --------------------------------------------------
+ """.format(
+         ', '.join(moods) if moods else 'None',
+         valence,
+         arousal
+     )
+
+     return output_text
+
+ title = "Music2Emo: Towards Unified Music Emotion Recognition across Dimensional and Categorical Models"
+ description_text = """
+ <p>
+ Upload an audio file to analyze its emotional characteristics using Music2Emo.
+ The model will predict:
+ • Mood tags describing the emotional content
+ • Valence score (1-9 scale, representing emotional positivity)
+ • Arousal score (1-9 scale, representing emotional intensity)
+ </p>
+ """
+
+ css = """
+ #output-text {
+     font-family: monospace;
+     white-space: pre-wrap;
+     font-size: 16px;
+     background-color: #333333;
+     padding: 20px;
+     border-radius: 10px;
+     margin: 10px 0;
+ }
+ .gradio-container {
+     font-family: 'Inter', -apple-system, system-ui, sans-serif;
+ }
+ .gr-button {
+     color: white;
+     background: #1565c0;
+     border-radius: 100vh;
+ }
+ """
+
+
+ # Initialize Music2emo (GPU if available, otherwise CPU)
+ if torch.cuda.is_available():
+     music2emo = Music2emo()
+ else:
+     music2emo = Music2emo(device="cpu")
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(f"<h1><center>{title}</center></h1>")
+     gr.Markdown(description_text)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_audio = gr.Audio(
+                 label="Upload Audio File",
+                 type="filepath"  # Removed 'source' parameter
+             )
+             threshold = gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=0.5,
+                 step=0.01,
+                 label="Mood Detection Threshold",
+                 info="Adjust threshold for mood detection (0.0 to 1.0)"
+             )
+             predict_btn = gr.Button("🎭 Analyze Emotions", variant="primary")
+
+         with gr.Column(scale=1):
+             output_text = gr.Markdown(
+                 label="Analysis Results",
+                 elem_id="output-text"
+             )
+
+     # Add example usage
+     gr.Examples(
+         examples=["inference/input/test.mp3"],
+         inputs=input_audio,
+         outputs=output_text,
+         fn=lambda x: format_prediction(music2emo.predict(x, 0.5)),
+         cache_examples=True
+     )
+
+     predict_btn.click(
+         fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
+         inputs=[input_audio, threshold],
+         outputs=output_text
+     )
+
+     gr.Markdown("""
+ ### 📝 Notes:
+ - Supported audio formats: MP3, WAV
+ - For best results, use high-quality audio files
+ - Processing may take a few moments depending on file size
+ """)
+
+ # Launch the demo
+ demo.queue().launch()