"""Gradio app that samples a 4x4 grid of images from a DiT flow model."""

import pickle

import gradio as gr
import jax
import jax.numpy as jnp
from jax.experimental import ode
import yaml
from flax import nnx
import spaces


def load_model(config_path, ckpt_path):
    """Build the DiT graph definition and restore its state from a checkpoint.

    Args:
        config_path: Path to a YAML file whose ``"model"`` mapping supplies
            the keyword arguments for ``DiTConfig``.
        ckpt_path: Path to a pickle file holding the parameter leaves.

    Returns:
        A ``(graphdef, state)`` pair suitable for ``nnx.merge``.

    Raises:
        OSError: If either file cannot be opened.
        ValueError: If the number of loaded leaves does not match the
            flattened model state (raised by ``tree_unflatten``).
    """
    # Load config. The encoding is pinned so decoding does not depend on the
    # host locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Load model and state.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point ckpt_path at checkpoints from a trusted source.
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)
    # Every loaded leaf is cast to bfloat16 before being placed in the state.
    leaves = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), leaves)

    # Imported here, as in the original, so the project-local module is only
    # needed once a model is actually loaded.
    from model import DiT, DiTConfig

    dit_config = DiTConfig(**config["model"])
    # eval_shape constructs the model abstractly; only its structure is used.
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)

    # Reuse the tree structure of the abstract state and fill it with the
    # loaded leaves. NOTE(review): assumes the pickle stores the leaves in the
    # same order tree_flatten produces for this state - confirm against the
    # code that wrote the checkpoint.
    _, treedef = jax.tree_util.tree_flatten(state)
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state


@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the flow model from ``x0`` over the time grid ``t``.

    Args:
        graphdef: Graph definition returned by ``nnx.split``.
        state: Model state matching ``graphdef``.
        x0: Initial value handed to the ODE solver.
        t: Time points handed to ``ode.odeint``.

    Returns:
        The solution at the last time point, clipped to ``[0, 1]``.
    """
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t_i):
        # The model is evaluated in bfloat16; its output is returned to the
        # solver as float32.
        y = y.astype(jnp.bfloat16)
        t_i = t_i.astype(jnp.bfloat16)
        velocity = flow(y, t_i[None])
        return velocity.astype(jnp.float32)

    trajectory = ode.odeint(flow_fn, x0, t, rtol=1e-4)
    # Only the final time point is kept.
    return jnp.clip(trajectory[-1], 0, 1)


@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample 16 images and tile them into a single 4x4 grid image.

    Args:
        seed: Seed passed to ``nnx.Rngs`` to draw the initial noise.
        noise_level: Truncation bound; the initial noise is drawn from a
            truncated normal on ``[-noise_level, noise_level]``.

    Returns:
        The tiled grid, fetched to the host with ``jax.device_get``.
    """
    # Load model (doing this inside function to avoid global variables)
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    grid_side = 4

    # Two time points: the start (0) and the end (1) of the integration.
    t = jnp.linspace(0, 1, 2)
    x0 = jax.random.truncated_normal(
        nnx.Rngs(seed)(),
        -noise_level,
        noise_level,
        shape=(grid_side * grid_side, 64, 64, 3),
        dtype=jnp.float32,
    )

    # Generate images
    images = sample_images(graphdef, state, x0, t)

    # Convert to grid of 4x4: join each run of four images side by side
    # (axis=1), then stack the resulting rows vertically (axis=0).
    rows = [
        jnp.concatenate(images[i * grid_side : (i + 1) * grid_side], axis=1)
        for i in range(grid_side)
    ]
    grid = jnp.concatenate(rows, axis=0)
    return jax.device_get(grid)


# Create Gradio interface
demo = gr.Interface(
    fn=generate_grid,
    inputs=[
        gr.Number(label="Random Seed", value=0, precision=0),
        gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale"),
    ],
    outputs=gr.Image(label="Generated Images"),
    title="Anime Flow",
    description="Generate a 4x4 grid of anime faces using Anime Flow",
)

if __name__ == "__main__":
    demo.launch()