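"""Gradio app for ControlNet-guided Stable Diffusion: Canny-conditioned image
generation plus a controlnet training entry point (run_training)."""
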
import gradio as gr
import torch
import random
from PIL import Image
import os
import argparse
import shutil
import gc
import importlib
import json
from multiprocessing import cpu_count

import cv2
import numpy as np
from pathlib import Path

from diffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
    ControlNetModel,
    AutoencoderKL,
)

# ControlNetPipe intentionally shadows the diffusers class imported above
from src.controlnet_pipe import ControlNetPipe as StableDiffusionControlNetPipeline
from src.lab import Lab
from src.ui_shared import (
    default_scheduler,
    scheduler_dict,
    model_ids,
    controlnet_ids,
    is_hfspace,
)

CONTROLNET_REPO = "lint/anime_control"

_xformers_available = importlib.util.find_spec("xformers") is not None
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = None
loaded_model_id = ""
loaded_controlnet_id = ""
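
# Rebuild the global pipeline only when the requested model or controlnet id
# differs from what is already loaded; unchanged components are reused.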
def load_pipe(model_id, controlnet_id, scheduler_name):
    global pipe, loaded_model_id, loaded_controlnet_id

    scheduler = scheduler_dict[scheduler_name]
    reload_pipe = False

    # start from the currently loaded components, if any
    if pipe:
        new_weights = pipe.components
    else:
        new_weights = {}

    if model_id != loaded_model_id:
        new_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            vae=AutoencoderKL.from_pretrained("lint/anime_vae", torch_dtype=dtype),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            use_safetensors=False,
            torch_dtype=dtype,
        )
        loaded_model_id = model_id

        new_weights.update(new_pipe.components)
        new_weights["scheduler"] = scheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        reload_pipe = True

    if controlnet_id != loaded_controlnet_id:
        controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_REPO,
            subfolder=controlnet_id,
            torch_dtype=dtype,
        )
        loaded_controlnet_id = controlnet_id

        new_weights["controlnet"] = controlnet
        reload_pipe = True

    if reload_pipe:
        pipe = StableDiffusionControlNetPipeline(
            **new_weights,
            requires_safety_checker=False,
        )

    if device == "cuda":
        # move all torch modules to the GPU and enable memory optimizations
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to("cuda", torch.float16)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()

    return pipe

# initialize with preloaded pipe
if is_hfspace:
    pipe = load_pipe(model_ids[0], controlnet_ids[0], default_scheduler)
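
# Convert a PIL image into the 3-channel Canny edge map used as the
# controlnet conditioning image.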
def extract_canny(image):
    CANNY_THRESHOLD = (100, 200)
    image_array = np.asarray(image)
    canny_image = cv2.Canny(image_array, *CANNY_THRESHOLD)
    canny_image = canny_image[:, :, None]
    canny_image = np.concatenate([canny_image] * 3, axis=2)
    return Image.fromarray(canny_image)
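
# Main inference entry point, called from the Gradio UI. When no guidance
# image is supplied, a zero tensor is passed as the conditioning image.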
def generate(
    model_name,
    guidance_image,
    controlnet_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    controlnet_prompt=None,
    controlnet_negative_prompt=None,
    controlnet_cond_scale=1.0,
    progress=gr.Progress(),
):
    if seed == -1:
        seed = random.randint(0, 2147483647)

    if guidance_image:
        guidance_image = extract_canny(guidance_image)
    else:
        guidance_image = torch.zeros(1, 3, height, width)

    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        controlnet_id=controlnet_name,
        scheduler_name=scheduler_name,
    )

    status_message = (
        f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | "
        f"Scheduler: {scheduler_name} | Steps: {steps}"
    )

    # pass None so the pipeline falls back to the base prompt for the controlnet
    if controlnet_prompt == "":
        controlnet_prompt = None
    if controlnet_negative_prompt == "":
        controlnet_negative_prompt = None

    if controlnet_prompt:
        # encode a separate prompt for the controlnet branch
        controlnet_prompt_embeds = pipe._encode_prompt(
            controlnet_prompt,
            device,
            n_images,
            do_classifier_free_guidance=guidance > 1.0,
            negative_prompt=controlnet_negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
    else:
        controlnet_prompt_embeds = None

    result = pipe(
        prompt,
        image=guidance_image,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        generator=generator,
        controlnet_conditioning_scale=float(controlnet_cond_scale),
        controlnet_prompt_embeds=controlnet_prompt_embeds,
    )

    return result.images, status_message
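
# Fine-tune the controlnet in the currently loaded pipeline via src.lab.Lab,
# training either just the zero convolutions or the whole controlnet.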
def run_training(
    model_name,
    controlnet_weights_path,
    train_data_dir,
    valid_data_dir,
    train_batch_size,
    train_whole_controlnet,
    gradient_accumulation_steps,
    num_train_epochs,
    train_learning_rate,
    output_dir,
    checkpointing_steps,
    image_logging_steps,
    save_whole_pipeline,
    progress=gr.Progress(),
):
    global pipe

    if device == "cpu":
        raise gr.Error("Training is not supported on CPU")

    # split "repo_or_dir/subfolder" into the repo path and the subfolder name
    pathobj = Path(controlnet_weights_path)
    controlnet_path = str(Path().joinpath(*pathobj.parts[:-1]))
    subfolder = str(pathobj.parts[-1])

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=subfolder,
        low_cpu_mem_usage=False,
        device_map=None,
    )

    pipe.components["controlnet"] = controlnet
    pipe = StableDiffusionControlNetPipeline(
        **pipe.components,
        requires_safety_checker=False,
    )

    training_args = argparse.Namespace(
        # weights come from the pipe passed to Lab below, so no paths are needed
        pretrained_model_name_or_path=None,
        controlnet_weights_path=None,
        # dataset args
        train_data_dir=train_data_dir,
        valid_data_dir=valid_data_dir,
        resolution=512,
        from_hf_hub=train_data_dir == "lint/anybooru",
        controlnet_hint_key="canny",
        # training args
        # options are ["zero convolutions", "input hint blocks"];
        # an empty string trains the whole controlnet
        training_stage="" if train_whole_controlnet else "zero convolutions",
        learning_rate=float(train_learning_rate),
        num_train_epochs=int(num_train_epochs),
        seed=3434554,
        max_grad_norm=1.0,
        gradient_accumulation_steps=int(gradient_accumulation_steps),
        # VRAM args
        batch_size=train_batch_size,
        mixed_precision="fp16",  # set to "fp16" for mixed-precision training
        gradient_checkpointing=True,  # set to True to lower memory usage
        use_8bit_adam=False,  # use the 8-bit optimizer from bitsandbytes
        enable_xformers_memory_efficient_attention=True,
        allow_tf32=True,
        dataloader_num_workers=cpu_count(),
        # logging args
        output_dir=output_dir,
        report_to="tensorboard",
        image_logging_steps=image_logging_steps,  # disabled when 0; logging images costs extra VRAM
        save_whole_pipeline=save_whole_pipeline,
        checkpointing_steps=checkpointing_steps,
    )

    try:
        lab = Lab(training_args, pipe)
        lab.train(training_args.num_train_epochs, gr_progress=progress)
    except Exception as e:
        raise gr.Error(str(e))

    # restore inference dtype/device and free training memory
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module):
            component.to(device, dtype=dtype)
    gc.collect()
    torch.cuda.empty_cache()

    return f"Finished training! Check the {training_args.output_dir} directory for saved model weights"