gaunernst's picture
fix preprocessing. add examples
bcc0935
raw
history blame
1.9 kB
import json
import shlex
import subprocess
import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
# Hugging Face Hub id of the pretrained checkpoint used for classification.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Loaded once at import time and switched to eval mode, since the app only runs inference.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# JSON object mapping class id (as a string key) to a human-readable label.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# A timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error page into a clear error instead of a
# confusing JSON decoding failure.
_label_response = requests.get(LABEL_URL, timeout=30)
_label_response.raise_for_status()
_id2label = json.loads(_label_response.content)
# Order labels by numeric class id so that list position matches the model's
# output index regardless of the key order in the JSON file.
AUDIOSET_LABELS = [label for _, label in sorted(_id2label.items(), key=lambda item: int(item[0]))]

# Sample rate (Hz) that resample() converts every clip to.
SAMPLING_RATE = 16_000
# Normalization constants applied to the filterbank features in preprocess().
# NOTE(review): presumably dataset-level log-mel statistics - confirm the source.
MEAN = -4.2677393
STD = 4.5689974
# Maps (NumPy dtype kind, item size in bytes) to the ffmpeg raw-PCM format name.
_FFMPEG_PCM_FORMATS = {
    ("i", 2): "s16le",
    ("i", 4): "s32le",
    ("f", 4): "f32le",
    ("f", 8): "f64le",
}


def resample(x: np.ndarray, sr: int):
    """Resample raw PCM audio to ``SAMPLING_RATE`` mono float32 by piping it through ffmpeg.

    Args:
        x: Samples shaped ``(n_samples,)`` or ``(n_samples, n_channels)`` with
            dtype int16, int32, float32 or float64. Float input is handed to
            ffmpeg unchanged, so it should already be in ffmpeg's nominal
            [-1, 1] range.
        sr: Sample rate of ``x`` in Hz.

    Returns:
        A 1-D float32 array holding the resampled, mono-downmixed audio.
        Empty input yields an empty array without invoking ffmpeg.

    Raises:
        ValueError: If ``x`` has an unsupported number of dimensions or dtype.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    x = np.asarray(x)
    if x.ndim not in (1, 2):
        raise ValueError(f"expected 1-D or 2-D audio, got {x.ndim} dimensions")
    pcm_format = _FFMPEG_PCM_FORMATS.get((x.dtype.kind, x.dtype.itemsize))
    if pcm_format is None:
        raise ValueError(f"unsupported audio dtype: {x.dtype}")
    if x.size == 0:
        return np.zeros(0, dtype=np.float32)

    # A 2-D array serializes as interleaved frames; ffmpeg must be told the
    # channel count or it would read the stream as one long mono signal.
    channels = 1 if x.ndim == 1 else x.shape[1]
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        # Input: headerless PCM read from stdin.
        "-f", pcm_format, "-ar", str(sr), "-ac", str(channels), "-i", "-",
        # Output: mono float32 PCM at the target rate, written to stdout.
        "-f", "f32le", "-ar", str(SAMPLING_RATE), "-ac", "1", "-",
    ]
    proc = subprocess.run(cmd, capture_output=True, input=x.tobytes())
    if proc.returncode != 0:
        # Surface the failure instead of silently returning an empty array.
        details = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg exited with status {proc.returncode}: {details}")
    return np.frombuffer(proc.stdout, dtype=np.float32)
def preprocess(x: torch.Tensor):
    """Convert a waveform into the normalized filterbank tensor passed to the model.

    The waveform is mean-centered, turned into 128-bin Kaldi-style filterbank
    features, forced to exactly 1024 frames (zero-padded at the end or
    truncated), and then normalized with ``MEAN`` and ``STD``.

    Args:
        x: Waveform tensor; ``predict`` passes the 1-D float32 output of
            ``resample``.

    Returns:
        Tensor of shape ``(1, 1, 1024, 128)``.
    """
    centered = x - x.mean()
    # sample_frequency is not passed, so kaldi.fbank uses its library default.
    # NOTE(review): that default should equal SAMPLING_RATE - confirm.
    features = kaldi.fbank(centered.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)

    missing_frames = 1024 - features.shape[0]
    if missing_frames > 0:
        # Pad only the end of the frame axis; the mel-bin axis is left alone.
        # Padding happens before normalization, so padded frames are normalized too.
        features = F.pad(features, (0, 0, 0, missing_frames))
    else:
        features = features[:1024]

    normalized = (features - MEAN) / (STD * 2)
    return normalized.view(1, 1, 1024, 128)
def predict(audio, start):
    """Classify an audio clip and return its ten highest-scoring AudioSet labels.

    Args:
        audio: ``(sample_rate, samples)`` pair, where the first axis of
            ``samples`` indexes audio samples. ``None`` (no clip provided) is
            rejected with a user-facing error.
        start: Offset in seconds at which classification starts. ``None``
            (empty field) is treated as 0.

    Returns:
        A list of ``[label, score]`` rows, where ``score`` is the sigmoid
        probability expressed as a percentage.

    Raises:
        gr.Error: If no audio was provided, or ``start`` is negative or not
            smaller than the clip duration.
    """
    if audio is None:
        raise gr.Error("Please provide an audio clip")
    sr, x = audio

    if start is None:
        start = 0
    if start < 0:
        # A negative offset would otherwise silently slice from the end of the clip.
        raise gr.Error(f"`start` ({start}) must not be negative")

    offset = int(start * sr)
    # `>=` so that a start exactly at the end of the clip is rejected too;
    # it would leave no samples to classify.
    if offset >= x.shape[0]:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")

    samples = resample(x[offset:], sr)
    # Copy because the resampled array may be a read-only buffer view, which
    # torch.from_numpy warns about.
    waveform = torch.from_numpy(samples.copy())
    with torch.inference_mode():
        logits = MODEL(preprocess(waveform)).squeeze(0)

    topk_probs, topk_classes = logits.sigmoid().topk(10)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]
# Build the web UI around `predict` and start serving it as soon as the module runs.
# Inputs are an audio clip and a numeric start offset; results are shown as a table.
gr.Interface(
    examples=[
        # Each example row is [audio file, start offset].
        [clip, 0]
        for clip in ("LS_female_1462-170138-0008.flac", "LS_male_3170-137482-0005.flac")
    ],
    outputs="dataframe",
    inputs=["audio", "number"],
    fn=predict,
).launch()