AnimeFlow / sample.py
ntt123's picture
Upload folder using huggingface_hub
0c9bb32 verified
"""
Generate images from a trained model.
"""
import argparse
import pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import yaml
from flax import nnx
from jax.experimental import ode
from model import DiT, DiTConfig
def parse_args():
    """Parse the command-line options of the sampling script.

    Options:
        --config: path to the config file (str, default "config.yaml").
        --ckpt: path to the checkpoint file (str, default None).
        --seed: random seed (int, default 0).

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv``.
    """
    # Declarative table: (flag, keyword arguments handed to add_argument).
    option_table = (
        (
            "--config",
            {"type": str, "default": "config.yaml", "help": "Path to config file"},
        ),
        (
            "--ckpt",
            {"type": str, "default": None, "help": "Path to checkpoint file"},
        ),
        (
            "--seed",
            {"type": int, "default": 0, "help": "Random seed"},
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def load_config(config_path):
    """Read the YAML file at ``config_path`` and return its parsed content.

    Parsing is done with ``yaml.safe_load``; the file handle is released
    when the ``with`` block exits.
    """
    with open(config_path) as stream:
        return yaml.safe_load(stream)
@jax.jit
def sample_images(graphdef, state, rng):
    """Generate a batch of samples by integrating the model output over t in [0, 1].

    The module is rebuilt from ``graphdef`` and ``state`` with ``nnx.merge``
    and used as the right-hand side dy/dt of an ODE that starts from
    Gaussian noise.

    Args:
        graphdef: graph definition handed to ``nnx.merge``.
        state: module state handed to ``nnx.merge``.
        rng: PRNG key used to draw the initial noise.

    Returns:
        The ODE solution at the last time point, clipped to [0, 1]. The
        initial noise has shape (16, 64, 64, 3) and dtype float32.
    """
    model = nnx.merge(graphdef, state)

    def dydt(y, t):
        # The model is called with the time wrapped in a leading axis.
        return model(y, t[None])

    noise = jax.random.normal(rng, shape=(16, 64, 64, 3), dtype=jnp.float32)
    time_grid = jnp.linspace(0, 1, 1000)
    trajectory = ode.odeint(dydt, noise, time_grid)

    # Only the state at the final time point is returned.
    return jnp.clip(trajectory[-1], 0, 1)
def plot_new_images(graphdef, state, seed):
    """Sample 16 images and save them as a 4x4 grid to ``samples.png``.

    Args:
        graphdef: forwarded to ``sample_images``.
        state: forwarded to ``sample_images``.
        seed: seed handed to ``nnx.Rngs``; the key it yields drives sampling.
    """
    key = nnx.Rngs(seed)()
    images = sample_images(graphdef, state, key)

    fig = plt.figure(figsize=(2, 2))
    for cell in range(16):
        ax = fig.add_subplot(4, 4, cell + 1)  # subplot positions are 1-based
        ax.imshow(images[cell])
        ax.axis("off")

    # Zero margins and zero spacing so the tiles fill the whole figure.
    fig.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    fig.savefig("samples.png")
    plt.close(fig)
def main():
    """Script entry point: load a checkpoint and write a grid of samples.

    Steps:
        1. Parse the command-line arguments.
        2. Read the config file and build a ``DiTConfig`` from its
           ``"model"`` section.
        3. Build the model under ``nnx.eval_shape`` and keep only its graph
           definition; the parameter values come from the checkpoint.
        4. Unpickle the checkpoint and hand everything to
           ``plot_new_images``.

    Raises:
        SystemExit: if ``--ckpt`` was not supplied.
    """
    args = parse_args()
    if args.ckpt is None:
        # --ckpt defaults to None; without this guard open(None, "rb") fails
        # later with an opaque TypeError.
        raise SystemExit("error: --ckpt is required (path to checkpoint file)")

    config = load_config(args.config)
    model_cfg = config["model"]

    # Every DiTConfig field is read from the same-named key of config["model"].
    field_names = (
        "input_dim",
        "hidden_dim",
        "num_blocks",
        "num_heads",
        "patch_size",
        "patch_stride",
        "time_freq_dim",
        "time_max_period",
        "mlp_ratio",
        "use_bias",
        "padding",
        "pos_embed_cls_token",
        "pos_embed_extra_tokens",
    )
    dit_config = DiTConfig(**{name: model_cfg[name] for name in field_names})

    # Only the graph definition is kept from this model; its state is discarded.
    abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, _ = nnx.split(abstract_flow)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load checkpoints that come from a trusted source.
    with open(args.ckpt, "rb") as f:
        state = pickle.load(f, fix_imports=True)

    # NOTE(review): presumably some checkpoints store a sequence whose first
    # element is the model state - confirm against the training script.
    if "time_embedding" not in state:
        state = state[0]

    plot_new_images(graphdef, state, args.seed)
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()