# adapted from the HuggingFace diffusers train_dreambooth.py example:
# https://github.com/huggingface/diffusers/blob/024c4376fb19caa85275c038f071b6e1446a5cad/examples/dreambooth/train_dreambooth.py
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, StableDiffusionPipeline
from torchvision.utils import make_grid
import numpy as np
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.schedulers import UniPCMultistepScheduler
from .data import PNGDataModule
logger = get_logger(__name__)
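# Lab subclasses Accelerator and bundles model setup, data handling, and the
# training/validation loops for fine-tuning either a ControlNet or the UNet.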
class Lab(Accelerator):
def __init__(self, args, control_pipe=None):
self.cond_key = "prompts"
self.target_key = "images"
self.args = args
self.output_dir = Path(args.output_dir)
logging_dir = str(self.output_dir / "logs")
accelerator_project_config = ProjectConfiguration(
logging_dir=logging_dir,
)
super().__init__(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
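        # map the mixed-precision setting to the torch dtype used for frozen modules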
if self.mixed_precision == "fp16":
self.weight_dtype = torch.float16
elif self.mixed_precision == "bf16":
self.weight_dtype = torch.bfloat16
else:
self.weight_dtype = torch.float32
if args.seed is not None:
set_seed(args.seed)
if control_pipe is None:
control_pipe = self.load_pipe(
args.pretrained_model_name_or_path, args.controlnet_weights_path
)
self.control_pipe = control_pipe
vae = control_pipe.vae
unet = control_pipe.unet
text_encoder = control_pipe.text_encoder
tokenizer = control_pipe.tokenizer
controlnet = (
control_pipe.controlnet if hasattr(control_pipe, "controlnet") else None
)
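        # reuse the loaded pipeline's scheduler config for the training noise scheduler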
        self.noise_scheduler = UniPCMultistepScheduler.from_config(
            control_pipe.scheduler.config
        )
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
if controlnet:
unet.requires_grad_(False)
if args.training_stage == "zero convolutions":
controlnet.requires_grad_(False)
controlnet.controlnet_down_blocks.requires_grad_(True)
controlnet.controlnet_mid_block.requires_grad_(True)
# optimize only the zero convolution weights
params_to_optimize = list(
controlnet.controlnet_down_blocks.parameters()
) + list(controlnet.controlnet_mid_block.parameters())
elif args.training_stage == "input hint blocks":
controlnet.requires_grad_(False)
controlnet.controlnet_cond_embedding.requires_grad_(True)
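                # optimize only the input hint block (conditioning embedding) weights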
params_to_optimize = list(
controlnet.controlnet_cond_embedding.parameters()
)
else:
controlnet.requires_grad_(True)
params_to_optimize = list(controlnet.parameters())
else:
unet.requires_grad_(True)
params_to_optimize = list(unet.parameters())
self.params_to_optimize = params_to_optimize
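        # scale the base learning rate by the effective batch size, as in the diffusers examples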
args.learning_rate = (
args.learning_rate
* args.gradient_accumulation_steps
* args.batch_size
* self.num_processes
)
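        # bitsandbytes 8-bit AdamW substantially reduces optimizer state memory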
if args.use_8bit_adam:
import bitsandbytes as bnb
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
self.optimizer = self.prepare(
optimizer_class(
params_to_optimize,
lr=args.learning_rate,
)
)
if args.enable_xformers_memory_efficient_attention:
unet.enable_xformers_memory_efficient_attention()
if controlnet:
controlnet.enable_xformers_memory_efficient_attention()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if controlnet:
controlnet.enable_gradient_checkpointing()
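        # allow TF32 matmuls on Ampere+ GPUs for faster float32 training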
torch.backends.cuda.matmul.allow_tf32 = True
datamodule = PNGDataModule(
tokenizer=tokenizer,
from_hf_hub=args.from_hf_hub,
resolution=[args.resolution, args.resolution],
target_key=self.target_key,
cond_key=self.cond_key,
persistent_workers=True,
num_workers=args.dataloader_num_workers,
batch_size=args.batch_size,
controlnet_hint_key=None if controlnet is None else args.controlnet_hint_key,
)
self.train_dataloader = self.prepare(
datamodule.get_dataloader(args.train_data_dir, shuffle=True)
)
if args.valid_data_dir:
self.valid_dataloader = self.prepare(
datamodule.get_dataloader(args.valid_data_dir)
)
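        # frozen modules run in the reduced weight dtype; the module being trained
        # stays in float32 so gradients and optimizer state keep full precision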
self.vae = vae.to(self.device, dtype=self.weight_dtype)
self.text_encoder = text_encoder.to(self.device, dtype=self.weight_dtype)
if controlnet:
controlnet = self.prepare(controlnet)
self.controlnet = controlnet.to(self.device, dtype=torch.float32)
self.unet = unet.to(self.device, dtype=self.weight_dtype)
else:
unet = self.prepare(unet)
self.unet = unet.to(self.device, dtype=torch.float32)
self.controlnet = None
def load_pipe(self, sd_model_path, controlnet_path=None):
if self.args.vae_path:
vae = AutoencoderKL.from_pretrained(
self.args.vae_path, torch_dtype=self.weight_dtype
)
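        # a single checkpoint file is converted from the original Stable Diffusion format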
if os.path.isfile(sd_model_path):
file_ext = sd_model_path.rsplit(".", 1)[-1]
from_safetensors = file_ext == "safetensors"
pipe = download_from_original_stable_diffusion_ckpt(
sd_model_path,
from_safetensors=from_safetensors,
device="cpu",
load_safety_checker=False,
)
pipe.safety_checker = None
pipe.feature_extractor = None
if self.args.vae_path:
pipe.vae = vae
else:
if self.args.vae_path:
kw_args = dict(vae=vae)
else:
kw_args = dict()
pipe = StableDiffusionPipeline.from_pretrained(
sd_model_path,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
torch_dtype=self.weight_dtype,
**kw_args
)
if not controlnet_path:
return pipe
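        # a single ControlNet weights file is loaded through a local config plus an
        # SD-checkpoint loader; otherwise the path is treated as <repo_dir>/<subfolder>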
pathobj = Path(controlnet_path)
if pathobj.is_file():
controlnet = ControlNetModel.from_config(
ControlNetModel.load_config("configs/controlnet_config.json")
)
controlnet.load_weights_from_sd_ckpt(controlnet_path)
else:
            controlnet = ControlNetModel.from_pretrained(
                str(pathobj.parent),
                subfolder=pathobj.name,
low_cpu_mem_usage=False,
device_map=None,
)
return StableDiffusionControlNetPipeline(
**pipe.components,
controlnet=controlnet,
requires_safety_checker=False,
)
@torch.autocast("cuda")
def compute_loss(self, batch):
images = batch[self.target_key].to(dtype=self.weight_dtype)
latents = self.vae.encode(images).latent_dist.sample()
latents = latents * self.vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(latents.shape[0],),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = self.text_encoder(batch[self.cond_key])[0]
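        # run the ControlNet to produce residuals that condition the UNet blocks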
if self.controlnet:
if self.args.controlnet_hint_key in batch:
controlnet_hint = batch[self.args.controlnet_hint_key].to(
dtype=self.weight_dtype
)
else:
                controlnet_hint = torch.zeros_like(images)
down_block_res_samples, mid_block_res_sample = self.controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_hint,
return_dict=False,
)
else:
down_block_res_samples, mid_block_res_sample = None, None
noise_pred = self.unet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
# Get the target for loss depending on the prediction type
if self.noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif self.noise_scheduler.config.prediction_type == "v_prediction":
target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
)
loss = F.mse_loss(noise_pred, target, reduction="mean")
return loss, encoder_hidden_states
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
output_latents = self.vae.decode(latents).sample
output_latents = (output_latents / 2 + 0.5).clamp(0, 1)
return output_latents
@torch.no_grad()
@torch.autocast("cuda")
    def log_images(self, batch, encoder_hidden_states, cond_scales=(0.0, 0.5, 1.0)):
input_tensors = batch[self.target_key].to(self.weight_dtype)
input_tensors = (input_tensors / 2 + 0.5).clamp(0, 1)
tensors_to_log = [input_tensors.cpu()]
        height, width = input_tensors.shape[-2:]
if self.controlnet:
if self.args.controlnet_hint_key in batch:
controlnet_hint = batch[self.args.controlnet_hint_key].to(
self.weight_dtype
)
else:
controlnet_hint = None
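            # sample at several conditioning scales to visualize the ControlNet's influence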
for cond_scale in cond_scales:
latents = self.control_pipe(
image=controlnet_hint,
prompt_embeds=encoder_hidden_states,
controlnet_conditioning_scale=cond_scale,
height=height,
width=width,
output_type="latent",
num_inference_steps=25,
)[0]
tensors_to_log.append(self.decode_latents(latents).detach().cpu())
if controlnet_hint is not None:
tensors_to_log.append(controlnet_hint.detach().cpu())
else:
latents = self.control_pipe(
prompt_embeds=encoder_hidden_states,
height=height,
width=width,
output_type="latent",
num_inference_steps=25,
)[0]
tensors_to_log.append(self.decode_latents(latents).detach().cpu())
image_tensors = torch.cat(tensors_to_log)
grid = make_grid(image_tensors, normalize=False, nrow=input_tensors.shape[0])
grid = grid.permute(1, 2, 0).squeeze(-1) * 255
grid = grid.numpy().astype(np.uint8)
image_grid = Image.fromarray(grid)
image_grid.save(Path(self.trackers[0].logging_dir) / f"{self.global_step}.png")
def save_weights(self, to_safetensors=True):
save_dir = self.output_dir / f"checkpoint-{self.global_step}"
os.makedirs(save_dir, exist_ok=True)
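        # save the whole pipeline, only the ControlNet, or only the UNet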
if self.args.save_whole_pipeline:
self.control_pipe.save_pretrained(
str(save_dir), safe_serialization=to_safetensors
)
elif self.controlnet:
self.controlnet.save_pretrained(
str(save_dir / "controlnet"), safe_serialization=to_safetensors
)
else:
self.unet.save_pretrained(
str(save_dir / "unet"), safe_serialization=to_safetensors
)
    def train(self, num_train_epochs=1000, gr_progress=None):
args = self.args
if args.num_train_epochs:
num_train_epochs = args.num_train_epochs
max_train_steps = (
num_train_epochs
* len(self.train_dataloader)
// args.gradient_accumulation_steps
)
if self.is_main_process:
self.init_trackers("tb_logs", config=vars(args))
self.global_step = 0
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(max_train_steps),
disable=not self.is_local_main_process,
)
progress_bar.set_description("Steps")
try:
for epoch in range(num_train_epochs):
# run training loop
if gr_progress is not None:
gr_progress(0, desc=f"Starting Epoch {epoch}")
if self.controlnet:
self.controlnet.train()
else:
self.unet.train()
for i, batch in enumerate(self.train_dataloader):
loss, encoder_hidden_states = self.compute_loss(batch)
loss /= args.gradient_accumulation_steps
self.backward(loss)
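                    # manual gradient accumulation: step the optimizer every N micro-batches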
                    if (i + 1) % args.gradient_accumulation_steps == 0:
if self.sync_gradients:
self.clip_grad_norm_(
self.params_to_optimize, args.max_grad_norm
)
self.optimizer.step()
self.optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if self.sync_gradients:
progress_bar.update(1)
if gr_progress is not None:
                            gr_progress(float(i / len(self.train_dataloader)))
self.global_step += 1
if self.is_main_process:
if self.global_step % args.checkpointing_steps == 0:
self.save_weights()
if args.image_logging_steps and (
self.global_step % args.image_logging_steps == 0
or self.global_step == 1
):
self.log_images(batch, encoder_hidden_states)
logs = {"training_loss": loss.detach().item()}
self.log(logs, step=self.global_step)
progress_bar.set_postfix(**logs)
if self.global_step >= max_train_steps:
break
self.wait_for_everyone()
# run validation loop
if args.valid_data_dir:
total_valid_loss = 0
if self.controlnet:
self.controlnet.eval()
else:
self.unet.eval()
for batch in self.valid_dataloader:
with torch.no_grad():
loss, encoder_hidden_states = self.compute_loss(batch)
loss = loss.detach().item()
total_valid_loss += loss
logs = {"validation_loss": loss}
progress_bar.set_postfix(**logs)
self.log(
{
"validation_loss": total_valid_loss
/ len(self.valid_dataloader)
},
step=self.global_step,
)
self.wait_for_everyone()
except KeyboardInterrupt:
print("Keyboard interrupt detected, attempting to save trained weights")
self.save_weights()
self.end_training()
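# Minimal usage sketch (hypothetical; the actual entry point builds `args`
# elsewhere in this repo, e.g. from a CLI or the web UI):
#
#     lab = Lab(args)  # args must provide the fields referenced above
#     lab.train()      # trains, logs, and writes checkpoints under args.output_dir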