import gradio as gr
import os
from pathlib import Path
import gc
import re
import shutil
from utils import set_token, get_download_file, list_uniq
from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key


TEMP_DIR = "."
KEYS_DIR = "keys"
KEYS_FILES = [f"{KEYS_DIR}/sdxl_keys.txt"]
DEFAULT_KEYS_FILE = f"{KEYS_DIR}/sdxl_keys.txt"
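# Note (assumption): KEYS_DIR is expected to hold one reference key list per model
# family, e.g. keys/sdxl_keys.txt, presumably with one state-dict key per line;
# update_keys_files() below picks up any *.txt file placed there.
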
def update_keys_files():
    """Refresh the global KEYS_FILES list from the *.txt files in KEYS_DIR."""
    global KEYS_FILES
    KEYS_FILES = [str(file) for file in Path(KEYS_DIR).glob("*.txt")]


update_keys_files()


def upload_keys_file(path: str):
    """Copy an uploaded keys file into KEYS_DIR and refresh the dropdown choices."""
    global KEYS_FILES
    newpath = str(Path(KEYS_DIR, Path(path).stem + ".txt"))
    if not Path(newpath).exists(): shutil.copy(str(Path(path)), newpath)
    update_keys_files()
    return gr.update(choices=KEYS_FILES)

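# Illustrative wiring for the helper above (assumption: the real event binding lives
# in the UI definition, and the component names here are made up for the example):
#   keys_file_upload.upload(upload_keys_file, inputs=[keys_file_upload], outputs=[rfile_dropdown])

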
def parse_urls(s):
    """Extract all http(s) URLs from a block of text."""
    url_pattern = r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+"
    try:
        urls = re.findall(url_pattern, s)
        return list(urls)
    except Exception:
        return []


def to_urls(l: list[str]):
    return "\n".join(l)


def uniq_urls(s):
    return to_urls(list_uniq(parse_urls(s)))

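# Illustrative behaviour of the URL helpers above (assumption: list_uniq from utils
# deduplicates while preserving first-seen order, as its name suggests):
#   uniq_urls("https://example.com/a.safetensors https://example.com/a.safetensors\n"
#             "https://example.com/b.safetensors")
#   -> "https://example.com/a.safetensors\nhttps://example.com/b.safetensors"

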
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download a single file into TEMP_DIR and return its local path."""
    download_dir = TEMP_DIR
    progress(0, desc=f"Start downloading... {dl_url}")
    output_filename = get_download_file(download_dir, dl_url, civitai_key)
    return output_filename


def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Extract the state-dict keys of a safetensors file, optionally validate them
    against a reference key list, and delete the inspected file afterwards."""
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    try:
        progress(0, desc=f"Loading keys... {filename}")
        keys = read_safetensors_key(filename)
        if len(keys) == 0: raise Exception("No keys found.")
        progress(0.5, desc=f"Checking keys... {filename}")
        if write_safetensors_key(keys, str(Path(filename).stem + ".txt"), is_validate, rfile):
            paths.append(str(Path(filename).stem + ".txt"))
            paths.append(str(Path(filename).stem + "_missing.txt"))
            paths.append(str(Path(filename).stem + "_added.txt"))
            missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        print(f"Error: Failed check {filename}. {e}")
        gr.Warning(f"Error: Failed check {filename}. {e}")
    finally:
        Path(filename).unlink(missing_ok=True)  # always remove the (possibly large) model file
    return paths, metadata, keys, missing, added

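# Illustrative result for get_stkey("model.safetensors") when validation succeeds
# (inferred from the filenames built above, not a guaranteed contract):
#   paths    -> ["model.txt", "model_missing.txt", "model_added.txt"]
#   missing  -> reference keys absent from the checked file
#   added    -> keys present in the checked file but not in the reference list
#   metadata -> the safetensors header metadata dict

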
def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
             is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Download every URL found in dl_url, run the key check on each file,
    and return updated values for the Gradio output components."""
    if not hf_token: hf_token = os.environ.get("HF_TOKEN")
    set_token(hf_token)
    if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY")
    dl_urls = parse_urls(dl_url)
    if not urls: urls = []
    if not files: files = []
    metadata = {}
    keys = []
    missing = []
    added = []
    for u in dl_urls:
        file = download_file(u, civitai_key)
        if not Path(file).exists() or not Path(file).is_file(): continue
        paths, metadata, keys, missing, added = get_stkey(file, is_validate, rfile)
        if len(paths) != 0: files.extend(paths)
    progress(1, desc="Processing...")
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=files), gr.update(visible=False), metadata, keys, missing, added
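

# Minimal wiring sketch for stkey_gr (illustrative only: the real Blocks layout lives
# in the app module, and every component name below is an assumption):
#
#   with gr.Blocks() as demo:
#       dl_url = gr.Textbox(label="Download URL(s)", lines=3)
#       civitai_key = gr.Textbox(label="CivitAI API key", type="password")
#       hf_token = gr.Textbox(label="HF token", type="password")
#       urls_cb = gr.CheckboxGroup(label="Processed URLs", choices=[])
#       files_out = gr.Files(label="Generated key lists")
#       rfile_dd = gr.Dropdown(label="Reference key file", choices=KEYS_FILES, value=DEFAULT_KEYS_FILE)
#       is_validate = gr.Checkbox(label="Validate keys", value=True)
#       run_btn = gr.Button("Check keys")
#       metadata_out, keys_out, missing_out, added_out = gr.JSON(), gr.Textbox(), gr.Textbox(), gr.Textbox()
#       run_btn.click(stkey_gr,
#                     [dl_url, civitai_key, hf_token, urls_cb, files_out, is_validate, rfile_dd],
#                     [urls_cb, files_out, run_btn, metadata_out, keys_out, missing_out, added_out])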