anime_controlnet / src /ui_functions.py
1lint
fix and revise app
0d0a1c2
raw
history blame
7.98 kB
import argparse
import gc
import importlib
import importlib.util
import json
import os
import random
import shutil
from multiprocessing import cpu_count
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from PIL import Image

from src.controlnet_pipe import ControlNetPipe as StableDiffusionControlNetPipeline
from src.lab import Lab
from src.ui_shared import (
    controlnet_ids,
    default_scheduler,
    is_hfspace,
    model_ids,
    scheduler_dict,
)
# Hub repo whose subfolders hold the ControlNet weights (the subfolder name is
# the ControlNet id passed to ``load_pipe``).
CONTROLNET_REPO = "lint/anime_control"

# xformers is optional: memory-efficient attention is only enabled when the
# package can be found. Note that ``importlib.util`` must be imported
# explicitly; ``import importlib`` alone does not load the ``util`` submodule.
_xformers_available = importlib.util.find_spec("xformers") is not None

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

# Cached pipeline plus the ids it was built from. ``load_pipe`` (re)builds it
# on demand and ``run_training`` swaps in a pipeline with new ControlNet weights.
pipe = None
loaded_model_id = ""
loaded_controlnet_id = ""
def load_pipe(model_id, controlnet_id, scheduler_name):
    """Return the shared ControlNet pipeline, rebuilding it only when needed.

    The pipeline is cached in the module-level ``pipe`` together with the ids
    it was built from (``loaded_model_id`` / ``loaded_controlnet_id``). The
    base model, the scheduler and the ControlNet are each reloaded only when
    the request differs from what the cached pipeline already holds; all other
    components are reused from the existing pipeline.

    Args:
        model_id: Base Stable Diffusion checkpoint, loaded with
            ``StableDiffusionPipeline.from_pretrained`` (with the
            ``lint/anime_vae`` VAE swapped in).
        controlnet_id: Subfolder of ``CONTROLNET_REPO`` holding the ControlNet.
        scheduler_name: Key into ``scheduler_dict`` selecting the scheduler.

    Returns:
        The (possibly rebuilt) pipeline, which is also stored in ``pipe``.
    """
    global pipe, loaded_model_id, loaded_controlnet_id

    scheduler = scheduler_dict[scheduler_name]
    # ``components`` builds a fresh dict, so editing it does not touch ``pipe``.
    new_weights = pipe.components if pipe is not None else {}
    reload_pipe = False

    model_changed = model_id != loaded_model_id
    if model_changed:
        new_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            vae=AutoencoderKL.from_pretrained("lint/anime_vae", torch_dtype=dtype),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            use_safetensors=False,
            torch_dtype=dtype,
        )
        new_weights.update(new_pipe.components)
        reload_pipe = True

    # Refresh the scheduler not only with a new model but also when a different
    # scheduler is requested for the already-loaded model (previously ignored).
    if model_changed or type(new_weights.get("scheduler")) is not scheduler:
        new_weights["scheduler"] = scheduler.from_pretrained(model_id, subfolder="scheduler")
        reload_pipe = True

    if controlnet_id != loaded_controlnet_id:
        new_weights["controlnet"] = ControlNetModel.from_pretrained(
            CONTROLNET_REPO,
            subfolder=controlnet_id,
            torch_dtype=dtype,
        )
        reload_pipe = True

    if reload_pipe:
        pipe = StableDiffusionControlNetPipeline(
            **new_weights,
            requires_safety_checker=False,
        )
        # Record the cached ids only once the new pipeline exists, so a failed
        # download cannot leave them out of sync with ``pipe``.
        loaded_model_id = model_id
        loaded_controlnet_id = controlnet_id

    if device == "cuda":
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to("cuda", torch.float16)
        # Attention slicing installs its own attention processors and would
        # replace the xformers ones, so it is only used as the fallback.
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
        else:
            pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()

    return pipe
# When running as a Hugging Face Space, build the default pipeline at import
# time (first model, first ControlNet, default scheduler) so it is already
# cached before the first request arrives.
if is_hfspace:
    pipe = load_pipe(
        model_id=model_ids[0],
        controlnet_id=controlnet_ids[0],
        scheduler_name=default_scheduler,
    )
def extract_canny(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel Canny edge map of ``image`` as a PIL image.

    Args:
        image: Anything ``np.asarray`` turns into an 8-bit image array, e.g. a
            PIL image (``cv2.Canny`` requires 8-bit input).
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        A ``PIL.Image`` in which all three channels hold the same edge map.
    """
    edges = cv2.Canny(np.asarray(image), low_threshold, high_threshold)
    # Canny yields a single-channel map; replicate it so the result is RGB.
    return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
@torch.no_grad()
def generate(
    model_name,
    guidance_image,
    controlnet_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    controlnet_prompt=None,
    controlnet_negative_prompt=None,
    controlnet_cond_scale=1.0,
    progress=gr.Progress(),
):
    """Generate images with the shared ControlNet pipeline.

    Args:
        model_name: Base model id forwarded to ``load_pipe``.
        guidance_image: Image whose Canny edges condition the ControlNet
            (presumably a PIL image from the UI); when ``None``, an all-zero
            ``(n_images, 3, height, width)`` tensor is used instead.
        controlnet_name: ControlNet id forwarded to ``load_pipe``.
        scheduler_name: Scheduler key forwarded to ``load_pipe``.
        prompt: Text prompt.
        guidance: Classifier-free guidance scale.
        steps: Number of inference steps.
        n_images: Images to generate per prompt.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed; ``-1`` picks a random seed.
        neg_prompt: Negative prompt.
        controlnet_prompt: Optional separate prompt for the ControlNet branch;
            empty or ``None`` means the pipeline reuses the base prompt.
        controlnet_negative_prompt: Negative prompt for the ControlNet branch,
            only used together with ``controlnet_prompt``.
        controlnet_cond_scale: ControlNet conditioning scale.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(images, status_message)``: the pipeline's generated images and a
        one-line summary of the settings used.
    """
    # UI number widgets may deliver floats; manual_seed / torch.zeros need ints.
    n_images, width, height, seed = int(n_images), int(width), int(height), int(seed)
    if seed == -1:
        seed = random.randint(0, 2147483647)

    # ``is not None`` rather than truthiness: numpy arrays cannot be used as bools.
    if guidance_image is not None:
        guidance_image = extract_canny(guidance_image)
    else:
        guidance_image = torch.zeros(n_images, 3, height, width)

    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        controlnet_id=controlnet_name,
        scheduler_name=scheduler_name,
    )

    status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"

    # Treat empty strings as "not set": passing None lets the pipeline use the
    # base prompt for the ControlNet branch.
    controlnet_prompt = controlnet_prompt or None
    controlnet_negative_prompt = controlnet_negative_prompt or None

    controlnet_prompt_embeds = None
    if controlnet_prompt:
        controlnet_prompt_embeds = pipe._encode_prompt(
            controlnet_prompt,
            device,
            n_images,
            do_classifier_free_guidance=guidance > 1.0,
            negative_prompt=controlnet_negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )

    result = pipe(
        prompt,
        image=guidance_image,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        generator=generator,
        controlnet_conditioning_scale=float(controlnet_cond_scale),
        controlnet_prompt_embeds=controlnet_prompt_embeds,
    )

    return result.images, status_message
def run_training(
    model_name,
    controlnet_weights_path,
    train_data_dir,
    valid_data_dir,
    train_batch_size,
    train_whole_controlnet,
    gradient_accumulation_steps,
    num_train_epochs,
    train_learning_rate,
    output_dir,
    checkpointing_steps,
    image_logging_steps,
    save_whole_pipeline,
    progress=gr.Progress(),
):
    """Train a ControlNet with ``Lab`` on top of the shared pipeline.

    The base model is ``model_name`` (loaded through ``load_pipe`` when it is
    not the one currently cached). The ControlNet is initialised from
    ``controlnet_weights_path``, interpreted as ``<repo or directory>/<subfolder>``.
    The global ``pipe`` is rebuilt with that ControlNet and handed to ``Lab``;
    presumably ``Lab`` trains it in place (TODO confirm in ``src.lab``).

    Args:
        model_name: Base model id (same ids as ``generate`` uses).
        controlnet_weights_path: ``<repo or dir>/<subfolder>`` of the starting
            ControlNet weights.
        train_data_dir: Training dataset location; ``"lint/anybooru"`` is
            loaded from the Hugging Face Hub.
        valid_data_dir: Validation dataset location.
        train_batch_size: Training batch size.
        train_whole_controlnet: Train the whole ControlNet if truthy, otherwise
            only its zero convolutions.
        gradient_accumulation_steps: Gradient accumulation steps.
        num_train_epochs: Number of training epochs.
        train_learning_rate: Learning rate.
        output_dir: Directory for checkpoints, logs and saved weights.
        checkpointing_steps: Checkpoint interval, forwarded to ``Lab``.
        image_logging_steps: Image logging interval; 0 disables it.
        save_whole_pipeline: Forwarded to ``Lab``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A status message naming the output directory.

    Raises:
        gr.Error: when running on CPU, or wrapping any exception raised while
            setting up or running ``Lab`` training.
    """
    global pipe

    if device == "cpu":
        raise gr.Error("Training not supported on CPU")

    # ``model_name`` used to be ignored: training silently used whatever model
    # was cached, and crashed outright when no pipeline had been loaded yet.
    if pipe is None or model_name != loaded_model_id:
        load_pipe(model_name, loaded_controlnet_id or controlnet_ids[0], default_scheduler)

    weights_path = Path(controlnet_weights_path)
    controlnet = ControlNetModel.from_pretrained(
        str(weights_path.parent),
        subfolder=weights_path.name,
        low_cpu_mem_usage=False,
        device_map=None,
    )

    # ``pipe.components`` returns a new dict on every access, so the new
    # ControlNet must be put into a local copy that is then used to rebuild
    # the pipeline (assigning into ``pipe.components[...]`` would be lost).
    components = pipe.components
    components["controlnet"] = controlnet
    pipe = StableDiffusionControlNetPipeline(
        **components,
        requires_safety_checker=False,
    )

    training_args = argparse.Namespace(
        # start training from preexisting models
        pretrained_model_name_or_path=None,
        controlnet_weights_path=None,
        # dataset args
        train_data_dir=train_data_dir,
        valid_data_dir=valid_data_dir,
        resolution=512,
        from_hf_hub=train_data_dir == "lint/anybooru",
        controlnet_hint_key="canny",
        # training args
        # options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
        training_stage="" if train_whole_controlnet else "zero convolutions",
        learning_rate=float(train_learning_rate),
        num_train_epochs=int(num_train_epochs),
        seed=3434554,
        max_grad_norm=1.0,
        gradient_accumulation_steps=int(gradient_accumulation_steps),
        # VRAM args
        # UI numbers may be floats; data loaders require an integer batch size.
        batch_size=int(train_batch_size),
        mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
        gradient_checkpointing=True,  # set this to True to lower the memory usage.
        use_8bit_adam=False,  # use 8bit optimizer from bitsandbytes
        enable_xformers_memory_efficient_attention=True,
        allow_tf32=True,
        dataloader_num_workers=cpu_count(),
        # logging args
        output_dir=output_dir,
        report_to="tensorboard",
        image_logging_steps=image_logging_steps,  # disabled when 0. costs additional VRAM to log images
        save_whole_pipeline=save_whole_pipeline,
        checkpointing_steps=checkpointing_steps,
    )

    lab = None
    try:
        lab = Lab(training_args, pipe)
        lab.train(training_args.num_train_epochs, gr_progress=progress)
    except Exception as e:
        raise gr.Error(str(e)) from e
    finally:
        # Restore inference device/precision and release memory even when
        # training fails, so the app can keep generating afterwards. Drop the
        # trainer reference first so whatever state it holds can be collected.
        lab = None
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to(device, dtype=dtype)
        gc.collect()
        torch.cuda.empty_cache()

    return f"Finished training! Check the {training_args.output_dir} directory for saved model weights"