import gradio as gr
import torch
import random
from PIL import Image
import os
import argparse
import shutil
import gc
import importlib
import importlib.util  # `import importlib` alone does not guarantee `importlib.util` is loaded
import json
from multiprocessing import cpu_count
import cv2
import numpy as np
from pathlib import Path

from diffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
    ControlNetModel,
    AutoencoderKL,
)

# The project-local ControlNet pipeline intentionally replaces the diffusers class of the same name.
from src.controlnet_pipe import ControlNetPipe as StableDiffusionControlNetPipeline
from src.lab import Lab
from src.ui_shared import (
    default_scheduler,
    scheduler_dict,
    model_ids,
    controlnet_ids,
    is_hfspace,
)

# Hub repository whose subfolders hold the selectable ControlNet weights.
CONTROLNET_REPO = "lint/anime_control"

_xformers_available = importlib.util.find_spec("xformers") is not None

device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'  # uncomment to force CPU inference
dtype = torch.float16 if device == "cuda" else torch.float32

# Single cached pipeline shared by all requests, plus identifiers describing what it
# was built from. load_pipe only updates these after a (re)build succeeds.
pipe = None
loaded_model_id = ""
loaded_controlnet_id = ""
loaded_scheduler_name = ""


def load_pipe(model_id, controlnet_id, scheduler_name):
    """Return the shared ControlNet pipeline configured for the requested parts.

    Only the pieces that differ from the cached pipeline are loaded: the base
    Stable Diffusion model (with the ``lint/anime_vae`` VAE), the ControlNet
    weights from ``CONTROLNET_REPO``, and the scheduler. When anything changes
    the pipeline is reassembled from its components and, on CUDA, its modules
    are moved to ``dtype`` with memory-saving attention enabled.

    Args:
        model_id: id/path passed to ``StableDiffusionPipeline.from_pretrained``;
            its ``scheduler`` subfolder also supplies the scheduler config.
        controlnet_id: subfolder of ``CONTROLNET_REPO`` holding ControlNet weights.
        scheduler_name: key into ``scheduler_dict``.

    Returns:
        The pipeline, which is also stored in the module-level ``pipe``.
    """
    global pipe, loaded_model_id, loaded_controlnet_id, loaded_scheduler_name

    scheduler_cls = scheduler_dict[scheduler_name]

    # With no cached pipeline every component has to be loaded.
    model_changed = pipe is None or model_id != loaded_model_id
    controlnet_changed = pipe is None or controlnet_id != loaded_controlnet_id
    # A new base model needs its own scheduler config, and a scheduler-only
    # change must also take effect.
    scheduler_changed = model_changed or scheduler_name != loaded_scheduler_name

    if not (model_changed or controlnet_changed or scheduler_changed):
        return pipe

    # Start from a copy of the current components so unchanged parts are reused.
    new_weights = dict(pipe.components) if pipe is not None else {}

    if model_changed:
        base_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            vae=AutoencoderKL.from_pretrained("lint/anime_vae", torch_dtype=dtype),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            use_safetensors=False,
            torch_dtype=dtype,
        )
        new_weights.update(base_pipe.components)

    if scheduler_changed:
        new_weights["scheduler"] = scheduler_cls.from_pretrained(model_id, subfolder="scheduler")

    if controlnet_changed:
        new_weights["controlnet"] = ControlNetModel.from_pretrained(
            CONTROLNET_REPO,
            subfolder=controlnet_id,
            torch_dtype=dtype,
        )

    new_pipe = StableDiffusionControlNetPipeline(
        **new_weights,
        requires_safety_checker=False,
    )

    if device == "cuda":
        for component in new_pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to(device, dtype)
        # diffusers advises against combining attention slicing with xFormers
        # (it can severely slow inference), so slicing is only a fallback.
        if _xformers_available:
            new_pipe.enable_xformers_memory_efficient_attention()
        else:
            new_pipe.enable_attention_slicing()
        new_pipe.enable_vae_tiling()

    # Commit the new state only once everything above succeeded, so the
    # loaded_* identifiers never disagree with `pipe`.
    pipe = new_pipe
    loaded_model_id = model_id
    loaded_controlnet_id = controlnet_id
    loaded_scheduler_name = scheduler_name
    return pipe


# Preload the default pipeline at import time when `is_hfspace` is set.
if is_hfspace:
    pipe = load_pipe(model_ids[0], controlnet_ids[0], default_scheduler)


def extract_canny(image, low_threshold=100, high_threshold=200):
    """Return a Canny edge map of ``image`` as a 3-channel PIL image.

    Args:
        image: anything ``np.asarray`` turns into an array ``cv2.Canny`` accepts
            (presumably the 8-bit image from the Gradio input — confirm).
        low_threshold: lower hysteresis threshold for ``cv2.Canny`` (default 100).
        high_threshold: upper hysteresis threshold for ``cv2.Canny`` (default 200).

    Returns:
        ``PIL.Image.Image`` whose three channels are identical copies of the edge map.
    """
    edges = cv2.Canny(np.asarray(image), low_threshold, high_threshold)
    # Replicate the single-channel edge map into 3 identical channels.
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges)


@torch.no_grad()
def generate(
    model_name,
    guidance_image,
    controlnet_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    controlnet_prompt=None,
    controlnet_negative_prompt=None,
    controlnet_cond_scale=1.0,
    progress=gr.Progress(),
):
    """Generate images with the shared ControlNet pipeline.

    Args:
        model_name: base model id, forwarded to ``load_pipe``.
        guidance_image: optional conditioning image; its Canny edge map is used.
            When it is ``None`` an all-zero tensor of shape
            ``(n_images, 3, height, width)`` is passed instead.
        controlnet_name: ControlNet subfolder, forwarded to ``load_pipe``.
        scheduler_name: scheduler key, forwarded to ``load_pipe``.
        prompt: base text prompt.
        guidance: guidance scale; values > 1.0 also enable classifier-free
            guidance when encoding the ControlNet prompt.
        steps: number of inference steps.
        n_images: images generated per prompt.
        width, height: output size in pixels.
        seed: RNG seed; ``-1`` picks a random one.
        neg_prompt: base negative prompt.
        controlnet_prompt: optional separate prompt for the ControlNet; empty
            or ``None`` lets the pipeline use the base prompt.
        controlnet_negative_prompt: optional negative prompt for the ControlNet
            prompt encoding; an empty string is treated as ``None``.
        controlnet_cond_scale: ControlNet conditioning scale.
        progress: Gradio progress tracker (unused here).

    Returns:
        ``(images, status_message)``.
    """
    # Gradio number/slider inputs may arrive as floats; torch needs ints here.
    seed = int(seed)
    n_images = int(n_images)
    width = int(width)
    height = int(height)

    if seed == -1:
        seed = random.randint(0, 2147483647)

    # `is not None` rather than truthiness: a NumPy image array has no truth value.
    if guidance_image is not None:
        guidance_image = extract_canny(guidance_image)
    else:
        guidance_image = torch.zeros(n_images, 3, height, width)

    generator = torch.Generator(device).manual_seed(seed)

    pipeline = load_pipe(
        model_id=model_name,
        controlnet_id=controlnet_name,
        scheduler_name=scheduler_name,
    )

    status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"

    # pass None so pipeline uses base prompt as controlnet_prompt
    if controlnet_prompt == "":
        controlnet_prompt = None
    if controlnet_negative_prompt == "":
        controlnet_negative_prompt = None

    if controlnet_prompt:
        controlnet_prompt_embeds = pipeline._encode_prompt(
            controlnet_prompt,
            device,
            n_images,
            do_classifier_free_guidance=guidance > 1.0,
            negative_prompt=controlnet_negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
    else:
        controlnet_prompt_embeds = None

    result = pipeline(
        prompt,
        image=guidance_image,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        generator=generator,
        controlnet_conditioning_scale=float(controlnet_cond_scale),
        controlnet_prompt_embeds=controlnet_prompt_embeds,
    )

    return result.images, status_message


def run_training(
    model_name,
    controlnet_weights_path,
    train_data_dir,
    valid_data_dir,
    train_batch_size,
    train_whole_controlnet,
    gradient_accumulation_steps,
    num_train_epochs,
    train_learning_rate,
    output_dir,
    checkpointing_steps,
    image_logging_steps,
    save_whole_pipeline,
    progress=gr.Progress(),
):
    """Train ControlNet weights on top of ``model_name`` using ``Lab``.

    The ControlNet is loaded from ``controlnet_weights_path``: its parent is the
    repo/directory and its last component the ``subfolder``. It is swapped into
    the shared pipeline and trained. Afterwards every module is moved back to the
    inference device/dtype and GPU memory is released, even if training fails.

    Returns:
        A status string naming ``output_dir``.

    Raises:
        gr.Error: when running on CPU, when the weights path is empty, or when
            training raises (the original exception is chained).
    """
    global pipe, loaded_controlnet_id

    if device == "cpu":
        raise gr.Error("Training not supported on CPU")

    pathobj = Path(controlnet_weights_path)
    if not pathobj.parts:
        raise gr.Error("ControlNet weights path must not be empty")

    # "<repo or dir>/<subfolder>" -> from_pretrained(parent, subfolder=last component)
    controlnet = ControlNetModel.from_pretrained(
        str(pathobj.parent),
        subfolder=pathobj.name,
        low_cpu_mem_usage=False,
        device_map=None,
    )

    # Make sure a pipeline for the requested base model exists; there may be none
    # yet if nothing has been generated. The current ControlNet/scheduler are kept
    # because the ControlNet is replaced right below.
    pipe = load_pipe(
        model_id=model_name,
        controlnet_id=loaded_controlnet_id or controlnet_ids[0],
        scheduler_name=loaded_scheduler_name or default_scheduler,
    )

    # Copy the components before modifying them: diffusers builds a new
    # `components` dict on each access, so assigning into `pipe.components`
    # would not reach the pipeline.
    components = dict(pipe.components)
    components["controlnet"] = controlnet
    pipe = StableDiffusionControlNetPipeline(
        **components,
        requires_safety_checker=False,
    )
    # The cached pipeline no longer holds a CONTROLNET_REPO ControlNet.
    loaded_controlnet_id = controlnet_weights_path

    training_args = argparse.Namespace(
        # start training from preexisting models
        pretrained_model_name_or_path=None,
        controlnet_weights_path=None,

        # dataset args
        train_data_dir=train_data_dir,
        valid_data_dir=valid_data_dir,
        resolution=512,
        from_hf_hub=train_data_dir == "lint/anybooru",
        controlnet_hint_key="canny",

        # training args
        # options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
        training_stage="" if train_whole_controlnet else "zero convolutions",
        learning_rate=float(train_learning_rate),
        num_train_epochs=int(num_train_epochs),
        seed=3434554,
        max_grad_norm=1.0,
        gradient_accumulation_steps=int(gradient_accumulation_steps),

        # VRAM args
        batch_size=int(train_batch_size),
        mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
        gradient_checkpointing=True,  # set this to True to lower the memory usage.
        use_8bit_adam=False,  # use 8bit optimizer from bitsandbytes
        enable_xformers_memory_efficient_attention=True,
        allow_tf32=True,
        dataloader_num_workers=cpu_count(),

        # logging args
        output_dir=output_dir,
        report_to="tensorboard",
        image_logging_steps=image_logging_steps,  # disabled when 0. costs additional VRAM to log images
        save_whole_pipeline=save_whole_pipeline,
        checkpointing_steps=checkpointing_steps,
    )

    try:
        lab = Lab(training_args, pipe)
        lab.train(training_args.num_train_epochs, gr_progress=progress)
    except Exception as e:
        raise gr.Error(str(e)) from e
    finally:
        # Drop the trainer before collecting so whatever training state it holds
        # can actually be freed by gc / empty_cache.
        lab = None
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to(device, dtype=dtype)
        gc.collect()
        torch.cuda.empty_cache()

    return f"Finished training! Check the {training_args.output_dir} directory for saved model weights"