|
""" |
|
Generate images from a trained model.
|
""" |
|
|
|
import argparse |
|
import pickle |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import matplotlib.pyplot as plt |
|
import yaml |
|
from flax import nnx |
|
from jax.experimental import ode |
|
|
|
from model import DiT, DiTConfig |
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is what a plain
            ``parse_args()`` call has always done.

    Returns:
        An ``argparse.Namespace`` with three attributes: ``config`` (str,
        default ``"config.yaml"``), ``ckpt`` (str, or ``None`` when the
        option is not given) and ``seed`` (int, default ``0``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--ckpt", type=str, default=None, help="Path to checkpoint file"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    return parser.parse_args(argv)
|
|
|
|
|
def load_config(config_path):
    """Read the YAML configuration file at ``config_path``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents. The
        caller in this file indexes the result as a nested mapping with a
        ``"model"`` key.
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so the same config file parses identically on every OS.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
|
|
|
|
|
@jax.jit
def sample_images(graphdef, state, rng):
    """Draw a batch of images by integrating the model from t=0 to t=1.

    Args:
        graphdef: Graph definition handed to ``nnx.merge`` to rebuild the model.
        state: Model state handed to ``nnx.merge`` together with ``graphdef``.
        rng: JAX PRNG key used to draw the initial Gaussian noise.

    Returns:
        The ODE solution at t=1 for a ``(16, 64, 64, 3)`` batch of noise,
        clipped to the range ``[0, 1]``.
    """
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        # odeint supplies a scalar time; the model is called with a
        # length-1 time vector instead.
        return flow(y, t[None])

    noise = jax.random.normal(rng, shape=(16, 64, 64, 3), dtype=jnp.float32)

    # Only the state at t=1 is used, so request just the two endpoints.
    # odeint picks its integration steps adaptively and interpolates at the
    # requested times, so a dense 1000-point grid only adds 1000 stored
    # copies of the batch without changing the final state.
    endpoints = jnp.array([0.0, 1.0])
    trajectory = ode.odeint(flow_fn, noise, endpoints)
    return jnp.clip(trajectory[-1], 0, 1)
|
|
|
|
|
def plot_new_images(graphdef, state, seed, output_path="samples.png"):
    """Sample 16 images and save them as a borderless 4x4 grid.

    Args:
        graphdef: Graph definition forwarded to ``sample_images``.
        state: Model state forwarded to ``sample_images``.
        seed: Seed given to ``nnx.Rngs`` to derive the sampling PRNG key.
        output_path: Where the figure is written. Defaults to
            ``"samples.png"``, the file name this function has always used.
    """
    images = sample_images(graphdef, state, nnx.Rngs(seed)())

    # Work on an explicit Figure so the save and close below act on this
    # figure rather than on whatever pyplot currently treats as "current".
    fig = plt.figure(figsize=(2, 2))
    for i in range(16):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.imshow(images[i])
        ax.axis("off")
    # Remove all margins and gaps so the tiles touch each other.
    fig.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    fig.savefig(output_path)
    plt.close(fig)
|
|
|
|
|
def main():
    """Load the config and checkpoint, then write a grid of sampled images.

    Reads the ``model`` section of the YAML config to build a ``DiTConfig``,
    creates the model's graph definition, loads the pickled state from
    ``--ckpt`` and hands both to ``plot_new_images`` with ``--seed``.

    Raises:
        SystemExit: If ``--ckpt`` was not supplied.
        KeyError: If the config lacks ``"model"`` or one of its expected keys.
    """
    args = parse_args()
    if args.ckpt is None:
        # --ckpt defaults to None; fail with a clear message here instead of
        # the TypeError that open(None, "rb") would raise further down.
        raise SystemExit("error: --ckpt is required (path to a checkpoint file)")

    model_cfg = load_config(args.config)["model"]

    # Every DiTConfig field is copied verbatim from the same-named config key.
    field_names = (
        "input_dim",
        "hidden_dim",
        "num_blocks",
        "num_heads",
        "patch_size",
        "patch_stride",
        "time_freq_dim",
        "time_max_period",
        "mlp_ratio",
        "use_bias",
        "padding",
        "pos_embed_cls_token",
        "pos_embed_extra_tokens",
    )
    dit_config = DiTConfig(**{name: model_cfg[name] for name in field_names})

    # Only the graph definition of this abstract model is kept; the state
    # half of the split is discarded and replaced by the checkpoint below.
    abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, _ = nnx.split(abstract_flow)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints that come from a trusted source.
    with open(args.ckpt, "rb") as f:
        state = pickle.load(f, fix_imports=True)

    # If the loaded object has no top-level "time_embedding" entry, use its
    # first element instead. Presumably such checkpoints are a sequence whose
    # first item is the model state; TODO confirm against the training script.
    if "time_embedding" not in state:
        state = state[0]

    plot_new_images(graphdef, state, args.seed)
|
|
|
|
|
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|