kjysmu committed
Commit e163262 · verified · 1 Parent(s): fa556e0

Delete app.py

Files changed (1)
  1. app.py +0 -810
app.py DELETED
@@ -1,810 +0,0 @@
- import os
- import shutil
- import json
- import torch
- import torchaudio
- import numpy as np
- import logging
- import warnings
- import subprocess
- import math
- import random
- import time
- from pathlib import Path
- from tqdm import tqdm
- from PIL import Image
- from huggingface_hub import snapshot_download
- from omegaconf import DictConfig
- import hydra
- from hydra.utils import to_absolute_path
- from transformers import Wav2Vec2FeatureExtractor, AutoModel
- import mir_eval
- import pretty_midi as pm
- import gradio as gr
- from gradio import Markdown
- from music21 import converter
- import torchaudio.transforms as T
-
- # Custom utility imports
- from utils import logger
- from utils.btc_model import BTC_model
- from utils.transformer_modules import *
- from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
- from utils.hparams import HParams
- from utils.mir_eval_modules import (
-     audio_file_to_features, idx2chord, idx2voca_chord,
-     get_audio_paths, get_lab_paths
- )
- from utils.mert import FeatureExtractorMERT
- from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
-
- # Suppress unnecessary warnings and logs
- warnings.filterwarnings("ignore")
- logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
-
- # from gradio import Markdown
-
- PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
-
- pitch_num_dic = {
-     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
-     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
- }
-
- minor_major_dic = {
-     'D-':'C#', 'E-':'D#', 'G-':'F#', 'A-':'G#', 'B-':'A#'
- }
- minor_major_dic2 = {
-     'Db':'C#', 'Eb':'D#', 'Gb':'F#', 'Ab':'G#', 'Bb':'A#'
- }
-
- shift_major_dic = {
-     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
-     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
- }
-
- shift_minor_dic = {
-     'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5,
-     'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11,
- }
-
- flat_to_sharp_mapping = {
-     "Cb": "B",
-     "Db": "C#",
-     "Eb": "D#",
-     "Fb": "E",
-     "Gb": "F#",
-     "Ab": "G#",
-     "Bb": "A#"
- }
-
- segment_duration = 30
- resample_rate = 24000
- is_split = True
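- # The three settings above control feature extraction: audio is resampled to 24 kHz
- # and, when is_split is True, chunked into 30-second segments before MERT embedding.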
-
- def normalize_chord(file_path, key, key_type='major'):
-     with open(file_path, 'r') as f:
-         lines = f.readlines()
-
-     if key == "None":
-         new_key = "C major"
-         shift = 0
-     else:
-         if len(key) == 1:
-             key = key[0].upper()
-         else:
-             key = key[0].upper() + key[1:]
-
-         if key in minor_major_dic2:
-             key = minor_major_dic2[key]
-
-         shift = 0
-
-         if key_type == "major":
-             new_key = "C major"
-             shift = shift_major_dic[key]
-         else:
-             new_key = "A minor"
-             shift = shift_minor_dic[key]
-
-     converted_lines = []
-     for line in lines:
-         if line.strip():  # Skip empty lines
-             parts = line.split()
-             start_time = parts[0]
-             end_time = parts[1]
-             chord = parts[2]  # The chord is in the 3rd column
-             if chord == "N":
-                 newchordnorm = "N"
-             elif chord == "X":
-                 newchordnorm = "X"
-             elif ":" in chord:
-                 pitch = chord.split(":")[0]
-                 attr = chord.split(":")[1]
-                 pnum = pitch_num_dic[pitch]
-                 new_idx = (pnum - shift) % 12
-                 newchord = PITCH_CLASS[new_idx]
-                 newchordnorm = newchord + ":" + attr
-             else:
-                 pitch = chord
-                 pnum = pitch_num_dic[pitch]
-                 new_idx = (pnum - shift) % 12
-                 newchord = PITCH_CLASS[new_idx]
-                 newchordnorm = newchord
-
-             converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
-
-     return converted_lines
-
- def sanitize_key_signature(key):
-     return key.replace('-', 'b')
-
- def resample_waveform(waveform, original_sample_rate, target_sample_rate):
-     if original_sample_rate != target_sample_rate:
-         resampler = T.Resample(original_sample_rate, target_sample_rate)
-         return resampler(waveform), target_sample_rate
-     return waveform, original_sample_rate
-
- def split_audio(waveform, sample_rate):
-     segment_samples = segment_duration * sample_rate
-     total_samples = waveform.size(0)
-
-     segments = []
-     for start in range(0, total_samples, segment_samples):
-         end = start + segment_samples
-         if end <= total_samples:
-             segment = waveform[start:end]
-             segments.append(segment)
-
-     # In case audio length is shorter than segment length.
-     if len(segments) == 0:
-         segment = waveform
-         segments.append(segment)
-
-     return segments
-
-
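- # Music2emo wraps the full inference pipeline: MERT audio embeddings plus chord and key
- # features are fed to a multi-task head (FeedforwardModelMTAttnCK) that predicts mood
- # tags (classification) and valence/arousal (regression).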
- class Music2emo:
-     def __init__(
-         self,
-         name="amaai-lab/music2emo",
-         device="cuda:0",
-         cache_dir=None,
-         local_files_only=False,
-     ):
-
-         # use_cuda = torch.cuda.is_available()
-         # self.device = torch.device("cuda" if use_cuda else "cpu")
-         model_weights = "saved_models/J_all.ckpt"
-         self.device = device
-
-         self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
-         self.model_weights = model_weights
-
-         self.music2emo_model = FeedforwardModelMTAttnCK(
-             input_size=768 * 2,
-             output_size_classification=56,
-             output_size_regression=2
-         )
-
-         checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
-         state_dict = checkpoint["state_dict"]
-
-         # Adjust the keys in the state_dict
-         state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
-
-         # Filter state_dict to match model's keys
-         model_keys = set(self.music2emo_model.state_dict().keys())
-         filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
-
-         # Load the filtered state_dict and set the model to evaluation mode
-         self.music2emo_model.load_state_dict(filtered_state_dict)
-
-         self.music2emo_model.to(self.device)
-         self.music2emo_model.eval()
-
-     def predict(self, audio, threshold=0.5):
-
-         feature_dir = Path("./temp_out")
-         output_dir = Path("./output")
-
-         if feature_dir.exists():
-             shutil.rmtree(str(feature_dir))
-         if output_dir.exists():
-             shutil.rmtree(str(output_dir))
-
-         feature_dir.mkdir(parents=True)
-         output_dir.mkdir(parents=True)
-
-         warnings.filterwarnings('ignore')
-         logger.logging_verbosity(1)
-
-         mert_dir = feature_dir / "mert"
-         mert_dir.mkdir(parents=True)
-
-         waveform, sample_rate = torchaudio.load(audio)
-         if waveform.shape[0] > 1:
-             waveform = waveform.mean(dim=0).unsqueeze(0)
-         waveform = waveform.squeeze()
-         waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
-
-         if is_split:
-             segments = split_audio(waveform, sample_rate)
-             for i, segment in enumerate(segments):
-                 segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
-                 self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
-         else:
-             segment_save_path = os.path.join(mert_dir, "segment_0.npy")
-             self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
-
-         embeddings = []
-         layers_to_extract = [5, 6]
-         segment_embeddings = []
-         for filename in sorted(os.listdir(mert_dir)):  # Sort files to ensure sequential order
-             file_path = os.path.join(mert_dir, filename)
-             if os.path.isfile(file_path) and filename.endswith('.npy'):
-                 segment = np.load(file_path)
-                 concatenated_features = np.concatenate(
-                     [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
-                 )
-                 concatenated_features = np.squeeze(concatenated_features)  # Shape: 768 * 2 = 1536
-                 segment_embeddings.append(concatenated_features)
-
-         segment_embeddings = np.array(segment_embeddings)
-         if len(segment_embeddings) > 0:
-             final_embedding_mert = np.mean(segment_embeddings, axis=0)
-         else:
-             final_embedding_mert = np.zeros((1536,))
-
-         final_embedding_mert = torch.from_numpy(final_embedding_mert)
-         final_embedding_mert = final_embedding_mert.to(self.device)
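-         # final_embedding_mert is a 1536-dim vector: MERT layers 5 and 6 (768 dims each)
-         # concatenated per segment, then averaged across segments.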
-
-         # --- Chord feature extract ---
-         config = HParams.load("./inference/data/run_config.yaml")
-         config.feature['large_voca'] = True
-         config.model['num_chords'] = 170
-         model_file = './inference/data/btc_model_large_voca.pt'
-         idx_to_chord = idx2voca_chord()
-         model = BTC_model(config=config.model).to(self.device)
-
-         if os.path.isfile(model_file):
-             checkpoint = torch.load(model_file)
-             mean = checkpoint['mean']
-             std = checkpoint['std']
-             model.load_state_dict(checkpoint['model'])
-
-         audio_path = audio
-         audio_id = audio_path.split("/")[-1][:-4]
-         try:
-             feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
-         except:
-             logger.info("audio file failed to load : %s" % audio_path)
-             assert(False)
-
-         logger.info("audio file loaded and feature computation success : %s" % audio_path)
-
-         feature = feature.T
-         feature = (feature - mean) / std
-         time_unit = feature_per_second
-         n_timestep = config.model['timestep']
-
-         num_pad = n_timestep - (feature.shape[0] % n_timestep)
-         feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
-         num_instance = feature.shape[0] // n_timestep
-
-         start_time = 0.0
-         lines = []
-         with torch.no_grad():
-             model.eval()
-             feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
-             for t in range(num_instance):
-                 self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
-                 prediction, _ = model.output_layer(self_attn_output)
-                 prediction = prediction.squeeze()
-                 for i in range(n_timestep):
-                     if t == 0 and i == 0:
-                         prev_chord = prediction[i].item()
-                         continue
-                     if prediction[i].item() != prev_chord:
-                         lines.append(
-                             '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
-                         start_time = time_unit * (n_timestep * t + i)
-                         prev_chord = prediction[i].item()
-                     if t == num_instance - 1 and i + num_pad == n_timestep:
-                         if start_time != time_unit * (n_timestep * t + i):
-                             lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
-                         break
-
-         save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
-         with open(save_path, 'w') as f:
-             for line in lines:
-                 f.write(line)
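-
-         # The .lab chord segments are rendered below into a rough piano-roll MIDI file so
-         # that music21 can estimate the global key, which is then used to normalize the
-         # chord symbols before encoding.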
-
-         # logger.info("label file saved : %s" % save_path)
-
-         # lab file to midi file
-         starts, ends, pitchs = list(), list(), list()
-
-         intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
-         for p in range(12):
-             for i, (interval, chord) in enumerate(zip(intervals, chords)):
-                 root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
-                 tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
-                 if i == 0:
-                     start_time = interval[0]
-                     label = tmp_label
-                     continue
-                 if tmp_label != label:
-                     if label == 1.0:
-                         starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48)
-                     start_time = interval[0]
-                     label = tmp_label
-                 if i == (len(intervals) - 1):
-                     if label == 1.0:
-                         starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48)
-
-         midi = pm.PrettyMIDI()
-         instrument = pm.Instrument(program=0)
-
-         for start, end, pitch in zip(starts, ends, pitchs):
-             pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
-             instrument.notes.append(pm_note)
-
-         midi.instruments.append(instrument)
-         midi.write(save_path.replace('.lab', '.midi'))
-
-         tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
-         mode_signatures = ["major", "minor"]  # Major and minor modes
-
-         tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
-         mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
-         idx_to_tonic = {idx: tonic for tonic, idx in tonic_to_idx.items()}
-         idx_to_mode = {idx: mode for mode, idx in mode_to_idx.items()}
-
-         with open('inference/data/chord.json', 'r') as f:
-             chord_to_idx = json.load(f)
-         with open('inference/data/chord_inv.json', 'r') as f:
-             idx_to_chord = json.load(f)
-             idx_to_chord = {int(k): v for k, v in idx_to_chord.items()}  # Ensure keys are ints
-         with open('inference/data/chord_root.json') as json_file:
-             chordRootDic = json.load(json_file)
-         with open('inference/data/chord_attr.json') as json_file:
-             chordAttrDic = json.load(json_file)
-
-         try:
-             midi_file = converter.parse(save_path.replace('.lab', '.midi'))
-             key_signature = str(midi_file.analyze('key'))
-         except Exception as e:
-             key_signature = "None"
-
-         key_parts = key_signature.split()
-         key_signature = sanitize_key_signature(key_parts[0])  # Sanitize key signature
-         key_type = key_parts[1] if len(key_parts) > 1 else 'major'
-
-         # --- Key feature (Tonic and Mode separation) ---
-         if key_signature == "None":
-             mode = "major"
-         else:
-             mode = key_signature.split()[-1]
-
-         encoded_mode = mode_to_idx.get(mode, 0)
-         mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)
-
-         converted_lines = normalize_chord(save_path, key_signature, key_type)
-
-         lab_norm_path = save_path[:-4] + "_norm.lab"
-
-         # Write the converted lines to the new file
-         with open(lab_norm_path, 'w') as f:
-             f.writelines(converted_lines)
-
-         chords = []
-
-         if not os.path.exists(lab_norm_path):
-             chords.append((float(0), float(0), "N"))
-         else:
-             with open(lab_norm_path, 'r') as file:
-                 for line in file:
-                     start, end, chord = line.strip().split()
-                     chords.append((float(start), float(end), chord))
-
-         encoded = []
-         encoded_root = []
-         encoded_attr = []
-         durations = []
-
-         for start, end, chord in chords:
-             chord_arr = chord.split(":")
-             if len(chord_arr) == 1:
-                 chordRootID = chordRootDic[chord_arr[0]]
-                 if chord_arr[0] == "N" or chord_arr[0] == "X":
-                     chordAttrID = 0
-                 else:
-                     chordAttrID = 1
-             elif len(chord_arr) == 2:
-                 chordRootID = chordRootDic[chord_arr[0]]
-                 chordAttrID = chordAttrDic[chord_arr[1]]
-             encoded_root.append(chordRootID)
-             encoded_attr.append(chordAttrID)
-
-             if chord in chord_to_idx:
-                 encoded.append(chord_to_idx[chord])
-             else:
-                 print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
-
-             durations.append(end - start)  # Compute duration
-
-         encoded_chords = np.array(encoded)
-         encoded_chords_root = np.array(encoded_root)
-         encoded_chords_attr = np.array(encoded_attr)
-
-         # Maximum sequence length for chords
-         max_sequence_length = 100  # Define this globally or as a parameter
-
-         # Truncate or pad chord sequences
-         if len(encoded_chords) > max_sequence_length:
-             # Truncate to max length
-             encoded_chords = encoded_chords[:max_sequence_length]
-             encoded_chords_root = encoded_chords_root[:max_sequence_length]
-             encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
-         else:
-             # Pad with zeros (padding value for chords)
-             padding = [0] * (max_sequence_length - len(encoded_chords))
-             encoded_chords = np.concatenate([encoded_chords, padding])
-             encoded_chords_root = np.concatenate([encoded_chords_root, padding])
-             encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
-
-         # Convert to tensor
-         chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
-         chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
-         chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
-
-         model_input_dic = {
-             "x_mert": final_embedding_mert.unsqueeze(0),
-             "x_chord": chords_tensor.unsqueeze(0),
-             "x_chord_root": chords_root_tensor.unsqueeze(0),
-             "x_chord_attr": chords_attr_tensor.unsqueeze(0),
-             "x_key": mode_tensor.unsqueeze(0)
-         }
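-         # Shapes after unsqueeze(0): x_mert is (1, 1536); x_chord, x_chord_root and
-         # x_chord_attr are (1, 100); x_key is (1, 1).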
-
-         model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
-         classification_output, regression_output = self.music2emo_model(model_input_dic)
-         probs = torch.sigmoid(classification_output)
-
-         tag_list = np.load("./inference/data/tag_list.npy")
-         tag_list = tag_list[127:]
-         mood_list = [t.replace("mood/theme---", "") for t in tag_list]
-         predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]
-         valence, arousal = regression_output.squeeze().tolist()
-
-         model_output_dic = {
-             "valence": valence,
-             "arousal": arousal,
-             "predicted_moods": predicted_moods
-         }
-
-         return model_output_dic
-
-
- def format_prediction(model_output_dic):
-     """Format the model output in a more readable and attractive format"""
-     valence = model_output_dic["valence"]
-     arousal = model_output_dic["arousal"]
-     moods = model_output_dic["predicted_moods"]
-
-     # Create a formatted string with emojis and proper formatting
-     output_text = """
- 🎵 **Music Emotion Recognition Results** 🎵
- --------------------------------------------------
- 🎭 **Predicted Mood Tags:** {}
- 💖 **Valence:** {:.2f} (Scale: 1-9)
- ⚡ **Arousal:** {:.2f} (Scale: 1-9)
- --------------------------------------------------
- """.format(
-         ', '.join(moods) if moods else 'None',
-         valence,
-         arousal
-     )
-
-     return output_text
-
- title = "Music2Emo: Towards Unified Music Emotion Recognition across Dimensional and Categorical Models"
- description_text = """
- <p>
- Upload an audio file to analyze its emotional characteristics using Music2Emo.
- The model will predict:
- • Mood tags describing the emotional content
- • Valence score (1-9 scale, representing emotional positivity)
- • Arousal score (1-9 scale, representing emotional intensity)
- </p>
- """
-
- css = """
- #output-text {
-     font-family: monospace;
-     white-space: pre-wrap;
-     font-size: 16px;
-     background-color: #333333;
-     padding: 20px;
-     border-radius: 10px;
-     margin: 10px 0;
- }
- .gradio-container {
-     font-family: 'Inter', -apple-system, system-ui, sans-serif;
- }
- .gr-button {
-     color: white;
-     background: #1565c0;
-     border-radius: 100vh;
- }
- """
-
- # Initialize Music2Emo
- if torch.cuda.is_available():
-     music2emo = Music2emo()
- else:
-     music2emo = Music2emo(device="cpu")
-
- with gr.Blocks(css=css) as demo:
-     gr.HTML(f"<h1><center>{title}</center></h1>")
-     gr.Markdown(description_text)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             input_audio = gr.Audio(
-                 label="Upload Audio File",
-                 type="filepath"  # Removed 'source' parameter
-             )
-             threshold = gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 value=0.5,
-                 step=0.01,
-                 label="Mood Detection Threshold",
-                 info="Adjust threshold for mood detection (0.0 to 1.0)"
-             )
-             predict_btn = gr.Button("🎭 Analyze Emotions", variant="primary")
-
-         with gr.Column(scale=1):
-             output_text = gr.Markdown(
-                 label="Analysis Results",
-                 elem_id="output-text"
-             )
-
-     # Add example usage
-     gr.Examples(
-         examples=["inference/input/test.mp3"],
-         inputs=input_audio,
-         outputs=output_text,
-         fn=lambda x: format_prediction(music2emo.predict(x, 0.5)),
-         cache_examples=True
-     )
-
-     predict_btn.click(
-         fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
-         inputs=[input_audio, threshold],
-         outputs=output_text
-     )
-
-     gr.Markdown("""
-     ### 📝 Notes:
-     - Supported audio formats: MP3, WAV
-     - For best results, use high-quality audio files
-     - Processing may take a few moments depending on file size
-     """)
-
- # Launch the demo
- demo.queue().launch()
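-
- # A minimal sketch of programmatic use without the Gradio UI (assuming this module is
- # importable and the bundled example clip at inference/input/test.mp3 is available):
- #
- #     music2emo = Music2emo(device="cpu")
- #     result = music2emo.predict("inference/input/test.mp3", threshold=0.5)
- #     print(format_prediction(result))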