"""Gradio demo: AudioSet tagging with a timm AudioMAE ViT checkpoint.

An uploaded clip is processed as follows:

1. The first ``start`` seconds are optionally skipped.
2. The clip is resampled to 16 kHz mono with ffmpeg.
3. It is turned into a NUM_FRAMES x NUM_MEL_BINS Kaldi-style log-mel filterbank.
4. The filterbank is normalised and classified.

The TOP_K most probable AudioSet labels are returned as a table.
"""

import json
import shlex
import subprocess

import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi

TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Fetched from the Hugging Face Hub at import time; eval() switches the model to inference mode.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# Seconds. Without a timeout, a stalled label download would hang start-up indefinitely.
LABEL_TIMEOUT_S = 30

SAMPLING_RATE = 16_000
NUM_FRAMES = 1024  # time frames the model input is padded/truncated to
NUM_MEL_BINS = 128  # mel bins per frame in the model input
# kaldi.fbank asserts that the input holds at least one analysis window. Its
# default frame_length is 25 ms, i.e. 400 samples at 16 kHz.
MIN_SAMPLES = SAMPLING_RATE * 25 // 1000
# Filterbank normalisation statistics. These are presumably the dataset mean/std
# used when the checkpoint was trained — TODO confirm against the training recipe.
MEAN = -4.2677393
STD = 4.5689974
TOP_K = 10  # number of labels shown in the UI


def _load_labels(url: str = LABEL_URL, timeout: float = LABEL_TIMEOUT_S) -> list:
    """Download the AudioSet id->label JSON and return labels indexed by class id.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if connecting or waiting for data exceeds ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    id2label = json.loads(response.content)
    # Sort by the numeric id rather than trusting key order in the file, so that
    # AUDIOSET_LABELS[i] is always the name of model output i.
    return [id2label[key] for key in sorted(id2label, key=int)]


AUDIOSET_LABELS = _load_labels()


def _to_mono_float32(x: np.ndarray) -> np.ndarray:
    """Convert PCM samples to a contiguous 1-D float32 array.

    Integer input is divided by its full-scale value (e.g. 32768 for int16) so
    it lands in [-1, 1). Float input is passed through unchanged.

    A 2-D array is averaged across axis 1. This assumes a (samples, channels)
    layout, consistent with ``predict`` using ``x.shape[0]`` as the sample
    count — TODO confirm for every Gradio audio source.
    """
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.signedinteger):
        x = x.astype(np.float32) / float(-np.iinfo(x.dtype).min)
    elif np.issubdtype(x.dtype, np.unsignedinteger):
        # Unsigned PCM is offset-binary: silence sits at the mid-point.
        mid = float(np.iinfo(x.dtype).max + 1) / 2
        x = (x.astype(np.float32) - mid) / mid
    else:
        x = x.astype(np.float32, copy=False)
    if x.ndim == 2:
        x = x.mean(axis=1)
    return np.ascontiguousarray(x, dtype=np.float32)


def resample(x: np.ndarray, sr: int) -> np.ndarray:
    """Resample audio sampled at ``sr`` Hz to mono float32 at SAMPLING_RATE using ffmpeg.

    Args:
        x: PCM samples, 1-D (mono) or 2-D (samples, channels), integer or float dtype.
        sr: sample rate of ``x`` in Hz.

    Returns:
        A writable 1-D float32 array sampled at SAMPLING_RATE.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    pcm = _to_mono_float32(x)
    # Raw float32 mono goes in on stdin and comes out on stdout. The sample
    # format is stated explicitly so that multi-channel or non-int16 input is
    # not misread. int(sr) keeps the formatted command free of extra arguments.
    cmd = (
        "ffmpeg -hide_banner -loglevel error "
        f"-f f32le -ar {int(sr)} -ac 1 -i - "
        f"-ar {SAMPLING_RATE} -ac 1 -f f32le -"
    )
    proc = subprocess.run(shlex.split(cmd), capture_output=True, input=pcm.tobytes())
    if proc.returncode != 0:
        stderr = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed to resample audio (exit {proc.returncode}): {stderr}")
    # frombuffer returns a read-only view. Copy it so torch.from_numpy receives a writable array.
    return np.frombuffer(proc.stdout, dtype=np.float32).copy()


def preprocess(x: torch.Tensor) -> torch.Tensor:
    """Turn a 1-D SAMPLING_RATE waveform into the normalised model input.

    Returns:
        A (1, 1, NUM_FRAMES, NUM_MEL_BINS) tensor. It holds a Kaldi-compatible
        log-mel filterbank, zero-padded or truncated along time to NUM_FRAMES
        frames, then normalised with MEAN and 2 * STD.
    """
    if x.numel() > 0:  # mean of an empty tensor is NaN
        x = x - x.mean()
    if x.shape[-1] < MIN_SAMPLES:
        # kaldi.fbank rejects inputs shorter than one window, so pad with silence.
        x = F.pad(x, (0, MIN_SAMPLES - x.shape[-1]))
    melspec = kaldi.fbank(
        x.unsqueeze(0),  # fbank expects (channels, time)
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=NUM_MEL_BINS,
    )
    num_frames = melspec.shape[0]
    if num_frames < NUM_FRAMES:
        melspec = F.pad(melspec, (0, 0, 0, NUM_FRAMES - num_frames))
    else:
        melspec = melspec[:NUM_FRAMES]
    melspec = (melspec - MEAN) / (STD * 2)
    return melspec.view(1, 1, NUM_FRAMES, NUM_MEL_BINS)


def predict(audio, start):
    """Return the TOP_K AudioSet labels for ``audio``, skipping its first ``start`` seconds.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio input, or
            None when no clip was provided.
        start: offset in seconds. None (an empty number box) is treated as 0.

    Returns:
        Rows of ``[label, probability in percent]``, most probable first.

    Raises:
        gr.Error: for missing audio or an out-of-range ``start``. Gradio shows the
            message in the UI.
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip.")
    sr, x = audio
    start = 0.0 if start is None else float(start)
    duration = x.shape[0] / sr
    # `start == duration` would leave nothing to classify. A negative start would
    # silently slice from the end of the clip.
    if start < 0 or start >= duration:
        raise gr.Error(
            f"`start` ({start}) must be non-negative and smaller than audio duration ({duration:.2f}s)"
        )
    waveform = torch.from_numpy(resample(x[int(start * sr):], sr))
    with torch.inference_mode():
        logits = MODEL(preprocess(waveform)).squeeze(0)
    # Sigmoid rather than softmax: each class gets an independent probability.
    probs, classes = logits.sigmoid().topk(TOP_K)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(classes.tolist(), probs.tolist())
    ]


demo = gr.Interface(
    fn=predict,
    inputs=["audio", "number"],
    outputs="dataframe",
    examples=[["LS_female_1462-170138-0008.flac", 0], ["LS_male_3170-137482-0005.flac", 0]],
)

if __name__ == "__main__":
    demo.launch()