ntt123 commited on
Commit
da24dea
·
verified ·
1 Parent(s): 0c9bb32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -5,6 +5,7 @@ from jax.experimental import ode
5
  import yaml
6
  from flax import nnx
7
  import pickle
 
8
 
9
 
10
  def load_model(config_path, ckpt_path):
@@ -38,7 +39,7 @@ def sample_images(graphdef, state, x0, t):
38
  o = jnp.clip(o[-1], 0, 1)
39
  return o
40
 
41
-
42
  def generate_grid(seed, noise_level):
43
  # Load model (doing this inside function to avoid global variables)
44
  graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")
@@ -64,6 +65,9 @@ def generate_grid(seed, noise_level):
64
 
65
  return jax.device_get(grid)
66
 
 
 
 
67
  # Create Gradio interface
68
  demo = gr.Interface(
69
  fn=generate_grid,
@@ -77,4 +81,4 @@ demo = gr.Interface(
77
  )
78
 
79
  if __name__ == "__main__":
80
- demo.launch(share=True)
 
5
  import yaml
6
  from flax import nnx
7
  import pickle
8
+ import spaces
9
 
10
 
11
  def load_model(config_path, ckpt_path):
 
39
  o = jnp.clip(o[-1], 0, 1)
40
  return o
41
 
42
+ @spaces.GPU
43
  def generate_grid(seed, noise_level):
44
  # Load model (doing this inside function to avoid global variables)
45
  graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")
 
65
 
66
  return jax.device_get(grid)
67
 
68
+
69
+ generate_grid(0, 3)
70
+
71
  # Create Gradio interface
72
  demo = gr.Interface(
73
  fn=generate_grid,
 
81
  )
82
 
83
  if __name__ == "__main__":
84
+ demo.launch()