# sdfinetuned/utils/functions.py
import gradio as gr
import torch
import random
from PIL import Image
import os
import argparse
import shutil
import gc
import importlib
import json
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
)
from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
from .textual_inversion import main as run_textual_inversion
from .shared import default_scheduler, scheduler_dict, model_ids
_xformers_available = importlib.util.find_spec("xformers") is not None
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
dtype = torch.float16 if device == "cuda" else torch.float32
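# when True, attention slicing is enabled after loading to trade speed for lower VRAM use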
low_vram_mode = False
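# Gradio tab index -> pipeline class (1: text-to-image, 2: image-to-image, 3: inpainting)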
tab_to_pipeline = {
1: StableDiffusionPipeline,
2: StableDiffusionImg2ImgPipeline,
3: StableDiffusionInpaintPipelineLegacy,
}
def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
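    """Return a Stable Diffusion pipeline for the requested model, scheduler and tab.

    Weights are re-read from disk only when model_id changes; otherwise the existing
    components are rewrapped in the pipeline class for the active tab.
    """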
global pipe, loaded_model_id
scheduler = scheduler_dict[scheduler_name]
pipe_class = tab_to_pipeline[tab_index]
# load new weights from disk only when changing model_id
if model_id != loaded_model_id:
pipe = pipe_class.from_pretrained(
model_id,
torch_dtype=dtype,
safety_checker=None,
requires_safety_checker=False,
scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
**json.loads(pipe_kwargs),
)
loaded_model_id = model_id
# if same model_id, instantiate new pipeline with same underlying pytorch objects to avoid reloading weights from disk
elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
        # note: pipe.components builds a fresh dict on every access, so mutate a local
        # copy and rebuild the pipeline from it to actually swap the scheduler
        components = pipe.components
        components["scheduler"] = scheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        pipe = pipe_class(**components)
if device == "cuda":
pipe = pipe.to(device)
if _xformers_available:
pipe.enable_xformers_memory_efficient_attention()
print("using xformers")
if low_vram_mode:
pipe.enable_attention_slicing()
print("using attention slicing to lower VRAM")
return pipe
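# module-level pipeline cache shared across requests; warm it with the default model at import time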
pipe = None
loaded_model_id = ""
pipe = load_pipe(model_ids[0], default_scheduler)
def pad_image(image):
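    """Pad a PIL image with black borders so it becomes square, keeping the content centered."""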
w, h = image.size
if w == h:
return image
elif w > h:
new_image = Image.new(image.mode, (w, w), (0, 0, 0))
new_image.paste(image, (0, (w - h) // 2))
return new_image
else:
new_image = Image.new(image.mode, (h, h), (0, 0, 0))
new_image.paste(image, ((h - w) // 2, 0))
return new_image
@torch.no_grad()
def generate(
model_name,
scheduler_name,
prompt,
guidance,
steps,
n_images=1,
width=512,
height=512,
seed=0,
image=None,
strength=0.5,
inpaint_image=None,
inpaint_strength=0.5,
inpaint_radio="",
neg_prompt="",
tab_index=1,
pipe_kwargs="{}",
progress=gr.Progress(track_tqdm=True),
):
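    """Run text-to-image (tab 1), image-to-image (tab 2) or inpainting (tab 3) and
    return (images, status message). A seed of -1 selects a random seed."""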
if seed == -1:
seed = random.randint(0, 2147483647)
generator = torch.Generator(device).manual_seed(seed)
pipe = load_pipe(
model_id=model_name,
scheduler_name=scheduler_name,
tab_index=tab_index,
pipe_kwargs=pipe_kwargs,
)
status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"
if tab_index == 1:
status_message = "Text to Image " + status_message
result = pipe(
prompt,
negative_prompt=neg_prompt,
num_images_per_prompt=n_images,
num_inference_steps=int(steps),
guidance_scale=guidance,
width=width,
height=height,
generator=generator,
)
elif tab_index == 2:
status_message = "Image to Image " + status_message
print(image.size)
image = image.resize((width, height))
print(image.size)
result = pipe(
prompt,
negative_prompt=neg_prompt,
num_images_per_prompt=n_images,
image=image,
num_inference_steps=int(steps),
strength=strength,
guidance_scale=guidance,
generator=generator,
)
elif tab_index == 3:
status_message = "Inpainting " + status_message
init_image = inpaint_image["image"].resize((width, height))
mask = inpaint_image["mask"].resize((width, height))
result = pipe(
prompt,
negative_prompt=neg_prompt,
num_images_per_prompt=n_images,
image=init_image,
mask_image=mask,
num_inference_steps=int(steps),
strength=inpaint_strength,
preserve_unmasked_image=(
inpaint_radio == "preserve non-masked portions of image"
),
guidance_scale=guidance,
generator=generator,
)
else:
return None, f"Unhandled tab index: {tab_index}"
return result.images, status_message
# based on lvkaokao/textual-inversion-training
def train_textual_inversion(
model_name,
scheduler_name,
type_of_thing,
files,
concept_word,
init_word,
text_train_steps,
text_train_bsz,
text_learning_rate,
progress=gr.Progress(track_tqdm=True),
):
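    """Train a textual inversion embedding for concept_word on the uploaded images
    and save the resulting weights to the output_model directory."""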
if device == "cpu":
raise gr.Error("Textual inversion training not supported on CPU")
pipe = load_pipe(
model_id=model_name,
scheduler_name=scheduler_name,
tab_index=1,
)
pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
concept_dir = "concept_images"
output_dir = "output_model"
training_resolution = 512
    # start each training run from clean concept/output directories
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    if os.path.exists(concept_dir):
        shutil.rmtree(concept_dir)
os.makedirs(concept_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
gc.collect()
torch.cuda.empty_cache()
    if not concept_word:
raise gr.Error("You forgot to define your concept prompt")
for j, file_temp in enumerate(files):
file = Image.open(file_temp.name)
image = pad_image(file)
image = image.resize((training_resolution, training_resolution))
        extension = os.path.splitext(file_temp.name)[1].lstrip(".")
image = image.convert("RGB")
image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)
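    # hyperparameters forwarded to the textual inversion training script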
args_general = argparse.Namespace(
train_data_dir=concept_dir,
learnable_property=type_of_thing,
placeholder_token=concept_word,
initializer_token=init_word,
resolution=training_resolution,
train_batch_size=text_train_bsz,
gradient_accumulation_steps=1,
gradient_checkpointing=True,
mixed_precision="fp16",
use_bf16=False,
max_train_steps=int(text_train_steps),
learning_rate=text_learning_rate,
scale_lr=True,
lr_scheduler="constant",
lr_warmup_steps=0,
output_dir=output_dir,
)
try:
final_result = run_textual_inversion(pipe, args_general)
except Exception as e:
        raise gr.Error(str(e))
pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
pipe.unet = pipe.unet.eval().to(device, dtype=dtype)
gc.collect()
torch.cuda.empty_cache()
return (
f"Finished training! Check the {output_dir} directory for saved model weights"
)