|
import gradio as gr |
|
import torch |
|
from transformers import pipeline |
|
|
|
# Fine-tuned DistilHuBERT checkpoint published on the Hugging Face Hub
# under this account.
username = "ardneebwar"
model_id = f"{username}/distilhubert-finetuned-gtzan"

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Audio-classification pipeline built once at import time and reused by
# classify_audio for every prediction.
pipe = pipeline(
    "audio-classification",
    model=model_id,
    device=device,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_audio(filepath):
    """Predict the music genre of an audio file.

    Args:
        filepath: Path to the audio file, as supplied by the
            ``gr.Audio(type="filepath")`` input component.

    Returns:
        A tuple ``(outputs, prediction_time)``:
        ``outputs`` maps each predicted label to its score (the mapping
        ``gr.Label`` displays), and ``prediction_time`` is the time in
        seconds spent inside the pipeline call.

    Raises:
        gr.Error: If no audio was provided, so Gradio shows a readable
            message instead of a pipeline traceback.
    """
    import time

    # Gradio passes None when the form is submitted without audio.
    if not filepath:
        raise gr.Error("Please upload or record an audio clip first.")

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    start_time = time.perf_counter()
    preds = pipe(filepath)
    prediction_time = time.perf_counter() - start_time

    # Each prediction is a dict carrying "label" and "score" keys.
    outputs = {p["label"]: p["score"] for p in preds}

    return outputs, prediction_time
|
|
|
|
|
# Page heading shown by the Gradio interface. The leading character is the
# musical-note emoji (U+1F3B5); it is written as an escape so the title
# cannot be mangled again by a wrong source-file encoding.
title = "\U0001F3B5 Music Genre Classifier"

# Markdown shown under the title: model provenance and a link to the dataset.
description = """

Music Genre Classifier model (Fine-tuned "ntu-spml/distilhubert"). Dataset: [GTZAN](https://huggingface.co/datasets/marsyas/gtzan)

"""
|
|
|
# Sample clips offered as one-click examples in the UI; the "./" prefix makes
# each path relative to the current working directory.
filenames = ["./" + name for name in ["rock-it-21275.mp3"]]

demo = gr.Interface(
    fn=classify_audio,
    title=title,
    description=description,
    # Single audio input, handed to fn as a path on disk.
    inputs=gr.Audio(type="filepath"),
    # Two outputs, one per element of the (scores, seconds) tuple fn returns.
    outputs=[
        gr.Label(),
        gr.Number(label="Prediction time (s)"),
    ],
    # One row per example, holding one value per input component.
    examples=[(path,) for path in filenames],
)

# Start the Gradio web server.
demo.launch()
|
|