""" |
A simple implementation of conditional flow matching for generating anime faces. |
""" |
import argparse |
import pickle |
import random |
import time |
from pathlib import Path |
import jax |
import jax.numpy as jnp |
import kagglehub |
import matplotlib.pyplot as plt |
import numpy as np |
import optax |
import ot |
import yaml |
from flax import nnx |
from jax.experimental import ode |
from PIL import Image |
from tqdm.cli import tqdm |
from model import DiT, DiTConfig |
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path to
        the config file (``"config.yaml"`` when ``--config`` is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    return parser.parse_args(argv)
def load_config(config_path):
    """Load the YAML configuration stored at ``config_path``.

    The file is opened in binary mode so that PyYAML performs its own
    encoding detection (BOM-based, UTF-8 when no BOM is present) instead of
    the text being decoded with the platform's locale encoding.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document.
    """
    with open(config_path, "rb") as f:
        return yaml.safe_load(f)
def gen_data_batches(data, batch_size):
    """Yield an endless stream of random mini-batches taken from ``data``.

    Each batch selects ``batch_size`` distinct rows with
    ``np.random.choice`` (NumPy's global random state), casts them to
    float32 and divides by 256.

    Args:
        data: Array whose first axis indexes the samples.
        batch_size: Number of rows per batch; must not exceed
            ``data.shape[0]`` because rows are drawn without replacement.

    Yields:
        A float32 array holding ``batch_size`` rescaled rows of ``data``.
    """
    num_samples = data.shape[0]
    while True:
        picked = np.random.choice(num_samples, size=batch_size, replace=False)
        yield data[picked].astype(np.float32) / 256
def loss_fn(flow, batch):
    """Return the mean squared error between predicted and target velocity.

    Args:
        flow: Callable invoked as ``flow(xt, t)`` that returns the predicted
            velocity.
        batch: Tuple ``(xt, t, vt)`` of model inputs, times and target
            velocities.

    Returns:
        The scalar mean of the squared difference ``flow(xt, t) - vt``.
    """
    inputs, times, target_velocity = batch
    residual = flow(inputs, times) - target_velocity
    return jnp.mean(jnp.square(residual))
def train_step(flow, optimizer, rngs, batch):
    """Run one flow-matching optimisation step and return its loss.

    Args:
        flow: Model called as ``flow(xt, t)`` inside ``loss_fn``.
        optimizer: Object whose ``update(grads)`` consumes the gradients.
        rngs: Callable returning a PRNG key per call; it is called twice,
            first for the pixel noise and then for the time samples.
        batch: Tuple ``(x0, x1)`` of source samples and data samples. The
            time broadcast below expects 4-D arrays.

    Returns:
        The loss computed by ``loss_fn`` for this step.
    """
    source, target = batch
    # Uniform noise in [0, 1/256) is added to the data samples.
    target = target + jax.random.uniform(
        rngs(), shape=target.shape, minval=0, maxval=1 / 256
    )
    # One time value in [0, 1) per example.
    times = jax.random.uniform(rngs(), (target.shape[0],), minval=0, maxval=1)
    # Straight-line path from source to target; its velocity is constant.
    displacement = target - source
    interpolated = source + displacement * times[:, None, None, None]
    loss, grads = nnx.value_and_grad(loss_fn)(
        flow, (interpolated, times, displacement)
    )
    optimizer.update(grads)
    return loss
@jax.jit
def train_step_raw(graphdef, state, batch):
    """Jitted functional wrapper around ``train_step``.

    The (flow, optimizer, rngs) objects are rebuilt from ``graphdef`` and
    ``state``, one training step is applied to them, and they are split
    again so the updated state can be returned to the caller.

    Args:
        graphdef: Graph definition produced by ``nnx.split``.
        state: State matching ``graphdef``.
        batch: Tuple ``(x0, x1)`` forwarded to ``train_step``.

    Returns:
        Tuple ``(new_state, loss)``.
    """
    model, opt, keys = nnx.merge(graphdef, state)
    loss = train_step(model, opt, keys, batch)
    new_state = nnx.split((model, opt, keys))[1]
    return new_state, loss
@jax.jit
def sample_images(graphdef, state):
    """Generate 16 images by integrating the learned flow from t=0 to t=1.

    Args:
        graphdef: Graph definition produced by ``nnx.split``.
        state: State matching ``graphdef``; only the model part is used.

    Returns:
        Array of shape (16, 64, 64, 3) with values clipped to [0, 1].
    """
    flow, _, _ = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        # odeint supplies a scalar time; t[None] turns it into a length-1
        # array before it is handed to the model.
        return flow(y, t[None])

    # The key comes from a fixed seed, so every call starts from the same
    # noise and successive plots show the same samples evolving.
    x = jax.random.normal(nnx.Rngs(0)(), shape=(16, 64, 64, 3), dtype=jnp.float32)
    # Only the state at t=1 is used. odeint returns one state per requested
    # time, so asking for 1000 times would materialise a
    # (1000, 16, 64, 64, 3) trajectory just to read its last entry. The
    # solver is adaptive-step, so the output grid selects which states are
    # returned rather than how accurately the ODE is integrated.
    trajectory = ode.odeint(flow_fn, x, jnp.linspace(0, 1, 2))
    return jnp.clip(trajectory[-1], 0, 1)
def generate_ot_pairs(x1):
    """Pair a batch of data samples with Gaussian noise via optimal transport.

    Noise with the same shape as ``x1`` is drawn from NumPy's global random
    state. Both batches are flattened per sample, a pairwise cost matrix is
    built with ``ot.dist`` and a transport plan between two unit-mass
    histograms is solved with ``ot.emd``. The plan is then applied to the
    flattened data.

    NOTE(review): the plan is presumably a permutation matrix here, so
    applying it just reorders the rows of ``x1`` - confirm against the POT
    documentation.

    Args:
        x1: Batch of data samples; the first axis indexes the samples.

    Returns:
        Tuple ``(x0, x1)`` of the noise batch and the re-paired data batch,
        both with the shape of the input.
    """
    batch_size = x1.shape[0]
    x0 = np.random.randn(*x1.shape)
    flat_noise = x0.reshape(batch_size, -1)
    flat_data = x1.reshape(batch_size, -1)
    cost = ot.dist(flat_noise, flat_data)
    source_mass = np.ones((batch_size,))
    target_mass = np.ones((batch_size,))
    plan = ot.emd(source_mass, target_mass, cost)
    # Row i of the result is the plan-weighted combination of data rows
    # assigned to noise sample i.
    matched = plan @ flat_data
    return x0, matched.reshape(x1.shape)
def plot_new_images(step: int, graphdef, state):
    """Sample 16 images and save them as a 4x4 grid.

    Args:
        step: Training step, used zero-padded to six digits in the file
            name ``images_{step:06d}.png``.
        graphdef: Graph definition forwarded to ``sample_images``.
        state: State forwarded to ``sample_images``.
    """
    samples = sample_images(graphdef, state)
    plt.figure(figsize=(2, 2))
    for cell in range(1, 17):
        # Subplot positions are 1-based, image indices 0-based.
        plt.subplot(4, 4, cell)
        plt.imshow(samples[cell - 1])
        plt.axis("off")
    # Remove all margins and gaps so the tiles touch each other.
    plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    plt.savefig(f"images_{step:06d}.png")
    plt.close()
# --- Configuration and dataset loading -------------------------------------
args = parse_args()
config = load_config(args.config)
path = kagglehub.dataset_download("thimac/anime-face-64")
data_path = Path(path) / "64x64"
print("Path to dataset files:", data_path)
data_dir = data_path
# Sort first so the seeded shuffle below is reproducible regardless of the
# order in which the file system lists the directory.
image_files = sorted(data_dir.glob("*.jpg"))
random.Random(config["data"]["random_seed"]).shuffle(image_files)
N = len(image_files)
if N == 0:
    # Fail here with a clear message instead of an IndexError further down.
    raise FileNotFoundError(f"No .jpg images found in {data_dir}")
# Kept as uint8 in memory; batches are converted to float32 when drawn.
dataset = np.empty((N, 64, 64, 3), dtype=np.uint8)
for i, file_path in enumerate(tqdm(image_files)):
    # The context manager closes the file handle deterministically, and the
    # RGB conversion keeps greyscale/CMYK JPEGs from failing with a shape
    # mismatch on assignment.
    with Image.open(file_path) as image:
        dataset[i] = np.asarray(image.convert("RGB"))
# The first `train_split` fraction is used for training, the rest is held out.
L = int(N * config["data"]["train_split"])
train_data = dataset[:L]
test_data = dataset[L:]
# Save a 4x4 preview of the first 16 training images as a quick sanity
# check of the loading pipeline.
plt.figure(figsize=(2, 2))
for cell in range(1, 17):
    # Subplot positions are 1-based, image indices 0-based.
    plt.subplot(4, 4, cell)
    plt.imshow(train_data[cell - 1])
    plt.axis("off")
# Remove all margins and gaps so the tiles touch each other.
plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
plt.savefig("train_data_samples.png")
plt.close()
# --- Learning-rate schedule and gradient transformation --------------------
train_cfg = config["training"]
# One-cycle cosine schedule spanning the configured number of steps.
scheduler = optax.cosine_onecycle_schedule(
    transition_steps=train_cfg["num_steps"],
    peak_value=train_cfg["learning_rate"],
    pct_start=train_cfg["warmup_pct"],
)
# NOTE(review): the weight-decay term is added AFTER the schedule scaling,
# so it is not multiplied by the learning rate (unlike optax.adamw, which
# decays before scaling by the learning rate). Confirm this is intended
# for the configured `weight_decay` value.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(train_cfg["grad_clip_norm"]),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.add_decayed_weights(train_cfg["weight_decay"]),
    # Flip the sign so the update is a descent step.
    optax.scale(-1.0),
)
# --- Model, optimizer and functional state ---------------------------------
model_cfg = config["model"]
# Every DiTConfig field is read from the config key of the same name.
dit_config = DiTConfig(
    **{
        field: model_cfg[field]
        for field in (
            "input_dim",
            "hidden_dim",
            "num_blocks",
            "num_heads",
            "patch_size",
            "patch_stride",
            "time_freq_dim",
            "time_max_period",
            "mlp_ratio",
            "use_bias",
            "padding",
            "pos_embed_cls_token",
            "pos_embed_extra_tokens",
        )
    }
)
flow = DiT(dit_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(flow, gradient_transform)
# Separate RNG stream (same seed) used for the noise and time sampling
# during training.
rngs = nnx.Rngs(0)
# Split into a graph definition plus state so the jitted functions can take
# and return the state explicitly.
graphdef, state = nnx.split((flow, optimizer, rngs))
# --- Optional resume and training loop --------------------------------------
train_data_iter = gen_data_batches(train_data, config["training"]["batch_size"])
start = time.perf_counter()
losses = []
ckpt_path = config["checkpointing"].get("resume_from_checkpoint")
if ckpt_path:
    # The freshly initialised state is discarded in favour of the checkpoint.
    del state
    # SECURITY: pickle.load can execute arbitrary code. Only resume from
    # checkpoint files that were produced by this script and are trusted.
    with open(ckpt_path, "rb") as f:
        state = pickle.load(f)
    print(f"Resuming from checkpoint {ckpt_path}")
    # Checkpoints are written below as state_{step:06d}.ckpt, so the step
    # number is the last "_"-separated token of the file stem.
    step_str = Path(ckpt_path).stem.split("_")[-1]
    start_step = int(step_str) + 1
else:
    start_step = 1
# The batch generator is endless and the learning-rate schedule is only
# defined over `num_steps` transitions, so stop once that budget is used up.
num_steps = config["training"]["num_steps"]
for step, batch in enumerate(train_data_iter, start=start_step):
    if step > num_steps:
        break
    x0, x1 = generate_ot_pairs(batch)
    state, loss = train_step_raw(graphdef, state, (x0, x1))
    if step % 100 == 0:
        # The loss is pulled to the host only every 100 steps.
        losses.append(loss.item())
    if step % config["checkpointing"]["log_every"] == 0:
        end = time.perf_counter()
        duration = end - start
        # `losses` is empty when no multiple of 100 fell inside this logging
        # window (e.g. log_every < 100). Fall back to the current loss
        # instead of dividing by zero.
        mean_loss = sum(losses) / len(losses) if losses else loss.item()
        start = time.perf_counter()
        losses = []
        print(
            f"step {step:06d} loss {mean_loss:.3f} duration {duration:.3f}s",
            flush=True,
        )
    if step % config["checkpointing"]["plot_every"] == 0:
        plot_new_images(step, graphdef, state)
    if step % config["checkpointing"]["save_every"] == 0:
        with open(f"state_{step:06d}.ckpt", "wb") as f:
            pickle.dump(state, f)