Spaces:
Sleeping
Sleeping
File size: 1,900 Bytes
cac3ec7 bcc0935 cac3ec7 bcc0935 cac3ec7 bcc0935 cac3ec7 bcc0935 cac3ec7 bcc0935 cac3ec7 bcc0935 cac3ec7 bcc0935 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import json
import shlex
import subprocess
import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
# timm checkpoint on the Hugging Face Hub; judging by its name, an AudioMAE
# ViT-Base (patch 16, 1024x128 input) fine-tuned on AudioSet. Loaded once at
# import time and put in eval mode.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# JSON object mapping class id (string keys "0", "1", ...) -> AudioSet label.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"


def _load_labels(url, timeout=30.0):
    """Download an id2label JSON file and return its labels as a list indexed by class id.

    Keys are sorted numerically instead of trusting the order they appear in
    the file, so ``labels[i]`` is the label of model output ``i``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout`` seconds.
    """
    # Without a timeout a stalled connection would block app startup forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    id2label = json.loads(response.content)
    return [id2label[key] for key in sorted(id2label, key=int)]


AUDIOSET_LABELS = _load_labels(LABEL_URL)

# Rate that resample() converts audio to; it must match the sample rate
# kaldi.fbank assumes in preprocess() (its default is 16 kHz).
SAMPLING_RATE = 16_000
# Feature normalization constants used in preprocess() as
# (fbank - MEAN) / (STD * 2); presumably the dataset statistics used when the
# checkpoint was trained — confirm against the model card.
MEAN = -4.2677393
STD = 4.5689974
# (dtype.kind, itemsize) -> ffmpeg raw little-endian PCM format name.
_FFMPEG_RAW_FORMATS = {
    ("u", 1): "u8",
    ("i", 2): "s16le",
    ("i", 4): "s32le",
    ("f", 4): "f32le",
    ("f", 8): "f64le",
}


def resample(x: np.ndarray, sr: int):
    """Resample raw PCM audio to ``SAMPLING_RATE`` mono float32 using ffmpeg.

    Args:
        x: Audio samples, either 1-D (mono) or 2-D ``(num_samples, channels)``.
            Supported dtypes: uint8, int16, int32, float32, float64.
        sr: Sample rate of ``x`` in Hz.

    Returns:
        A writable 1-D float32 array sampled at ``SAMPLING_RATE``. Multi-channel
        input is downmixed to mono by ffmpeg.

    Raises:
        ValueError: if ``x`` has an unsupported dtype or rank.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    fmt = _FFMPEG_RAW_FORMATS.get((x.dtype.kind, x.dtype.itemsize))
    if fmt is None:
        raise ValueError(f"unsupported audio dtype: {x.dtype}")
    if x.ndim not in (1, 2):
        raise ValueError(f"expected 1-D or 2-D audio, got shape {x.shape}")
    channels = 1 if x.ndim == 1 else x.shape[1]
    # The raw formats are little-endian. C-order (samples, channels) bytes are
    # already interleaved frame by frame, which is what ffmpeg expects.
    x = np.ascontiguousarray(x, dtype=x.dtype.newbyteorder("<"))
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-f", fmt, "-ar", str(sr), "-ac", str(channels), "-i", "-",
        "-f", "f32le", "-ar", str(SAMPLING_RATE), "-ac", "1", "-",
    ]
    proc = subprocess.run(cmd, capture_output=True, input=x.tobytes())
    if proc.returncode != 0:
        message = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed (exit {proc.returncode}): {message}")
    # frombuffer views immutable bytes; copy so torch.from_numpy gets a writable array.
    return np.frombuffer(proc.stdout, dtype=np.float32).copy()
def preprocess(x: torch.Tensor):
    """Turn a waveform into the normalized fbank tensor fed to MODEL.

    The waveform is mean-centered and converted to 128-bin Kaldi-style mel
    filterbank features. The frame axis is zero-padded or truncated to exactly
    1024 frames, and the features are normalized with MEAN and STD.

    Args:
        x: 1-D waveform. kaldi.fbank is called without ``sample_frequency``,
            so ``x`` is assumed to be at that function's default rate
            (16 kHz per torchaudio docs, matching SAMPLING_RATE).

    Returns:
        Tensor of shape ``(1, 1, 1024, 128)``.
    """
    target_frames, num_bins = 1024, 128
    centered = x - x.mean()  # remove DC offset
    fbank = kaldi.fbank(
        centered.unsqueeze(0),  # kaldi.fbank expects a leading channel dimension
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=num_bins,
    )
    shortfall = target_frames - fbank.shape[0]
    # Padding happens before normalization, so padded frames end up at
    # -MEAN / (2 * STD) rather than 0.
    fbank = F.pad(fbank, (0, 0, 0, shortfall)) if shortfall > 0 else fbank[:target_frames]
    normalized = (fbank - MEAN) / (STD * 2)
    return normalized.view(1, 1, target_frames, num_bins)
def predict(audio, start):
    """Return the 10 highest-scoring AudioSet labels for a clip, starting at ``start`` seconds.

    Args:
        audio: ``(sample_rate, samples)`` tuple from Gradio's ``"audio"`` input,
            or ``None`` when nothing was uploaded.
        start: Offset in seconds at which analysis begins. An empty number
            field (``None``) is treated as 0.

    Returns:
        List of ``[label, score_percent]`` rows, highest score first. Scores
        are independent per-class sigmoids (multi-label), so they need not
        sum to 100.

    Raises:
        gr.Error: if no audio was given, or ``start`` is negative or not
            before the end of the clip.
    """
    if audio is None:
        raise gr.Error("No audio provided: upload or record a clip first.")
    sr, x = audio
    if start is None:
        start = 0
    # A negative offset would silently slice from the end of the array.
    if start < 0:
        raise gr.Error(f"`start` ({start}) must be non-negative")
    offset = start * sr
    # `>=` also rejects start == duration, which would leave no audio at all.
    if offset >= x.shape[0]:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.2f}s)")
    waveform = torch.from_numpy(resample(x[int(offset):], sr))
    with torch.inference_mode():
        logits = MODEL(preprocess(waveform)).squeeze(0)
        topk_probs, topk_classes = logits.sigmoid().topk(10)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]
# Web UI: an audio clip and a start offset (seconds) go in; predict()'s rows
# are shown as a table.
demo = gr.Interface(
    fn=predict,
    inputs=["audio", "number"],
    outputs="dataframe",
    examples=[
        ["LS_female_1462-170138-0008.flac", 0],
        ["LS_male_3170-137482-0005.flac", 0],
    ],
)
demo.launch()
|