|
import spaces
|
|
import os
|
|
from stablepy import Model_Diffusers
|
|
from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
|
|
from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
|
|
import torch
|
|
import re
|
|
from huggingface_hub import HfApi
|
|
from stablepy import (
|
|
CONTROLNET_MODEL_IDS,
|
|
VALID_TASKS,
|
|
T2I_PREPROCESSOR_NAME,
|
|
FLASH_LORA,
|
|
SCHEDULER_CONFIG_MAP,
|
|
scheduler_names,
|
|
IP_ADAPTER_MODELS,
|
|
IP_ADAPTERS_SD,
|
|
IP_ADAPTERS_SDXL,
|
|
REPO_IMAGE_ENCODER,
|
|
ALL_PROMPT_WEIGHT_OPTIONS,
|
|
SD15_TASKS,
|
|
SDXL_TASKS,
|
|
)
|
|
import time
|
|
|
|
import gradio as gr
|
|
import logging

# Only errors from the diffusers logger are shown.
logging.getLogger("diffusers").setLevel(logging.ERROR)

import diffusers

# 40 == logging.ERROR: also lower diffusers' own verbosity setting.
diffusers.utils.logging.set_verbosity(40)

import warnings

# Suppress FutureWarning/UserWarning raised from diffusers and FutureWarning
# from transformers (presumably to keep the Space logs readable).
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")

warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")

warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")

from stablepy import logger

# stablepy's logger only reports CRITICAL messages.
logger.setLevel(logging.CRITICAL)
|
|
|
|
from env import (
|
|
HF_TOKEN, hf_read_token,
|
|
CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
|
|
HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
|
|
HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
|
|
directory_models, directory_loras, directory_vaes, directory_embeds,
|
|
directory_embeds_sdxl, directory_embeds_positive_sdxl,
|
|
load_diffusers_format_model, download_model_list, download_lora_list,
|
|
download_vae_list, download_embeds)
|
|
|
|
# ControlNet task (stablepy task name, i.e. a value of TASK_STABLEPY) -> the
# preprocessor names selectable for it. Most tasks also offer "None"
# (presumably "use the control image without preprocessing" — TODO confirm
# against stablepy).
PREPROCESSOR_CONTROLNET = {
    "openpose": [
        "Openpose",
        "None",
    ],
    "scribble": [
        "HED",
        "Pidinet",
        "None",
    ],
    "softedge": [
        "Pidinet",
        "HED",
        "HED safe",
        "Pidinet safe",
        "None",
    ],
    "segmentation": [
        "UPerNet",
        "None",
    ],
    "depth": [
        "DPT",
        "Midas",
        "None",
    ],
    "normalbae": [
        "NormalBae",
        "None",
    ],
    "lineart": [
        "Lineart",
        "Lineart coarse",
        "Lineart (anime)",
        "None",
        "None (anime)",
    ],
    "lineart_anime": [
        "Lineart",
        "Lineart coarse",
        "Lineart (anime)",
        "None",
        "None (anime)",
    ],
    "shuffle": [
        "ContentShuffle",
        "None",
    ],
    "canny": [
        "Canny",
        "None",
    ],
    "mlsd": [
        "MLSD",
        "None",
    ],
    "ip2p": [
        "ip2p"
    ],
    "recolor": [
        "Recolor luminance",
        "Recolor intensity",
        "None",
    ],
    "tile": [
        "Mild Blur",
        "Moderate Blur",
        "Heavy Blur",
        "None",
    ],
}
|
|
|
|
# GUI task label -> stablepy task name. Insertion order matters: the first
# three entries (txt2img, img2img, inpaint) are taken positionally for FLUX,
# and TASK_MODEL_LIST[0] ("txt2img") is used as the default task by infer().
TASK_STABLEPY = {
    'txt2img': 'txt2img',
    'img2img': 'img2img',
    'inpaint': 'inpaint',
    'openpose ControlNet': 'openpose',
    'canny ControlNet': 'canny',
    'mlsd ControlNet': 'mlsd',
    'scribble ControlNet': 'scribble',
    'softedge ControlNet': 'softedge',
    'segmentation ControlNet': 'segmentation',
    'depth ControlNet': 'depth',
    'normalbae ControlNet': 'normalbae',
    'lineart ControlNet': 'lineart',
    'lineart_anime ControlNet': 'lineart_anime',
    'shuffle ControlNet': 'shuffle',
    'ip2p ControlNet': 'ip2p',
    'optical pattern ControlNet': 'pattern',
    'recolor ControlNet': 'recolor',
    'tile ControlNet': 'tile',
}

# GUI task labels in definition order.
TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
|
|
|
|
# Upscaler choice -> value given to stablepy.
# The first nine entries (None, Lanczos, Nearest and the six Latent modes) map
# to themselves and are passed by name; GuiSD.generate_pipeline relies on this
# via UPSCALER_KEYS[:9]. The remaining entries are URLs of ESRGAN-style weight
# files that are downloaded into ./upscalers on first use.
UPSCALER_DICT_GUI = {
    None: None,
    "Lanczos": "Lanczos",
    "Nearest": "Nearest",
    'Latent': 'Latent',
    'Latent (antialiased)': 'Latent (antialiased)',
    'Latent (bicubic)': 'Latent (bicubic)',
    'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
    'Latent (nearest)': 'Latent (nearest)',
    'Latent (nearest-exact)': 'Latent (nearest-exact)',
    "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
    "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
    "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
    "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
    "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
    "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
    "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
    "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
}

# Upscaler choices in definition order (see the positional note above).
UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
|
|
|
|
def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download ``url`` into ``directory`` using the tool suited to its host.

    * Google Drive links: ``gdown --fuzzy <url>`` run with ``directory`` as its
      working directory.
    * Hugging Face links: ``?download=true`` is dropped and ``/blob/`` becomes
      ``/resolve/``. The file is fetched with ``aria2c`` under its URL basename.
      ``hf_token`` is sent as a Bearer ``Authorization`` header when given.
    * Civitai links: the query string is replaced by ``?token=<civitai_api_key>``.
      Without a key nothing is downloaded and a message is printed.
    * Anything else is passed to ``aria2c`` unchanged.

    Downloads are best-effort. External tool failures are printed, not raised,
    and blank URLs are ignored.

    Args:
        directory: Destination directory.
        url: URL to fetch; surrounding whitespace is stripped.
        hf_token: Optional Hugging Face access token.
        civitai_api_key: Civitai API key, required for civitai.com URLs.
    """
    import subprocess

    def _run(argv, cwd=None):
        # Argument list and no shell: URLs and tokens reach the tool verbatim.
        # Characters such as `&`, `?`, quotes or spaces therefore cannot break
        # the command line or inject extra commands.
        try:
            subprocess.run(argv, cwd=cwd, check=False)
        except OSError as err:  # tool not installed, or cwd does not exist
            print(f"\033[91mFailed to run {argv[0]}: {err}\033[0m")

    aria2c = [
        "aria2c", "--console-log-level=error", "--summary-interval=10",
        "-c", "-x", "16", "-k", "1M", "-s", "16",
    ]

    url = url.strip()
    if not url:
        return

    if "drive.google.com" in url:
        # cwd= instead of os.chdir(): chdir changes the working directory of the
        # whole process, which is unsafe while other requests run concurrently.
        _run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        filename = url.split("/")[-1]
        if hf_token:
            extra = [f"--header=Authorization: Bearer {hf_token}"]
        else:
            extra = ["--optimize-concurrent-downloads"]
        _run(aria2c + extra + [url, "-d", directory, "-o", filename])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            _run(aria2c + ["-d", directory, url])
        else:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
    else:
        _run(aria2c + ["-d", directory, url])
|
|
|
|
def get_model_list(directory_path):
    """List model-weight files located directly inside ``directory_path``.

    Keeps entries whose extension is exactly one of ``.ckpt``, ``.pt``,
    ``.pth``, ``.safetensors`` or ``.bin`` (case-sensitive). Each kept path is
    printed. Paths are returned joined with ``directory_path``, in
    ``os.listdir`` order.
    """
    allowed_suffixes = ('.ckpt', '.pt', '.pth', '.safetensors', '.bin')
    found = []
    for entry in os.listdir(directory_path):
        _, suffix = os.path.splitext(entry)
        if suffix not in allowed_suffixes:
            continue
        full_path = os.path.join(directory_path, entry)
        found.append(full_path)
        print('\033[34mFILE: ' + full_path + '\033[0m')
    return found
|
|
|
|
|
|
from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
|
|
get_tupled_model_list, get_lora_model_list, download_private_repo)
|
|
|
|
|
|
# Comma-joined views of the configured download lists (kept as module attributes).
download_model = ", ".join(download_model_list)
download_vae = ", ".join(download_vae_list)
download_lora = ", ".join(download_lora_list)

download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)

# Configured Diffusers repos plus the ids returned by get_model_id_list(), deduplicated.
load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())


def _download_missing(urls, directory):
    """Download every URL in ``urls`` into ``directory`` unless its file already exists.

    The existence check looks in the same ``directory`` the file is downloaded
    to. It previously used hard-coded "./models/", "./vaes/", "./loras/" and
    "./embedings/" paths that could drift from the configured directories.
    Blank entries are skipped.
    """
    for url in (u.strip() for u in urls):
        if not url:
            continue
        if not os.path.exists(os.path.join(directory, url.split('/')[-1])):
            download_things(directory, url, HF_TOKEN, CIVITAI_API_KEY)


_download_missing(download_model.split(','), directory_models)
_download_missing(download_vae.split(','), directory_vaes)
_download_missing(download_lora.split(','), directory_loras)
_download_missing(download_embeds, directory_embeds)

embed_list = get_model_list(directory_embeds)
# Diffusers repo ids first, then local single-file checkpoints.
model_list = get_model_list(directory_models)
model_list = load_diffusers_format_model + model_list

lora_model_list = get_lora_model_list()
vae_model_list = get_model_list(directory_vaes)
vae_model_list.insert(0, "None")

embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
|
|
|
|
def get_embed_list(pipeline_name):
    """Return the tupled textual-inversion embeddings suited to ``pipeline_name``.

    "StableDiffusionXLPipeline" gets ``embed_sdxl_list``. Every other
    pipeline name gets ``embed_list``.
    """
    if pipeline_name == "StableDiffusionXLPipeline":
        source = embed_sdxl_list
    else:
        source = embed_list
    return get_tupled_embed_list(source)
|
|
|
|
|
|
print('\033[33m🏁 Download and listing of valid models completed.\033[0m')

# Shown via gr.Warning (in GuiSD.load_new_model) when the VAE family guessed
# from its name differs from the model family.
msg_inc_vae = (
    "Use the right VAE for your model to maintain image quality. The wrong"
    " VAE can lead to poor results, like blurriness in the generated images."
)

# GUI task labels whose stablepy task is supported by each model family.
SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
# FLUX: the first three TASK_STABLEPY labels (txt2img, img2img, inpaint) plus
# the ControlNet labels whose stablepy task is a FLUX ControlNet-union mode.
FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]

# Model family -> selectable GUI task labels.
MODEL_TYPE_TASK = {
    "SD 1.5": SD_TASK,
    "SDXL": SDXL_TASK,
    "FLUX": FLUX_TASK,
}

# Hugging Face "diffusers:<PipelineClass>" repo tag -> model family.
MODEL_TYPE_CLASS = {
    "diffusers:StableDiffusionPipeline": "SD 1.5",
    "diffusers:StableDiffusionXLPipeline": "SDXL",
    "diffusers:FluxPipeline": "FLUX",
}

# Sampler choices for post-processing: "Use same sampler" followed by
# scheduler_names minus its last two entries (presumably not usable for
# post-processing — TODO confirm).
POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
|
|
|
|
def extract_parameters(input_string):
    """Parse an A1111-style generation-parameters text into a dict.

    Newlines are removed first. The expected layout is::

        <prompt>Negative prompt: <negative>Steps: 20, Sampler: Euler a, Size: 512x768, ...

    Returns:
        A dict with:

        * ``"prompt"``: text before "Negative prompt:". This is the whole
          input when the marker is absent, and nothing else is parsed.
        * ``"neg_prompt"``: text between "Negative prompt:" and the first
          "Steps:" (or the rest of the text when there is no "Steps:").
        * One entry per ``Key: value`` pair in the parameter section. The key
          is the last word before the colon, the value is a string with
          surrounding double quotes stripped, and later duplicates win.
        * ``"Steps"`` as ``int``, and ``"Size"`` / ``"width"`` / ``"height"``
          when a ``Size: WxH`` pair is present.
    """
    parameters = {}
    input_string = input_string.replace("\n", "")

    if "Negative prompt:" not in input_string:
        print("Negative prompt not detected")
        parameters["prompt"] = input_string
        return parameters

    # maxsplit=1 so that a repeated marker cannot truncate the remaining text.
    prompt_part, rest = input_string.split("Negative prompt:", 1)
    parameters["prompt"] = prompt_part
    if "Steps:" not in rest:
        print("Steps not detected")
        parameters["neg_prompt"] = rest
        return parameters

    # maxsplit=1 keeps everything after the first "Steps:". Previously a later
    # "...Steps:" (e.g. "Hires Steps: 10") cut off all following parameters.
    neg_part, params_part = rest.split("Steps:", 1)
    parameters["neg_prompt"] = neg_part
    input_string = "Steps:" + params_part

    # Run the generic pass first so the typed values set below are not
    # overwritten by their raw string forms (previously "Steps" ended up a str).
    for key, value in re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string):
        parameters[key] = value.strip('"')

    steps_match = re.search(r'Steps: (\d+)', input_string)
    if steps_match:
        parameters['Steps'] = int(steps_match.group(1))

    size_match = re.search(r'Size: (\d+x\d+)', input_string)
    if size_match:
        parameters['Size'] = size_match.group(1)
        width, height = map(int, parameters['Size'].split('x'))
        parameters['width'] = width
        parameters['height'] = height

    return parameters
|
|
|
|
def get_model_type(repo_id: str):
    """Return the model family ("SD 1.5", "SDXL" or "FLUX") of a Hub repo.

    The family comes from the first repo tag found in MODEL_TYPE_CLASS.
    Any failure while querying the Hub (5 s timeout), or a repo without a
    known tag, yields the default "SD 1.5".
    """
    client = HfApi(token=os.environ.get("HF_TOKEN"))
    fallback = "SD 1.5"
    try:
        info = client.model_info(repo_id=repo_id, timeout=5.0)
        return next(
            (MODEL_TYPE_CLASS[tag] for tag in info.tags if tag in MODEL_TYPE_CLASS),
            fallback,
        )
    except Exception:
        return fallback
|
|
|
|
|
|
class GuiSD:
    """Owns one stablepy ``Model_Diffusers`` instance and drives it from the GUI.

    At construction the model is created on CPU ("Lykon/dreamshaper-8",
    txt2img, float16). ``load_new_model`` swaps the pipeline and
    ``generate_pipeline`` switches the device to ``cuda:0`` and runs inference.
    """

    def __init__(self):
        self.model = None

        print("Loading model...")
        self.model = Model_Diffusers(
            base_model_id="Lykon/dreamshaper-8",
            task_name="txt2img",
            vae_model=None,
            type_model_precision=torch.float16,
            retain_task_model_in_cache=False,
            device="cpu",
        )
        self.model.device = torch.device("cpu")

    def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
        """Run ``model(**pipe_params)`` and return the images as gallery entries.

        Returns:
            A list of ``(image, None)`` pairs, one per item returned by
            ``save_images`` for the generated images.
        """
        progress(0, desc="Start inference...")
        images, seed, image_list, metadata = model(**pipe_params)
        progress(1, desc="Inference completed.")
        if not isinstance(images, list):
            images = [images]
        images = save_images(images, metadata)
        return [(image, None) for image in images]

    def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
        """Load ``model_name`` (optionally with ``vae_model``) for GUI task ``task``.

        Emits ``gr.Warning(msg_inc_vae)`` when the VAE family (SDXL if "sdxl"
        appears in its name, else SD 1.5) differs from the model family
        reported by ``get_model_type``. FLUX models load in bfloat16, all others
        in float16. ``progress`` is accepted for Gradio but unused.
        """
        vae_model = vae_model if vae_model != "None" else None
        model_type = get_model_type(model_name)

        if vae_model:
            vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
            if model_type != vae_type:
                gr.Warning(msg_inc_vae)

        # Load on CPU; generate_pipeline switches the device to cuda:0.
        self.model.device = torch.device("cpu")
        dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16

        self.model.load_pipe(
            model_name,
            task_name=TASK_STABLEPY[task],
            vae_model=vae_model,  # already normalised from "None" above
            type_model_precision=dtype_model,
            retain_task_model_in_cache=False,
        )

    @torch.inference_mode()
    def generate_pipeline(
        self,
        prompt,
        neg_prompt,
        num_images,
        steps,
        cfg,
        clip_skip,
        seed,
        lora1,
        lora_scale1,
        lora2,
        lora_scale2,
        lora3,
        lora_scale3,
        lora4,
        lora_scale4,
        lora5,
        lora_scale5,
        sampler,
        img_height,
        img_width,
        model_name,
        vae_model,
        task,
        image_control,
        preprocessor_name,
        preprocess_resolution,
        image_resolution,
        style_prompt,
        style_json_file,
        image_mask,
        strength,
        low_threshold,
        high_threshold,
        value_threshold,
        distance_threshold,
        controlnet_output_scaling_in_unet,
        controlnet_start_threshold,
        controlnet_stop_threshold,
        textual_inversion,
        syntax_weights,
        upscaler_model_path,
        upscaler_increases_size,
        esrgan_tile,
        esrgan_tile_overlap,
        hires_steps,
        hires_denoising_strength,
        hires_sampler,
        hires_prompt,
        hires_negative_prompt,
        hires_before_adetailer,
        hires_after_adetailer,
        loop_generation,
        leave_progress_bar,
        disable_progress_bar,
        image_previews,
        display_images,
        save_generated_images,
        image_storage_location,
        retain_compel_previous_load,
        retain_detailfix_model_previous_load,
        retain_hires_model_previous_load,
        t2i_adapter_preprocessor,
        t2i_adapter_conditioning_scale,
        t2i_adapter_conditioning_factor,
        xformers_memory_efficient_attention,
        freeu,
        generator_in_cpu,
        adetailer_inpaint_only,
        adetailer_verbose,
        adetailer_sampler,
        adetailer_active_a,
        prompt_ad_a,
        negative_prompt_ad_a,
        strength_ad_a,
        face_detector_ad_a,
        person_detector_ad_a,
        hand_detector_ad_a,
        mask_dilation_a,
        mask_blur_a,
        mask_padding_a,
        adetailer_active_b,
        prompt_ad_b,
        negative_prompt_ad_b,
        strength_ad_b,
        face_detector_ad_b,
        person_detector_ad_b,
        hand_detector_ad_b,
        mask_dilation_b,
        mask_blur_b,
        mask_padding_b,
        retain_task_cache_gui,
        image_ip1,
        mask_ip1,
        model_ip1,
        mode_ip1,
        scale_ip1,
        image_ip2,
        mask_ip2,
        model_ip2,
        mode_ip2,
        scale_ip2,
        pag_scale,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Build stablepy call parameters from the GUI values and run inference.

        The positional order is the GUI/``sd_gen_generate_pipeline`` contract
        (LoRA names at positions 7, 9, ..., 15). ``style_json_file`` and
        ``retain_task_cache_gui`` are accepted but not used.

        Returns:
            ``(gallery, info)``: the ``infer_short`` result and an empty info string.

        Raises:
            ValueError: a non-txt2img task without ``image_control``, or the
                inpaint task without ``image_mask``.
        """
        global lora_model_list

        progress(0, desc="Preparing inference...")

        vae_model = vae_model if vae_model != "None" else None
        loras_list = [lora1, lora2, lora3, lora4, lora5]
        # Treat "" like "None" (both mean "no LoRA"). This matches
        # sd_gen_generate_pipeline, which normalises the same way before merging,
        # so stablepy gets the LoRA list it already has in memory.
        lora_names = [None if lora in (None, "", "None") else lora for lora in loras_list]

        print("Config model:", model_name, vae_model, loras_list)

        prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
        lora_model_list = get_lora_model_list()

        task = TASK_STABLEPY[task]

        # IP-Adapter parameters: only adapters with an image are used. Masks are
        # appended only when given, so params_ip_msk may be shorter than the others.
        params_ip_img = []
        params_ip_msk = []
        params_ip_model = []
        params_ip_mode = []
        params_ip_scale = []

        all_adapters = [
            (image_ip1, mask_ip1, model_ip1, mode_ip1, scale_ip1),
            (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
        ]

        for imgip, mskip, modelip, modeip, scaleip in all_adapters:
            if imgip:
                params_ip_img.append(imgip)
                if mskip:
                    params_ip_msk.append(mskip)
                params_ip_model.append(modelip)
                params_ip_mode.append(modeip)
                params_ip_scale.append(scaleip)

        if task != "txt2img" and not image_control:
            raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")

        if task == "inpaint" and not image_mask:
            raise ValueError("No mask image found: Specify one in 'Image Mask'")

        # The first nine UPSCALER_KEYS (None, Lanczos, Nearest and the Latent
        # modes) are passed to stablepy by name. The others are weight-file URLs,
        # downloaded once into ./upscalers.
        if upscaler_model_path in UPSCALER_KEYS[:9]:
            upscaler_model = upscaler_model_path
        else:
            directory_upscalers = 'upscalers'
            os.makedirs(directory_upscalers, exist_ok=True)

            url_upscaler = UPSCALER_DICT_GUI[upscaler_model_path]
            upscaler_model = f"./upscalers/{url_upscaler.split('/')[-1]}"

            if not os.path.exists(upscaler_model):
                download_things(directory_upscalers, url_upscaler, HF_TOKEN)

        logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)

        adetailer_params_A = {
            "face_detector_ad": face_detector_ad_a,
            "person_detector_ad": person_detector_ad_a,
            "hand_detector_ad": hand_detector_ad_a,
            "prompt": prompt_ad_a,
            "negative_prompt": negative_prompt_ad_a,
            "strength": strength_ad_a,
            "mask_dilation": mask_dilation_a,
            "mask_blur": mask_blur_a,
            "mask_padding": mask_padding_a,
            "inpaint_only": adetailer_inpaint_only,
            "sampler": adetailer_sampler,
        }

        adetailer_params_B = {
            "face_detector_ad": face_detector_ad_b,
            "person_detector_ad": person_detector_ad_b,
            "hand_detector_ad": hand_detector_ad_b,
            "prompt": prompt_ad_b,
            "negative_prompt": negative_prompt_ad_b,
            "strength": strength_ad_b,
            "mask_dilation": mask_dilation_b,
            "mask_blur": mask_blur_b,
            "mask_padding": mask_padding_b,
        }

        pipe_params = {
            "prompt": prompt,
            "negative_prompt": neg_prompt,
            "img_height": img_height,
            "img_width": img_width,
            "num_images": num_images,
            "num_steps": steps,
            "guidance_scale": cfg,
            "clip_skip": clip_skip,
            "pag_scale": float(pag_scale),
            "seed": seed,
            "image": image_control,
            "preprocessor_name": preprocessor_name,
            "preprocess_resolution": preprocess_resolution,
            "image_resolution": image_resolution,
            "style_prompt": style_prompt if style_prompt else "",
            "style_json_file": "",
            "image_mask": image_mask,
            "strength": strength,
            "low_threshold": low_threshold,
            "high_threshold": high_threshold,
            "value_threshold": value_threshold,
            "distance_threshold": distance_threshold,
            "lora_A": lora_names[0],
            "lora_scale_A": lora_scale1,
            "lora_B": lora_names[1],
            "lora_scale_B": lora_scale2,
            "lora_C": lora_names[2],
            "lora_scale_C": lora_scale3,
            "lora_D": lora_names[3],
            "lora_scale_D": lora_scale4,
            "lora_E": lora_names[4],
            "lora_scale_E": lora_scale5,
            "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
            "syntax_weights": syntax_weights,
            "sampler": sampler,
            "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
            "gui_active": True,
            "loop_generation": loop_generation,
            "controlnet_conditioning_scale": float(controlnet_output_scaling_in_unet),
            "control_guidance_start": float(controlnet_start_threshold),
            "control_guidance_end": float(controlnet_stop_threshold),
            "generator_in_cpu": generator_in_cpu,
            "FreeU": freeu,
            "adetailer_A": adetailer_active_a,
            "adetailer_A_params": adetailer_params_A,
            "adetailer_B": adetailer_active_b,
            "adetailer_B_params": adetailer_params_B,
            "leave_progress_bar": leave_progress_bar,
            "disable_progress_bar": disable_progress_bar,
            "image_previews": image_previews,
            "display_images": display_images,
            "save_generated_images": save_generated_images,
            "image_storage_location": image_storage_location,
            "retain_compel_previous_load": retain_compel_previous_load,
            "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
            "retain_hires_model_previous_load": retain_hires_model_previous_load,
            "t2i_adapter_preprocessor": t2i_adapter_preprocessor,
            "t2i_adapter_conditioning_scale": float(t2i_adapter_conditioning_scale),
            "t2i_adapter_conditioning_factor": float(t2i_adapter_conditioning_factor),
            "upscaler_model_path": upscaler_model,
            "upscaler_increases_size": upscaler_increases_size,
            "esrgan_tile": esrgan_tile,
            "esrgan_tile_overlap": esrgan_tile_overlap,
            "hires_steps": hires_steps,
            "hires_denoising_strength": hires_denoising_strength,
            "hires_prompt": hires_prompt,
            "hires_negative_prompt": hires_negative_prompt,
            "hires_sampler": hires_sampler,
            "hires_before_adetailer": hires_before_adetailer,
            "hires_after_adetailer": hires_after_adetailer,
            "ip_adapter_image": params_ip_img,
            "ip_adapter_mask": params_ip_msk,
            "ip_adapter_model": params_ip_model,
            "ip_adapter_mode": params_ip_mode,
            "ip_adapter_scale": params_ip_scale,
        }

        self.model.device = torch.device("cuda:0")
        # Pipelines with a `transformer` (e.g. FLUX) get it moved to the GPU
        # explicitly unless every LoRA slot is the literal "None".
        if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
            self.model.pipe.transformer.to(self.model.device)
            print("transformer to cuda")

        progress(1, desc="Inference preparation completed. Starting inference...")

        info_state = ""
        return self.infer_short(self.model, pipe_params, progress), info_state
|
|
|
|
|
|
def dynamic_gpu_duration(func, duration, *args):
    """Run ``func(*args)`` inside a ZeroGPU allocation of ``duration`` seconds.

    ``spaces.GPU`` is applied at call time, so the requested duration can vary
    per call. Returns a generator that re-yields whatever iterating the result
    of ``func(*args)`` produces.

    NOTE(review): ``yield from`` iterates the value ``func`` returns. With
    ``GuiSD.generate_pipeline``, which returns an ``(images, info)`` tuple, this
    yields ``images`` and ``info`` as two separate items rather than one pair.
    """

    @spaces.GPU(duration=duration)
    def wrapped_func():
        yield from func(*args)

    return wrapped_func()
|
|
|
|
|
|
@spaces.GPU
def dummy_gpu():
    """No-op ``@spaces.GPU`` function that returns None.

    Not called in the visible code. It is presumably kept so that ZeroGPU
    finds at least one GPU-decorated function at startup (TODO confirm).
    """
    return None
|
|
|
|
|
|
def sd_gen_generate_pipeline(*args):
    """Generator entry point that runs ``sd_gen.generate_pipeline`` on ZeroGPU.

    ``args`` are GuiSD.generate_pipeline's positional arguments followed by
    three control values, which are stripped before forwarding:

    * ``args[-3]`` load_lora_cpu: merge LoRAs on CPU before requesting the GPU.
    * ``args[-2]`` verbose flag (converted with ``int``): show gr.Info/Warning toasts.
    * ``args[-1]`` requested GPU seconds; 59 when falsy.

    Positions 7, 9, 11, 13, 15 hold the five LoRA names and positions
    8, 10, 12, 14, 16 their scales, following generate_pipeline's parameter order.

    Yields:
        ``(None, <status message>)`` when the requested LoRAs differ from
        ``sd_gen.model.lora_memory`` and at least one is set; then everything
        yielded by ``dynamic_gpu_duration``.
    """
    gpu_duration_arg = int(args[-1]) if args[-1] else 59
    verbose_arg = int(args[-2])
    load_lora_cpu = args[-3]
    generation_args = args[:-3]
    # "None" and "" both mean "no LoRA" and are normalised to None.
    lora_list = [
        None if item == "None" or item == "" else item
        for item in [args[7], args[9], args[11], args[13], args[15]]
    ]
    lora_status = [None] * 5

    msg_load_lora = "Updating LoRAs in GPU..."
    if load_lora_cpu:
        msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."

    if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
        yield None, msg_load_lora

    # Optionally merge the LoRAs on CPU before the GPU slot is requested.
    if load_lora_cpu:
        lora_status = sd_gen.model.lora_merge(
            lora_A=lora_list[0], lora_scale_A=args[8],
            lora_B=lora_list[1], lora_scale_B=args[10],
            lora_C=lora_list[2], lora_scale_C=args[12],
            lora_D=lora_list[3], lora_scale_D=args[14],
            lora_E=lora_list[4], lora_scale_E=args[16],
        )
        print(lora_status)

    if verbose_arg:
        # lora_status stays [None] * 5 unless lora_merge ran. A truthy entry is
        # reported as loaded, any other non-None entry as a failure.
        for status, lora in zip(lora_status, lora_list):
            if status:
                gr.Info(f"LoRA loaded in CPU: {lora}")
            elif status is not None:
                gr.Warning(f"Failed to load LoRA: {lora}")

        # In CPU mode an all-None status while LoRAs are in memory presumably
        # means nothing needed (re)loading; report the cached LoRAs.
        if lora_status == [None] * 5 and sd_gen.model.lora_memory != [None] * 5 and load_lora_cpu:
            lora_cache_msg = ", ".join(
                str(x) for x in sd_gen.model.lora_memory if x is not None
            )
            gr.Info(f"LoRAs in cache: {lora_cache_msg}")

    msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
    gr.Info(msg_request)
    print(msg_request)

    start_time = time.time()

    yield from dynamic_gpu_duration(
        sd_gen.generate_pipeline,
        gpu_duration_arg,
        *generation_args,
    )

    end_time = time.time()

    if verbose_arg:
        execution_time = end_time - start_time
        msg_task_complete = (
            f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
        )
        gr.Info(msg_task_complete)
        print(msg_task_complete)
|
|
|
|
|
|
# Flag these callables with `zerogpu = True` (presumably read by the Hugging
# Face ZeroGPU / Gradio runtime — TODO confirm).
dynamic_gpu_duration.zerogpu = True
sd_gen_generate_pipeline.zerogpu = True
|
|
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
import random, json
|
|
from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
|
|
get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
|
|
get_valid_lora_path, get_valid_lora_wt, get_lora_info,
|
|
normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en)
|
|
|
|
# Shared GuiSD instance used by sd_gen_generate_pipeline and infer. It is
# created at import time, which loads the default model on CPU.
sd_gen = GuiSD()
|
|
|
|
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
          model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
          lora3 = None, lora3_wt = 1.0, lora4 = None, lora4_wt = 1.0, lora5 = None, lora5_wt = 1.0,
          sampler = "Euler a", vae = None, translate=True, progress=gr.Progress(track_tqdm=True)):
    """Generate one txt2img image with default settings for every advanced option.

    Flow:

    1. Optionally randomise the seed.
    2. Optionally translate both prompts to English.
    3. Insert the model's recommended tags.
    4. Resolve LoRAs referenced in the prompt via ``set_prompt_loras``.
    5. Load ``model_name``/``vae``.
    6. Run ``sd_gen_generate_pipeline`` once, with no ControlNet, hires,
       ADetailer or IP-Adapter.

    Returns:
        The first generated image, or ``None`` if nothing was produced.
    """
    import numpy as np

    MAX_SEED = np.iinfo(np.int32).max

    load_lora_cpu = False
    verbose_info = False
    gpu_duration = 59

    progress(0, desc="Preparing...")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Bug fix: the seed used to be passed as torch.Generator().manual_seed(seed).seed().
    # Generator.seed() re-seeds from a non-deterministic source and returns that
    # new value, so the chosen seed was discarded and runs were never reproducible.
    seed = int(seed)

    if translate:
        prompt = translate_to_en(prompt)
        # Bug fix: this previously translated `prompt`, replacing the negative prompt.
        negative_prompt = translate_to_en(negative_prompt)

    prompt, negative_prompt = insert_model_recom_prompt(prompt, negative_prompt, model_name)
    progress(0.5, desc="Preparing...")
    lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt = \
        set_prompt_loras(prompt, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt)
    lora1 = get_valid_lora_path(lora1)
    lora2 = get_valid_lora_path(lora2)
    lora3 = get_valid_lora_path(lora3)
    lora4 = get_valid_lora_path(lora4)
    lora5 = get_valid_lora_path(lora5)
    progress(1, desc="Preparation completed. Starting inference preparation...")

    sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0], progress)
    # Positional arguments follow GuiSD.generate_pipeline's order, followed by
    # the three control values consumed by sd_gen_generate_pipeline.
    outputs = list(sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
        guidance_scale, True, seed, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
        lora4, lora4_wt, lora5, lora5_wt, sampler,
        height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
        None, None, None, 0.35, 100, 200, 0.1, 0.1, 1.0, 0., 1., False, "Classic", None,
        1.0, 100, 10, 30, 0.55, "Use same sampler", "", "",
        False, True, 1, True, False, False, False, False, "./images", False, False, False, True, 1, 0.55,
        False, False, False, True, False, "Use same sampler", False, "", "", 0.35, True, True, False, 4, 4, 32,
        False, "", "", 0.35, True, True, False, 4, 4, 32,
        True, None, None, "plus_face", "original", 0.7, None, None, "base", "style", 0.7, 0.0,
        load_lora_cpu, verbose_info, gpu_duration
    ))
    # The generator may first yield a LoRA status tuple `(None, msg)`. Then
    # dynamic_gpu_duration re-yields the items of generate_pipeline's
    # `(gallery, info)` result one at a time. The former `images, info = ...`
    # unpacking therefore raised ValueError whenever the status tuple appeared.
    # Pick the gallery list (a list of `(image, caption)` pairs) instead.
    images = next((out for out in reversed(outputs) if isinstance(out, list)), [])

    progress(1, desc="Inference completed.")
    return images[0][0] if images else None
|
|
|
|
|
|
|
|
def _infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
           model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
           lora3 = None, lora3_wt = 1.0, lora4 = None, lora4_wt = 1.0, lora5 = None, lora5_wt = 1.0,
           sampler = "Euler a", vae = None, translate = True, progress=gr.Progress(track_tqdm=True)):
    """Stand-in with ``infer``'s signature.

    Ignores all arguments and returns a Gradio update that makes the bound
    output component visible.
    """
    visible_update = gr.update(visible=True)
    return visible_update
|
|
|
|
|
|
# Flag the public inference callbacks with `zerogpu = True` (presumably read by
# the Hugging Face ZeroGPU / Gradio runtime — TODO confirm).
infer.zerogpu = True
_infer.zerogpu = True
|
|
|
|
|
|
def pass_result(result):
    """Identity helper: return ``result`` unchanged (usable as a Gradio callback)."""
    passthrough = result
    return passthrough


def get_samplers():
    """Return the module-level stablepy ``scheduler_names`` list (not a copy)."""
    names = scheduler_names
    return names


def get_vaes():
    """Return the module-level ``vae_model_list`` (not a copy); "None" comes first."""
    vaes = vae_model_list
    return vaes
|
|
|
|
|
|
# When True, get_diffusers_model_list() returns the detailed tupled choices.
show_diffusers_model_list_detail = False
# Tupled (detailed) choices for load_diffusers_format_model, computed once at import.
cached_diffusers_model_tupled_list = get_tupled_model_list(load_diffusers_format_model)


def get_diffusers_model_list():
    """Return the Diffusers model dropdown choices.

    Returns the cached tupled list when detail mode is on, otherwise the
    plain ``load_diffusers_format_model`` list.
    """
    if not show_diffusers_model_list_detail:
        return load_diffusers_format_model
    return cached_diffusers_model_tupled_list
|
|
|
|
|
|
def enable_diffusers_model_detail(is_enable: bool = False, model_name: str = ""):
    """Switch the Diffusers model dropdown between plain and detailed choices.

    Sets the module flag, then re-expresses ``model_name`` in the new form:
    element ``[1]`` of its cached tuple when enabling, the plain id when
    disabling. A name not found in ``load_diffusers_format_model`` falls back
    to the first model.

    Returns:
        ``(checkbox_update, dropdown_update)`` Gradio updates.
    """
    global show_diffusers_model_list_detail
    show_diffusers_model_list_detail = is_enable

    known_models = set(load_diffusers_format_model)
    position = load_diffusers_format_model.index(model_name) if model_name in known_models else 0

    if is_enable:
        selected = cached_diffusers_model_tupled_list[position][1]
    else:
        selected = load_diffusers_format_model[position]

    checkbox_update = gr.update(value=is_enable)
    dropdown_update = gr.update(value=selected, choices=get_diffusers_model_list())
    return checkbox_update, dropdown_update
|
|
|
|
|
|
def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary for a public Diffusers repo.

    Returns ``""`` when any of these hold:

    * ``repo_id`` contains a space;
    * the repo does not exist or cannot be queried;
    * the repo is private or gated;
    * the repo is not tagged ``diffusers``.

    Otherwise returns ``gr.update(value=<markdown>)``. The Markdown lists the
    model family, card tags (minus generic ones), download and like counts,
    the last-modified date when known, and a link to the repo.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info. {e}")
        return ""
    if model.private or model.gated:
        return ""
    # ModelInfo.tags and ModelInfo.last_modified are Optional in huggingface_hub;
    # None used to crash on `in` / `.strftime`.
    tags = model.tags or []
    if 'diffusers' not in tags:
        return ""
    info = []
    url = f"https://huggingface.co/{repo_id}/"
    if 'diffusers:FluxPipeline' in tags:
        info.append("FLUX.1")
    elif 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    if model.last_modified is not None:
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)
|
|
|
|
|
|
def load_model_prompt_dict(path='model_dict.json'):
    """Load the per-model recommended-prompt table from a JSON file.

    insert_model_recom_prompt reads entries as
    ``{model_name: {"prompt": ..., "negative_prompt": ...}}``.

    Best-effort: a missing file returns ``{}`` silently. An unreadable file,
    invalid UTF-8/JSON, or a top-level value that is not a JSON object also
    returns ``{}``, with a printed message.

    Args:
        path: JSON file to read; defaults to ``model_dict.json`` in the cwd.
    """
    import json

    try:
        with open(path, encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as err:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"Failed to load {path}: {err}")
        return {}
    if not isinstance(loaded, dict):
        # Callers use dict operations (`in`, `.keys()`, indexing by model name).
        print(f"Ignoring {path}: expected a JSON object at top level")
        return {}
    return loaded
|
|
|
|
|
|
# Per-model recommended prompts, loaded once at import (empty dict if unavailable).
model_prompt_dict = load_model_prompt_dict()

# Global switch for insert_model_recom_prompt (see enable_model_recom_prompt).
model_recom_prompt_enabled = True
# Tag presets as lists (via to_list). insert_model_recom_prompt strips the
# animagine/pony/other presets from user prompts, and default_ps/default_nps
# are appended for models without an entry in model_prompt_dict.
animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
|
|
def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
    """Replace known quality tags in the prompts with those recommended for a model.

    Tags from the animagine/pony/other sets are removed from both prompts. Then
    either the model's entry in ``model_prompt_dict`` or the default tag sets
    are appended, and duplicates are dropped in first-seen order.

    Args:
        prompt: Comma-separated positive prompt.
        neg_prompt: Comma-separated negative prompt.
        model_name: Key looked up in ``model_prompt_dict``.

    Returns:
        ``(prompt, neg_prompt)``. The inputs are returned unchanged when the
        feature is disabled or ``model_name`` is empty.
    """
    if not model_recom_prompt_enabled or not model_name:
        return prompt, neg_prompt

    prompts = list_sub(to_list(prompt), animagine_ps + pony_ps + other_ps)
    neg_prompts = list_sub(to_list(neg_prompt), animagine_nps + pony_nps + other_nps)

    # When the user supplied no tags of their own, a trailing "" makes the
    # joined result end in ", ". The original condition also tested
    # `type != "None"`. That compared the builtin `type`, so it was always
    # True and has been removed with no change in behaviour.
    last_empty_p = [] if prompts else [""]
    last_empty_np = [] if neg_prompts else [""]

    if model_name in model_prompt_dict:
        ps = to_list(model_prompt_dict[model_name]["prompt"])
        nps = to_list(model_prompt_dict[model_name]["negative_prompt"])
    else:
        ps = default_ps
        nps = default_nps

    prompt = ", ".join(list_uniq(prompts + ps) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts + nps) + last_empty_np)
    return prompt, neg_prompt
|
|
|
|
|
|
def enable_model_recom_prompt(is_enable: bool = True):
    """Switch model-recommended prompt insertion on or off.

    Args:
        is_enable: New value for the module flag ``model_recom_prompt_enabled``.

    Returns:
        The value that was stored, i.e. ``is_enable`` itself.
    """
    global model_recom_prompt_enabled
    model_recom_prompt_enabled = is_enable
    return model_recom_prompt_enabled
|
|
|
|
|
|
# Module-level import: `json` is otherwise imported only inside
# load_model_prompt_dict(). Without it, the NameError below would be
# swallowed and the dict would silently stay empty.
import json

# LoRA metadata loaded from lora_dict.json, keyed by escape_lora_basename(name).
# The file is optional. If it is missing, the dict stays empty. Other failures
# are reported but do not stop the app from starting.
private_lora_dict = {}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
    for k, v in d.items():
        private_lora_dict[escape_lora_basename(k)] = v
except FileNotFoundError:
    pass
except Exception as e:
    print(f"Failed to load lora_dict.json: {e}")
|
|
|
|
|
|
# Result of get_private_lora_model_lists(), computed once at import time.
private_lora_model_list = get_private_lora_model_lists()

# Metadata per LoRA key. "None" and "" are placeholders with five empty fields.
# Entries from private_lora_dict are merged on top and win on key clashes.
loras_dict = {"None": [""] * 5, "": [""] * 5, **private_lora_dict}

# Download URL -> local path, filled in by download_lora().
loras_url_to_path_dict = {}

# dl_url -> item from the most recent search_civitai_lora() call.
civitai_lora_last_results = {}

# Snapshot of the LoRA list, refreshed by get_all_lora_list().
all_lora_list = []
|
|
|
|
|
|
def get_all_lora_list():
    """Fetch the LoRA list and refresh the module snapshot ``all_lora_list``.

    Returns:
        The list produced by ``get_lora_model_list()``. The snapshot is a
        separate copy, so mutating the returned list leaves it untouched.
    """
    global all_lora_list
    current = get_lora_model_list()
    all_lora_list = current.copy()
    return current
|
|
|
|
|
|
def get_all_lora_tupled_list():
    """Build ``(label, path)`` dropdown choices for every known LoRA.

    Each LoRA's metadata comes from ``loras_dict``. If the key is absent, the
    metadata is fetched with ``get_civitai_info`` and cached back into
    ``loras_dict`` when it is not None. When the metadata has a non-empty third
    field, the label is decorated with ``items[1]`` (apparently the base model;
    "Pony" gets a 🐴 marker) and ``items[2]``.

    Returns:
        A list of ``(label, path)`` tuples, or ``[]`` when no LoRAs are found.
    """
    models = get_all_lora_list()
    if not models:
        return []

    tupled_list = []
    for model in models:
        basename = Path(model).stem
        key = to_lora_key(model)
        items = loras_dict.get(key)
        if key not in loras_dict:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items

        name = basename
        # Guard the index: entries from lora_dict.json may be shorter than the
        # 5-field placeholders.
        if items and len(items) > 2 and items[2] != "":
            base = f"{items[1]}🐴" if items[1] == "Pony" else items[1]
            name = f"{basename} (for {base}, {items[2]})"
        tupled_list.append((name, model))
    return tupled_list
|
|
|
|
|
|
def update_lora_dict(path: str):
    """Cache CivitAI metadata for the LoRA at *path* in ``loras_dict``.

    Does nothing if the LoRA's key is already cached or if
    ``get_civitai_info`` returns None.
    """
    key = to_lora_key(path)
    if key in loras_dict:
        return
    items = get_civitai_info(path)
    if items is None:
        return
    loras_dict[key] = items
|
|
|
|
|
|
def download_lora(dl_urls: str):
    """Download comma-separated LoRA URLs into ``directory_loras`` and register them.

    URLs whose target file already exists locally are skipped. Every file that
    appears in the directory during the call is renamed with
    ``escape_lora_basename``, recorded in ``loras_url_to_path_dict`` (matched to
    URLs by order) and passed to ``update_lora_dict``.

    Args:
        dl_urls: Comma-separated download URLs. Empty entries are ignored.

    Returns:
        The path of the last newly downloaded file as ``str``, or ``""`` if no
        new file appeared.
    """
    if not dl_urls:
        return ""

    dl_path = ""
    before = get_local_model_list(directory_loras)
    urls = []
    for url in (u.strip() for u in dl_urls.split(',')):
        if not url:
            continue  # tolerate "a, , b" and trailing commas
        local_path = f"{directory_loras}/{url.split('/')[-1]}"
        if not Path(local_path).exists():
            # NOTE(review): `hf_token` is not among the names the header imports
            # from env (HF_TOKEN, hf_read_token). Confirm it is defined at
            # module level elsewhere in this file.
            download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
            urls.append(url)

    after = get_local_model_list(directory_loras)
    new_paths = [Path(file) for file in list_sub(after, before)]
    new_paths = [path for path in new_paths if path.exists()]
    for i, path in enumerate(new_paths):
        # NOTE(review): the new path is built from the parent directory's name
        # relative to the cwd, so this assumes directory_loras is a direct
        # child of the cwd.
        new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
        path.resolve().rename(new_path.resolve())
        # Files are matched to URLs by position. A download can produce more
        # new files than URLs, which used to raise IndexError.
        if i < len(urls):
            loras_url_to_path_dict[urls[i]] = str(new_path)
        update_lora_dict(str(new_path))
        dl_path = str(new_path)
    return dl_path
|
|
|
|
|
|
def copy_lora(path: str, new_path: str):
    """Copy a LoRA file to *new_path* and cache its metadata.

    Args:
        path: Existing LoRA file.
        new_path: Destination path.

    Returns:
        ``new_path`` on success, or immediately when both paths are equal.
        ``None`` if *path* does not exist or the copy fails.
    """
    import shutil

    if path == new_path:
        return new_path
    source, target = Path(path), Path(new_path)
    if not source.exists():
        return None
    try:
        shutil.copy(str(source.resolve()), str(target.resolve()))
    except Exception:
        return None
    update_lora_dict(str(target))
    return new_path
|
|
|
|
|
|
def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str):
    """Download LoRA(s) and place the result in the first free LoRA slot.

    A slot counts as free when its value is empty or the string "None". If
    every slot is occupied, or nothing was downloaded, the slots are unchanged.

    Returns:
        Five ``gr.update`` objects (one per slot) with the current value and
        the refreshed choice list.
    """
    slots = [lora1, lora2, lora3, lora4, lora5]
    downloaded = download_lora(dl_urls)
    if downloaded:
        for idx, current in enumerate(slots):
            if not current or current == "None":
                slots[idx] = downloaded
                break
    choices = get_all_lora_tupled_list()
    return tuple(gr.update(value=value, choices=choices) for value in slots)
|
|
|
|
|
|
def set_prompt_loras(prompt, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
    """Fill the five LoRA slots from ``<lora:key:weight>`` tags in *prompt*.

    All slot values are first normalised with ``get_valid_lora_name``. If the
    prompt contains LoRA tags, each slot weight is refreshed with
    ``get_valid_lora_wt``. Then every tag whose LoRA is not already in a slot
    goes into the first inactive slot, according to ``get_lora_info``.

    Returns:
        A 10-tuple ``(lora1, lora1_wt, ..., lora5, lora5_wt)``.
    """
    import re

    def _flatten(names, weights):
        # Interleave back into (lora1, lora1_wt, lora2, lora2_wt, ...).
        return tuple(v for pair in zip(names, weights) for v in pair)

    loras = [get_valid_lora_name(lora, model_name) for lora in (lora1, lora2, lora3, lora4, lora5)]
    wts = [lora1_wt, lora2_wt, lora3_wt, lora4_wt, lora5_wt]
    if not prompt or "<lora" not in prompt:
        return _flatten(loras, wts)

    wts = [get_valid_lora_wt(prompt, lora, wt) for lora, wt in zip(loras, wts)]
    active = [get_lora_info(lora)[0] for lora in loras]

    for p in prompt.split(","):
        p = p.strip()
        if "<lora" not in p:
            continue
        result = re.findall(r'<lora:(.+?):(.+?)>', p)
        if not result:
            continue
        key, wt = result[0]
        path = to_lora_path(key)
        if key not in loras_dict or not path:
            # Pass model_name here as well, as every other call does. The
            # original single-argument call is inconsistent with those calls.
            path = get_valid_lora_name(path, model_name)
            if not path or path == "None":
                continue
        if path in loras:
            continue
        for slot, on in enumerate(active):
            if not on:
                loras[slot] = path
                wts[slot] = safe_float(wt)
                # Mark the slot as used. Slot 4 previously re-queried
                # get_lora_info instead, unlike the other slots.
                active[slot] = True
                break
    return _flatten(loras, wts)
|
|
|
|
|
|
def apply_lora_prompt(prompt: str, lora_info: str):
    """Merge a LoRA's trigger words into the prompt.

    Args:
        prompt: Comma-separated prompt.
        lora_info: Trigger words separated by "," or "/". With ``None`` or
            "None", the prompt is returned unchanged.

    Returns:
        ``gr.update`` with the de-duplicated prompt plus trigger words and a
        trailing ", ".
    """
    if lora_info is None or lora_info == "None":
        return gr.update(value=prompt)
    prompts = normalize_prompt_list(prompt.split(",") if prompt else [])
    # "/" is accepted as an alternative separator for trigger words.
    lora_prompts = normalize_prompt_list(lora_info.replace("/", ",").split(","))
    prompt = ", ".join(list_uniq(prompts + lora_prompts) + [""])
    return gr.update(value=prompt)
|
|
|
|
|
|
def update_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
    """Rewrite the prompt's LoRA tags from the slot state and refresh slot widgets.

    Non-LoRA tags are kept. A ``<lora:...>`` tag is kept (re-formatted with a
    2-decimal weight) only if it resolves to a known LoRA that is currently in
    a slot. A tag is then appended for every active slot.

    Returns:
        A 26-tuple: the prompt update, followed by five updates per slot
        (dropdown, weight, tag label, tag button, markdown).
    """
    import re

    slots = [(lora1, lora1_wt), (lora2, lora2_wt), (lora3, lora3_wt), (lora4, lora4_wt), (lora5, lora5_wt)]
    infos = [get_lora_info(lora) for lora, _ in slots]
    selected_paths = [lora for lora, _ in slots]

    kept = []
    for raw in (prompt.split(",") if prompt else []):
        piece = str(raw).strip()
        if "<lora" not in piece:
            if piece:
                kept.append(piece)
            continue
        found = re.findall(r'<lora:(.+?):(.+?)>', piece)
        if not found:
            continue
        key, weight = found[0]
        lora_path = to_lora_path(key)
        if key not in loras_dict or not lora_path:
            continue
        if lora_path in selected_paths:
            kept.append(f"<lora:{to_lora_key(lora_path)}:{safe_float(weight):.2f}>")

    slot_tags = [
        f"<lora:{to_lora_key(lora)}:{wt:.2f}>"
        for (lora, wt), info in zip(slots, infos)
        if info[0]
    ]
    output_prompt = ", ".join(list_uniq(kept + slot_tags + [""]))
    choices = get_all_lora_tupled_list()

    updates = [gr.update(value=output_prompt)]
    for (lora, wt), (on, label, tag, md) in zip(slots, infos):
        updates.extend((
            gr.update(value=lora, choices=choices),
            gr.update(value=wt),
            gr.update(value=tag, label=label, visible=on),
            gr.update(visible=on),
            gr.update(value=md, visible=on),
        ))
    return tuple(updates)
|
|
|
|
|
|
def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
    """Search CivitAI for LoRAs and populate the result dropdown.

    Each result is stored in ``civitai_lora_last_results`` under its
    ``dl_url`` so that ``select_civitai_lora`` can show its description later.

    Returns:
        Four ``gr.update`` objects: result dropdown, description markdown, and
        two components that are always made visible.
    """
    global civitai_lora_last_results

    def _no_results():
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
    if not items:
        return _no_results()

    civitai_lora_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        civitai_lora_last_results[value] = item
    # A truthy `items` may still yield nothing (e.g. an exhausted iterator).
    if not choices:
        return _no_results()

    first_value = choices[0][1]
    # Default to None, not "None": indexing the string "None" with 'md' would
    # raise TypeError.
    result = civitai_lora_last_results.get(first_value)
    md = result.get('md', "") if result else ""
    return gr.update(choices=choices, value=first_value, visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
|
|
|
|
|
|
def select_civitai_lora(search_result):
    """Show the description of the selected CivitAI search result.

    Args:
        search_result: The selected result's download URL, or an empty or
            placeholder value.

    Returns:
        ``(url_update, markdown_update)``. For non-URL input, the URL is cleared
        and the markdown shows "None". URLs missing from the last search results
        get an empty description instead of raising.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # The previous default of "None" made result['md'] raise TypeError for any
    # URL not present in the last search results.
    result = civitai_lora_last_results.get(search_result)
    md = result.get('md', "") if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)
|
|
|
|
|
|
def search_civitai_lora_json(query, base_model):
    """Search CivitAI for LoRAs and return the raw results keyed by download URL.

    Returns:
        ``gr.update`` whose value is ``{dl_url: item}``, or ``{}`` when the
        search found nothing.
    """
    found = search_lora_on_civitai(query, base_model)
    by_url = {entry['dl_url']: entry for entry in found} if found else {}
    return gr.update(value=by_url)
|
|
|
|
|
|
# Quality-tag presets offered in the UI. Each entry has a display "name" plus
# the "prompt" / "negative_prompt" tags that process_style_prompt() applies.
quality_prompt_list = [
    dict(
        name="None",
        prompt="",
        negative_prompt="lowres",
    ),
    dict(
        name="Animagine Common",
        prompt="anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
        negative_prompt="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    ),
    dict(
        name="Pony Anime Common",
        prompt="source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
        negative_prompt="source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    ),
    dict(
        name="Pony Common",
        prompt="source_anime, score_9, score_8_up, score_7_up",
        negative_prompt="source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    ),
    dict(
        name="Animagine Standard v3.0",
        prompt="masterpiece, best quality",
        negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
    ),
    dict(
        name="Animagine Standard v3.1",
        prompt="masterpiece, best quality, very aesthetic, absurdres",
        negative_prompt="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    ),
    dict(
        name="Animagine Light v3.1",
        prompt="(masterpiece), best quality, very aesthetic, perfect face",
        negative_prompt="(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
    ),
    dict(
        name="Animagine Heavy v3.1",
        prompt="(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
        negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
    ),
]
|
|
|
|
|
|
# Style presets offered in the UI. Each entry has a display "name" plus the
# "prompt" / "negative_prompt" tags that process_style_prompt() applies.
style_list = [
    dict(
        name="None",
        prompt="",
        negative_prompt="",
    ),
    dict(
        name="Cinematic",
        prompt="cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        negative_prompt="cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    ),
    dict(
        name="Photographic",
        prompt="cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        negative_prompt="drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    ),
    dict(
        name="Anime",
        prompt="anime artwork, anime style, vibrant, studio anime, highly detailed",
        negative_prompt="photo, deformed, black and white, realism, disfigured, low contrast",
    ),
    dict(
        name="Manga",
        prompt="manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
        negative_prompt="ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    ),
    dict(
        name="Digital Art",
        prompt="concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
        negative_prompt="photo, photorealistic, realism, ugly",
    ),
    dict(
        name="Pixel art",
        prompt="pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
        negative_prompt="sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    ),
    dict(
        name="Fantasy art",
        prompt="ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        negative_prompt="photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    ),
    dict(
        name="Neonpunk",
        prompt="neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        negative_prompt="painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    ),
    dict(
        name="3D Model",
        prompt="professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
        negative_prompt="ugly, deformed, noisy, low poly, blurry, painting",
    ),
]
|
|
|
|
|
|
# name -> (prompt, negative_prompt) lookups for process_style_prompt().
preset_styles = dict((entry["name"], (entry["prompt"], entry["negative_prompt"])) for entry in style_list)
preset_quality = dict((entry["name"], (entry["prompt"], entry["negative_prompt"])) for entry in quality_prompt_list)
|
|
|
|
|
|
def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None", model_type: str = ""):
    """Re-apply the selected style and quality presets to the prompts.

    Tags contributed by *any* preset (all styles, all quality presets, and the
    Animagine/Pony sets) are first stripped from the inputs. The tags of the
    selected presets are then appended, so switching presets replaces the old
    tags instead of piling them up.

    Args:
        prompt: Comma-separated positive prompt.
        neg_prompt: Comma-separated negative prompt.
        styles_key: Key into ``preset_styles``.
        quality_key: Key into ``preset_quality``.
        model_type: "Animagine" or "Pony" appends that family's tags. "None"
            and "Auto" suppress the trailing empty tag. The default "" matches
            the previous behaviour: the old code compared the builtin ``type``
            instead of a parameter, so neither branch ever fired.

    Returns:
        ``(gr.update(value=prompt), gr.update(value=neg_prompt))``.

    Raises:
        KeyError: if ``styles_key`` or ``quality_key`` is not a known preset.
    """
    def to_list(s):
        # Local variant: "" -> [], otherwise split on "," and strip each part.
        return [] if s == "" else [x.strip() for x in s.split(",")]

    def list_sub(a, b):
        return [e for e in a if e not in b]

    def list_uniq(l):
        return sorted(set(l), key=l.index)

    animagine_ps = to_list("anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
    prompts = to_list(prompt)
    neg_prompts = to_list(neg_prompt)

    all_styles_ps = []
    all_styles_nps = []
    for d in style_list:
        all_styles_ps.extend(to_list(str(d.get("prompt", ""))))
        all_styles_nps.extend(to_list(str(d.get("negative_prompt", ""))))

    all_quality_ps = []
    all_quality_nps = []
    for d in quality_prompt_list:
        all_quality_ps.extend(to_list(str(d.get("prompt", ""))))
        all_quality_nps.extend(to_list(str(d.get("negative_prompt", ""))))

    quality_ps = to_list(preset_quality[quality_key][0])
    quality_nps = to_list(preset_quality[quality_key][1])
    styles_ps = to_list(preset_styles[styles_key][0])
    styles_nps = to_list(preset_styles[styles_key][1])

    prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
    neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)

    # A trailing "" makes the joined result end in ", " when the user had no
    # tags of their own left.
    add_trailing = model_type not in ("None", "Auto") and styles_key != "None" and quality_key != "None"
    last_empty_p = [""] if not prompts and add_trailing else []
    last_empty_np = [""] if not neg_prompts and add_trailing else []

    if model_type == "Animagine":
        prompts = prompts + animagine_ps
        neg_prompts = neg_prompts + animagine_nps
    elif model_type == "Pony":
        prompts = prompts + pony_ps
        neg_prompts = neg_prompts + pony_nps

    prompts = prompts + styles_ps + quality_ps
    neg_prompts = neg_prompts + styles_nps + quality_nps

    prompt = ", ".join(list_uniq(prompts) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
    return gr.update(value=prompt), gr.update(value=neg_prompt)
|
|
|
|
|
|
def save_images(images: list[Image.Image], metadatas: list[str]):
    """Save each image as a PNG with its metadata in a "metadata" text chunk.

    The first image is written to ``image.png`` (as before) and later ones to
    ``image_1.png``, ``image_2.png``, ... in the cwd. Previously every image
    overwrote ``image.png``, so all returned paths pointed at the last image.
    NOTE(review): the names are still shared across concurrent calls.

    Args:
        images: Images to save.
        metadatas: Metadata strings, paired with *images* by position. Extra
            items in the longer list are ignored (``zip`` semantics).

    Returns:
        Absolute paths of the written files, in input order.

    Raises:
        Exception: wrapping the underlying error if any save fails.
    """
    from PIL import PngImagePlugin

    try:
        output_images = []
        for index, (image, metadata) in enumerate(zip(images, metadatas)):
            info = PngImagePlugin.PngInfo()
            info.add_text("metadata", metadata)
            savefile = "image.png" if index == 0 else f"image_{index}.png"
            image.save(savefile, "PNG", pnginfo=info)
            output_images.append(str(Path(savefile).resolve()))
        return output_images
    except Exception as e:
        print(f"Failed to save image file: {e}")
        raise Exception(f"Failed to save image file: {e}") from e
|
|
|