# modified starting from HuggingFace diffusers train_dreambooth.py example
# https://github.com/huggingface/diffusers/blob/024c4376fb19caa85275c038f071b6e1446a5cad/examples/dreambooth/train_dreambooth.py
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from diffusers.schedulers import UniPCMultistepScheduler
from PIL import Image
from torchvision.utils import make_grid
from tqdm.auto import tqdm

from .data import PNGDataModule

logger = get_logger(__name__)

class Lab(Accelerator):
    """Training harness built directly on top of an `accelerate.Accelerator`."""

    def __init__(self, args, control_pipe=None):
        self.cond_key = "prompts"
        self.target_key = "images"
        self.args = args
        self.output_dir = Path(args.output_dir)
        logging_dir = str(self.output_dir / "logs")
        accelerator_project_config = ProjectConfiguration(
            logging_dir=logging_dir,
        )
        super().__init__(
            mixed_precision=args.mixed_precision,
            log_with=args.report_to,
            project_config=accelerator_project_config,
        )
        # frozen components run in half precision when mixed precision is on
        if self.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif self.mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16
        else:
            self.weight_dtype = torch.float32
        if args.seed is not None:
            set_seed(args.seed)
        if control_pipe is None:
            control_pipe = self.load_pipe(
                args.pretrained_model_name_or_path, args.controlnet_weights_path
            )
        self.control_pipe = control_pipe
        vae = control_pipe.vae
        unet = control_pipe.unet
        text_encoder = control_pipe.text_encoder
        tokenizer = control_pipe.tokenizer
        controlnet = getattr(control_pipe, "controlnet", None)
        self.noise_scheduler = UniPCMultistepScheduler.from_config(
            control_pipe.scheduler.config
        )
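        # NOTE: the upstream diffusers training scripts build a DDPMScheduler
        # here; UniPCMultistepScheduler also exposes add_noise(), so training
        # noise can be drawn from it, but get_velocity() (needed below for
        # v-prediction) may not exist on all diffusers versions -- worth
        # verifying against the pinned release.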
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        if controlnet:
            unet.requires_grad_(False)
            if args.training_stage == "zero convolutions":
                # optimize only the zero-initialized projection convolutions
                controlnet.requires_grad_(False)
                controlnet.controlnet_down_blocks.requires_grad_(True)
                controlnet.controlnet_mid_block.requires_grad_(True)
                params_to_optimize = list(
                    controlnet.controlnet_down_blocks.parameters()
                ) + list(controlnet.controlnet_mid_block.parameters())
            elif args.training_stage == "input hint blocks":
                # optimize only the conditioning-image embedding
                controlnet.requires_grad_(False)
                controlnet.controlnet_cond_embedding.requires_grad_(True)
                params_to_optimize = list(
                    controlnet.controlnet_cond_embedding.parameters()
                )
            else:
                # full ControlNet fine-tuning
                controlnet.requires_grad_(True)
                params_to_optimize = list(controlnet.parameters())
        else:
            # no ControlNet attached: fine-tune the UNet itself
            unet.requires_grad_(True)
            params_to_optimize = list(unet.parameters())
        self.params_to_optimize = params_to_optimize
        # scale the learning rate by the effective batch size
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.batch_size
            * self.num_processes
        )
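        # e.g. learning_rate=1e-5 with gradient_accumulation_steps=2,
        # batch_size=4 and 2 processes scales to 1e-5 * 2 * 4 * 2 = 1.6e-4
        # (the same heuristic as --scale_lr in the diffusers examples)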
        if args.use_8bit_adam:
            import bitsandbytes as bnb

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW
        self.optimizer = self.prepare(
            optimizer_class(
                params_to_optimize,
                lr=args.learning_rate,
            )
        )
        if args.enable_xformers_memory_efficient_attention:
            unet.enable_xformers_memory_efficient_attention()
            if controlnet:
                controlnet.enable_xformers_memory_efficient_attention()
        if args.gradient_checkpointing:
            unet.enable_gradient_checkpointing()
            if controlnet:
                controlnet.enable_gradient_checkpointing()
        # TF32 speeds up fp32 matmuls on Ampere+ GPUs at a small precision cost
        torch.backends.cuda.matmul.allow_tf32 = True
        datamodule = PNGDataModule(
            tokenizer=tokenizer,
            from_hf_hub=args.from_hf_hub,
            resolution=[args.resolution, args.resolution],
            target_key=self.target_key,
            cond_key=self.cond_key,
            persistent_workers=True,
            num_workers=args.dataloader_num_workers,
            batch_size=args.batch_size,
            controlnet_hint_key=None if controlnet is None else args.controlnet_hint_key,
        )
        self.train_dataloader = self.prepare(
            datamodule.get_dataloader(args.train_data_dir, shuffle=True)
        )
        if args.valid_data_dir:
            self.valid_dataloader = self.prepare(
                datamodule.get_dataloader(args.valid_data_dir)
            )
        # frozen components run in weight_dtype; the trained model stays in fp32
        self.vae = vae.to(self.device, dtype=self.weight_dtype)
        self.text_encoder = text_encoder.to(self.device, dtype=self.weight_dtype)
        if controlnet:
            controlnet = self.prepare(controlnet)
            self.controlnet = controlnet.to(self.device, dtype=torch.float32)
            self.unet = unet.to(self.device, dtype=self.weight_dtype)
        else:
            unet = self.prepare(unet)
            self.unet = unet.to(self.device, dtype=torch.float32)
            self.controlnet = None

    def load_pipe(self, sd_model_path, controlnet_path=None):
        vae = None
        if self.args.vae_path:
            vae = AutoencoderKL.from_pretrained(
                self.args.vae_path, torch_dtype=self.weight_dtype
            )
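        # Example (hypothetical paths): load_pipe("v1-5-pruned.safetensors",
        # "runs/exp1/controlnet") loads a single-file SD checkpoint, then
        # attaches ControlNet weights from the runs/exp1/controlnet subfolder.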
        if os.path.isfile(sd_model_path):
            # single-file checkpoint (.ckpt or .safetensors)
            from_safetensors = sd_model_path.rsplit(".", 1)[-1] == "safetensors"
            pipe = download_from_original_stable_diffusion_ckpt(
                sd_model_path,
                from_safetensors=from_safetensors,
                device="cpu",
                load_safety_checker=False,
            )
            pipe.safety_checker = None
            pipe.feature_extractor = None
            if vae is not None:
                pipe.vae = vae
        else:
            kwargs = dict(vae=vae) if vae is not None else dict()
            pipe = StableDiffusionPipeline.from_pretrained(
                sd_model_path,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                torch_dtype=self.weight_dtype,
                **kwargs,
            )
        if not controlnet_path:
            return pipe
        pathobj = Path(controlnet_path)
        if pathobj.is_file():
            # custom helper (not stock diffusers) that reads ControlNet
            # weights out of a single stable-diffusion checkpoint file
            controlnet = ControlNetModel.from_config(
                ControlNetModel.load_config("configs/controlnet_config.json")
            )
            controlnet.load_weights_from_sd_ckpt(controlnet_path)
        else:
            # split "parent/subfolder" for diffusers-format weights
            controlnet = ControlNetModel.from_pretrained(
                str(pathobj.parent),
                subfolder=pathobj.name,
                low_cpu_mem_usage=False,
                device_map=None,
            )
        return StableDiffusionControlNetPipeline(
            **pipe.components,
            controlnet=controlnet,
            requires_safety_checker=False,
        )

    def compute_loss(self, batch):
        images = batch[self.target_key].to(dtype=self.weight_dtype)
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=latents.device,
        )
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
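        # i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
        # the standard DDPM forward process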
        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(batch[self.cond_key])[0]
        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    dtype=self.weight_dtype
                )
            else:
                # fall back to an all-zero conditioning image
                controlnet_hint = torch.zeros_like(images)
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_hint,
                return_dict=False,
            )
        else:
            down_block_res_samples, mid_block_res_sample = None, None
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        ).sample
        # Get the target for loss depending on the prediction type
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            # v-prediction target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )
        # cast to fp32 so half-precision targets don't destabilize the loss
        loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
        return loss, encoder_hidden_states

    def decode_latents(self, latents):
        # invert the VAE scaling (scaling_factor is typically 0.18215 for SD 1.x)
        latents = 1 / self.vae.config.scaling_factor * latents
        images = self.vae.decode(latents).sample
        # map from [-1, 1] back to [0, 1]
        images = (images / 2 + 0.5).clamp(0, 1)
        return images

    def log_images(self, batch, encoder_hidden_states, cond_scales=(0.0, 0.5, 1.0)):
        input_tensors = batch[self.target_key].to(self.weight_dtype)
        input_tensors = (input_tensors / 2 + 0.5).clamp(0, 1)
        tensors_to_log = [input_tensors.cpu()]
        height, width = input_tensors.shape[-2:]
        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    self.weight_dtype
                )
            else:
                # mirror compute_loss: use an all-zero hint rather than None,
                # which the ControlNet pipeline would reject
                controlnet_hint = torch.zeros_like(input_tensors)
            # sample at several conditioning scales to visualize ControlNet influence
            for cond_scale in cond_scales:
                latents = self.control_pipe(
                    image=controlnet_hint,
                    prompt_embeds=encoder_hidden_states,
                    controlnet_conditioning_scale=cond_scale,
                    height=height,
                    width=width,
                    output_type="latent",
                    num_inference_steps=25,
                )[0]
                tensors_to_log.append(self.decode_latents(latents).detach().cpu())
            tensors_to_log.append(controlnet_hint.detach().cpu())
        else:
            latents = self.control_pipe(
                prompt_embeds=encoder_hidden_states,
                height=height,
                width=width,
                output_type="latent",
                num_inference_steps=25,
            )[0]
            tensors_to_log.append(self.decode_latents(latents).detach().cpu())
        image_tensors = torch.cat(tensors_to_log)
        # one row per sample: input, generations at each scale, then the hint
        grid = make_grid(image_tensors, normalize=False, nrow=input_tensors.shape[0])
        grid = grid.permute(1, 2, 0).squeeze(-1) * 255
        grid = grid.numpy().astype(np.uint8)
        image_grid = Image.fromarray(grid)
        image_grid.save(Path(self.trackers[0].logging_dir) / f"{self.global_step}.png")

    def save_weights(self, to_safetensors=True):
        save_dir = self.output_dir / f"checkpoint-{self.global_step}"
        os.makedirs(save_dir, exist_ok=True)
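        # layout: output_dir/checkpoint-<step>/{controlnet,unet}/ in
        # diffusers format, reloadable with .from_pretrained(...)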
        if self.args.save_whole_pipeline:
            self.control_pipe.save_pretrained(
                str(save_dir), safe_serialization=to_safetensors
            )
        elif self.controlnet:
            # unwrap in case prepare() wrapped the model (e.g. in DDP)
            self.unwrap_model(self.controlnet).save_pretrained(
                str(save_dir / "controlnet"), safe_serialization=to_safetensors
            )
        else:
            self.unwrap_model(self.unet).save_pretrained(
                str(save_dir / "unet"), safe_serialization=to_safetensors
            )

    def train(self, num_train_epochs=1000, gr_progress=None):
        args = self.args
        if args.num_train_epochs:
            num_train_epochs = args.num_train_epochs
        max_train_steps = (
            num_train_epochs
            * len(self.train_dataloader)
            // args.gradient_accumulation_steps
        )
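        # e.g. 10 epochs * 200 batches // 2 accumulation steps = 1000 steps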
        if self.is_main_process:
            self.init_trackers("tb_logs", config=vars(args))
        self.global_step = 0
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            range(max_train_steps),
            disable=not self.is_local_main_process,
        )
        progress_bar.set_description("Steps")
        try:
            for epoch in range(num_train_epochs):
                # run training loop
                if gr_progress is not None:
                    gr_progress(0, desc=f"Starting Epoch {epoch}")
                if self.controlnet:
                    self.controlnet.train()
                else:
                    self.unet.train()
                for i, batch in enumerate(self.train_dataloader):
                    loss, encoder_hidden_states = self.compute_loss(batch)
                    loss /= args.gradient_accumulation_steps
                    self.backward(loss)
                    # step the optimizer once every gradient_accumulation_steps batches
                    if (i + 1) % args.gradient_accumulation_steps == 0:
                        if self.sync_gradients:
                            self.clip_grad_norm_(
                                self.params_to_optimize, args.max_grad_norm
                            )
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if self.sync_gradients:
                        progress_bar.update(1)
                        if gr_progress is not None:
                            gr_progress(i / len(self.train_dataloader))
                        self.global_step += 1
                        if self.is_main_process:
                            if self.global_step % args.checkpointing_steps == 0:
                                self.save_weights()
                            if args.image_logging_steps and (
                                self.global_step % args.image_logging_steps == 0
                                or self.global_step == 1
                            ):
                                self.log_images(batch, encoder_hidden_states)
                        logs = {"training_loss": loss.detach().item()}
                        self.log(logs, step=self.global_step)
                        progress_bar.set_postfix(**logs)
                    if self.global_step >= max_train_steps:
                        break
                self.wait_for_everyone()
                # run validation loop
                if args.valid_data_dir:
                    total_valid_loss = 0
                    if self.controlnet:
                        self.controlnet.eval()
                    else:
                        self.unet.eval()
                    for batch in self.valid_dataloader:
                        with torch.no_grad():
                            loss, encoder_hidden_states = self.compute_loss(batch)
                        loss = loss.detach().item()
                        total_valid_loss += loss
                        logs = {"validation_loss": loss}
                        progress_bar.set_postfix(**logs)
                    self.log(
                        {
                            "validation_loss": total_valid_loss
                            / len(self.valid_dataloader)
                        },
                        step=self.global_step,
                    )
                self.wait_for_everyone()
        except KeyboardInterrupt:
            print("Keyboard interrupt detected, attempting to save trained weights")
        # except Exception as e:
        #     print(f"Encountered error {e}, attempting to save trained weights")
        self.save_weights()
        self.end_training()
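
# Minimal usage sketch (hypothetical values -- the real entry point is the
# training script that builds `args`); the Namespace fields mirror every
# attribute this module reads:
#
#     from argparse import Namespace
#     args = Namespace(
#         pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
#         controlnet_weights_path=None, vae_path=None, seed=42,
#         mixed_precision="fp16", report_to="tensorboard",
#         output_dir="runs/exp1", train_data_dir="data/train",
#         valid_data_dir=None, from_hf_hub=False, resolution=512,
#         batch_size=4, dataloader_num_workers=2, learning_rate=1e-5,
#         gradient_accumulation_steps=1, use_8bit_adam=False,
#         enable_xformers_memory_efficient_attention=False,
#         gradient_checkpointing=False, controlnet_hint_key=None,
#         training_stage="", max_grad_norm=1.0, checkpointing_steps=500,
#         image_logging_steps=100, num_train_epochs=10,
#         save_whole_pipeline=False,
#     )
#     Lab(args).train()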