|
""" |
|
A simple implementation of conditional flow matching for generating anime faces. |
|
""" |
|
|
|
import argparse |
|
import pickle |
|
import random |
|
import time |
|
from pathlib import Path |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import kagglehub |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import optax |
|
import ot |
|
import yaml |
|
from flax import nnx |
|
from jax.experimental import ode |
|
from PIL import Image |
|
from tqdm.cli import tqdm |
|
|
|
from model import DiT, DiTConfig |
|
|
|
|
|
def parse_args():
    """Parse command-line flags.

    Returns:
        argparse.Namespace with a single ``config`` attribute: the path to
        the YAML configuration file (defaults to ``config.yaml``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    return cli.parse_args()
|
|
|
|
|
def load_config(config_path):
    """Read the YAML file at *config_path* and return it as a Python dict."""
    with open(config_path) as fh:
        return yaml.safe_load(fh)
|
|
|
|
|
def gen_data_batches(data, batch_size):
    """Endless generator of random training batches.

    Each batch is ``batch_size`` rows sampled without replacement from
    *data* (uint8 images), converted to float32 and divided by 256 so
    pixel values land in [0, 1).
    """
    total = data.shape[0]
    while True:
        idx = np.random.choice(total, size=batch_size, replace=False)
        yield data[idx].astype(np.float32) / 256
|
|
|
|
|
def loss_fn(flow, batch):
    """Flow-matching objective: MSE between predicted and target velocities.

    Args:
        flow: callable mapping ``(x_t, t)`` to a predicted velocity field.
        batch: tuple ``(xt, t, vt)`` of interpolated samples, times,
            and target velocities.

    Returns:
        Scalar mean squared error.
    """
    xt, t, vt = batch
    predicted = flow(xt, t)
    return jnp.mean((predicted - vt) ** 2)
|
|
|
|
|
def train_step(flow, optimizer, rngs, batch):
    """Run one flow-matching optimization step and return the loss.

    Args:
        flow: the DiT velocity model (nnx.Module).
        optimizer: nnx.Optimizer wrapping *flow*.
        rngs: nnx.Rngs stream used for dequantization noise and timesteps.
        batch: tuple ``(x0, x1)`` of noise samples and data samples.
    """
    x0, x1 = batch
    # Dequantize: the data was scaled by 1/256, so spread each pixel
    # uniformly within its quantization bin.
    dequant = jax.random.uniform(rngs(), shape=x1.shape, minval=0, maxval=1 / 256)
    x1 = x1 + dequant

    # One random time per sample in [0, 1).
    t = jax.random.uniform(rngs(), (x1.shape[0],), minval=0, maxval=1)

    # Straight-line probability path from noise x0 to data x1; its
    # (constant) velocity is the regression target.
    xt = x0 + (x1 - x0) * t[:, None, None, None]
    target_v = x1 - x0
    loss, grads = nnx.value_and_grad(loss_fn)(flow, (xt, t, target_v))
    optimizer.update(grads)
    return loss
|
|
|
|
|
@jax.jit
def train_step_raw(graphdef, state, batch):
    """JIT-compiled training step operating on raw nnx state.

    Merges the pytree *state* back into live (flow, optimizer, rngs)
    objects, runs one ``train_step``, then splits the mutated objects
    into a fresh state pytree so the caller can thread it through.
    """
    flow, optimizer, rngs = nnx.merge(graphdef, state)
    step_loss = train_step(flow, optimizer, rngs, batch)
    _, new_state = nnx.split((flow, optimizer, rngs))
    return new_state, step_loss
|
|
|
|
|
@jax.jit
def sample_images(graphdef, state):
    """Generate 16 images (64x64x3) by integrating the learned flow ODE.

    Starts from standard Gaussian noise (fixed seed 0, so samples are
    comparable across training steps) and integrates the velocity field
    from t=0 to t=1, clipping the result into [0, 1].
    """
    flow, _, _ = nnx.merge(graphdef, state)

    def velocity(y, t):
        # odeint passes a scalar t; the model expects a batched time axis.
        return flow(y, t[None])

    noise = jax.random.normal(nnx.Rngs(0)(), shape=(16, 64, 64, 3), dtype=jnp.float32)
    trajectory = ode.odeint(velocity, noise, jnp.linspace(0, 1, 1000))
    return jnp.clip(trajectory[-1], 0, 1)
|
|
|
|
|
def generate_ot_pairs(x1):
    """Couple Gaussian noise with data via minibatch optimal transport.

    Draws one noise sample per data sample, solves the exact EMD problem
    on pairwise squared distances, and applies the transport plan so each
    noise point is matched to its OT-assigned data point. Returns
    ``(x0, x1_matched)`` with the same shapes as the input.
    """
    n = x1.shape[0]
    x0 = np.random.randn(*x1.shape)
    flat1 = x1.reshape(n, -1)
    flat0 = x0.reshape(n, -1)

    cost = ot.dist(flat0, flat1)
    # Unit mass on every sample (both marginals sum to n), so the exact
    # EMD plan is a permutation-like matching.
    plan = ot.emd(np.ones((n,)), np.ones((n,)), cost)
    matched = np.matmul(plan, flat1)
    return x0, matched.reshape(*x1.shape)
|
|
|
|
|
def plot_new_images(step: int, graphdef, state):
    """Sample a 4x4 grid of generated images and save it to disk.

    The file is named ``images_{step:06d}.png`` so successive calls
    produce a chronological gallery of training progress.
    """
    images = sample_images(graphdef, state)

    plt.figure(figsize=(2, 2))
    for idx in range(16):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(images[idx])
        plt.axis("off")
    plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
    plt.savefig(f"images_{step:06d}.png")
    plt.close()
|
|
|
|
|
# ---- Script entry: read CLI flags, then the YAML configuration ----
args = parse_args()
config = load_config(args.config)
|
|
|
|
|
# Download (or reuse the kagglehub cache of) the 64x64 anime-face dataset.
path = kagglehub.dataset_download("thimac/anime-face-64")
data_path = Path(path) / "64x64"
print("Path to dataset files:", data_path)

data_dir = data_path
image_files = sorted(data_dir.glob("*.jpg"))
# Deterministic shuffle (seed from config) so the train/test split is
# reproducible across runs.
random.Random(config["data"]["random_seed"]).shuffle(image_files)
N = len(image_files)
# Load every image into a single uint8 array of shape (N, 64, 64, 3).
dataset = np.empty((N, 64, 64, 3), dtype=np.uint8)
for i, file_path in enumerate(tqdm(image_files)):
    dataset[i] = Image.open(file_path)

# Split by the configured train fraction; the tail becomes the test set.
L = int(N * config["data"]["train_split"])
train_data = dataset[:L]
test_data = dataset[L:]

# Sanity check: save a 4x4 grid of raw training samples.
plt.figure(figsize=(2, 2))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(train_data[i])
    plt.axis("off")
plt.subplots_adjust(left=0, bottom=0, top=1, right=1, wspace=0, hspace=0)
plt.savefig("train_data_samples.png")
plt.close()
|
|
|
# One-cycle cosine learning-rate schedule spanning the whole run.
scheduler = optax.cosine_onecycle_schedule(
    transition_steps=config["training"]["num_steps"],
    peak_value=config["training"]["learning_rate"],
    pct_start=config["training"]["warmup_pct"],
)

# AdamW-style update: global-norm clip -> Adam scaling -> LR schedule ->
# weight decay -> negate (so the final transform is gradient *descent*).
# NOTE(review): add_decayed_weights comes after scale_by_schedule, so the
# decay term is NOT multiplied by the learning rate — confirm intended.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(config["training"]["grad_clip_norm"]),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.add_decayed_weights(config["training"]["weight_decay"]),
    optax.scale(-1.0),
)

# DiT hyperparameters come straight from the YAML config.
dit_config = DiTConfig(
    input_dim=config["model"]["input_dim"],
    hidden_dim=config["model"]["hidden_dim"],
    num_blocks=config["model"]["num_blocks"],
    num_heads=config["model"]["num_heads"],
    patch_size=config["model"]["patch_size"],
    patch_stride=config["model"]["patch_stride"],
    time_freq_dim=config["model"]["time_freq_dim"],
    time_max_period=config["model"]["time_max_period"],
    mlp_ratio=config["model"]["mlp_ratio"],
    use_bias=config["model"]["use_bias"],
    padding=config["model"]["padding"],
    pos_embed_cls_token=config["model"]["pos_embed_cls_token"],
    pos_embed_extra_tokens=config["model"]["pos_embed_extra_tokens"],
)

flow = DiT(dit_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(flow, gradient_transform)

# NOTE(review): training rngs reuse seed 0, the same seed as the model
# init above — confirm this overlap is intentional.
rngs = nnx.Rngs(0)
# Split the live nnx objects into (static graphdef, pytree state) so the
# jitted train_step_raw can thread the state functionally.
graphdef, state = nnx.split((flow, optimizer, rngs))
train_data_iter = gen_data_batches(train_data, config["training"]["batch_size"])
|
|
|
start = time.perf_counter()
losses = []

# Optionally resume from a pickled state checkpoint named in the config.
ckpt_path = config["checkpointing"].get("resume_from_checkpoint")
if ckpt_path:
    del state
    # NOTE(review): pickle.load executes arbitrary code — only load
    # checkpoints you produced yourself.
    with open(ckpt_path, "rb") as f:
        state = pickle.load(f)
    print(f"Resuming from checkpoint {ckpt_path}")
    # Checkpoints are written as state_{step:06d}.ckpt; recover the step
    # counter from the filename suffix.
    step_str = Path(ckpt_path).stem.split("_")[-1]
    start_step = int(step_str) + 1
else:
    start_step = 1

# Main training loop: the batch iterator is infinite, so this runs until
# interrupted (or, presumably, until num_steps — TODO confirm a stop
# condition is intended).
for step, batch in enumerate(train_data_iter, start=start_step):
    # Couple noise and data with minibatch OT before the gradient step.
    x0, x1 = generate_ot_pairs(batch)
    state, loss = train_step_raw(graphdef, state, (x0, x1))

    if step % 100 == 0:
        # Pull the loss to host memory only every 100 steps to avoid
        # synchronizing with the device on every iteration.
        losses.append(loss.item())

    if step % config["checkpointing"]["log_every"] == 0:
        end = time.perf_counter()
        duration = end - start
        # NOTE(review): assumes log_every is a multiple of 100; otherwise
        # `losses` can be empty here and this divides by zero.
        loss = sum(losses) / len(losses)
        start = time.perf_counter()
        losses = []
        print(f"step {step:06d} loss {loss:.3f} duration {duration:.3f}s", flush=True)

    if step % config["checkpointing"]["plot_every"] == 0:
        plot_new_images(step, graphdef, state)

    if step % config["checkpointing"]["save_every"] == 0:
        with open(f"state_{step:06d}.ckpt", "wb") as f:
            pickle.dump(state, f)
|