Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ from jax.experimental import ode
|
|
5 |
import yaml
|
6 |
from flax import nnx
|
7 |
import pickle
|
|
|
8 |
|
9 |
|
10 |
def load_model(config_path, ckpt_path):
|
@@ -38,7 +39,7 @@ def sample_images(graphdef, state, x0, t):
|
|
38 |
o = jnp.clip(o[-1], 0, 1)
|
39 |
return o
|
40 |
|
41 |
-
|
42 |
def generate_grid(seed, noise_level):
|
43 |
# Load model (doing this inside function to avoid global variables)
|
44 |
graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")
|
@@ -64,6 +65,9 @@ def generate_grid(seed, noise_level):
|
|
64 |
|
65 |
return jax.device_get(grid)
|
66 |
|
|
|
|
|
|
|
67 |
# Create Gradio interface
|
68 |
demo = gr.Interface(
|
69 |
fn=generate_grid,
|
@@ -77,4 +81,4 @@ demo = gr.Interface(
|
|
77 |
)
|
78 |
|
79 |
if __name__ == "__main__":
|
80 |
-
demo.launch(
|
|
|
5 |
import yaml
|
6 |
from flax import nnx
|
7 |
import pickle
|
8 |
+
import spaces
|
9 |
|
10 |
|
11 |
def load_model(config_path, ckpt_path):
|
|
|
39 |
o = jnp.clip(o[-1], 0, 1)
|
40 |
return o
|
41 |
|
42 |
+
@spaces.GPU
|
43 |
def generate_grid(seed, noise_level):
|
44 |
# Load model (doing this inside function to avoid global variables)
|
45 |
graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")
|
|
|
65 |
|
66 |
return jax.device_get(grid)
|
67 |
|
68 |
+
|
69 |
+
generate_grid(0, 3)
|
70 |
+
|
71 |
# Create Gradio interface
|
72 |
demo = gr.Interface(
|
73 |
fn=generate_grid,
|
|
|
81 |
)
|
82 |
|
83 |
if __name__ == "__main__":
|
84 |
+
demo.launch()
|