# NOTE: the original scrape carried a Hugging Face Spaces status banner
# ("Spaces: Runtime error") here; it is page chrome, not source code.
import argparse
import gc
import importlib
import json
import os
import random
import shutil

import gradio as gr
import torch
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
)

from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
from .textual_inversion import main as run_textual_inversion
from .shared import default_scheduler, scheduler_dict, model_ids
# True when the optional xformers package is importable (enables
# memory-efficient attention in load_pipe).
_xformers_available = importlib.util.find_spec("xformers") is not None

# Prefer GPU; fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'

# fp16 weights on GPU, fp32 on CPU (half precision is poorly supported on CPU).
dtype = torch.float16 if device == "cuda" else torch.float32

# When True, load_pipe enables attention slicing to reduce VRAM usage.
low_vram_mode = False

# UI tab index -> diffusers pipeline class
# (1 = text-to-image, 2 = image-to-image, 3 = inpainting; see generate()).
tab_to_pipeline = {
    1: StableDiffusionPipeline,
    2: StableDiffusionImg2ImgPipeline,
    3: StableDiffusionInpaintPipelineLegacy,
}
def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
    """Return the global pipeline, (re)loading or re-wrapping it as needed.

    Caches the heavy model weights in the module-level ``pipe`` /
    ``loaded_model_id`` globals: weights are read from disk only when
    ``model_id`` changes; switching tab (pipeline class) or scheduler reuses
    the already-loaded components.

    Args:
        model_id: Hugging Face model id (or local path) of the checkpoint.
        scheduler_name: key into ``scheduler_dict`` selecting the scheduler class.
        tab_index: key into ``tab_to_pipeline`` selecting the pipeline class.
        pipe_kwargs: JSON string of extra kwargs for ``from_pretrained``.

    Returns:
        The configured diffusers pipeline, moved to GPU when available.
    """
    global pipe, loaded_model_id
    scheduler = scheduler_dict[scheduler_name]
    pipe_class = tab_to_pipeline[tab_index]

    if model_id != loaded_model_id:
        # New checkpoint: load fresh weights from disk.
        pipe = pipe_class.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
            scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
            **json.loads(pipe_kwargs),
        )
        loaded_model_id = model_id
    elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
        # Same checkpoint: rebuild the pipeline around the already-loaded
        # pytorch modules to avoid re-reading weights from disk.
        # BUG FIX: ``pipe.components`` is a property that builds a *fresh*
        # dict on every access, so the original in-place assignment
        # ``pipe.components["scheduler"] = ...`` was silently discarded when
        # the property was read again below. Snapshot it once and mutate
        # that snapshot instead.
        components = pipe.components
        components["scheduler"] = scheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        pipe = pipe_class(**components)

    if device == "cuda":
        pipe = pipe.to(device)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
            print("using xformers")
        if low_vram_mode:
            pipe.enable_attention_slicing()
            print("using attention slicing to lower VRAM")
    return pipe
# Module-level pipeline cache; load_pipe() reads and mutates both globals.
pipe = None
loaded_model_id = ""
# Eagerly load the default model so the first user request doesn't pay the
# (slow) weight-loading cost.
pipe = load_pipe(model_ids[0], default_scheduler)
def pad_image(image):
    """Pad a PIL image to a centered square with black borders.

    Already-square images are returned unchanged (same object); otherwise a
    new square canvas of side ``max(w, h)`` is created and the original is
    pasted centered on it.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new(image.mode, (side, side), (0, 0, 0))
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
def generate(
    model_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    image=None,
    strength=0.5,
    inpaint_image=None,
    inpaint_strength=0.5,
    inpaint_radio="",
    neg_prompt="",
    tab_index=1,
    pipe_kwargs="{}",
    progress=gr.Progress(track_tqdm=True),
):
    """Run one Stable Diffusion generation for the active UI tab.

    Dispatches on ``tab_index``: 1 = text-to-image, 2 = image-to-image
    (uses ``image``/``strength``), 3 = inpainting (uses ``inpaint_image``/
    ``inpaint_strength``/``inpaint_radio``).

    Args:
        model_name: checkpoint id passed through to ``load_pipe``.
        scheduler_name: scheduler key passed through to ``load_pipe``.
        prompt / neg_prompt: positive and negative text prompts.
        guidance: classifier-free guidance scale.
        steps: number of inference steps (cast to int before use).
        n_images: images generated per prompt.
        width, height: output (and resize target) dimensions in pixels.
        seed: RNG seed; -1 requests a random seed.
        image: source PIL image for the img2img tab.
        inpaint_image: assumes a Gradio sketch dict with "image" and "mask"
            PIL entries — TODO confirm against the UI wiring.
        inpaint_radio: when equal to the exact string
            "preserve non-masked portions of image", unmasked pixels are kept.
        pipe_kwargs: JSON string forwarded to ``load_pipe``.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        ``(images, status_message)`` on success, or ``(None, error_message)``
        for an unhandled tab index.
    """
    # -1 means "random": draw a concrete seed so it can be reported in the
    # status message and reproduced later.
    if seed == -1:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device).manual_seed(seed)

    # May swap pipeline class / scheduler / checkpoint; cached in load_pipe.
    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=tab_index,
        pipe_kwargs=pipe_kwargs,
    )

    status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"

    if tab_index == 1:
        # Text-to-image: prompt only, explicit output size.
        status_message = "Text to Image " + status_message

        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
            width=width,
            height=height,
            generator=generator,
        )

    elif tab_index == 2:
        # Image-to-image: resize the source to the requested dimensions first.
        status_message = "Image to Image " + status_message
        # Debug prints left in place: source size before/after resize.
        print(image.size)
        image = image.resize((width, height))
        print(image.size)

        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            image=image,
            num_inference_steps=int(steps),
            strength=strength,
            guidance_scale=guidance,
            generator=generator,
        )

    elif tab_index == 3:
        # Inpainting: image and mask must share the same dimensions.
        status_message = "Inpainting " + status_message
        init_image = inpaint_image["image"].resize((width, height))
        mask = inpaint_image["mask"].resize((width, height))

        result = pipe(
            prompt,
            negative_prompt=neg_prompt,
            num_images_per_prompt=n_images,
            image=init_image,
            mask_image=mask,
            num_inference_steps=int(steps),
            strength=inpaint_strength,
            # Radio-button label doubles as the feature flag; must match exactly.
            preserve_unmasked_image=(
                inpaint_radio == "preserve non-masked portions of image"
            ),
            guidance_scale=guidance,
            generator=generator,
        )

    else:
        return None, f"Unhandled tab index: {tab_index}"

    return result.images, status_message
# based on lvkaokao/textual-inversion-training
def train_textual_inversion(
    model_name,
    scheduler_name,
    type_of_thing,
    files,
    concept_word,
    init_word,
    text_train_steps,
    text_train_bsz,
    text_learning_rate,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a textual-inversion embedding for ``concept_word`` on ``files``.

    Stages the uploaded images into ``concept_images/`` (padded to square,
    resized to 512, converted to RGB), then hands off to the textual-inversion
    training script and restores the pipeline modules to eval mode on the
    inference device afterwards.

    Args:
        model_name / scheduler_name: forwarded to ``load_pipe``.
        type_of_thing: "learnable property" passed to the training script
            (e.g. object vs. style — exact values defined by that script).
        files: uploaded file handles; each must expose a ``.name`` path.
        concept_word: placeholder token to learn; must be non-empty.
        init_word: token used to initialize the new embedding.
        text_train_steps / text_train_bsz / text_learning_rate: training
            hyperparameters.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        A status string naming the output directory.

    Raises:
        gr.Error: on CPU-only machines, on a missing concept word, or when
            the training script itself fails.
    """
    if device == "cpu":
        raise gr.Error("Textual inversion training not supported on CPU")

    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=1,
    )
    pipe.disable_xformers_memory_efficient_attention()  # xformers handled by textual inversion script

    concept_dir = "concept_images"
    output_dir = "output_model"
    training_resolution = 512

    # Start from clean staging directories. Pass the variables (not repeated
    # string literals) so a future rename cannot desynchronize them.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    if os.path.exists(concept_dir):
        shutil.rmtree(concept_dir)
    os.makedirs(concept_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    gc.collect()
    torch.cuda.empty_cache()

    if concept_word is None or concept_word == "":
        raise gr.Error("You forgot to define your concept prompt")

    for j, file_temp in enumerate(files):
        file = Image.open(file_temp.name)
        image = pad_image(file)
        image = image.resize((training_resolution, training_resolution))
        # BUG FIX: splitext keeps the real extension for names containing
        # extra dots (the original split(".")[1] turned "photo.v2.png" into
        # extension "v2").
        extension = os.path.splitext(file_temp.name)[1].lstrip(".")
        image = image.convert("RGB")
        image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)

    args_general = argparse.Namespace(
        train_data_dir=concept_dir,
        learnable_property=type_of_thing,
        placeholder_token=concept_word,
        initializer_token=init_word,
        resolution=training_resolution,
        train_batch_size=text_train_bsz,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        mixed_precision="fp16",
        use_bf16=False,
        max_train_steps=int(text_train_steps),
        learning_rate=text_learning_rate,
        scale_lr=True,
        lr_scheduler="constant",
        lr_warmup_steps=0,
        output_dir=output_dir,
    )

    try:
        run_textual_inversion(pipe, args_general)
    except Exception as e:
        # Surface the failure in the UI while preserving the cause chain for logs.
        raise gr.Error(str(e)) from e

    # Training may have moved/re-cast the modules; restore eval mode on the
    # inference device and dtype.
    pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
    pipe.unet = pipe.unet.eval().to(device, dtype=dtype)

    gc.collect()
    torch.cuda.empty_cache()
    return (
        f"Finished training! Check the {output_dir} directory for saved model weights"
    )