# AnimeFlow / app.py
# ntt123's picture
# Update app.py
# e8774dc verified
import gradio as gr
import jax
import jax.numpy as jnp
from jax.experimental import ode
import yaml
from flax import nnx
import pickle
import spaces
def load_model(config_path, ckpt_path):
    """Rebuild the DiT graph definition and its state from files on disk.

    Args:
        config_path: Path to a YAML file whose "model" entry holds the
            keyword arguments for ``DiTConfig``.
        ckpt_path: Path to a pickle file holding the checkpoint leaves.

    Returns:
        A ``(graphdef, state)`` pair as produced by ``nnx.split``, with the
        state's leaves replaced by the checkpoint values cast to bfloat16.
    """
    with open(config_path) as config_file:
        settings = yaml.safe_load(config_file)

    # NOTE: pickle executes arbitrary code on load; only point this at
    # checkpoint files you trust.
    with open(ckpt_path, "rb") as ckpt_file:
        stored_leaves = pickle.load(ckpt_file)

    def _as_bfloat16(leaf):
        return leaf.astype(jnp.bfloat16)

    stored_leaves = jax.tree_util.tree_map(_as_bfloat16, stored_leaves)

    from model import DiT, DiTConfig

    def _build_model():
        return DiT(DiTConfig(**settings["model"]), rngs=nnx.Rngs(0))

    # eval_shape gives the model's structure; the real values come from the
    # checkpoint leaves below.
    graphdef, abstract_state = nnx.split(nnx.eval_shape(_build_model))

    # Reuse the abstract state's tree structure to slot the loaded leaves in.
    state_treedef = jax.tree_util.tree_flatten(abstract_state)[1]
    return graphdef, jax.tree_util.tree_unflatten(state_treedef, stored_leaves)
@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the flow model from ``x0`` over the time points ``t``.

    Args:
        graphdef: Graph definition half of the split model.
        state: State half of the split model.
        x0: Initial value handed to the ODE solver.
        t: Time points handed to ``ode.odeint``.

    Returns:
        The solution at the last time point, clipped to ``[0, 1]``.
    """
    flow = nnx.merge(graphdef, state)

    def dynamics(y, time):
        # The model is evaluated in bfloat16; the solver receives float32.
        prediction = flow(
            y.astype(jnp.bfloat16),
            time.astype(jnp.bfloat16)[None],
        )
        return prediction.astype(jnp.float32)

    trajectory = ode.odeint(dynamics, x0, t, rtol=1e-4)
    # Only the state at the final time point is kept.
    return jnp.clip(trajectory[-1], 0, 1)
@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample 16 images and tile them into a single 4x4 grid.

    Args:
        seed: Seed passed to ``nnx.Rngs`` to derive the sampling key.
        noise_level: Bound of the truncated normal; the initial noise is
            drawn from ``[-noise_level, noise_level]``.

    Returns:
        The tiled grid as a host-side array (via ``jax.device_get``).
    """
    # The model is loaded on every call on purpose, so that no module-level
    # globals are needed.
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    time_points = jnp.linspace(0, 1, 2)
    initial_noise = jax.random.truncated_normal(
        nnx.Rngs(seed)(),
        -noise_level,
        noise_level,
        shape=(16, 64, 64, 3),
        dtype=jnp.float32,
    )

    images = sample_images(graphdef, state, initial_noise, time_points)

    # Each group of four consecutive images is joined side by side (axis 1),
    # then the four resulting strips are stacked top to bottom (axis 0).
    strips = [
        jnp.concatenate(images[start : start + 4], axis=1)
        for start in range(0, 16, 4)
    ]
    return jax.device_get(jnp.concatenate(strips, axis=0))
# Gradio UI: a seed field and a noise-scale slider feed generate_grid, and
# the returned grid is shown in an image component.
seed_input = gr.Number(label="Random Seed", value=0, precision=0)
noise_input = gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale")
grid_output = gr.Image(label="Generated Images")

demo = gr.Interface(
    fn=generate_grid,
    inputs=[seed_input, noise_input],
    outputs=grid_output,
    title="Anime Flow",
    description="Generate a 4x4 grid of anime faces using Anime Flow",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()