import importlib
import json
import logging

import gradio as gr
import torch
import torchaudio
import torchaudio.transforms as T
from torch import nn
from transformers import AutoModel
from transformers import Wav2Vec2FeatureExtractor

from Prediction_Head.MTGGenre_head import MLPProberBase

# The MERT modeling code is loaded from a local "MERT-v0-public" directory; the
# hyphen in the directory name rules out a plain `import` statement.
modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

# input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py

# Module logger: INFO and above go to a stream handler with a timestamped format.
logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)

# Inputs for the "file or recording" tab: an uploaded file and a microphone
# recording, both delivered to the handler as file paths.
inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    # Same component the deprecated `gr.inputs.Audio` wrapper constructed.
    gr.components.Audio(source="microphone", type="filepath"),
]
# Input for the live tab: a streaming microphone, delivered as a file path.
live_inputs = [
    gr.Audio(source="microphone", streaming=True, type="filepath"),
]

title = "Predict the top 5 possible genres and tags of Music"
description = "An example of using map/MERT-95M-public model as backbone to conduct music genre/tagging prediction."
article = ""
audio_examples = [
    # ["input/example-1.wav"],
    # ["input/example-2.wav"],
]

# Load the MERT backbone and its feature extractor from the local checkout.
# Hub equivalents:
#   AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
#   Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")

# Which stacked hidden-state layer is time-averaged and fed to the genre probe.
# Layers perform differently on downstream tasks; this was chosen empirically.
MERT_LAYER_IDX = 7
# Number of top-ranked labels shown to the user.
TOP_K = 5

MTGGenre_classifier = MLPProberBase()
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts;
# the module is moved to `device` below.
MTGGenre_classifier.load_state_dict(
    torch.load("Prediction_Head/best_MTGGenre.ckpt", map_location="cpu")["state_dict"]
)
with open("Prediction_Head/MTGGenre_id2class.json", "r", encoding="utf-8") as f:
    id2cls = json.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
MTGGenre_classifier.to(device)
# Inference only: put both networks in eval mode (disables dropout etc.).
model.eval()
MTGGenre_classifier.eval()


def _load_mono_waveform(path):
    """Load ``path`` as a 1-D waveform resampled to the processor's rate.

    Returns:
        (waveform, sampling_rate): ``waveform`` has shape (n_samples,).
    """
    waveform, sample_rate = torchaudio.load(path)  # (channels, n_samples)
    target_rate = processor.sampling_rate
    # make sure the sample_rate aligned
    if sample_rate != target_rate:
        logger.info("setting rate from %s to %s", sample_rate, target_rate)
        waveform = T.Resample(sample_rate, target_rate)(waveform)
    # Down-mix by averaging channels; flattening with view(-1) would instead
    # concatenate the channels end to end. For mono this equals waveform[0].
    return waveform.mean(dim=0), target_rate


def _predict_top_genres(path):
    """Run MERT + the MTG genre probe on the audio at ``path``.

    Returns:
        A string with the device on the first line followed by the
        ``TOP_K`` highest-scoring labels, one per line, with the
        ``"genre---"`` prefix stripped.
    """
    waveform, sampling_rate = _load_mono_waveform(path)
    model_inputs = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)
        # Each hidden state is (batch=1, time, feature); stacking gives
        # (layers, 1, time, feature). Squeeze only the batch axis so that a
        # single-time-step clip keeps its time dimension.
        all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze(1)
        logger.info("hidden states shape: %s", tuple(all_layer_hidden_states.shape))
        # Average over time for the chosen layer -> one feature vector.
        pooled = all_layer_hidden_states[MERT_LAYER_IDX].mean(dim=0)
        logits = MTGGenre_classifier(pooled)
    logger.info("logits shape: %s", tuple(logits.shape))

    sorted_idx = torch.argsort(logits, dim=-1, descending=True)
    labels = [
        id2cls[str(idx.item())].replace("genre---", "")
        for idx in sorted_idx[:TOP_K]
    ]
    return f"device: {device}\n" + "\n".join(labels)


def convert_audio(inputs, microphone):
    """Gradio handler for the file/recording tab.

    Args:
        inputs: Path of an uploaded audio file, or None.
        microphone: Path of a microphone recording, or None. Takes precedence
            over ``inputs`` when both are given.
    """
    if microphone is not None:
        inputs = microphone
    if inputs is None:
        return "Please upload an audio file or record from the microphone."
    return _predict_top_genres(inputs)


def live_convert_audio(microphone):
    """Gradio handler for the live-streaming tab.

    Args:
        microphone: Path of the streamed microphone audio, or None before any
            audio has arrived (previously this raised UnboundLocalError).
    """
    if microphone is None:
        return "No microphone audio received yet."
    return _predict_top_genres(microphone)


audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)

live_audio_chunked = gr.Interface(
    fn=live_convert_audio,
    inputs=live_inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    # examples=audio_examples,
    live=True,
)

demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            audio_chunked,
            live_audio_chunked,
        ],
        [
            "Audio File or Recording",
            "Live Streaming Music",
        ],
    )
# One request at a time (single model instance), at most 5 queued.
demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)