ZeyuXie committed on
Commit aca4b77 · verified · 1 Parent(s): 5f1d3d1

Update app.py

Files changed (1): app.py (+14 −2)
app.py CHANGED

@@ -4,8 +4,10 @@ import json
 import numpy as np
 import torch
 import soundfile as sf
+import gradio as gr
 from diffusers import DDPMScheduler
 from pico_model import PicoDiffusion, build_pretrained_models
+from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
 
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
@@ -15,7 +17,11 @@ class dotdict(dict):
 
 class InferRunner:
     def __init__(self):
-        self.vae, _ = build_pretrained_models("audioldm-s-full")
+        vae_config = json.load(open("ckpts/ldm/vae_config.json".format(path)))
+        self.vae = AutoencoderKL(**vae_config).to(device)
+        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin".format(path), map_location=device)
+        self.vae.load_state_dict(vae_weights)
+
         train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
         self.pico_model = PicoDiffusion(
             scheduler_name=train_args.scheduler_name,
@@ -23,7 +29,7 @@ class InferRunner:
             snr_gamma=train_args.snr_gamma,
             freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
             diffusion_pt="ckpts/pico_model/diffusion.pt",
-        ).cuda().eval()
+        ).eval().to(device)
         self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
 
 def infer(caption, runner):
@@ -34,6 +40,12 @@ def infer(caption, runner):
     sf.write(f"synthesized/{caption}.wav", wave, samplerate=16000, subtype='PCM_16')
 
 infer_runner = InferRunner()
+if torch.cuda.is_available():
+    device = "cuda"
+    device_selection = "cuda:0"
+else:
+    device = "cpu"
+    device_selection = "cpu"
 
 with gr.Blocks() as demo:
     with gr.Row():
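Note: as committed, `device` is only assigned after `infer_runner = InferRunner()` has already run, and the stray `.format(path)` calls reference a `path` name that does not appear in the visible diff, so `InferRunner.__init__` would raise a NameError. A minimal sketch of how the same VAE-loading change could be ordered to avoid that, assuming the checkpoint paths and the `AutoencoderKL(**vae_config)` constructor shown in the diff:

import json
import torch
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL

# Pick the device before any model is constructed, so __init__ can use it.
device = "cuda" if torch.cuda.is_available() else "cpu"

class InferRunner:
    def __init__(self):
        # Load the VAE config and weights from the checkpoint directory;
        # the paths are literal strings, so no str.format() call is needed.
        with open("ckpts/ldm/vae_config.json") as f:
            vae_config = json.load(f)
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)

infer_runner = InferRunner()  # device already exists at construction time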