|
import gradio as gr |
|
import jax |
|
import jax.numpy as jnp |
|
from jax.experimental import ode |
|
import yaml |
|
from flax import nnx |
|
import pickle |
|
import spaces |
|
|
|
|
|
def load_model(config_path, ckpt_path):
    """Build the DiT graph definition and restore its parameters from disk.

    Args:
        config_path: Path to a YAML file. Its ``model`` mapping is passed as
            keyword arguments to ``DiTConfig``.
        ckpt_path: Path to a pickle holding the flattened leaves of the model
            state. They must be in the same order that
            ``jax.tree_util.tree_flatten`` yields for the state produced by
            ``nnx.split``, because they are re-attached with that tree
            structure below.

    Returns:
        A ``(graphdef, state)`` pair suitable for ``nnx.merge``. Every leaf of
        ``state`` has been cast to ``bfloat16``.
    """
    # YAML is text: pin the encoding so parsing does not depend on the host
    # locale's default encoding.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point ckpt_path at checkpoints you trust. In this file the only
    # caller passes a fixed local path.
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)

    # NOTE(review): this casts *every* leaf to bfloat16, including any
    # non-float state the model might carry. Confirm the checkpoint contains
    # only floating-point parameters.
    leaves = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), leaves)

    # Project-local import, deferred to call time.
    from model import DiT, DiTConfig

    dit_config = DiTConfig(**config["model"])
    # eval_shape builds the module abstractly. Only its structure is needed
    # here; real parameter values come from the checkpoint.
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)

    # Re-attach the checkpoint leaves onto the model's state tree structure.
    _, treedef = jax.tree_util.tree_flatten(state)
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state
|
|
|
|
|
@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the learned flow from ``x0`` over the times ``t``.

    The network is rebuilt from ``(graphdef, state)`` and used as the
    right-hand side dy/dt of an ODE solved by ``jax.experimental.ode.odeint``
    (relative tolerance 1e-4).

    Args:
        graphdef: Graph definition from ``nnx.split``.
        state: Matching parameter state from ``nnx.split``.
        x0: Initial condition of the ODE.
        t: Increasing sequence of times to integrate over.

    Returns:
        The solution at the last time in ``t``, clipped to ``[0, 1]``.
    """
    net = nnx.merge(graphdef, state)

    def dydt(y_now, time_now):
        # Evaluate the network in bfloat16 (presumably to match the bfloat16
        # parameters load_model produces), but give odeint float32 back.
        # The scalar time is lifted to shape (1,) before the call.
        out = net(y_now.astype(jnp.bfloat16), time_now.astype(jnp.bfloat16)[None])
        return out.astype(jnp.float32)

    trajectory = ode.odeint(dydt, x0, t, rtol=1e-4)
    # odeint returns the state at every requested time; keep only the last.
    return jnp.clip(trajectory[-1], 0, 1)
|
|
|
|
|
@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample 16 images and tile them into a single 4x4 grid.

    Args:
        seed: Integer seed for the initial-noise RNG, via ``nnx.Rngs``.
        noise_level: Half-width of the truncated-normal initial noise. Samples
            lie in ``[-noise_level, noise_level]``.

    Returns:
        A host (NumPy) array of shape ``(4 * 64, 4 * 64, 3)`` with values in
        ``[0, 1]``. Image ``k`` lands at grid row ``k // 4`` and column
        ``k % 4``.
    """
    # NOTE(review): the config and checkpoint are re-read on every request,
    # presumably because the GPU is only attached inside this decorated call.
    # Confirm before hoisting or caching.
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    noise_key = nnx.Rngs(seed)()
    times = jnp.linspace(0, 1, 2)  # integrate from t=0 to t=1
    start = jax.random.truncated_normal(
        noise_key,
        -noise_level,
        noise_level,
        shape=(16, 64, 64, 3),
        dtype=jnp.float32,
    )

    samples = sample_images(graphdef, state, start, times)

    # (16, H, W, C) -> (4, 4, H, W, C) indexed [grid_row, grid_col, ...].
    # Swapping the grid_col and H axes gives (4, H, 4, W, C), which flattens
    # to (4H, 4W, C) with each image placed in its grid cell.
    _, height, width, channels = samples.shape
    tiled = samples.reshape(4, 4, height, width, channels).transpose(0, 2, 1, 3, 4)
    grid = tiled.reshape(4 * height, 4 * width, channels)

    return jax.device_get(grid)
|
|
|
|
|
|
|
def _build_demo():
    """Create the Gradio interface: seed and noise-scale inputs, one image grid output."""
    seed_box = gr.Number(label="Random Seed", value=0, precision=0)
    noise_slider = gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale")
    grid_view = gr.Image(label="Generated Images")
    return gr.Interface(
        fn=generate_grid,
        inputs=[seed_box, noise_slider],
        outputs=grid_view,
        title="Anime Flow",
        description="Generate a 4x4 grid of anime faces using Anime Flow",
    )


# Module-level handle, kept so launchers that import this file can find `demo`.
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()
|
|