File size: 2,242 Bytes
0c9bb32 da24dea 0c9bb32 e8774dc 0c9bb32 e8774dc 0c9bb32 e8774dc 0c9bb32 e8774dc 0c9bb32 e8774dc da24dea 0c9bb32 da24dea 0c9bb32 d4e9c9c 0c9bb32 da24dea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import jax
import jax.numpy as jnp
from jax.experimental import ode
import yaml
from flax import nnx
import pickle
import spaces
def load_model(config_path, ckpt_path):
    """Build the DiT model structure from a YAML config and fill it with checkpoint leaves.

    Args:
        config_path: Path to a YAML file; its ``"model"`` mapping is passed as
            keyword arguments to ``DiTConfig``.
        ckpt_path: Path to a pickle file holding the model-state leaves, in the
            order produced by ``jax.tree_util.tree_flatten`` on the ``nnx`` state.

    Returns:
        ``(graphdef, state)`` as produced by ``nnx.split``, where ``state`` has the
        checkpoint leaves in place (floating-point leaves cast to bfloat16).

    Raises:
        ValueError: If the number of checkpoint leaves does not match the number
            of leaves in the model state.
    """
    # Load config
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Load saved state leaves.
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load
    # checkpoints from a trusted source.
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)

    def _to_bf16(x):
        # Cast only floating-point leaves; integer leaves (if any) keep their
        # dtype so their values are not corrupted by a bfloat16 round-trip.
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x

    leaves = jax.tree_util.tree_map(_to_bf16, leaves)

    from model import DiT, DiTConfig  # project-local module

    dit_config = DiTConfig(**config["model"])
    # eval_shape builds the module abstractly (no real parameter allocation);
    # it is only needed for the graph definition and the state's tree structure.
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)
    _, treedef = jax.tree_util.tree_flatten(state)

    # Fail with a clear message when the checkpoint does not match the config.
    if treedef.num_leaves != len(leaves):
        raise ValueError(
            f"Checkpoint {ckpt_path!r} has {len(leaves)} leaves but the model "
            f"built from {config_path!r} expects {treedef.num_leaves}."
        )
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state
@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the learned flow from ``x0`` and return the final state clipped to [0, 1].

    The model is rebuilt from ``graphdef``/``state`` with ``nnx.merge`` and used
    as the right-hand side ``dy/dt`` of ``jax.experimental.ode.odeint``
    (``rtol=1e-4``). The model is evaluated in bfloat16, and its output is
    returned to the integrator as float32.

    Args:
        graphdef: Graph definition from ``nnx.split``.
        state: Model state from ``nnx.split``.
        x0: Initial value of the ODE.
        t: Times at which ``odeint`` reports the solution; only the last one is used.

    Returns:
        The solution at ``t[-1]``, clipped elementwise to ``[0, 1]``.
    """
    net = nnx.merge(graphdef, state)

    def velocity(y_now, t_now):
        # odeint hands over a scalar time; the model is called with it as shape (1,).
        pred = net(y_now.astype(jnp.bfloat16), t_now.astype(jnp.bfloat16)[None])
        return pred.astype(jnp.float32)

    trajectory = ode.odeint(velocity, x0, t, rtol=1e-4)
    final = trajectory[-1]
    return jnp.clip(final, 0, 1)
@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample 16 images from the flow model and tile them into a 4x4 grid.

    Args:
        seed: Integer seed for the initial noise. ``None`` (Gradio's Number box
            returns this when cleared) falls back to 0; floats are truncated
            to int so they are accepted as a PRNG seed.
        noise_level: Truncation bound ``b``; the initial noise is a standard
            normal truncated to ``[-b, b]`` (it is not a scale factor).

    Returns:
        A host array containing the 4x4 tiled grid, with values already clipped
        to [0, 1] by ``sample_images``.
    """
    grid_side = 4  # images per row and per column
    seed = 0 if seed is None else int(seed)

    # Load model (doing this inside function to avoid global variables)
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    # Only the start (t=0) and end (t=1) are requested; odeint chooses its own
    # internal steps between them.
    t = jnp.linspace(0, 1, 2)
    x0 = jax.random.truncated_normal(
        nnx.Rngs(seed)(),
        -noise_level,
        noise_level,
        shape=(grid_side * grid_side, 64, 64, 3),
        dtype=jnp.float32,
    )

    # Generate images
    images = sample_images(graphdef, state, x0, t)

    # Tile (N, H, W, C) into (side*H, side*W, C): image i lands at
    # row i // side, column i % side — same layout as concatenating row by row.
    _, height, width, channels = images.shape
    grid = (
        images.reshape(grid_side, grid_side, height, width, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(grid_side * height, grid_side * width, channels)
    )
    return jax.device_get(grid)
# Gradio UI: a seed box and a noise slider feed generate_grid, and its result
# is shown in a single image panel.
demo = gr.Interface(
    title="Anime Flow",
    description="Generate a 4x4 grid of anime faces using Anime Flow",
    fn=generate_grid,
    inputs=[
        gr.Number(value=0, precision=0, label="Random Seed"),
        gr.Slider(label="Noise Scale", value=3.0, minimum=0, maximum=10),
    ],
    outputs=gr.Image(label="Generated Images"),
)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|