UTMOSv2 / app.py
kAIto47802
Fix and add quick option
8537948
raw
history blame
2.97 kB
import importlib
from types import SimpleNamespace
import gradio as gr
import pandas as pd
import spaces
import torch
from utmosv2.utils import get_dataset, get_model
# Markdown shown at the top of the demo page.
description = (
    "# 🚀 UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

# Inference device; hard-coded to CUDA (the predictor runs under @spaces.GPU).
device = torch.device("cuda")

# Copy every non-dunder attribute of the fusion_stage3 config module into a
# mutable namespace, so per-request fields (fold, weight path) can be set on it.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in vars(config).items() if not name.startswith("__")}
)

# Run-time overrides for inference in this demo. Their exact meaning is
# defined by the utmosv2 package (not visible here).
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
cfg.num_workers = 1
@spaces.GPU
@torch.inference_mode()
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the MOS of one audio file with the UTMOSv2 fusion_stage3 models.

    Full mode loads each of the five fold checkpoints and scores the input
    five times per fold, rebuilding the test dataset for every repetition.
    The 25 scores are averaged. Quick mode returns the very first score
    (fold 0, first repetition) without loading the remaining folds.

    Args:
        audio_path: Path of the uploaded audio file (from ``gr.Audio(type="filepath")``).
        domain: Data-domain ID, written to the ``dataset`` column of the input frame.
        quick: If true, return after a single prediction.

    Returns:
        The predicted MOS score taken from the model output.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_path:
        # Gradio passes None when Submit is clicked without an upload; fail with
        # a user-visible message instead of an opaque loader error.
        raise gr.Error("Please upload an audio file before submitting.")

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Dummy target column; presumably required by the dataset loader — TODO confirm.
    data["mos"] = 0

    n_folds = 5
    n_repeats = 5
    preds = 0.0
    for fold in range(n_folds):
        # Point the shared cfg at this fold's checkpoint before building the model
        # (presumably read by get_model / get_dataset — verify in utmosv2.utils).
        cfg.now_fold = fold
        cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
        model = get_model(cfg, device).eval()
        for _ in range(n_repeats):
            test_dataset = get_dataset(cfg, data, "test")
            # Every element of the sample except the last is a model input; the
            # last one is presumably the (dummy) target. unsqueeze adds a batch dim.
            inputs = [torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]]
            p = model(*inputs)
            preds += p.cpu().numpy()[0][0]
            if quick:
                return preds
    preds /= n_folds * n_repeats
    return preds
# Build the demo UI: audio upload, domain selector, quick-mode toggle and
# a text box for the predicted score.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # Implicitly concatenated fragments; no trailing comma after the
                # last one, otherwise `info` would become a tuple instead of a str.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "To make quick predictions by reducing this to a single repetition, "
                    "check this checkbox:"
                ),
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # Event listeners must be registered inside the Blocks context.
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

demo.queue().launch()