How to train a 16-channel VAE decoder
Prerequisites
Read the MNIST example first and get familiar with PyTorch.
Training loop
Train your first model on color images. A minimal PyTorch training loop looks like this:
for epoch in range(10):  # Several passes over the whole dataset.
    for inputs, labels in dataloader:
        # Each batch unpacks into an [inputs, labels] pair.
        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()
        # Forward pass, then the loss, then backward pass and parameter update.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
Compare images
Measure how similar the predicted image is to the original image.
Stable Diffusion computes the reconstruction loss either:
- with MSE alone, or
- with MSE + LPIPS * 0.1.
import numpy as np
from PIL import Image
import torch
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity, spectral_angle_mapper
# Load the reference image as 3-channel RGB.
pil_image = Image.open('vehicle.png').convert('RGB')
# The original image ("target"), and the predicted image ("output").
# torchmetrics' image metrics expect float batches shaped (N, 3, H, W), so the
# (H, W, 3) uint8 pixel array becomes a (1, 3, H, W) float32 tensor in [0, 1].
# Keeping both tensors float32 also avoids a float64/float32 dtype mismatch.
target = torch.from_numpy(np.asarray(pil_image, dtype=np.float32) / 255.0)
target = target.permute(2, 0, 1).unsqueeze(0)
# In this example, it's a perfect match. A real prediction should be clamped
# to the same [0, 1] range as the target.
output = target.clone().clamp(0.0, 1.0)
loss_sam = spectral_angle_mapper(output, target)
# LPIPS expects inputs in [-1, 1] by default; normalize=True declares [0, 1] inputs.
loss_lpips = learned_perceptual_image_patch_similarity(output, target, normalize=True)
loss_mse = torch.nn.functional.mse_loss(output, target, reduction='mean')
loss = loss_mse + loss_lpips * 0.1
Parquet
Load a downloaded dataset or write your own dataset loader.
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from torchvision.transforms.functional import pil_to_tensor
class ImageDataset(Dataset):
    """Map-style dataset that returns one RGB image per index in several formats.

    Fill in the ``...`` placeholders with your own image sources and loading code.

    Note: the 'pil_image' entry cannot be batched by DataLoader's default
    collate function, so pass a ``collate_fn`` that keeps it as a plain list
    (or drop that key) when wrapping this dataset in a DataLoader.
    """

    def __init__(self, transform: Compose):
        super().__init__()
        # Applied to each PIL image to produce the 'image' entry.
        self.transform = transform
        # Initialize image sources (URLs, file paths, dataset rows, ...).
        self.sources = []
        ...

    def __len__(self):
        # Required by DataLoader's default (sequential/random) sampler.
        return len(self.sources)

    def __getitem__(self, idx):
        # Load your images from urls, folders or downloaded datasets,
        # e.g. based on self.sources[idx].
        image = ...
        image = image.convert('RGB')
        # Images in one batch must share a shape to be stacked; this tutorial
        # uses square images. Raise instead of assert so the check is not
        # stripped when Python runs with -O.
        if image.width != image.height:
            raise ValueError(
                f'Expected a square image at index {idx}, '
                f'got {image.width}x{image.height}.')
        return {'image': self.transform(image),
                'pil_image': image,
                'image_array': np.array(image),
                'image_tensor': pil_to_tensor(image)}
Dataloader
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import default_collate

# Call by transform(pil_image). ToTensor() maps pixels to [0, 1] and
# Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def transform_batch(examples):
    """Apply ``transform`` to every image of a Hugging Face batch.

    ``Dataset.set_transform`` calls its function with a dict of column lists,
    not with a single PIL image, so ``transform`` cannot be passed directly.
    Assumes the images live in an 'image' column; adjust for your dataset.
    """
    examples['image'] = [transform(image.convert('RGB')) for image in examples['image']]
    return examples


def collate_with_pil(samples):
    """Batch ImageDataset samples, keeping 'pil_image' as a plain list.

    ``default_collate`` raises TypeError on PIL images, so that key is
    excluded from stacking and collected by hand.
    """
    batch = default_collate([
        {key: value for key, value in sample.items() if key != 'pil_image'}
        for sample in samples
    ])
    batch['pil_image'] = [sample['pil_image'] for sample in samples]
    return batch


# Option A: a Hugging Face dataset; pick a split so a Dataset (not a
# DatasetDict) is returned.
dataset = load_dataset(..., split='train')
dataset.set_transform(transform_batch)
dataloader = DataLoader(dataset, batch_size=2)
# Option B: your own ImageDataset.
dataloader = DataLoader(ImageDataset(transform), batch_size=2, collate_fn=collate_with_pil)
# Option C: a folder of class subfolders; batches are [images, labels] lists, not dicts.
dataloader = DataLoader(ImageFolder('path/to/image/folder', transform=transform), batch_size=2)
Rescaling the latent space
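In the latent space of a KL-regularized autoencoder, a diffusion model trained on unscaled latents tends to resolve a lot of semantic detail early on. In other words, the model captures too much information too quickly.
The authors therefore rescale the latent values by their component-wise standard deviation. In diffusers this is the VAE's `scaling_factor`, and SD3 additionally uses a `shift_factor`. The latents then have roughly unit variance, which effectively reduces the signal-to-noise ratio (SNR) at each noise level.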
import diffusers
from PIL import Image
import torch
import torchvision.transforms
from torchvision.transforms.functional import to_pil_image
def encode_image(image):
    """Encode an image into the scaled latent space of the pretrained VAE.

    Relies on the module-level ``transform``, ``vae`` and ``cuda_device``.

    Args:
        image: A PIL image, or a tensor already normalized to [-1, 1]
            (like the 'image' column produced by ``transform``), shaped
            (3, H, W) or (N, 3, H, W).

    Returns:
        A tuple ``(latent, image)``: the scaled latent batch, and the
        normalized image tensor (the transformed tensor for PIL input).
    """
    if isinstance(image, Image.Image):
        # `transform` (ToTensor + Normalize(0.5, 0.5)) already maps pixels to
        # [-1, 1]; applying `* 2 - 1` on top would push them into [-3, 1].
        image = transform(image.convert('RGB'))
    # The VAE encoder works on batches; add the batch axis for single images.
    image_tensor = image if image.dim() == 4 else image.unsqueeze(0)
    image_tensor = image_tensor.to(cuda_device)
    with torch.no_grad():
        latent = vae.encode(image_tensor).latent_dist.sample()
    # SD1.5/SDXL: scaling_factor * latent (no shift_factor configured).
    # SD3: scaling_factor * (latent - shift_factor).
    shift = getattr(vae.config, 'shift_factor', None) or 0.0
    return vae.config.scaling_factor * (latent - shift), image
def decode_image(latent: torch.Tensor):
    """Decode a scaled latent back into PIL images with the pretrained VAE.

    Inverse of ``encode_image``'s scaling. Relies on the module-level ``vae``.

    Args:
        latent: A scaled latent shaped (C, h, w) or (N, C, h, w).

    Returns:
        A list of PIL images, one per batch element.
    """
    # SD1.5/SDXL: latent / scaling_factor (no shift_factor configured).
    # SD3: latent / scaling_factor + shift_factor.
    shift = getattr(vae.config, 'shift_factor', None) or 0.0
    latent = latent / vae.config.scaling_factor + shift
    if latent.dim() == 3:
        latent = latent.unsqueeze(0)
    with torch.no_grad():
        # vae.decode returns a DecoderOutput; the image tensor is `.sample`.
        image_tensor = vae.decode(latent).sample
    # pt_to_pil maps [-1, 1] tensors to PIL images.
    return diffusers.utils.pt_to_pil(image_tensor)
Regression model
In the previous training exercises, you focused on classification tasks, where the goal was to predict a categorical label from a set of predefined categories. Regression models, by contrast, learn a mapping from inputs to outputs and make predictions on a continuous scale. Here, the decoder predicts pixel values from latents.
Copy the `sd3_impls.py` file into your project folder. You will work with two models: a pretrained Stable Diffusion VAE, and your own VAE decoder.
from diffusers import AutoencoderKL
from safetensors.torch import load_model, save_model
from sd3_impls import VAEDecoder
import torch
from torch import nn
class SDVAE(nn.Module):
    """Trainable wrapper around the ``VAEDecoder`` from ``sd3_impls``.

    Only the decoder is trained; latents come from a separate pretrained VAE.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, `self.decoder = ...` raises AttributeError.
        super().__init__()
        # Choose between the (newer) torch.bfloat16 and (older) torch.float32.
        self.decoder = VAEDecoder(dtype=torch.float32)

    def decode(self, latent):
        """Run the decoder on ``latent`` and return its output."""
        return self.decoder(latent)

    def forward(self, latent):
        # Lets `model(latent)` behave like `model.decode(latent)`.
        return self.decode(latent)
import os

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The pretrained VAE is frozen; it only provides the training latents.
vae = AutoencoderKL.from_pretrained(
    'path/to/pretrained/model',
    subfolder='vae',
    revision=None,
    variant=None,
).to(cuda_device)
vae.requires_grad_(False)  # Otherwise OOM.
vae.eval()

# Option A: train a new model.
# Option B: continue training from your previous checkpoint, if one exists.
checkpoint_path = 'vae.safetensors'
model = SDVAE().to(cuda_device)
if os.path.isfile(checkpoint_path):
    load_model(model, checkpoint_path)
model.train()
# Latents are cast to the decoder's dtype (float32 or bfloat16, see SDVAE).
model_dtype = next(model.parameters()).dtype

dataloader = ...
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
for epoch in range(10):
    for batch in dataloader:
        # Read the 'image' column (dict batches) or the first element
        # (ImageFolder-style [images, labels] batches); preferably it is
        # already a tensor normalized to [-1, 1].
        norm_img = batch['image'] if isinstance(batch, dict) else batch[0]
        # The target must live on the same device as the model output.
        norm_img = norm_img.to(cuda_device)
        optimizer.zero_grad()
        # Exercise to the reader:
        # Encode all images before the training loop,
        # store them in the latents[] list, instead of re-encoding here.
        # These are unscaled latents, i.e. what the pretrained VAE's own
        # decoder receives.
        with torch.no_grad():
            model_input = vae.encode(norm_img).latent_dist.sample()
        # Forward.
        output = model.decode(model_input.to(model_dtype))
        # Compare tensors, the output of the training model will be a bit blurry.
        # Exercise to the reader:
        # use the MSE + LPIPS * 0.1 formula to improve image fidelity.
        loss = torch.nn.functional.mse_loss(output[:, :3, :, :].float(),
                                            norm_img[:, :3, :, :].float(),
                                            reduction='mean')
        loss.backward()
        optimizer.step()
    save_model(model, './epoch_{:03d}.safetensors'.format(epoch))
save_model(model, checkpoint_path)
Note to model creators
Open-source training code and tagged datasets are essential for the community. Why not start sharing your know-how today?