|
import spaces
|
|
import os
|
|
import gradio as gr
|
|
import json
|
|
import logging
|
|
# Lower the "diffusers" logger to ERROR before the library itself is imported,
# so messages emitted during import are already filtered.
logging.getLogger("diffusers").setLevel(logging.ERROR)
import diffusers
# 40 is the numeric value of logging.ERROR; set diffusers' own verbosity to match.
diffusers.utils.logging.set_verbosity(40)
import warnings
# Ignore FutureWarning/UserWarning noise attributed to the diffusers and
# transformers modules (matched via the `module` argument).
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
|
from pathlib import Path
|
|
from huggingface_hub import HfApi
|
|
from env import (HF_TOKEN, hf_read_token,
|
|
CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
|
|
HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
|
|
download_model_list, download_lora_list, download_vae_list)
|
|
from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
|
|
safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list, download_things,
|
|
get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
|
|
get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai, MODEL_TYPE_DICT)
|
|
|
|
|
|
|
|
# Comma-joined forms of the configured model / VAE / LoRA download lists.
# NOTE: the module-level name `download_lora` is later rebound by
# `def download_lora(...)`; only code running before that def sees the string.
download_model, download_vae, download_lora = (
    ", ".join(url_list)
    for url_list in (download_model_list, download_vae_list, download_lora_list)
)
|
|
|
|
|
|
|
|
|
|
# Credentials read directly from the process environment. This rebinds the
# CIVITAI_API_KEY name that was imported from `env` at the top of the file.
CIVITAI_API_KEY = os.getenv("CIVITAI_API_KEY")
hf_token = os.getenv("HF_TOKEN")
|
|
|
|
|
|
def _download_missing(urls_csv: str, directory: str):
    """Download each comma-separated URL into `directory` unless already present.

    A URL counts as present when a file named after its last path segment
    exists inside `directory`, the same directory `download_things` writes to.
    Blank entries (from an empty list or stray commas) are skipped.
    """
    for url in (u.strip() for u in urls_csv.split(',')):
        if not url:
            continue
        # Check the directory we actually download into, not a hard-coded path.
        if not os.path.exists(os.path.join(directory, url.split('/')[-1])):
            download_things(directory, url, hf_token, CIVITAI_API_KEY)


# At this point `download_lora` is still the comma-joined string built above;
# the function of the same name is defined later in the module.
_download_missing(download_model, directory_models)
_download_missing(download_vae, directory_vaes)
_download_missing(download_lora, directory_loras)
|
|
|
|
# LoRA files currently available, and the local VAE list with a "None"
# entry prepended (in place) as the first choice.
lora_model_list = get_lora_model_list()
vae_model_list = get_local_model_list(directory_vaes)
vae_model_list[:0] = ["None"]
|
|
|
|
|
|
# Optional LoRA metadata loaded from lora_dict.json, keyed by the escaped
# LoRA basename; "" maps to a blank placeholder entry.
private_lora_dict = {"": ["", "", "", "", ""]}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
    for k, v in d.items():
        private_lora_dict[escape_lora_basename(k)] = v
except FileNotFoundError:
    pass  # the file is optional
except Exception as e:
    # Still best-effort, but a malformed or unreadable file is now reported
    # instead of being silently ignored.
    logging.getLogger(__name__).warning("Could not load lora_dict.json: %s", e)
|
|
|
|
|
|
private_lora_model_list = get_private_lora_model_lists()

# LoRA metadata cache: blank placeholders for "None" and "", then the
# private entries layered on top (a private "" entry overrides the default).
loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]}
loras_dict.update(private_lora_dict)

loras_url_to_path_dict = {}      # download URL -> local path, filled by download_lora
civitai_lora_last_results = {}   # dl_url -> item from the most recent Civitai search
all_lora_list = []               # copy of the list last returned by get_all_lora_list
|
|
|
|
|
|
def get_all_lora_list():
    """Return the current LoRA model list, caching a copy in ``all_lora_list``."""
    global all_lora_list
    found_loras = get_lora_model_list()
    # Store a separate copy so callers mutating the returned list don't touch the cache.
    all_lora_list = found_loras.copy()
    return found_loras
|
|
|
|
|
|
def _format_lora_label(basename: str, items) -> str:
    """Build the dropdown label for a LoRA from its metadata list.

    Returns the bare basename when metadata is missing, has fewer than three
    entries, or its third entry is an empty string. Otherwise it appends
    "(for <items[1]>, <items[2]>)", with a pony emoji after "Pony".
    """
    # Guard the length: indices 1 and 2 are read below.
    if not items or len(items) < 3 or items[2] == "":
        return basename
    base_model = f"{items[1]}🐴" if items[1] == "Pony" else items[1]
    return f"{basename} (for {base_model}, {items[2]})"


def get_all_lora_tupled_list():
    """Return ``(label, path)`` pairs for every LoRA from ``get_all_lora_list``.

    Metadata comes from ``loras_dict`` when cached. Otherwise it is fetched with
    ``get_civitai_info`` and stored in ``loras_dict`` when not None.
    Returns [] when no LoRAs are found.
    """
    models = get_all_lora_list()
    if not models:
        return []
    tupled_list = []
    for model in models:
        key = to_lora_key(model)
        if key in loras_dict:
            items = loras_dict.get(key)
        else:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items
        tupled_list.append((_format_lora_label(Path(model).stem, items), model))
    return tupled_list
|
|
|
|
|
|
def update_lora_dict(path: str):
    """Cache Civitai metadata for the LoRA at ``path`` in ``loras_dict``.

    Does nothing if the LoRA's key is already cached or if ``get_civitai_info``
    returns None.
    """
    # loras_dict is only mutated (not rebound), so no `global` is needed.
    key = to_lora_key(path)
    if key in loras_dict:
        return
    items = get_civitai_info(path)
    if items is not None:
        loras_dict[key] = items
|
|
|
|
|
|
def download_lora(dl_urls: str):
    """Download the comma-separated LoRA URLs in ``dl_urls`` into ``directory_loras``.

    URLs whose last path segment already exists in ``directory_loras`` are
    skipped. Each newly downloaded file is:
    - renamed to its ``escape_lora_basename``-escaped name,
    - recorded in ``loras_url_to_path_dict`` under the URL that produced it,
    - registered via ``update_lora_dict``.

    Returns the path of the last newly downloaded file, or "" if none.
    """
    dl_path = ""
    for url in (u.strip() for u in dl_urls.split(',')):
        if not url:
            continue  # ignore blanks from stray commas
        if Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
            continue
        # Diff the directory around each individual download so every new
        # file is attributed to the URL that produced it. Pairing new files
        # with URLs by index is unreliable: directory ordering need not match
        # URL order, and more new files than URLs raised IndexError.
        before = get_local_model_list(directory_loras)
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
        after = get_local_model_list(directory_loras)
        for file in list_sub(after, before):
            path = Path(file)
            if not path.exists():
                continue
            # NOTE(review): the target is built from the parent directory's
            # *name* only, i.e. relative to the CWD; this assumes the LoRA
            # directory is a direct child of the working directory - confirm.
            new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
            path.resolve().rename(new_path.resolve())
            loras_url_to_path_dict[url] = str(new_path)
            update_lora_dict(str(new_path))
            dl_path = str(new_path)
    return dl_path
|
|
|
|
|
|
def copy_lora(path: str, new_path: str):
    """Copy the LoRA file at ``path`` to ``new_path`` and register it.

    Returns ``new_path`` if both paths are equal (no copy is made) or the copy
    succeeded. Returns None if the source is missing or the copy raised.
    """
    import shutil

    if path == new_path:
        return new_path
    source, target = Path(path), Path(new_path)
    if not source.exists():
        return None
    try:
        shutil.copy(str(source.resolve()), str(target.resolve()))
    except Exception:
        return None
    update_lora_dict(str(target))
    return new_path
|
|
|
|
|
|
def download_my_lora(dl_urls: str, lora):
    """Download LoRAs from ``dl_urls`` and refresh the LoRA dropdown.

    Selects the path returned by ``download_lora`` when it is non-empty;
    otherwise keeps the current ``lora`` value.
    """
    downloaded = download_lora(dl_urls)
    selected = downloaded if downloaded else lora
    return gr.update(value=selected, choices=get_all_lora_tupled_list())
|
|
|
|
|
|
def apply_lora_prompt(lora_info: str):
    """Turn a LoRA tag string into a comma-separated prompt fragment.

    Both "/" and "," act as separators. The pieces are normalized with
    ``normalize_prompt_list`` and de-duplicated with ``list_uniq``.
    Returns "" for None, an empty string, or the sentinel "None".
    """
    # Also guard against a real None, which previously crashed on .replace().
    if not lora_info or lora_info == "None":
        return ""
    lora_tags = lora_info.replace("/", ",").split(",")
    return ", ".join(list_uniq(normalize_prompt_list(lora_tags)))
|
|
|
|
|
|
def update_loras(prompt, lora, lora_wt):
    """Return Gradio updates refreshing the prompt and LoRA-related widgets.

    The five updates are, in order: prompt, LoRA dropdown (with refreshed
    choices), LoRA weight, the tag component (value/label/visibility from
    ``get_lora_info``), and the info markdown.
    """
    visible, tag_label, tag_value, info_md = get_lora_info(lora)
    lora_choices = get_all_lora_tupled_list()
    return (
        gr.update(value=prompt),
        gr.update(value=lora, choices=lora_choices),
        gr.update(value=lora_wt),
        gr.update(value=tag_value, label=tag_label, visible=visible),
        gr.update(value=info_md, visible=visible),
    )
|
|
|
|
|
|
def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
    """Search Civitai for LoRAs and build Gradio updates for the results UI.

    Args:
        query: Search text passed to ``search_lora_on_civitai``.
        base_model: Base-model filter, passed through unchanged.
        sort, period, tag: Further search options, passed through unchanged.
            Up to 100 results are requested.

    Returns:
        A 4-tuple of ``gr.update`` objects: the result dropdown, the markdown
        for the first result, and two components that are always shown. With
        no results, the dropdown and markdown are cleared and hidden.

    Side effect: when the search returns results, ``civitai_lora_last_results``
    is replaced by a mapping from download URL to result item.
    """
    global civitai_lora_last_results

    def _no_results():
        # Shared "nothing found" UI state.
        return (
            gr.update(choices=[("", "")], value="", visible=False),
            gr.update(value="", visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
    if not items:
        return _no_results()
    civitai_lora_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        civitai_lora_last_results[value] = item
    # Still needed if `items` was truthy but yielded nothing (e.g. an iterator).
    if not choices:
        return _no_results()
    # A missing entry or missing 'md' key yields "" instead of raising.
    first = civitai_lora_last_results.get(choices[0][1]) or {}
    md = first.get('md', "")
    return (
        gr.update(choices=choices, value=choices[0][1], visible=True),
        gr.update(value=md, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
|
|
|
|
|
|
def select_civitai_lora(search_result):
    """Return Gradio updates for a selected Civitai search result.

    ``search_result`` is expected to be a download URL (a key of
    ``civitai_lora_last_results``). Non-URL or empty input clears the first
    component and shows "None" in the second. Otherwise the URL is echoed and
    the cached item's markdown is shown, or "" if the URL is not cached.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # Default to an empty dict. The old "None"-string default made
    # result['md'] raise TypeError for any URL not in the cache.
    result = civitai_lora_last_results.get(search_result) or {}
    md = result.get('md', "")
    return gr.update(value=search_result), gr.update(value=md, visible=True)
|
|
|
|
|
|
def search_civitai_lora_json(query, base_model):
    """Search Civitai and return a Gradio update mapping download URL -> result item.

    An empty search yields an update whose value is an empty dict.
    """
    found = search_lora_on_civitai(query, base_model)
    by_url = {entry['dl_url']: entry for entry in found} if found else {}
    return gr.update(value=by_url)
|
|
|
|
|