ayousanz committed
Commit babf34b · verified · 1 Parent(s): 537c090

Update app.py

Files changed (1)
  1. app.py +127 -327
app.py CHANGED
@@ -24,8 +24,9 @@ import gradio as gr
24
  from gradio import Markdown
25
  from music21 import converter
26
  import torchaudio.transforms as T
 
27
 
28
- # Custom utility imports
29
  from utils import logger
30
  from utils.btc_model import BTC_model
31
  from utils.transformer_modules import *
@@ -38,20 +39,13 @@ from utils.mir_eval_modules import (
38
  from utils.mert import FeatureExtractorMERT
39
  from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
40
 
41
- import matplotlib.pyplot as plt
42
-
43
-
44
- # Suppress unnecessary warnings and logs
45
  warnings.filterwarnings("ignore")
46
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
47
 
48
- # from gradio import Markdown
49
-
50
  PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
51
-
52
  tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
53
- mode_signatures = ["major", "minor"] # Major and minor modes
54
-
55
 
56
  pitch_num_dic = {
57
  'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
@@ -97,20 +91,15 @@ def normalize_chord(file_path, key, key_type='major'):
97
  new_key = "C major"
98
  shift = 0
99
  else:
100
- #print ("asdas",key)
101
  if len(key) == 1:
102
  key = key[0].upper()
103
  else:
104
  key = key[0].upper() + key[1:]
105
-
106
  if key in minor_major_dic2:
107
  key = minor_major_dic2[key]
108
-
109
  shift = 0
110
-
111
  if key_type == "major":
112
  new_key = "C major"
113
-
114
  shift = shift_major_dic[key]
115
  else:
116
  new_key = "A minor"
@@ -118,31 +107,27 @@ def normalize_chord(file_path, key, key_type='major'):
118
 
119
  converted_lines = []
120
  for line in lines:
121
- if line.strip(): # Skip empty lines
122
  parts = line.split()
123
  start_time = parts[0]
124
  end_time = parts[1]
125
- chord = parts[2] # The chord is in the 3rd column
126
- if chord == "N":
127
- newchordnorm = "N"
128
- elif chord == "X":
129
- newchordnorm = "X"
130
  elif ":" in chord:
131
  pitch = chord.split(":")[0]
132
  attr = chord.split(":")[1]
133
- pnum = pitch_num_dic [pitch]
134
- new_idx = (pnum - shift)%12
135
  newchord = PITCH_CLASS[new_idx]
136
  newchordnorm = newchord + ":" + attr
137
  else:
138
  pitch = chord
139
- pnum = pitch_num_dic [pitch]
140
- new_idx = (pnum - shift)%12
141
  newchord = PITCH_CLASS[new_idx]
142
  newchordnorm = newchord
143
-
144
  converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
145
-
146
  return converted_lines
147
 
148
  def sanitize_key_signature(key):
@@ -157,146 +142,108 @@ def resample_waveform(waveform, original_sample_rate, target_sample_rate):
157
  def split_audio(waveform, sample_rate):
158
  segment_samples = segment_duration * sample_rate
159
  total_samples = waveform.size(0)
160
-
161
  segments = []
162
  for start in range(0, total_samples, segment_samples):
163
  end = start + segment_samples
164
  if end <= total_samples:
165
- segment = waveform[start:end]
166
- segments.append(segment)
167
-
168
- # In case audio length is shorter than segment length.
169
  if len(segments) == 0:
170
- segment = waveform
171
- segments.append(segment)
172
-
173
  return segments
174
 
175
-
176
  def safe_remove_dir(directory):
177
- """
178
- Safely removes a directory only if it exists and is empty.
179
- """
180
  directory = Path(directory)
181
  if directory.exists():
182
  try:
183
  shutil.rmtree(directory)
184
- except FileNotFoundError:
185
- print(f"Warning: Some files in {directory} were already deleted.")
186
- except PermissionError:
187
- print(f"Warning: Permission issue encountered while deleting {directory}.")
188
  except Exception as e:
189
- print(f"Unexpected error while deleting {directory}: {e}")
190
-
191
-
192
  class Music2emo:
193
- def __init__(
194
- self,
195
- name="amaai-lab/music2emo",
196
- device="cuda:0",
197
- cache_dir=None,
198
- local_files_only=False,
199
- ):
200
-
201
- # use_cuda = torch.cuda.is_available()
202
- # self.device = torch.device("cuda" if use_cuda else "cpu")
203
  model_weights = "saved_models/J_all.ckpt"
204
  self.device = device
205
-
206
  self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
207
  self.model_weights = model_weights
208
-
209
  self.music2emo_model = FeedforwardModelMTAttnCK(
210
- input_size= 768 * 2,
211
  output_size_classification=56,
212
  output_size_regression=2
213
  )
214
-
215
  checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
216
- state_dict = checkpoint["state_dict"]
217
-
218
- # Adjust the keys in the state_dict
219
- state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
220
-
221
- # Filter state_dict to match model's keys
222
  model_keys = set(self.music2emo_model.state_dict().keys())
223
  filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
224
-
225
- # Load the filtered state_dict and set the model to evaluation mode
226
  self.music2emo_model.load_state_dict(filtered_state_dict)
227
-
228
  self.music2emo_model.to(self.device)
229
  self.music2emo_model.eval()
230
-
231
  self.config = HParams.load("./inference/data/run_config.yaml")
232
  self.config.feature['large_voca'] = True
233
  self.config.model['num_chords'] = 170
234
  model_file = './inference/data/btc_model_large_voca.pt'
235
  self.idx_to_voca = idx2voca_chord()
236
  self.btc_model = BTC_model(config=self.config.model).to(self.device)
237
-
238
  if os.path.isfile(model_file):
239
  checkpoint = torch.load(model_file, map_location=self.device)
240
  self.mean = checkpoint['mean']
241
  self.std = checkpoint['std']
242
  self.btc_model.load_state_dict(checkpoint['model'])
243
-
244
-
245
  self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
246
  self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
247
  self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
248
  self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}
249
-
250
  with open('inference/data/chord.json', 'r') as f:
251
  self.chord_to_idx = json.load(f)
252
  with open('inference/data/chord_inv.json', 'r') as f:
253
- self.idx_to_chord = json.load(f)
254
- self.idx_to_chord = {int(k): v for k, v in self.idx_to_chord.items()} # Ensure keys are ints
255
  with open('inference/data/chord_root.json') as json_file:
256
  self.chordRootDic = json.load(json_file)
257
  with open('inference/data/chord_attr.json') as json_file:
258
  self.chordAttrDic = json.load(json_file)
259
 
260
-
261
-
262
- def predict(self, audio, threshold = 0.5):
263
-
264
  feature_dir = Path("./inference/temp_out")
265
  output_dir = Path("./inference/output")
266
-
267
- # if feature_dir.exists():
268
- # shutil.rmtree(str(feature_dir))
269
- # if output_dir.exists():
270
- # shutil.rmtree(str(output_dir))
271
-
272
- # feature_dir.mkdir(parents=True)
273
- # output_dir.mkdir(parents=True)
274
-
275
- # warnings.filterwarnings('ignore')
276
- # logger.logging_verbosity(1)
277
-
278
- # mert_dir = feature_dir / "mert"
279
- # mert_dir.mkdir(parents=True)
280
-
281
  safe_remove_dir(feature_dir)
282
  safe_remove_dir(output_dir)
283
-
284
  feature_dir.mkdir(parents=True, exist_ok=True)
285
  output_dir.mkdir(parents=True, exist_ok=True)
286
-
287
  warnings.filterwarnings('ignore')
288
  logger.logging_verbosity(1)
289
-
290
  mert_dir = feature_dir / "mert"
291
  mert_dir.mkdir(parents=True, exist_ok=True)
292
-
293
  waveform, sample_rate = torchaudio.load(audio)
294
  if waveform.shape[0] > 1:
295
  waveform = waveform.mean(dim=0).unsqueeze(0)
296
  waveform = waveform.squeeze()
297
  waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
298
-
299
- if is_split:
300
  segments = split_audio(waveform, sample_rate)
301
  for i, segment in enumerate(segments):
302
  segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
@@ -304,50 +251,38 @@ class Music2emo:
304
  else:
305
  segment_save_path = os.path.join(mert_dir, f"segment_0.npy")
306
  self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
307
-
308
- embeddings = []
309
- layers_to_extract = [5,6]
310
  segment_embeddings = []
311
- for filename in sorted(os.listdir(mert_dir)): # Sort files to ensure sequential order
 
312
  file_path = os.path.join(mert_dir, filename)
313
  if os.path.isfile(file_path) and filename.endswith('.npy'):
314
  segment = np.load(file_path)
315
  concatenated_features = np.concatenate(
316
  [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
317
  )
318
- concatenated_features = np.squeeze(concatenated_features) # Shape: 768 * 2 = 1536
319
  segment_embeddings.append(concatenated_features)
320
-
321
  segment_embeddings = np.array(segment_embeddings)
322
  if len(segment_embeddings) > 0:
323
  final_embedding_mert = np.mean(segment_embeddings, axis=0)
324
  else:
325
  final_embedding_mert = np.zeros((1536,))
326
-
327
- final_embedding_mert = torch.from_numpy(final_embedding_mert)
328
- final_embedding_mert.to(self.device)
329
-
330
- # --- Chord feature extract ---
331
-
332
  audio_path = audio
333
- audio_id = audio_path.split("/")[-1][:-4]
334
  try:
335
  feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
336
  except:
337
- logger.info("audio file failed to load : %s" % audio_path)
338
  assert(False)
339
-
340
- logger.info("audio file loaded and feature computation success : %s" % audio_path)
341
-
342
  feature = feature.T
343
  feature = (feature - self.mean) / self.std
344
  time_unit = feature_per_second
345
  n_timestep = self.config.model['timestep']
346
-
347
  num_pad = n_timestep - (feature.shape[0] % n_timestep)
348
  feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
349
  num_instance = feature.shape[0] // n_timestep
350
-
351
  start_time = 0.0
352
  lines = []
353
  with torch.no_grad():
@@ -362,85 +297,30 @@ class Music2emo:
362
  prev_chord = prediction[i].item()
363
  continue
364
  if prediction[i].item() != prev_chord:
365
- lines.append(
366
- '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
367
  start_time = time_unit * (n_timestep * t + i)
368
  prev_chord = prediction[i].item()
369
  if t == num_instance - 1 and i + num_pad == n_timestep:
370
  if start_time != time_unit * (n_timestep * t + i):
371
  lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
372
  break
373
-
374
  save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
375
  with open(save_path, 'w') as f:
376
  for line in lines:
377
  f.write(line)
378
-
379
- # logger.info("label file saved : %s" % save_path)
380
-
381
- # lab file to midi file
382
- starts, ends, pitchs = list(), list(), list()
383
-
384
- intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
385
- for p in range(12):
386
- for i, (interval, chord) in enumerate(zip(intervals, chords)):
387
- root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
388
- tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
389
- if i == 0:
390
- start_time = interval[0]
391
- label = tmp_label
392
- continue
393
- if tmp_label != label:
394
- if label == 1.0:
395
- starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48)
396
- start_time = interval[0]
397
- label = tmp_label
398
- if i == (len(intervals) - 1):
399
- if label == 1.0:
400
- starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48)
401
-
402
- midi = pm.PrettyMIDI()
403
- instrument = pm.Instrument(program=0)
404
-
405
- for start, end, pitch in zip(starts, ends, pitchs):
406
- pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
407
- instrument.notes.append(pm_note)
408
-
409
- midi.instruments.append(instrument)
410
- midi.write(save_path.replace('.lab', '.midi'))
411
-
412
-
413
-
414
-
415
  try:
416
  midi_file = converter.parse(save_path.replace('.lab', '.midi'))
417
  key_signature = str(midi_file.analyze('key'))
418
  except Exception as e:
419
  key_signature = "None"
420
-
421
  key_parts = key_signature.split()
422
- key_signature = sanitize_key_signature(key_parts[0]) # Sanitize key signature
423
  key_type = key_parts[1] if len(key_parts) > 1 else 'major'
424
-
425
- # --- Key feature (Tonic and Mode separation) ---
426
- if key_signature == "None":
427
- mode = "major"
428
- else:
429
- mode = key_signature.split()[-1]
430
-
431
- encoded_mode = self.mode_to_idx.get(mode, 0)
432
- mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)
433
-
434
  converted_lines = normalize_chord(save_path, key_signature, key_type)
435
-
436
  lab_norm_path = save_path[:-4] + "_norm.lab"
437
-
438
- # Write the converted lines to the new file
439
  with open(lab_norm_path, 'w') as f:
440
  f.writelines(converted_lines)
441
-
442
  chords = []
443
-
444
  if not os.path.exists(lab_norm_path):
445
  chords.append((float(0), float(0), "N"))
446
  else:
@@ -448,202 +328,148 @@ class Music2emo:
448
  for line in file:
449
  start, end, chord = line.strip().split()
450
  chords.append((float(start), float(end), chord))
451
-
452
  encoded = []
453
- encoded_root= []
454
- encoded_attr=[]
455
  durations = []
456
-
457
  for start, end, chord in chords:
458
  chord_arr = chord.split(":")
459
  if len(chord_arr) == 1:
460
  chordRootID = self.chordRootDic[chord_arr[0]]
461
- if chord_arr[0] == "N" or chord_arr[0] == "X":
462
- chordAttrID = 0
463
- else:
464
- chordAttrID = 1
465
  elif len(chord_arr) == 2:
466
  chordRootID = self.chordRootDic[chord_arr[0]]
467
  chordAttrID = self.chordAttrDic[chord_arr[1]]
468
  encoded_root.append(chordRootID)
469
  encoded_attr.append(chordAttrID)
470
-
471
  if chord in self.chord_to_idx:
472
  encoded.append(self.chord_to_idx[chord])
473
  else:
474
- print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
475
-
476
- durations.append(end - start) # Compute duration
477
-
478
  encoded_chords = np.array(encoded)
479
  encoded_chords_root = np.array(encoded_root)
480
  encoded_chords_attr = np.array(encoded_attr)
481
-
482
- # Maximum sequence length for chords
483
- max_sequence_length = 100 # Define this globally or as a parameter
484
-
485
- # Truncate or pad chord sequences
486
  if len(encoded_chords) > max_sequence_length:
487
- # Truncate to max length
488
  encoded_chords = encoded_chords[:max_sequence_length]
489
  encoded_chords_root = encoded_chords_root[:max_sequence_length]
490
  encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
491
-
492
  else:
493
- # Pad with zeros (padding value for chords)
494
  padding = [0] * (max_sequence_length - len(encoded_chords))
495
  encoded_chords = np.concatenate([encoded_chords, padding])
496
  encoded_chords_root = np.concatenate([encoded_chords_root, padding])
497
  encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
498
-
499
- # Convert to tensor
500
  chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
501
  chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
502
  chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
503
-
504
  model_input_dic = {
505
  "x_mert": final_embedding_mert.unsqueeze(0),
506
  "x_chord": chords_tensor.unsqueeze(0),
507
  "x_chord_root": chords_root_tensor.unsqueeze(0),
508
  "x_chord_attr": chords_attr_tensor.unsqueeze(0),
509
- "x_key": mode_tensor.unsqueeze(0)
510
  }
511
-
512
  model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
513
  classification_output, regression_output = self.music2emo_model(model_input_dic)
514
- # probs = torch.sigmoid(classification_output)
515
-
516
- tag_list = np.load ( "./inference/data/tag_list.npy")
517
  tag_list = tag_list[127:]
518
  mood_list = [t.replace("mood/theme---", "") for t in tag_list]
519
- threshold = threshold
520
-
521
- # Get probabilities
522
  probs = torch.sigmoid(classification_output).squeeze().tolist()
523
-
524
- # Include both mood names and scores
525
  predicted_moods_with_scores = [
526
- {"mood": mood_list[i], "score": round(p, 4)} # Rounded for better readability
527
  for i, p in enumerate(probs) if p > threshold
528
  ]
529
-
530
- # Include both mood names and scores
531
  predicted_moods_with_scores_all = [
532
- {"mood": mood_list[i], "score": round(p, 4)} # Rounded for better readability
533
  for i, p in enumerate(probs)
534
  ]
535
-
536
-
537
- # Sort by highest probability
538
  predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True)
539
-
540
  valence, arousal = regression_output.squeeze().tolist()
541
-
542
  model_output_dic = {
543
  "valence": valence,
544
  "arousal": arousal,
545
  "predicted_moods": predicted_moods_with_scores,
546
  "predicted_moods_all": predicted_moods_with_scores_all
547
  }
548
-
549
- # predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]
550
- # valence, arousal = regression_output.squeeze().tolist()
551
- # model_output_dic = {
552
- # "valence": valence,
553
- # "arousal": arousal,
554
- # "predicted_moods": predicted_moods
555
- # }
556
-
557
  return model_output_dic
558
 
559
- # Music2Emo Model Initialization
560
  if torch.cuda.is_available():
561
  music2emo = Music2emo()
562
  else:
563
  music2emo = Music2emo(device="cpu")
564
 
565
- # Plot Functions
566
  def plot_mood_probabilities(predicted_moods_with_scores):
567
- """Plot mood probabilities as a horizontal bar chart."""
568
  if not predicted_moods_with_scores:
569
  return None
570
-
571
- # Extract mood names and their scores
572
  moods = [m["mood"] for m in predicted_moods_with_scores]
573
  probs = [m["score"] for m in predicted_moods_with_scores]
574
-
575
- # Sort moods by probability
576
  sorted_indices = np.argsort(probs)[::-1]
577
  sorted_probs = [probs[i] for i in sorted_indices]
578
  sorted_moods = [moods[i] for i in sorted_indices]
579
-
580
- # Create bar chart
581
  fig, ax = plt.subplots(figsize=(8, 4))
582
  ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50")
583
- ax.set_xlabel("Probability")
584
- ax.set_title("Top 10 Predicted Mood Tags")
585
  ax.invert_yaxis()
586
-
587
  return fig
588
 
589
  def plot_valence_arousal(valence, arousal):
590
- """Plot valence-arousal on a 2D circumplex model."""
591
  fig, ax = plt.subplots(figsize=(4, 4))
592
  ax.scatter(valence, arousal, color="red", s=100)
593
  ax.set_xlim(1, 9)
594
  ax.set_ylim(1, 9)
595
-
596
- # Add midpoint lines
597
- ax.axhline(y=5, color='gray', linestyle='--', linewidth=1) # Horizontal middle line
598
- ax.axvline(x=5, color='gray', linestyle='--', linewidth=1) # Vertical middle line
599
-
600
- # Labels & Grid
601
- ax.set_xlabel("Valence (Positivity)")
602
- ax.set_ylabel("Arousal (Intensity)")
603
- ax.set_title("Valence-Arousal Plot")
604
- ax.legend()
605
  ax.grid(True, linestyle="--", alpha=0.6)
606
-
607
  return fig
608
 
609
-
610
- # Prediction Formatting
611
- def format_prediction(model_output_dic):
612
- """Format the model output in a structured format"""
613
- valence = model_output_dic["valence"]
614
- arousal = model_output_dic["arousal"]
615
- predicted_moods_with_scores = model_output_dic["predicted_moods"]
616
- predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"]
617
-
618
- # Generate charts
619
- va_chart = plot_valence_arousal(valence, arousal)
620
- mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all)
621
-
622
- # Format mood output with scores
623
- if predicted_moods_with_scores:
624
- moods_text = ", ".join(
625
- [f"{m['mood']} ({m['score']:.2f})" for m in predicted_moods_with_scores]
626
- )
627
- else:
628
- moods_text = "No significant moods detected."
629
-
630
- # Create formatted output
631
- output_text = f"""🎭 Predicted Mood Tags: {moods_text}
632
-
633
- 💖 Valence: {valence:.2f} (Scale: 1-9)
634
- ⚡ Arousal: {arousal:.2f} (Scale: 1-9)"""
635
-
636
- return output_text, va_chart, mood_chart
637
-
638
- # Gradio UI Elements
639
- title="🎵 Music2Emo: Toward Unified Music Emotion Recognition"
640
  description_text = """
641
- <p> Upload an audio file to analyze its emotional characteristics using Music2Emo. The model will predict: 1) Mood tags describing the emotional content, 2) Valence score (1-9 scale, representing emotional positivity), and 3) Arousal score (1-9 scale, representing emotional intensity)
642
- <br/><br/> This is the demo for Music2Emo for music emotion recognition: <a href="https://arxiv.org/abs/2502.03979">Read our paper.</a>
 
 
643
  </p>
644
  """
645
-
646
- # Custom CSS Styling
647
  css = """
648
  .gradio-container {
649
  font-family: 'Inter', -apple-system, system-ui, sans-serif;
@@ -654,7 +480,6 @@ css = """
654
  border-radius: 8px;
655
  padding: 10px;
656
  }
657
- /* Add padding to the top of the two plot boxes */
658
  .gr-box {
659
  padding-top: 25px !important;
660
  }
@@ -663,52 +488,27 @@ css = """
663
  with gr.Blocks(css=css) as demo:
664
  gr.HTML(f"<h1 style='text-align: center;'>{title}</h1>")
665
  gr.Markdown(description_text)
666
-
667
- # Notes Section
668
  gr.Markdown("""
669
- ### 📝 Notes:
670
- - **Supported audio formats:** MP3, WAV
671
- - **Recommended:** High-quality audio files
 
672
  """)
673
-
674
  with gr.Row():
675
- # Left Panel (Input)
676
  with gr.Column(scale=1):
677
- input_audio = gr.Audio(
678
- label="Upload Audio File",
679
- type="filepath"
680
- )
681
- threshold = gr.Slider(
682
- minimum=0.0,
683
- maximum=1.0,
684
- value=0.5,
685
- step=0.01,
686
- label="Mood Detection Threshold",
687
- info="Adjust threshold for mood detection"
688
- )
689
- predict_btn = gr.Button("🎭 Analyze Emotions", variant="primary")
690
-
691
- # Right Panel (Output)
692
  with gr.Column(scale=1):
693
- output_text = gr.Textbox(
694
- label="Analysis Results",
695
- lines=4,
696
- interactive=False # Prevent user input
697
- )
698
-
699
- # Ensure both plots have padding on top
700
  with gr.Row(equal_height=True):
701
- mood_chart = gr.Plot(label="Mood Probabilities", scale=2, elem_classes=["gr-box"])
702
- va_chart = gr.Plot(label="Valence-Arousal Space", scale=1, elem_classes=["gr-box"])
703
-
704
  predict_btn.click(
705
- fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
706
- inputs=[input_audio, threshold],
707
  outputs=[output_text, va_chart, mood_chart]
708
  )
709
 
710
- # Launch the App
711
  demo.queue().launch()
712
-
713
-
714
-
 
24
  from gradio import Markdown
25
  from music21 import converter
26
  import torchaudio.transforms as T
27
+ import matplotlib.pyplot as plt
28
 
29
+ # Custom utility imports
30
  from utils import logger
31
  from utils.btc_model import BTC_model
32
  from utils.transformer_modules import *
 
39
  from utils.mert import FeatureExtractorMERT
40
  from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
41
 
42
+ # Suppress unnecessary warnings and logs
 
 
 
43
  warnings.filterwarnings("ignore")
44
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
45
 
 
 
46
  PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
 
47
  tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
48
+ mode_signatures = ["major", "minor"]
 
49
 
50
  pitch_num_dic = {
51
  'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
 
91
  new_key = "C major"
92
  shift = 0
93
  else:
 
94
  if len(key) == 1:
95
  key = key[0].upper()
96
  else:
97
  key = key[0].upper() + key[1:]
 
98
  if key in minor_major_dic2:
99
  key = minor_major_dic2[key]
 
100
  shift = 0
 
101
  if key_type == "major":
102
  new_key = "C major"
 
103
  shift = shift_major_dic[key]
104
  else:
105
  new_key = "A minor"
 
107
 
108
  converted_lines = []
109
  for line in lines:
110
+ if line.strip():
111
  parts = line.split()
112
  start_time = parts[0]
113
  end_time = parts[1]
114
+ chord = parts[2]
115
+ if chord == "N" or chord == "X":
116
+ newchordnorm = chord
 
 
117
  elif ":" in chord:
118
  pitch = chord.split(":")[0]
119
  attr = chord.split(":")[1]
120
+ pnum = pitch_num_dic[pitch]
121
+ new_idx = (pnum - shift) % 12
122
  newchord = PITCH_CLASS[new_idx]
123
  newchordnorm = newchord + ":" + attr
124
  else:
125
  pitch = chord
126
+ pnum = pitch_num_dic[pitch]
127
+ new_idx = (pnum - shift) % 12
128
  newchord = PITCH_CLASS[new_idx]
129
  newchordnorm = newchord
 
130
  converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
 
131
  return converted_lines
132
 
133
  def sanitize_key_signature(key):
 
142
  def split_audio(waveform, sample_rate):
143
  segment_samples = segment_duration * sample_rate
144
  total_samples = waveform.size(0)
 
145
  segments = []
146
  for start in range(0, total_samples, segment_samples):
147
  end = start + segment_samples
148
  if end <= total_samples:
149
+ segments.append(waveform[start:end])
 
 
 
150
  if len(segments) == 0:
151
+ segments.append(waveform)
 
 
152
  return segments
153
 
 
154
  def safe_remove_dir(directory):
 
 
 
155
  directory = Path(directory)
156
  if directory.exists():
157
  try:
158
  shutil.rmtree(directory)
 
 
 
 
159
  except Exception as e:
160
+ print(f"ディレクトリ {directory} の削除中にエラーが発生しました: {e}")
161
+
162
+ # Added: download audio from a YouTube URL
163
+ def download_audio_from_youtube(url, output_dir="inference/input"):
164
+ import yt_dlp
165
+ os.makedirs(output_dir, exist_ok=True)
166
+ ydl_opts = {
167
+ 'format': 'bestaudio/best',
168
+ 'outtmpl': os.path.join(output_dir, 'tmp.%(ext)s'),
169
+ 'postprocessors': [{
170
+ 'key': 'FFmpegExtractAudio',
171
+ 'preferredcodec': 'mp3',
172
+ 'preferredquality': '192',
173
+ }],
174
+ 'noplaylist': True,
175
+ 'quiet': True,
176
+ }
177
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
178
+ info = ydl.extract_info(url, download=True)
179
+ title = info.get('title', 'Unknown title')
180
+ output_file = os.path.join(output_dir, 'tmp.mp3')
181
+ return output_file, title
182
+
183
+ # Music2emo class (existing code)
184
  class Music2emo:
185
+ def __init__(self,
186
+ name="amaai-lab/music2emo",
187
+ device="cuda:0",
188
+ cache_dir=None,
189
+ local_files_only=False):
 
 
 
 
 
190
  model_weights = "saved_models/J_all.ckpt"
191
  self.device = device
 
192
  self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
193
  self.model_weights = model_weights
 
194
  self.music2emo_model = FeedforwardModelMTAttnCK(
195
+ input_size=768 * 2,
196
  output_size_classification=56,
197
  output_size_regression=2
198
  )
 
199
  checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
200
+ state_dict = {key.replace("model.", ""): value for key, value in checkpoint["state_dict"].items()}
 
 
 
 
 
201
  model_keys = set(self.music2emo_model.state_dict().keys())
202
  filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
 
 
203
  self.music2emo_model.load_state_dict(filtered_state_dict)
 
204
  self.music2emo_model.to(self.device)
205
  self.music2emo_model.eval()
 
206
  self.config = HParams.load("./inference/data/run_config.yaml")
207
  self.config.feature['large_voca'] = True
208
  self.config.model['num_chords'] = 170
209
  model_file = './inference/data/btc_model_large_voca.pt'
210
  self.idx_to_voca = idx2voca_chord()
211
  self.btc_model = BTC_model(config=self.config.model).to(self.device)
 
212
  if os.path.isfile(model_file):
213
  checkpoint = torch.load(model_file, map_location=self.device)
214
  self.mean = checkpoint['mean']
215
  self.std = checkpoint['std']
216
  self.btc_model.load_state_dict(checkpoint['model'])
 
 
217
  self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
218
  self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
219
  self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
220
  self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}
 
221
  with open('inference/data/chord.json', 'r') as f:
222
  self.chord_to_idx = json.load(f)
223
  with open('inference/data/chord_inv.json', 'r') as f:
224
+ self.idx_to_chord = {int(k): v for k, v in json.load(f).items()}
 
225
  with open('inference/data/chord_root.json') as json_file:
226
  self.chordRootDic = json.load(json_file)
227
  with open('inference/data/chord_attr.json') as json_file:
228
  self.chordAttrDic = json.load(json_file)
229
 
230
+ def predict(self, audio, threshold=0.5):
 
 
 
231
  feature_dir = Path("./inference/temp_out")
232
  output_dir = Path("./inference/output")
233
  safe_remove_dir(feature_dir)
234
  safe_remove_dir(output_dir)
 
235
  feature_dir.mkdir(parents=True, exist_ok=True)
236
  output_dir.mkdir(parents=True, exist_ok=True)
 
237
  warnings.filterwarnings('ignore')
238
  logger.logging_verbosity(1)
 
239
  mert_dir = feature_dir / "mert"
240
  mert_dir.mkdir(parents=True, exist_ok=True)
 
241
  waveform, sample_rate = torchaudio.load(audio)
242
  if waveform.shape[0] > 1:
243
  waveform = waveform.mean(dim=0).unsqueeze(0)
244
  waveform = waveform.squeeze()
245
  waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
246
+ if is_split:
 
247
  segments = split_audio(waveform, sample_rate)
248
  for i, segment in enumerate(segments):
249
  segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
 
251
  else:
252
  segment_save_path = os.path.join(mert_dir, f"segment_0.npy")
253
  self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
 
 
 
254
  segment_embeddings = []
255
+ layers_to_extract = [5,6]
256
+ for filename in sorted(os.listdir(mert_dir)):
257
  file_path = os.path.join(mert_dir, filename)
258
  if os.path.isfile(file_path) and filename.endswith('.npy'):
259
  segment = np.load(file_path)
260
  concatenated_features = np.concatenate(
261
  [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
262
  )
263
+ concatenated_features = np.squeeze(concatenated_features)
264
  segment_embeddings.append(concatenated_features)
 
265
  segment_embeddings = np.array(segment_embeddings)
266
  if len(segment_embeddings) > 0:
267
  final_embedding_mert = np.mean(segment_embeddings, axis=0)
268
  else:
269
  final_embedding_mert = np.zeros((1536,))
270
+ final_embedding_mert = torch.from_numpy(final_embedding_mert).to(self.device)
 
 
 
 
 
271
  audio_path = audio
272
+ audio_id = os.path.split(audio_path)[-1][:-4]
273
  try:
274
  feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
275
  except:
276
+ logger.info("音声ファイルの読み込みに失敗しました : %s" % audio_path)
277
  assert(False)
278
+ logger.info("音声ファイルの読み込みと特徴量計算に成功しました : %s" % audio_path)
 
 
279
  feature = feature.T
280
  feature = (feature - self.mean) / self.std
281
  time_unit = feature_per_second
282
  n_timestep = self.config.model['timestep']
 
283
  num_pad = n_timestep - (feature.shape[0] % n_timestep)
284
  feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
285
  num_instance = feature.shape[0] // n_timestep
 
286
  start_time = 0.0
287
  lines = []
288
  with torch.no_grad():
 
297
  prev_chord = prediction[i].item()
298
  continue
299
  if prediction[i].item() != prev_chord:
300
+ lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
 
301
  start_time = time_unit * (n_timestep * t + i)
302
  prev_chord = prediction[i].item()
303
  if t == num_instance - 1 and i + num_pad == n_timestep:
304
  if start_time != time_unit * (n_timestep * t + i):
305
  lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
306
  break
 
307
  save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
308
  with open(save_path, 'w') as f:
309
  for line in lines:
310
  f.write(line)
311
  try:
312
  midi_file = converter.parse(save_path.replace('.lab', '.midi'))
313
  key_signature = str(midi_file.analyze('key'))
314
  except Exception as e:
315
  key_signature = "None"
 
316
  key_parts = key_signature.split()
317
+ key_signature = sanitize_key_signature(key_parts[0])
318
  key_type = key_parts[1] if len(key_parts) > 1 else 'major'
319
  converted_lines = normalize_chord(save_path, key_signature, key_type)
 
320
  lab_norm_path = save_path[:-4] + "_norm.lab"
 
 
321
  with open(lab_norm_path, 'w') as f:
322
  f.writelines(converted_lines)
 
323
  chords = []
 
324
  if not os.path.exists(lab_norm_path):
325
  chords.append((float(0), float(0), "N"))
326
  else:
 
328
  for line in file:
329
  start, end, chord = line.strip().split()
330
  chords.append((float(start), float(end), chord))
 
331
  encoded = []
332
+ encoded_root = []
333
+ encoded_attr = []
334
  durations = []
 
335
  for start, end, chord in chords:
336
  chord_arr = chord.split(":")
337
  if len(chord_arr) == 1:
338
  chordRootID = self.chordRootDic[chord_arr[0]]
339
+ chordAttrID = 0 if chord_arr[0] in ["N", "X"] else 1
 
 
 
340
  elif len(chord_arr) == 2:
341
  chordRootID = self.chordRootDic[chord_arr[0]]
342
  chordAttrID = self.chordAttrDic[chord_arr[1]]
343
  encoded_root.append(chordRootID)
344
  encoded_attr.append(chordAttrID)
 
345
  if chord in self.chord_to_idx:
346
  encoded.append(self.chord_to_idx[chord])
347
  else:
348
+ print(f"警告: {chord} chord.json に見つかりませんでした。スキップします。")
349
+ durations.append(end - start)
 
 
350
  encoded_chords = np.array(encoded)
351
  encoded_chords_root = np.array(encoded_root)
352
  encoded_chords_attr = np.array(encoded_attr)
353
+ max_sequence_length = 100
 
 
 
 
354
  if len(encoded_chords) > max_sequence_length:
 
355
  encoded_chords = encoded_chords[:max_sequence_length]
356
  encoded_chords_root = encoded_chords_root[:max_sequence_length]
357
  encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
 
358
  else:
 
359
  padding = [0] * (max_sequence_length - len(encoded_chords))
360
  encoded_chords = np.concatenate([encoded_chords, padding])
361
  encoded_chords_root = np.concatenate([encoded_chords_root, padding])
362
  encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
 
 
363
  chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
364
  chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
365
  chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
 
366
  model_input_dic = {
367
  "x_mert": final_embedding_mert.unsqueeze(0),
368
  "x_chord": chords_tensor.unsqueeze(0),
369
  "x_chord_root": chords_root_tensor.unsqueeze(0),
370
  "x_chord_attr": chords_attr_tensor.unsqueeze(0),
371
+ "x_key": torch.tensor([self.mode_to_idx.get(key_type, 0)], dtype=torch.long).unsqueeze(0).to(self.device)
372
  }
 
373
  model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
374
  classification_output, regression_output = self.music2emo_model(model_input_dic)
375
+ tag_list = np.load("./inference/data/tag_list.npy")
 
 
376
  tag_list = tag_list[127:]
377
  mood_list = [t.replace("mood/theme---", "") for t in tag_list]
 
 
 
378
  probs = torch.sigmoid(classification_output).squeeze().tolist()
 
 
379
  predicted_moods_with_scores = [
380
+ {"mood": mood_list[i], "score": round(p, 4)}
381
  for i, p in enumerate(probs) if p > threshold
382
  ]
 
 
383
  predicted_moods_with_scores_all = [
384
+ {"mood": mood_list[i], "score": round(p, 4)}
385
  for i, p in enumerate(probs)
386
  ]
 
 
 
387
  predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True)
 
388
  valence, arousal = regression_output.squeeze().tolist()
 
389
  model_output_dic = {
390
  "valence": valence,
391
  "arousal": arousal,
392
  "predicted_moods": predicted_moods_with_scores,
393
  "predicted_moods_all": predicted_moods_with_scores_all
394
  }
395
  return model_output_dic
396
 
397
+ # Initialize the Music2Emo model
398
  if torch.cuda.is_available():
399
  music2emo = Music2emo()
400
  else:
401
  music2emo = Music2emo(device="cpu")
402
 
403
+ # Handle input (an audio file or a YouTube URL)
404
+ def process_input(audio, youtube_url, threshold):
405
+ if youtube_url and youtube_url.strip().startswith("http"):
406
+ # If a YouTube URL was provided, download its audio
407
+ audio_file, video_title = download_audio_from_youtube(youtube_url)
408
+ output_dic = music2emo.predict(audio_file, threshold)
409
+ output_text, va_chart, mood_chart = format_prediction(output_dic)
410
+ output_text += f"\n動画タイトル: {video_title}"
411
+ return output_text, va_chart, mood_chart
412
+ elif audio:
413
+ output_dic = music2emo.predict(audio, threshold)
414
+ return format_prediction(output_dic)
415
+ else:
416
+ return "音声ファイルまたは YouTube URL を入力してください。", None, None
417
+
418
+ # Format the analysis results
419
+ def format_prediction(model_output_dic):
420
+ valence = model_output_dic["valence"]
421
+ arousal = model_output_dic["arousal"]
422
+ predicted_moods_with_scores = model_output_dic["predicted_moods"]
423
+ predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"]
424
+ va_chart = plot_valence_arousal(valence, arousal)
425
+ mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all)
426
+ if predicted_moods_with_scores:
427
+ moods_text = ", ".join([f"{m['mood']} ({m['score']:.2f})" for m in predicted_moods_with_scores])
428
+ else:
429
+ moods_text = "顕著なムードは検出されませんでした。"
430
+ output_text = f"""🎭 ムードタグ: {moods_text}
431
+
432
+ 💖 Valence: {valence:.2f} (Scale: 1-9)
433
+ ⚡ Arousal: {arousal:.2f} (Scale: 1-9)"""
434
+ return output_text, va_chart, mood_chart
435
+
436
  def plot_mood_probabilities(predicted_moods_with_scores):
 
437
  if not predicted_moods_with_scores:
438
  return None
 
 
439
  moods = [m["mood"] for m in predicted_moods_with_scores]
440
  probs = [m["score"] for m in predicted_moods_with_scores]
 
 
441
  sorted_indices = np.argsort(probs)[::-1]
442
  sorted_probs = [probs[i] for i in sorted_indices]
443
  sorted_moods = [moods[i] for i in sorted_indices]
 
 
444
  fig, ax = plt.subplots(figsize=(8, 4))
445
  ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50")
446
+ ax.set_xlabel("確率")
447
+ ax.set_title("上位10のムードタグ")
448
  ax.invert_yaxis()
 
449
  return fig
450
 
451
  def plot_valence_arousal(valence, arousal):
 
452
  fig, ax = plt.subplots(figsize=(4, 4))
453
  ax.scatter(valence, arousal, color="red", s=100)
454
  ax.set_xlim(1, 9)
455
  ax.set_ylim(1, 9)
456
+ ax.axhline(y=5, color='gray', linestyle='--', linewidth=1)
457
+ ax.axvline(x=5, color='gray', linestyle='--', linewidth=1)
458
+ ax.set_xlabel("バレンス (ポジティブ度)")
459
+ ax.set_ylabel("アラウザル (活発度)")
460
+ ax.set_title("バレンス・アラウザル プロット")
 
 
 
 
 
461
  ax.grid(True, linestyle="--", alpha=0.6)
 
462
  return fig
463
 
464
+ # Gradio UI setup
465
+ title = "🎵 Music2Emo: Toward Unified Music Emotion Recognition"
466
  description_text = """
467
+ <p>
468
+ Provide an audio file or a YouTube URL, and Music2Emo will analyze the track's emotional characteristics.<br/><br/>
469
+ This demo predicts: 1) mood tags, 2) valence (1-9 scale, emotional positivity), and 3) arousal (1-9 scale, emotional intensity).<br/><br/>
470
+ For details, see the <a href="https://arxiv.org/abs/2502.03979" target="_blank">paper</a>.
471
  </p>
472
  """
 
 
473
  css = """
474
  .gradio-container {
475
  font-family: 'Inter', -apple-system, system-ui, sans-serif;
 
480
  border-radius: 8px;
481
  padding: 10px;
482
  }
 
483
  .gr-box {
484
  padding-top: 25px !important;
485
  }
 
488
  with gr.Blocks(css=css) as demo:
489
  gr.HTML(f"<h1 style='text-align: center;'>{title}</h1>")
490
  gr.Markdown(description_text)
 
 
491
  gr.Markdown("""
492
+ ### 📝 Notes:
493
+ - **Supported audio formats:** MP3, WAV
494
+ - **YouTube URL:** can also be provided (optional)
495
+ - **Recommended:** high-quality audio files
496
  """)
 
497
  with gr.Row():
 
498
  with gr.Column(scale=1):
499
+ input_audio = gr.Audio(label="音声ファイルをアップロード", type="filepath")
500
+ youtube_url = gr.Textbox(label="YouTube URL (任意)", placeholder="例: https://youtu.be/XXXXXXX")
501
+ threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="ムード検出のしきい値", info="しきい値を調整してください")
502
+ predict_btn = gr.Button("🎭 感情解析を実行", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
503
  with gr.Column(scale=1):
504
+ output_text = gr.Textbox(label="解析結果", lines=4, interactive=False)
 
 
 
 
 
 
505
  with gr.Row(equal_height=True):
506
+ mood_chart = gr.Plot(label="ムード確率", scale=2, elem_classes=["gr-box"])
507
+ va_chart = gr.Plot(label="バレンス・アラウザル", scale=1, elem_classes=["gr-box"])
 
508
  predict_btn.click(
509
+ fn=process_input,
510
+ inputs=[input_audio, youtube_url, threshold],
511
  outputs=[output_text, va_chart, mood_chart]
512
  )
513
 
 
514
  demo.queue().launch()
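
For reference, a minimal usage sketch of the new YouTube input path, assuming yt-dlp and ffmpeg are installed and the snippet runs after the definitions above (before demo.queue().launch()); the URL is a placeholder:

# Hypothetical quick test of the YouTube path defined above (not part of the commit).
audio_file, video_title = download_audio_from_youtube("https://youtu.be/XXXXXXX")  # placeholder URL
result = music2emo.predict(audio_file, threshold=0.5)
text, va_chart, mood_chart = format_prediction(result)
print(video_title)
print(text)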