File size: 2,242 Bytes
0c9bb32
 
 
 
 
 
 
da24dea
0c9bb32
 
 
 
 
 
 
 
 
 
 
e8774dc
 
0c9bb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8774dc
 
0c9bb32
e8774dc
0c9bb32
e8774dc
0c9bb32
 
 
e8774dc
da24dea
0c9bb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da24dea
0c9bb32
 
 
 
 
 
 
 
d4e9c9c
0c9bb32
 
 
 
da24dea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import jax
import jax.numpy as jnp
from jax.experimental import ode
import yaml
from flax import nnx
import pickle
import spaces


def load_model(config_path, ckpt_path):
    """Build the DiT graph definition and restore its weights from a pickle.

    The architecture comes from the ``model`` section of the YAML config.
    The checkpoint is unpickled as a collection of state leaves. Those
    leaves are placed into the tree structure of the model's state as
    produced by ``jax.tree_util.tree_flatten``.

    Args:
        config_path: Path to the YAML config file.
        ckpt_path: Path to the pickled checkpoint leaves.

    Returns:
        ``(graphdef, state)`` suitable for ``nnx.merge``. Floating-point
        leaves are cast to bfloat16; any non-floating leaves keep their
        original dtype.
    """
    # Load config
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Load the checkpoint leaves.
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file; only load checkpoints from a trusted source.
    with open(ckpt_path, "rb") as f:
        leaves = pickle.load(f)

    def _cast(x):
        # Downcast only floating-point leaves. Casting integer leaves
        # (counters, indices, if the state contains any) to bfloat16
        # would silently corrupt them.
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x

    leaves = jax.tree_util.tree_map(_cast, leaves)

    # Imported lazily, presumably so this module can be imported before the
    # local `model` package is needed — TODO confirm.
    from model import DiT, DiTConfig

    dit_config = DiTConfig(**config["model"])
    # Build the model abstractly (shapes/dtypes only). It is used solely to
    # obtain the graph definition and the tree structure of its state.
    model = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, state = nnx.split(model)
    _, treedef = jax.tree_util.tree_flatten(state)
    # Swap the abstract leaves for the checkpoint values. This relies on the
    # checkpoint's leaf order matching tree_flatten's order for this state.
    state = jax.tree_util.tree_unflatten(treedef, leaves)
    return graphdef, state


@jax.jit
def sample_images(graphdef, state, x0, t):
    """Integrate the learned flow starting at ``x0`` and return the end state.

    Args:
        graphdef: Graph definition from ``nnx.split``.
        state: Model state matching ``graphdef``.
        x0: Initial (noise) batch; the ODE state keeps this shape.
        t: 1-D array of time points passed to ``odeint``.

    Returns:
        The ODE solution at the last time in ``t``, clipped to ``[0, 1]``.
    """
    model = nnx.merge(graphdef, state)

    def velocity(y, time):
        # Model inputs are cast to bfloat16; the solver works on the
        # float32 output.
        out = model(y.astype(jnp.bfloat16), time.astype(jnp.bfloat16)[None])
        return out.astype(jnp.float32)

    # NOTE(review): only rtol is set; atol stays at odeint's default —
    # verify that is intended given the bfloat16 model evaluations.
    trajectory = ode.odeint(velocity, x0, t, rtol=1e-4)
    # trajectory[k] is the state at time t[k]; keep only the final one.
    return jnp.clip(trajectory[-1], 0, 1)


@spaces.GPU
def generate_grid(seed, noise_level):
    """Sample 16 images and tile them into a 4x4 grid.

    Args:
        seed: Integer seed from the UI. ``None`` (an emptied input box) is
            rejected with a user-facing ``gr.Error``.
        noise_level: Bound ``b`` for the initial noise, which is drawn from a
            standard normal truncated to ``[-b, b]``.

    Returns:
        A host (NumPy) array of shape ``(4 * 64, 4 * 64, 3)`` with values in
        ``[0, 1]`` (clipped by ``sample_images``).

    Raises:
        gr.Error: if ``seed`` is ``None``.
    """
    if seed is None:
        raise gr.Error("Please enter a random seed.")

    grid_side = 4
    num_images = grid_side * grid_side
    image_hw = 64

    # Load model (doing this inside function to avoid global variables).
    # NOTE(review): this re-reads the config and checkpoint on every request;
    # consider caching if the hosting setup allows it.
    graphdef, state = load_model("config.yaml", "ckpt_1000k.pkl")

    # Integrate from t=0 to t=1; odeint reports the state only at these two
    # times and sample_images keeps the last one.
    t = jnp.linspace(0, 1, 2)
    x0 = jax.random.truncated_normal(
        nnx.Rngs(int(seed))(),
        -noise_level,
        noise_level,
        shape=(num_images, image_hw, image_hw, 3),
        dtype=jnp.float32,
    )

    # Generate images
    images = sample_images(graphdef, state, x0, t)

    # Tile (rows*cols, H, W, C) -> (rows*H, cols*W, C). Image k lands at grid
    # row k // grid_side, column k % grid_side.
    _, height, width, channels = images.shape
    grid = (
        images.reshape(grid_side, grid_side, height, width, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(grid_side * height, grid_side * width, channels)
    )

    return jax.device_get(grid)


def _build_demo():
    """Assemble the Gradio interface that wraps ``generate_grid``."""
    # Components are created in the same order as the UI lists them.
    seed_box = gr.Number(label="Random Seed", value=0, precision=0)
    noise_slider = gr.Slider(minimum=0, maximum=10, value=3.0, label="Noise Scale")
    image_out = gr.Image(label="Generated Images")
    return gr.Interface(
        fn=generate_grid,
        inputs=[seed_box, noise_slider],
        outputs=image_out,
        title="Anime Flow",
        description="Generate a 4x4 grid of anime faces using Anime Flow",
    )


# Create Gradio interface
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()