import json
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class InferRunner:
    def __init__(self, device):
        # Load the VAE used to decode latents back to mel-spectrograms and waveforms.
        with open("ckpts/ldm/vae_config.json") as f:
            vae_config = json.load(f)
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)

        # Training arguments are stored on the first line of the summary jsonl.
        with open("ckpts/pico_model/summary.jsonl") as f:
            train_args = dotdict(json.loads(f.readline()))
        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
            diffusion_pt="ckpts/pico_model/diffusion.pt",
        ).eval().to(device)
        self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")

device = "cuda" if torch.cuda.is_available() else "cpu"
runner = InferRunner(device)
event_list = [
    "burping_belching",             # 0
    "car_horn_honking",             # 1
    "cat_meowing",                  # 2
    "cow_mooing",                   # 3
    "dog_barking",                  # 4
    "door_knocking",                # 5
    "door_slamming",                # 6
    "explosion",                    # 7
    "gunshot",                      # 8
    "sheep_goat_bleating",          # 9
    "sneeze",                       # 10
    "spraying",                     # 11
    "thump_thud",                   # 12
    "train_horn",                   # 13
    "tapping_clicking_clanking",    # 14
    "woman_laughing",               # 15
    "duck_quacking",                # 16
    "whistling",                    # 17
]

def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000 * 10):
    """Generate audio for a timestamp-annotated caption, write it to output.wav, and return the path."""
    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption, runner.scheduler,
            num_steps=num_steps, guidance_scale=guidance_scale,
            num_samples_per_prompt=1, disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]  # truncate to 10 s at 16 kHz
    outpath = "output.wav"
    sf.write(outpath, wave, samplerate=16000, subtype="PCM_16")
    return outpath
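
# Direct-use sketch (assumes the checkpoints under ckpts/ are in place); the Gradio
# interface below calls infer() in the same way:
#
#   wav_path = infer("dog_barking at 0.562-2.562_4.25-6.25.", num_steps=100, guidance_scale=3.0)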


description_text = f"18 events: {', '.join(event_list)}"
prompt = gr.Textbox(
    label="Prompt: enter a caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1' (onsets/offsets in seconds).",
    value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
)
outaudio = gr.Audio()
num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)    
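
# Sketch of a caption builder matching the format described by the prompt label above.
# Illustrative only and not wired into the demo; event names are assumed to come from
# event_list, with onset/offset times given in seconds.
def format_caption(events):
    """Build a caption such as 'spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729.'

    events: dict mapping an event name to a list of (onset, offset) tuples in seconds.
    """
    parts = [
        f"{name} at {'_'.join(f'{onset}-{offset}' for onset, offset in spans)}"
        for name, spans in events.items()
    ]
    return " and ".join(parts) + "."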


gr_interface = gr.Interface(
    fn=infer,
    inputs=[prompt, num_steps, guidance_scale], 
    outputs=[outaudio],
    title="PicoAudio",
    description=description_text,
    allow_flagging="never",
    examples=[
        ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
        ["dog_barking at 0.562-2.562_4.25-6.25."],
        ["cow_mooing at 0.958-3.582_5.272-7.896."],
    ],
    cache_examples="lazy",  # cache example outputs lazily, on first request
)
    
gr_interface.queue(10).launch()