import os import shutil import json import torch import torchaudio import numpy as np import logging import warnings import subprocess import math import random import time from pathlib import Path from tqdm import tqdm from PIL import Image from huggingface_hub import snapshot_download from omegaconf import DictConfig import hydra from hydra.utils import to_absolute_path from transformers import Wav2Vec2FeatureExtractor, AutoModel import mir_eval import pretty_midi as pm import gradio as gr from gradio import Markdown from music21 import converter import torchaudio.transforms as T import matplotlib.pyplot as plt # カスタムユーティリティのインポート from utils import logger from utils.btc_model import BTC_model from utils.transformer_modules import * from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask from utils.hparams import HParams from utils.mir_eval_modules import ( audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths, get_lab_paths ) from utils.mert import FeatureExtractorMERT from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK # 不要な警告・ログを抑制 warnings.filterwarnings("ignore") logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"] mode_signatures = ["major", "minor"] pitch_num_dic = { 'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5, 'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11 } minor_major_dic = { 'D-':'C#', 'E-':'D#', 'G-':'F#', 'A-':'G#', 'B-':'A#' } minor_major_dic2 = { 'Db':'C#', 'Eb':'D#', 'Gb':'F#', 'Ab':'G#', 'Bb':'A#' } shift_major_dic = { 'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5, 'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11 } shift_minor_dic = { 'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5, 'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11, } flat_to_sharp_mapping = { "Cb": "B", "Db": "C#", "Eb": "D#", "Fb": "E", "Gb": "F#", "Ab": "G#", "Bb": "A#" } segment_duration = 30 resample_rate = 24000 is_split = True def normalize_chord(file_path, key, key_type='major'): with open(file_path, 'r') as f: lines = f.readlines() if key == "None": new_key = "C major" shift = 0 else: if len(key) == 1: key = key[0].upper() else: key = key[0].upper() + key[1:] if key in minor_major_dic2: key = minor_major_dic2[key] shift = 0 if key_type == "major": new_key = "C major" shift = shift_major_dic[key] else: new_key = "A minor" shift = shift_minor_dic[key] converted_lines = [] for line in lines: if line.strip(): parts = line.split() start_time = parts[0] end_time = parts[1] chord = parts[2] if chord == "N" or chord == "X": newchordnorm = chord elif ":" in chord: pitch = chord.split(":")[0] attr = chord.split(":")[1] pnum = pitch_num_dic[pitch] new_idx = (pnum - shift) % 12 newchord = PITCH_CLASS[new_idx] newchordnorm = newchord + ":" + attr else: pitch = chord pnum = pitch_num_dic[pitch] new_idx = (pnum - shift) % 12 newchord = PITCH_CLASS[new_idx] newchordnorm = newchord converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n") return converted_lines def sanitize_key_signature(key): return key.replace('-', 'b') def resample_waveform(waveform, original_sample_rate, target_sample_rate): if original_sample_rate != target_sample_rate: resampler = T.Resample(original_sample_rate, target_sample_rate) return resampler(waveform), target_sample_rate return waveform, original_sample_rate def split_audio(waveform, sample_rate): segment_samples = segment_duration * sample_rate total_samples = waveform.size(0) segments = [] for start in range(0, total_samples, segment_samples): end = start + segment_samples if end <= total_samples: segments.append(waveform[start:end]) if len(segments) == 0: segments.append(waveform) return segments def safe_remove_dir(directory): directory = Path(directory) if directory.exists(): try: shutil.rmtree(directory) except Exception as e: print(f"ディレクトリ {directory} の削除中にエラーが発生しました: {e}") # 追加:YouTube URL から音声をダウンロードする関数 def download_audio_from_youtube(url, output_dir="inference/input"): import yt_dlp os.makedirs(output_dir, exist_ok=True) ydl_opts = { 'format': 'bestaudio/best', 'outtmpl': os.path.join(output_dir, 'tmp.%(ext)s'), 'postprocessors': [{ 'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192', }], 'noplaylist': True, 'quiet': True, } with yt_dlp.YoutubeDL(ydl_opts) as ydl: info = ydl.extract_info(url, download=True) title = info.get('title', '不明なタイトル') output_file = os.path.join(output_dir, 'tmp.mp3') return output_file, title # Music2emo クラス(既存コード) class Music2emo: def __init__(self, name="amaai-lab/music2emo", device="cuda:0", cache_dir=None, local_files_only=False): model_weights = "saved_models/J_all.ckpt" self.device = device self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate) self.model_weights = model_weights self.music2emo_model = FeedforwardModelMTAttnCK( input_size=768 * 2, output_size_classification=56, output_size_regression=2 ) checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False) state_dict = {key.replace("model.", ""): value for key, value in checkpoint["state_dict"].items()} model_keys = set(self.music2emo_model.state_dict().keys()) filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys} self.music2emo_model.load_state_dict(filtered_state_dict) self.music2emo_model.to(self.device) self.music2emo_model.eval() self.config = HParams.load("./inference/data/run_config.yaml") self.config.feature['large_voca'] = True self.config.model['num_chords'] = 170 model_file = './inference/data/btc_model_large_voca.pt' self.idx_to_voca = idx2voca_chord() self.btc_model = BTC_model(config=self.config.model).to(self.device) if os.path.isfile(model_file): checkpoint = torch.load(model_file, map_location=self.device) self.mean = checkpoint['mean'] self.std = checkpoint['std'] self.btc_model.load_state_dict(checkpoint['model']) self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)} self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)} self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()} self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()} with open('inference/data/chord.json', 'r') as f: self.chord_to_idx = json.load(f) with open('inference/data/chord_inv.json', 'r') as f: self.idx_to_chord = {int(k): v for k, v in json.load(f).items()} with open('inference/data/chord_root.json') as json_file: self.chordRootDic = json.load(json_file) with open('inference/data/chord_attr.json') as json_file: self.chordAttrDic = json.load(json_file) def predict(self, audio, threshold=0.5): feature_dir = Path("./inference/temp_out") output_dir = Path("./inference/output") safe_remove_dir(feature_dir) safe_remove_dir(output_dir) feature_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True) warnings.filterwarnings('ignore') logger.logging_verbosity(1) mert_dir = feature_dir / "mert" mert_dir.mkdir(parents=True, exist_ok=True) waveform, sample_rate = torchaudio.load(audio) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0).unsqueeze(0) waveform = waveform.squeeze() waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate) if is_split: segments = split_audio(waveform, sample_rate) for i, segment in enumerate(segments): segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy") self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path) else: segment_save_path = os.path.join(mert_dir, f"segment_0.npy") self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path) segment_embeddings = [] layers_to_extract = [5,6] for filename in sorted(os.listdir(mert_dir)): file_path = os.path.join(mert_dir, filename) if os.path.isfile(file_path) and filename.endswith('.npy'): segment = np.load(file_path) concatenated_features = np.concatenate( [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1 ) concatenated_features = np.squeeze(concatenated_features) segment_embeddings.append(concatenated_features) segment_embeddings = np.array(segment_embeddings) if len(segment_embeddings) > 0: final_embedding_mert = np.mean(segment_embeddings, axis=0) else: final_embedding_mert = np.zeros((1536,)) final_embedding_mert = torch.from_numpy(final_embedding_mert).to(self.device) audio_path = audio audio_id = os.path.split(audio_path)[-1][:-4] try: feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config) except: logger.info("音声ファイルの読み込みに失敗しました : %s" % audio_path) assert(False) logger.info("音声ファイルの読み込みと特徴量計算に成功しました : %s" % audio_path) feature = feature.T feature = (feature - self.mean) / self.std time_unit = feature_per_second n_timestep = self.config.model['timestep'] num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): self.btc_model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device) for t in range(num_instance): self_attn_output, _ = self.btc_model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :]) prediction, _ = self.btc_model.output_layer(self_attn_output) prediction = prediction.squeeze() for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord])) break save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab') with open(save_path, 'w') as f: for line in lines: f.write(line) try: midi_file = converter.parse(save_path.replace('.lab', '.midi')) key_signature = str(midi_file.analyze('key')) except Exception as e: key_signature = "None" key_parts = key_signature.split() key_signature = sanitize_key_signature(key_parts[0]) key_type = key_parts[1] if len(key_parts) > 1 else 'major' converted_lines = normalize_chord(save_path, key_signature, key_type) lab_norm_path = save_path[:-4] + "_norm.lab" with open(lab_norm_path, 'w') as f: f.writelines(converted_lines) chords = [] if not os.path.exists(lab_norm_path): chords.append((float(0), float(0), "N")) else: with open(lab_norm_path, 'r') as file: for line in file: start, end, chord = line.strip().split() chords.append((float(start), float(end), chord)) encoded = [] encoded_root = [] encoded_attr = [] durations = [] for start, end, chord in chords: chord_arr = chord.split(":") if len(chord_arr) == 1: chordRootID = self.chordRootDic[chord_arr[0]] chordAttrID = 0 if chord_arr[0] in ["N", "X"] else 1 elif len(chord_arr) == 2: chordRootID = self.chordRootDic[chord_arr[0]] chordAttrID = self.chordAttrDic[chord_arr[1]] encoded_root.append(chordRootID) encoded_attr.append(chordAttrID) if chord in self.chord_to_idx: encoded.append(self.chord_to_idx[chord]) else: print(f"警告: {chord} は chord.json に見つかりませんでした。スキップします。") durations.append(end - start) encoded_chords = np.array(encoded) encoded_chords_root = np.array(encoded_root) encoded_chords_attr = np.array(encoded_attr) max_sequence_length = 100 if len(encoded_chords) > max_sequence_length: encoded_chords = encoded_chords[:max_sequence_length] encoded_chords_root = encoded_chords_root[:max_sequence_length] encoded_chords_attr = encoded_chords_attr[:max_sequence_length] else: padding = [0] * (max_sequence_length - len(encoded_chords)) encoded_chords = np.concatenate([encoded_chords, padding]) encoded_chords_root = np.concatenate([encoded_chords_root, padding]) encoded_chords_attr = np.concatenate([encoded_chords_attr, padding]) chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device) chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device) chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device) model_input_dic = { "x_mert": final_embedding_mert.unsqueeze(0), "x_chord": chords_tensor.unsqueeze(0), "x_chord_root": chords_root_tensor.unsqueeze(0), "x_chord_attr": chords_attr_tensor.unsqueeze(0), "x_key": torch.tensor([self.mode_to_idx.get(key_type, 0)], dtype=torch.long).unsqueeze(0).to(self.device) } model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()} classification_output, regression_output = self.music2emo_model(model_input_dic) tag_list = np.load("./inference/data/tag_list.npy") tag_list = tag_list[127:] mood_list = [t.replace("mood/theme---", "") for t in tag_list] probs = torch.sigmoid(classification_output).squeeze().tolist() predicted_moods_with_scores = [ {"mood": mood_list[i], "score": round(p, 4)} for i, p in enumerate(probs) if p > threshold ] predicted_moods_with_scores_all = [ {"mood": mood_list[i], "score": round(p, 4)} for i, p in enumerate(probs) ] predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True) valence, arousal = regression_output.squeeze().tolist() model_output_dic = { "valence": valence, "arousal": arousal, "predicted_moods": predicted_moods_with_scores, "predicted_moods_all": predicted_moods_with_scores_all } return model_output_dic # Music2Emo モデルの初期化 if torch.cuda.is_available(): music2emo = Music2emo() else: music2emo = Music2emo(device="cpu") # 入力(音声ファイルまたはYouTube URL)を処理する関数 def process_input(audio, youtube_url, threshold): if youtube_url and youtube_url.strip().startswith("http"): # YouTube URL が入力されている場合、音声をダウンロード audio_file, video_title = download_audio_from_youtube(youtube_url) output_dic = music2emo.predict(audio_file, threshold) output_text, va_chart, mood_chart = format_prediction(output_dic) output_text += f"\n動画タイトル: {video_title}" return output_text, va_chart, mood_chart elif audio: output_dic = music2emo.predict(audio, threshold) return format_prediction(output_dic) else: return "音声ファイルまたは YouTube URL を入力してください。", None, None # 解析結果のフォーマット関数 def format_prediction(model_output_dic): valence = model_output_dic["valence"] arousal = model_output_dic["arousal"] predicted_moods_with_scores = model_output_dic["predicted_moods"] predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"] va_chart = plot_valence_arousal(valence, arousal) mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all) if predicted_moods_with_scores: moods_text = ", ".join([f"{m['mood']} ({m['score']:.2f})" for m in predicted_moods_with_scores]) else: moods_text = "顕著なムードは検出されませんでした。" output_text = f"""🎭 ムードタグ: {moods_text} 💖 バレンス: {valence:.2f} (1〜9 スケール) ⚡ アラウザル: {arousal:.2f} (1〜9 スケール)""" return output_text, va_chart, mood_chart def plot_mood_probabilities(predicted_moods_with_scores): if not predicted_moods_with_scores: return None moods = [m["mood"] for m in predicted_moods_with_scores] probs = [m["score"] for m in predicted_moods_with_scores] sorted_indices = np.argsort(probs)[::-1] sorted_probs = [probs[i] for i in sorted_indices] sorted_moods = [moods[i] for i in sorted_indices] fig, ax = plt.subplots(figsize=(8, 4)) ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50") ax.set_xlabel("確率") ax.set_title("上位10のムードタグ") ax.invert_yaxis() return fig def plot_valence_arousal(valence, arousal): fig, ax = plt.subplots(figsize=(4, 4)) ax.scatter(valence, arousal, color="red", s=100) ax.set_xlim(1, 9) ax.set_ylim(1, 9) ax.axhline(y=5, color='gray', linestyle='--', linewidth=1) ax.axvline(x=5, color='gray', linestyle='--', linewidth=1) ax.set_xlabel("バレンス (ポジティブ度)") ax.set_ylabel("アラウザル (活発度)") ax.set_title("バレンス・アラウザル プロット") ax.grid(True, linestyle="--", alpha=0.6) return fig # Gradio UI の設定 title = "🎵 Music2Emo:統一型音楽感情認識システム" description_text = """

音声ファイルまたは YouTube の URL を入力すると、Music2Emo が楽曲の感情的特徴を解析します。

このデモでは、1) ムードタグ、2) バレンス(1〜9 スケール)、3) アラウザル(1〜9 スケール)を予測します。

詳細は 論文 をご参照ください。

""" css = """ .gradio-container { font-family: 'Inter', -apple-system, system-ui, sans-serif; } .gr-button { color: white; background: #4CAF50; border-radius: 8px; padding: 10px; } .gr-box { padding-top: 25px !important; } """ with gr.Blocks(css=css) as demo: gr.HTML(f"

{title}

") gr.Markdown(description_text) gr.Markdown(""" ### 📝 注意事項: - **対応音声フォーマット:** MP3, WAV - **YouTube URL も入力可能です(任意) - **推奨:** 高品質な音声ファイル """) with gr.Row(): with gr.Column(scale=1): input_audio = gr.Audio(label="音声ファイルをアップロード", type="filepath") youtube_url = gr.Textbox(label="YouTube URL (任意)", placeholder="例: https://youtu.be/XXXXXXX") threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="ムード検出のしきい値", info="しきい値を調整してください") predict_btn = gr.Button("🎭 感情解析を実行", variant="primary") with gr.Column(scale=1): output_text = gr.Textbox(label="解析結果", lines=4, interactive=False) with gr.Row(equal_height=True): mood_chart = gr.Plot(label="ムード確率", scale=2, elem_classes=["gr-box"]) va_chart = gr.Plot(label="バレンス・アラウザル", scale=1, elem_classes=["gr-box"]) predict_btn.click( fn=process_input, inputs=[input_audio, youtube_url, threshold], outputs=[output_text, va_chart, mood_chart] ) demo.queue().launch()