# modified starting from the HuggingFace diffusers train_dreambooth.py example
# https://github.com/huggingface/diffusers/blob/024c4376fb19caa85275c038f071b6e1446a5cad/examples/dreambooth/train_dreambooth.py
import os
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, StableDiffusionPipeline
from torchvision.utils import make_grid
import numpy as np
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.schedulers import UniPCMultistepScheduler

from .data import PNGDataModule

logger = get_logger(__name__)


class Lab(Accelerator):
    def __init__(self, args, control_pipe=None):
        self.cond_key = "prompts"
        self.target_key = "images"
        self.args = args
        self.output_dir = Path(args.output_dir)
        logging_dir = str(self.output_dir / "logs")

        accelerator_project_config = ProjectConfiguration(
            logging_dir=logging_dir,
        )
        super().__init__(
            mixed_precision=args.mixed_precision,
            log_with=args.report_to,
            project_config=accelerator_project_config,
        )

        # dtype used for the frozen models; the trainable module stays in float32
        if self.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif self.mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16
        else:
            self.weight_dtype = torch.float32

        if args.seed is not None:
            set_seed(args.seed)

        if control_pipe is None:
            control_pipe = self.load_pipe(
                args.pretrained_model_name_or_path, args.controlnet_weights_path
            )
        self.control_pipe = control_pipe
        vae = control_pipe.vae
        unet = control_pipe.unet
        text_encoder = control_pipe.text_encoder
        tokenizer = control_pipe.tokenizer
        controlnet = (
            control_pipe.controlnet if hasattr(control_pipe, "controlnet") else None
        )
        self.noise_scheduler = UniPCMultistepScheduler.from_config(
            control_pipe.scheduler.config
        )

        # freeze everything except the module selected for training
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        if controlnet:
            unet.requires_grad_(False)
            if args.training_stage == "zero convolutions":
                controlnet.requires_grad_(False)
                controlnet.controlnet_down_blocks.requires_grad_(True)
                controlnet.controlnet_mid_block.requires_grad_(True)
                # optimize only the zero convolution weights
                params_to_optimize = list(
                    controlnet.controlnet_down_blocks.parameters()
                ) + list(controlnet.controlnet_mid_block.parameters())
            elif args.training_stage == "input hint blocks":
                # optimize only the conditioning (hint) embedding
                controlnet.requires_grad_(False)
                controlnet.controlnet_cond_embedding.requires_grad_(True)
                params_to_optimize = list(
                    controlnet.controlnet_cond_embedding.parameters()
                )
            else:
                # default: train the full ControlNet
                controlnet.requires_grad_(True)
                params_to_optimize = list(controlnet.parameters())
        else:
            # no ControlNet attached: fine-tune the UNet itself
            unet.requires_grad_(True)
            params_to_optimize = list(unet.parameters())
        self.params_to_optimize = params_to_optimize

        # scale the learning rate by the effective batch size
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.batch_size
            * self.num_processes
        )

        if args.use_8bit_adam:
            import bitsandbytes as bnb

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        self.optimizer = self.prepare(
            optimizer_class(
                params_to_optimize,
                lr=args.learning_rate,
            )
        )

        if args.enable_xformers_memory_efficient_attention:
            unet.enable_xformers_memory_efficient_attention()
            if controlnet:
                controlnet.enable_xformers_memory_efficient_attention()

        if args.gradient_checkpointing:
            unet.enable_gradient_checkpointing()
            if controlnet:
                controlnet.enable_gradient_checkpointing()

        torch.backends.cuda.matmul.allow_tf32 = True

        datamodule = PNGDataModule(
            tokenizer=tokenizer,
            from_hf_hub=args.from_hf_hub,
            resolution=[args.resolution, args.resolution],
            target_key=self.target_key,
            cond_key=self.cond_key,
            persistent_workers=True,
            num_workers=args.dataloader_num_workers,
            batch_size=args.batch_size,
            controlnet_hint_key=None if controlnet is None else args.controlnet_hint_key,
        )
        self.train_dataloader = self.prepare(
            datamodule.get_dataloader(args.train_data_dir, shuffle=True)
        )
        if args.valid_data_dir:
            self.valid_dataloader = self.prepare(
                datamodule.get_dataloader(args.valid_data_dir)
            )

        # frozen models run in the (possibly reduced-precision) weight dtype,
        # while the trainable module is kept in float32
        self.vae = vae.to(self.device, dtype=self.weight_dtype)
        self.text_encoder = text_encoder.to(self.device, dtype=self.weight_dtype)
        if controlnet:
            controlnet = self.prepare(controlnet)
            self.controlnet = controlnet.to(self.device, dtype=torch.float32)
            self.unet = unet.to(self.device, dtype=self.weight_dtype)
        else:
            unet = self.prepare(unet)
            self.unet = unet.to(self.device, dtype=torch.float32)
            self.controlnet = None

    def load_pipe(self, sd_model_path, controlnet_path=None):
        if self.args.vae_path:
            vae = AutoencoderKL.from_pretrained(
                self.args.vae_path, torch_dtype=self.weight_dtype
            )
        if os.path.isfile(sd_model_path):
            # single original-format checkpoint (.ckpt / .safetensors)
            file_ext = sd_model_path.rsplit(".", 1)[-1]
            from_safetensors = file_ext == "safetensors"
            pipe = download_from_original_stable_diffusion_ckpt(
                sd_model_path,
                from_safetensors=from_safetensors,
                device="cpu",
                load_safety_checker=False,
            )
            pipe.safety_checker = None
            pipe.feature_extractor = None
            if self.args.vae_path:
                pipe.vae = vae
        else:
            # diffusers-format model directory or hub id
            if self.args.vae_path:
                kw_args = dict(vae=vae)
            else:
                kw_args = dict()
            pipe = StableDiffusionPipeline.from_pretrained(
                sd_model_path,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                torch_dtype=self.weight_dtype,
                **kw_args
            )
        if not controlnet_path:
            return pipe

        pathobj = Path(controlnet_path)
        if pathobj.is_file():
            controlnet = ControlNetModel.from_config(
                ControlNetModel.load_config("configs/controlnet_config.json")
            )
            controlnet.load_weights_from_sd_ckpt(controlnet_path)
        else:
            controlnet_path = str(Path().joinpath(*pathobj.parts[:-1]))
            subfolder = str(pathobj.parts[-1])
            controlnet = ControlNetModel.from_pretrained(
                controlnet_path,
                subfolder=subfolder,
                low_cpu_mem_usage=False,
                device_map=None,
            )
        return StableDiffusionControlNetPipeline(
            **pipe.components,
            controlnet=controlnet,
            requires_safety_checker=False,
        )

    @torch.autocast("cuda")
    def compute_loss(self, batch):
        images = batch[self.target_key].to(dtype=self.weight_dtype)
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=latents.device,
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(batch[self.cond_key])[0]

        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    dtype=self.weight_dtype
                )
            else:
                controlnet_hint = torch.zeros(images.shape).to(images)
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_hint,
                return_dict=False,
            )
        else:
            down_block_res_samples, mid_block_res_sample = None, None

        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        ).sample

        # Get the target for loss depending on the prediction type
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )
        loss = F.mse_loss(noise_pred, target, reduction="mean")
        return loss, encoder_hidden_states

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        output_latents = self.vae.decode(latents).sample
        output_latents = (output_latents / 2 + 0.5).clamp(0, 1)
        return output_latents

    @torch.no_grad()
    @torch.autocast("cuda")
    def log_images(self, batch, encoder_hidden_states, cond_scales=[0.0, 0.5, 1.0]):
        input_tensors = batch[self.target_key].to(self.weight_dtype)
        input_tensors = (input_tensors / 2 + 0.5).clamp(0, 1)
        tensors_to_log = [input_tensors.cpu()]
        [height, width] = input_tensors.shape[-2:]
        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    self.weight_dtype
                )
            else:
                controlnet_hint = None
            for cond_scale in cond_scales:
                latents = self.control_pipe(
                    image=controlnet_hint,
                    prompt_embeds=encoder_hidden_states,
                    controlnet_conditioning_scale=cond_scale,
                    height=height,
                    width=width,
                    output_type="latent",
                    num_inference_steps=25,
                )[0]
                tensors_to_log.append(self.decode_latents(latents).detach().cpu())
            if controlnet_hint is not None:
                tensors_to_log.append(controlnet_hint.detach().cpu())
        else:
            latents = self.control_pipe(
                prompt_embeds=encoder_hidden_states,
                height=height,
                width=width,
                output_type="latent",
                num_inference_steps=25,
            )[0]
            tensors_to_log.append(self.decode_latents(latents).detach().cpu())

        image_tensors = torch.cat(tensors_to_log)
        grid = make_grid(image_tensors, normalize=False, nrow=input_tensors.shape[0])
        grid = grid.permute(1, 2, 0).squeeze(-1) * 255
        grid = grid.numpy().astype(np.uint8)
        image_grid = Image.fromarray(grid)
        image_grid.save(Path(self.trackers[0].logging_dir) / f"{self.global_step}.png")

    def save_weights(self, to_safetensors=True):
        save_dir = self.output_dir / f"checkpoint-{self.global_step}"
        os.makedirs(save_dir, exist_ok=True)
        if self.args.save_whole_pipeline:
            self.control_pipe.save_pretrained(
                str(save_dir), safe_serialization=to_safetensors
            )
        elif self.controlnet:
            self.controlnet.save_pretrained(
                str(save_dir / "controlnet"), safe_serialization=to_safetensors
            )
        else:
            self.unet.save_pretrained(
                str(save_dir / "unet"), safe_serialization=to_safetensors
            )

    def train(self, num_train_epochs=1000, gr_progress=None):
        args = self.args
        if args.num_train_epochs:
            num_train_epochs = args.num_train_epochs
        max_train_steps = (
            num_train_epochs
            * len(self.train_dataloader)
            // args.gradient_accumulation_steps
        )

        if self.is_main_process:
            self.init_trackers("tb_logs", config=vars(args))

        self.global_step = 0
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            range(max_train_steps),
            disable=not self.is_local_main_process,
        )
        progress_bar.set_description("Steps")

        try:
            for epoch in range(num_train_epochs):
                # run training loop
                if gr_progress is not None:
                    gr_progress(0, desc=f"Starting Epoch {epoch}")
                if self.controlnet:
                    self.controlnet.train()
                else:
                    self.unet.train()
                for i, batch in enumerate(self.train_dataloader):
                    loss, encoder_hidden_states = self.compute_loss(batch)
                    loss /= args.gradient_accumulation_steps
                    self.backward(loss)
                    if self.global_step % args.gradient_accumulation_steps == 0:
                        if self.sync_gradients:
                            self.clip_grad_norm_(
                                self.params_to_optimize, args.max_grad_norm
                            )
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if self.sync_gradients:
                        progress_bar.update(1)
                        if gr_progress is not None:
                            gr_progress(float(i / len(self.train_dataloader)))
                        self.global_step += 1
                        if self.is_main_process:
                            if self.global_step % args.checkpointing_steps == 0:
                                self.save_weights()
                            if args.image_logging_steps and (
                                self.global_step % args.image_logging_steps == 0
                                or self.global_step == 1
                            ):
                                self.log_images(batch, encoder_hidden_states)

                    logs = {"training_loss": loss.detach().item()}
                    self.log(logs, step=self.global_step)
                    progress_bar.set_postfix(**logs)

                    if self.global_step >= max_train_steps:
                        break
                self.wait_for_everyone()

                # run validation loop
                if args.valid_data_dir:
                    total_valid_loss = 0
                    if self.controlnet:
                        self.controlnet.eval()
                    else:
                        self.unet.eval()
                    for batch in self.valid_dataloader:
                        with torch.no_grad():
                            loss, encoder_hidden_states = self.compute_loss(batch)
                            loss = loss.detach().item()
                            total_valid_loss += loss
                            logs = {"validation_loss": loss}
                            progress_bar.set_postfix(**logs)
                    self.log(
                        {
                            "validation_loss": total_valid_loss
                            / len(self.valid_dataloader)
                        },
                        step=self.global_step,
                    )
                    self.wait_for_everyone()
        except KeyboardInterrupt:
            print("Keyboard interrupt detected, attempting to save trained weights")
        # except Exception as e:
        #     print(f"Encountered error {e}, attempting to save trained weights")
        self.save_weights()
        self.end_training()
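

# ---------------------------------------------------------------------------
# Minimal usage sketch. The attribute names below mirror the `args` fields this
# module reads; every value here is an illustrative assumption, not a default
# shipped with the code, so adjust paths, keys, and hyperparameters to your
# setup. Run it as a module (python -m <package>.<this_module>) so the relative
# `from .data import PNGDataModule` import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        controlnet_weights_path=None,  # or a ControlNet checkpoint / directory
        vae_path=None,
        output_dir="output",
        mixed_precision="fp16",
        report_to="tensorboard",
        seed=42,
        training_stage="",  # "", "zero convolutions", or "input hint blocks"
        learning_rate=1e-5,
        batch_size=1,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
        use_8bit_adam=False,
        enable_xformers_memory_efficient_attention=False,
        gradient_checkpointing=False,
        from_hf_hub=False,
        resolution=512,
        dataloader_num_workers=2,
        controlnet_hint_key="hints",  # hypothetical key name used by the data module
        train_data_dir="data/train",
        valid_data_dir=None,
        num_train_epochs=10,
        checkpointing_steps=500,
        image_logging_steps=100,
        save_whole_pipeline=False,
    )

    # build the pipeline, dataloaders, and optimizer, then run the training loop
    lab = Lab(example_args)
    lab.train()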