import os
import json
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
class InferRunner:
    def __init__(self, device):
        # AudioLDM VAE: decodes diffusion latents to mel spectrograms and waveforms.
        vae_config = json.load(open("ckpts/ldm/vae_config.json"))
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)
        # Training arguments are stored on the first line of the JSONL summary file.
        train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
            diffusion_pt="ckpts/pico_model/diffusion.pt",
        ).eval().to(device)
        self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
device = "cuda" if torch.cuda.is_available() else "cpu"
runner = InferRunner(device)
event_list = [
    "burping_belching",  # 0
    "car_horn_honking",  # 1
    "cat_meowing",  # 2
    "cow_mooing",  # 3
    "dog_barking",  # 4
    "door_knocking",  # 5
    "door_slamming",  # 6
    "explosion",  # 7
    "gunshot",  # 8
    "sheep_goat_bleating",  # 9
    "sneeze",  # 10
    "spraying",  # 11
    "thump_thud",  # 12
    "train_horn",  # 13
    "tapping_clicking_clanking",  # 14
    "woman_laughing",  # 15
    "duck_quacking",  # 16
    "whistling",  # 17
]
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    with torch.no_grad():
        # Sample latents conditioned on the caption, decode to a mel spectrogram,
        # then vocode to a waveform and keep at most audio_len samples (10 s at 16 kHz).
        latents = runner.pico_model.demo_inference(caption, runner.scheduler, num_steps=num_steps, guidance_scale=guidance_scale, num_samples_per_prompt=1, disable_progress=True)
        mel = runner.vae.decode_first_stage(latents)
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
    outpath = "output.wav"
    sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
    return outpath
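# For reference, infer() can also be called directly, outside the Gradio UI.
# A minimal sketch (the caption is taken from the examples below; the keyword
# defaults match the sliders defined further down):
#   wav_path = infer("dog_barking at 0.562-2.562_4.25-6.25.", num_steps=200, guidance_scale=3.0)
#   # writes a 16 kHz PCM_16 file and returns its path ("output.wav")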
description_text = f"18 events: {', '.join(event_list)}"
prompt = gr.Textbox(
    label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
    value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
)
outaudio = gr.Audio()
num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
gr_interface = gr.Interface(
    fn=infer,
    inputs=[prompt, num_steps, guidance_scale],
    outputs=[outaudio],
    title="PicoAudio",
    description=description_text,
    allow_flagging="never",
    examples=[
        ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
        ["dog_barking at 0.562-2.562_4.25-6.25."],
        ["cow_mooing at 0.958-3.582_5.272-7.896."],
    ],
    cache_examples="lazy",  # Cache each example's output the first time it is requested.
)
gr_interface.queue(10).launch()
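# To try the demo locally (a sketch, assuming this script is saved as app.py and
# the ckpts/ directory contains the checkpoints referenced above):
#   python app.py
# Gradio prints a local URL to open in a browser once the models have loaded.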