import os
import json
import gradio as gr
from PIL import Image
from huggingface_hub import HfApi
from requests import HTTPError, Timeout


# Private Hugging Face repos holding LoRA files. get_private_lora_model_lists()
# keeps group 1 in listing order and appends group 2's results sorted.
HF_LORA_PRIVATE_REPOS1 = ['John6666/loratest1', 'John6666/loratest3', 'John6666/loratest4', 'John6666/loratest6']
HF_LORA_PRIVATE_REPOS2 = ['John6666/loratest10', 'John6666/loratest']
# NOTE(review): the combined list and the "essential" repo below are not
# referenced in the visible part of this file; presumably used elsewhere.
HF_LORA_PRIVATE_REPOS = HF_LORA_PRIVATE_REPOS1 + HF_LORA_PRIVATE_REPOS2
HF_LORA_ESSENTIAL_PRIVATE_REPO = 'John6666/loratest1'
# Directory (relative to the working directory) where LoRA files are kept.
directory_loras = 'loras'
# Civitai API key from the environment; None when unset.
# NOTE(review): not sent with the Civitai requests visible in this file.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")


def get_user_agent():
    """Return the browser-like User-Agent header value used for Civitai requests."""
    firefox_ua = (
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) '
        'Gecko/20100101 Firefox/127.0'
    )
    return firefox_ua


def change_interface_mode(mode: str):
    """Return ten Gradio updates that reconfigure the UI for ``mode``.

    The outputs set, in order: open, visible, open, open, visible, open,
    visible, open, visible, value. "Fast", "Simple" (t2i) and "LoRA" (t2i LoRA)
    have their own settings; anything else is treated as "Standard".
    """
    # Which property each of the ten output slots controls.
    slot_props = ("open", "visible", "open", "open", "visible",
                  "open", "visible", "open", "visible", "value")
    standard = (False, True, False, False, True, False, True, False, True, "Standard")
    presets = {
        "Fast": (False, True, False, False, True, False, True, False, True, "Fast"),
        "Simple": (True, True, False, False, True, False, False, True, False, "Standard"),  # t2i mode
        "LoRA": (True, True, True, False, True, True, True, False, False, "Standard"),  # t2i LoRA mode
    }
    chosen = presets.get(mode, standard)
    return tuple(gr.update(**{prop: setting}) for prop, setting in zip(slot_props, chosen))


def get_model_list(directory_path):
    """Return paths of model-weight files directly inside ``directory_path``.

    A file qualifies when its extension is one of .ckpt, .pt, .pth,
    .safetensors or .bin (case-sensitive). Paths are ``directory_path`` joined
    with the filename, in ``os.listdir`` order. A missing directory yields []
    rather than raising, so callers that merge local and remote lists still
    work before the directory exists.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    try:
        filenames = os.listdir(directory_path)
    except FileNotFoundError:
        return []
    return [
        os.path.join(directory_path, filename)
        for filename in filenames
        if os.path.splitext(filename)[1] in valid_extensions
    ]


def list_uniq(l):
    """Return ``l`` without duplicates, keeping each item's first occurrence.

    Items must be hashable. ``dict.fromkeys`` preserves insertion order, so this
    is O(n), unlike sorting a set by ``l.index`` (O(n^2)).
    """
    return list(dict.fromkeys(l))


def list_sub(a, b):
    """Return the items of ``a`` not present in ``b``, keeping order and duplicates."""
    kept = []
    for item in a:
        if item in b:
            continue
        kept.append(item)
    return kept


def normalize_prompt_list(tags):
    """Convert each tag to a stripped string and drop the ones that end up empty."""
    stripped = (str(tag).strip() for tag in tags)
    return [tag for tag in stripped if tag]


def escape_lora_basename(basename: str):
    """Return ``basename`` with '.' and ' ' turned into '_' and ',' removed."""
    table = str.maketrans({".": "_", " ": "_", ",": None})
    return basename.translate(table)


def download_private_repo(repo_id, dir_path, is_replace):
    """Download every model-weight file of a private HF repo into ``dir_path``.

    Does nothing when the HF_READ_TOKEN environment variable is unset. Download
    failures are printed and otherwise ignored (best effort). When
    ``is_replace`` is true, model files in ``dir_path`` whose stem changes under
    ``escape_lora_basename`` are renamed to the escaped name in the same
    directory.
    """
    model_suffixes = ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']
    hf_read_token = os.environ.get('HF_READ_TOKEN')
    if not hf_read_token: return
    from huggingface_hub import snapshot_download
    try:
        snapshot_download(repo_id=repo_id, local_dir=dir_path,
                          allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'],
                          token=hf_read_token)
    except Exception as e:
        print(f"download_private_repo: failed to download {repo_id}: {e}")
        return
    if not is_replace: return
    from pathlib import Path
    for file in Path(dir_path).glob("*"):
        # Explicit grouping: the old `a and b or c and d` condition also renamed
        # non-model files whose stem merely contained a '.'.
        if not file.is_file() or file.suffix not in model_suffixes: continue
        escaped_stem = escape_lora_basename(file.stem)
        if escaped_stem == file.stem: continue
        # with_name keeps the full parent directory; the old
        # f'{file.parent.name}/...' form broke for nested or '.' directories.
        file.rename(file.with_name(f'{escaped_stem}{file.suffix}'))


# Local model path (as listed by get_private_model_list) -> private HF repo id
# it came from. Filled by get_private_model_list; read by
# download_private_file_from_somewhere.
private_model_path_repo_dict = {}


def get_private_model_list(repo_id, dir_path):
    """List a private HF repo's model-weight files as local paths under ``dir_path``.

    Each repo file with a .ckpt/.pt/.pth/.safetensors/.bin suffix becomes
    ``str(Path(f"{dir_path}/{file}"))`` and is recorded in
    ``private_model_path_repo_dict`` so it can be downloaded on demand later.
    Returns [] when HF_READ_TOKEN is unset or the repo listing fails.
    """
    from pathlib import Path
    hf_read_token = os.environ.get('HF_READ_TOKEN')
    if not hf_read_token: return []
    try:
        files = HfApi().list_repo_files(repo_id, token=hf_read_token)
    except Exception:
        return []
    model_suffixes = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    local_paths = (Path(f"{dir_path}/{file}") for file in files)
    model_list = [str(path) for path in local_paths if path.suffix in model_suffixes]
    for model in model_list:
        private_model_path_repo_dict[model] = repo_id
    return model_list


def get_private_lora_model_lists():
    """Collect LoRA paths from every private repo, de-duplicated.

    Group-1 repos keep their listing order, group-2 results are sorted and
    appended, and repeats are dropped keeping the first occurrence.
    """
    def collect(repos):
        found = []
        for repo in repos:
            found += get_private_model_list(repo, directory_loras)
        return found

    first_group = collect(HF_LORA_PRIVATE_REPOS1)
    second_group = collect(HF_LORA_PRIVATE_REPOS2)
    return list_uniq(first_group + sorted(second_group))


def download_private_file(repo_id, path, is_replace):
    """Download one file from a private HF repo to the local ``path``.

    The repo file named ``Path(path).name`` is fetched into ``path``'s parent
    directory. When ``is_replace`` is true it is then renamed to its escaped
    basename (``escape_lora_basename``). Nothing happens when HF_READ_TOKEN is
    unset or the (possibly escaped) target already exists. Download errors are
    printed and otherwise ignored.
    """
    from pathlib import Path
    hf_read_token = os.environ.get('HF_READ_TOKEN')
    if not hf_read_token: return
    file = Path(path)
    # with_name keeps the full parent path; the old f'{file.parent.name}/...'
    # form produced '/name' for bare filenames and broke for nested dirs.
    newpath = file.with_name(f'{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
    if newpath.exists(): return
    from huggingface_hub import hf_hub_download
    try:
        hf_hub_download(repo_id=repo_id, filename=file.name, local_dir=str(file.parent), token=hf_read_token)
    except Exception as e:
        print(f"download_private_file: failed to download {file.name} from {repo_id}: {e}")
        return
    if is_replace and file != newpath:
        file.rename(newpath)


def download_private_file_from_somewhere(path, is_replace):
    """Download ``path`` from the private repo that listed it; unknown paths are ignored."""
    if path in private_model_path_repo_dict:
        download_private_file(private_model_path_repo_dict[path], path, is_replace)


def get_model_id_list():
    """Return public model ids from the 'votepurchase' and 'John6666' HF accounts.

    Order: votepurchase models (as the Hub returns them for sort="likes"),
    then John6666 models tagged 'anime', then the other John6666 models (each
    in the order returned for sort="last_modified"). Private models are
    skipped. Returns [] if listing fails.
    """
    api = HfApi()
    try:
        # list_models returns a lazy iterator; materialize it here so network
        # errors raised while paging are caught by this handler.
        models_vp = list(api.list_models(author="votepurchase", cardData=True, sort="likes"))
        models_john = list(api.list_models(author="John6666", cardData=True, sort="last_modified"))
    except Exception:
        return []
    model_ids = [model.id for model in models_vp if not model.private]
    anime_models = []
    real_models = []
    for model in models_john:
        if model.private:
            continue
        if 'anime' in (model.tags or []):
            anime_models.append(model.id)
        else:
            real_models.append(model.id)
    return model_ids + anime_models + real_models


def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary of a public diffusers model on the HF Hub.

    Returns "" when ``repo_id`` is empty or contains a space, the repo does not
    exist or cannot be queried, is private/gated, or is not tagged 'diffusers'.
    Otherwise returns a gr.update whose value is "Model Info: ..." listing the
    pipeline type (SDXL / SD1.5), card tags minus generic ones, download and
    like counts, the last-modified date (when known) and a repo link.
    """
    if not repo_id: return ""
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id): return ""
        model = api.model_info(repo_id=repo_id)
    except (EnvironmentError, OSError, ValueError, HTTPError, Timeout):
        return ""
    if model.private or model.gated: return ""
    tags = model.tags or []
    if 'diffusers' not in tags: return ""
    info = []
    if 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    # last_modified is optional on ModelInfo; skip the date rather than raise.
    if model.last_modified:
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
    url = f"https://huggingface.co/{repo_id}/"
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)


def get_tupled_model_list(model_list):
    """Turn HF repo ids into (label, repo_id) dropdown choices.

    Repos that don't exist, can't be queried, are private/gated, or lack the
    'diffusers' tag are skipped. The label shows the pipeline type and card
    tags; a 'pony' card tag is replaced by a leading "Pony🐴".
    """
    if not model_list: return []
    api = HfApi()  # one client for every lookup instead of one per repo
    generic_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
    tupled_list = []
    for repo_id in model_list:
        try:
            if not api.repo_exists(repo_id): continue
            model = api.model_info(repo_id=repo_id)
        except Exception:
            continue
        if model.private or model.gated: continue
        tags = model.tags or []
        if 'diffusers' not in tags: continue
        info = []
        if 'diffusers:StableDiffusionXLPipeline' in tags:
            info.append("SDXL")
        elif 'diffusers:StableDiffusionPipeline' in tags:
            info.append("SD1.5")
        if model.card_data and model.card_data.tags:
            info.extend(list_sub(model.card_data.tags, generic_tags))
        if "pony" in info:
            info.remove("pony")
            name = f"{repo_id} (Pony🐴, {', '.join(info)})"
        else:
            name = f"{repo_id} ({', '.join(info)})"
        tupled_list.append((name, repo_id))
    return tupled_list


def save_gallery_images(images):
    """Move gallery images to timestamped PNG names in the working directory.

    Only element 0 of each ``images`` item is used, as the source file path.
    Files are named ``YYYYmmdd_HHMMSS_<n>.png`` using the current time at
    UTC+9. Returns three gr.update objects: the gallery items
    (new path, filename), the list of new paths, and visible=True. For empty
    ``images``, three no-op updates are returned so each output still gets a
    value.
    """
    import shutil
    from datetime import datetime, timezone, timedelta
    from pathlib import Path
    if not images:
        return gr.update(), gr.update(), gr.update()
    japan_tz = timezone(timedelta(hours=9))
    # Aware "now" directly in UTC+9; datetime.utcnow() is deprecated since 3.12.
    basename = datetime.now(japan_tz).strftime('%Y%m%d_%H%M%S_')
    output_images = []
    output_paths = []
    for i, image in enumerate(images, start=1):
        filename = f"{basename}{i}.png"
        # shutil.move falls back to copy+delete when the source temp dir is on
        # another filesystem, where Path.rename raises OSError (EXDEV).
        newpath = Path(shutil.move(str(Path(image[0]).resolve()), str(Path(filename).resolve())))
        output_paths.append(str(newpath))
        output_images.append((str(newpath), filename))
    return gr.update(value=output_images), gr.update(value=output_paths), gr.update(visible=True)


# Optimization presets keyed by UI name. Each value is
# [steps, cfg, sampler, clip_skip, lora1, lora_scale_1], the order in which
# set_optimization() reads it. Presets other than "None"/"Default" enable a
# LoRA from loras/ together with the steps/CFG/sampler listed here.
optimization_list = {
    "None": [28, 7., 'Euler a', False, 'None', 1.],
    "Default": [28, 7., 'Euler a', False, 'None', 1.],
    "SPO": [28, 7., 'Euler a', True, 'loras/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors', 1.],
    "DPO": [28, 7., 'Euler a', True, 'loras/sdxl-DPO-LoRA.safetensors', 1.],
    "DPO Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_dpo_turbo_lora_v1-128dim.safetensors', 1.],
    "SDXL Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_turbo_lora_v1.safetensors', 1.],
    "Hyper-SDXL 12step": [12, 5., 'TCD', True, 'loras/Hyper-SDXL-12steps-CFG-lora.safetensors', 1.],
    "Hyper-SDXL 8step": [8, 5., 'TCD', True, 'loras/Hyper-SDXL-8steps-CFG-lora.safetensors', 1.],
    "Hyper-SDXL 4step": [4, 0, 'TCD', True, 'loras/Hyper-SDXL-4steps-lora.safetensors', 1.],
    "Hyper-SDXL 2step": [2, 0, 'TCD', True, 'loras/Hyper-SDXL-2steps-lora.safetensors', 1.],
    "Hyper-SDXL 1step": [1, 0, 'TCD', True, 'loras/Hyper-SDXL-1steps-lora.safetensors', 1.],
    "PCM 16step": [16, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_16step_converted.safetensors', 1.],
    "PCM 8step": [8, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_8step_converted.safetensors', 1.],
    "PCM 4step": [4, 2., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_4step_converted.safetensors', 1.],
    "PCM 2step": [2, 1., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_2step_converted.safetensors', 1.],
}


def set_optimization(opt, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora1_gui, lora_scale_1_gui):
    """Return six Gradio updates applying optimization preset ``opt``.

    Unknown presets are treated as "None".
    - "None": keeps the current steps and CFG, raised to at least 28 steps and
      CFG 7.0, and keeps the current clip-skip.
    - "SPO"/"DPO": apply the same floors, but clip-skip comes from the preset.
    - Any other preset: all values come from ``optimization_list``.

    ``sampler_gui``, ``lora1_gui`` and ``lora_scale_1_gui`` are accepted for
    interface compatibility but not read.

    Returns updates for: steps, cfg, sampler, clip_skip, lora1, lora_scale_1.
    """
    if opt not in optimization_list: opt = "None"
    def_steps_gui = 28
    def_cfg_gui = 7.
    # Direct indexing: the key is guaranteed present here, and the old
    # `.get(opt, "None")[i]` fallback would have indexed into a string.
    steps, cfg, sampler, clip_skip, lora1, lora_scale_1 = optimization_list[opt]
    if opt == "None":
        steps = max(steps_gui, def_steps_gui)
        cfg = max(cfg_gui, def_cfg_gui)
        clip_skip = clip_skip_gui
    elif opt in ("SPO", "DPO"):
        steps = max(steps_gui, def_steps_gui)
        cfg = max(cfg_gui, def_cfg_gui)

    return gr.update(value=steps), gr.update(value=cfg), gr.update(value=sampler),\
          gr.update(value=clip_skip), gr.update(value=lora1), gr.update(value=lora_scale_1)


def set_lora_prompt(prompt_gui, prompt_syntax_gui, lora1_gui, lora_scale_1_gui, lora2_gui, lora_scale_2_gui,\
                     lora3_gui, lora_scale_3_gui, lora4_gui, lora_scale_4_gui, lora5_gui, lora_scale_5_gui):
    """Rewrite the prompt's ``<lora:name:scale>`` tags to match the five LoRA slots.

    Only applies when ``prompt_syntax_gui`` contains "Classic"; otherwise
    ``prompt_gui`` is returned unchanged as a plain value.
    - Comma-separated prompt tags are stripped; empty ones and ones containing
      "<lora" are dropped.
    - For each slot whose LoRA is set and not "None", one
      ``<lora:basename:scale>`` tag is appended (scale with 2 decimals).
    - The joined result ends with ", " and is returned as a gr.update.
    """
    if "Classic" not in str(prompt_syntax_gui): return prompt_gui
    slots = (
        (lora1_gui, lora_scale_1_gui),
        (lora2_gui, lora_scale_2_gui),
        (lora3_gui, lora_scale_3_gui),
        (lora4_gui, lora_scale_4_gui),
        (lora5_gui, lora_scale_5_gui),
    )
    loras = [
        f"<lora:{os.path.splitext(os.path.basename(lora))[0]}:{scale:.2f}>"
        for lora, scale in slots
        if lora and lora != "None"
    ]
    # Existing <lora:...> tags are dropped: the slots above are the source of truth.
    tags = prompt_gui.split(",") if prompt_gui else []
    prompts = [tag for tag in (str(t).strip() for t in tags) if tag and "<lora" not in tag]
    prompt = ", ".join(prompts + loras + [""])
    return gr.update(value=prompt)


# LoRA metadata loaded at import time from lora_dict.json in the working
# directory (a missing file raises here). Keys are passed through
# escape_lora_basename so they match the escaped basenames used for lookups.
# Entries are read elsewhere with the same index layout as get_civitai_info's
# result: [trigger words, base model, model name, model URL, thumbnail URL].
# NOTE(review): confirm the JSON file actually follows that layout.
temp_dict = {}
lora_trigger_dict = {}
with open('lora_dict.json', encoding='utf-8') as f:
    temp_dict = json.load(f)
for k, v in temp_dict.items():
    lora_trigger_dict[escape_lora_basename(k)] = v


# Local paths whose hash Civitai answered without a 'baseModel'; get_civitai_info
# returns empty info for these without querying again.
civitai_not_exists_list = []


def get_civitai_info(path):
    """Look up a local LoRA file on Civitai by its SHA-256 hash.

    Returns:
        A 5-item list on success:
          1. trigger words joined by " / "
          2. base model
          3. model name
          4. model page URL
          5. first preview image URL ("" if there is none)
        ``["", "", "", "", ""]`` when the path is known to be absent from
        Civitai, or when the request fails or returns unparsable JSON.
        None when ``path`` does not exist or Civitai returns a non-OK status.

    Paths for which Civitai returns no 'baseModel' are recorded in
    ``civitai_not_exists_list`` and not queried again.
    """
    from pathlib import Path
    if path in civitai_not_exists_list: return ["", "", "", "", ""]
    if not Path(path).exists(): return None
    import hashlib
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util import Retry
    # Hash in 1 MiB chunks rather than reading the whole (possibly huge) file.
    hasher = hashlib.sha256()
    with open(path, 'rb') as file:
        for chunk in iter(lambda: file.read(1 << 20), b""):
            hasher.update(chunk)
    url = 'https://civitai.com/api/v1/model-versions/by-hash/' + hasher.hexdigest()
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    # The session is closed on exit. The body is read eagerly (no stream=True),
    # so no connection is left half-read.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        try:
            r = session.get(url, headers=headers, timeout=(3.0, 15))
        except requests.exceptions.RequestException:
            # Connection errors, timeouts and exhausted retries are non-fatal.
            return ["", "", "", "", ""]
    if not r.ok: return None
    try:
        data = r.json()
    except ValueError:
        return ["", "", "", "", ""]
    if not isinstance(data, dict) or 'baseModel' not in data:
        civitai_not_exists_list.append(path)
        return ["", "", "", "", ""]
    images = data.get('images') or []
    return [
        " / ".join(data.get('trainedWords') or []),
        data['baseModel'],
        (data.get('model') or {}).get('name', ""),
        f"https://civitai.com/models/{data.get('modelId', '')}",
        images[0].get('url', "") if images else "",
    ]


def update_lora_dict(path):
    """Cache Civitai info for ``path`` in ``lora_trigger_dict`` if not cached yet.

    The cache key is the escaped file stem. Nothing is stored when
    get_civitai_info returns None.
    """
    from pathlib import Path
    key = escape_lora_basename(Path(path).stem)
    if key not in lora_trigger_dict:
        fetched = get_civitai_info(path)
        if fetched is not None:
            lora_trigger_dict[key] = fetched


def get_lora_tupled_list(lora_model_list):
    """Build (label, path) dropdown choices for LoRA paths.

    The label is the file stem. When metadata with a model name is available,
    " (for <base model>, <model name>)" is appended, with 🐴 after "Pony".
    Metadata comes from lora_trigger_dict, or is fetched from Civitai and
    cached for local files. Falsy entries are skipped.
    """
    from pathlib import Path
    if not lora_model_list: return []
    local_models = set(get_model_list(directory_loras))
    choices = []
    for model in filter(None, lora_model_list):
        stem = Path(model).stem
        key = escape_lora_basename(stem)
        if key in lora_trigger_dict:
            meta = lora_trigger_dict.get(key, None)
        elif model in local_models:
            meta = get_civitai_info(model)
            if meta is not None:
                lora_trigger_dict[key] = meta
        else:
            meta = None
        label = stem
        if meta and meta[2] != "":
            base = f"{meta[1]}🐴" if meta[1] == "Pony" else meta[1]
            label = f"{stem} (for {base}, {meta[2]})"
        choices.append((label, model))
    return choices


def set_lora_trigger(lora_gui: str):
    """Show trigger words, thumbnail and link for the selected LoRA.

    Returns four gr.update objects, in order:
      1. trigger text (value/label/visibility)
      2. visibility toggle
      3. thumbnail/link markup ``md``
      4. the LoRA selector: selected value and, when Civitai metadata was
         freshly fetched, refreshed choices
    """
    from pathlib import Path
    # Nothing selected: hide the first three outputs and set the fourth to "None".
    if not lora_gui or lora_gui == "None": return gr.update(value="", visible=False), gr.update(visible=False),\
          gr.update(value="", visible=False), gr.update(value="None", visible=True)
    path = Path(lora_gui)
    # Escaped local path (parent dir name + escaped stem): the same name that
    # download_private_file(..., is_replace=True) renames downloads to.
    new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
    # Unknown LoRA: no cached metadata and not in any private or local listing.
    # NOTE(review): get_private_lora_model_lists() lists every private repo via
    # the Hub API on each call.
    if not new_path.stem in lora_trigger_dict.keys() and not str(path) in set(get_private_lora_model_lists() + get_model_list(directory_loras)):
        return gr.update(value="", visible=False), gr.update(visible=False),\
              gr.update(value="", visible=False), gr.update(value="", visible=True)
    # Fetch the file from its private repo (renamed to new_path) if missing locally.
    if not new_path.exists():
        download_private_file_from_somewhere(str(path), True)
    basename = new_path.stem
    tag = ""
    label = f'Trigger: {basename}  /  Prompt:'
    value = "None"
    md = "None"
    flag = False
    # Use cached metadata, else query Civitai; flag marks a fresh lookup.
    items = lora_trigger_dict.get(basename, None)
    if items == None:
        items = get_civitai_info(str(new_path))
        if items != None:
            lora_trigger_dict[basename] = items
            flag = True
    # items layout (as returned by get_civitai_info):
    # [trigger words, base model, model name, model URL, thumbnail URL]
    if items and items[2] != "":
        tag = items[0]
        label = f'Trigger: {basename}  /  Prompt:'
        if items[1] == "Pony":
            label = f'Trigger: {basename}  /  Prompt (for Pony🐴):'
        if items[4]:
            md = f'<img src="{items[4]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL]({items[3]})'
        elif items[3]:
            md = f'[LoRA Model URL]({items[3]})'
    # Freshly fetched metadata changes dropdown labels, so rebuild the choices too.
    if tag and flag:
        new_lora_model_list = list_uniq(get_private_lora_model_lists() + get_model_list(directory_loras))
        return gr.update(value=tag, label=label, visible=True), gr.update(visible=True),\
              gr.update(value=md, visible=True), gr.update(value=str(new_path), choices=get_lora_tupled_list(new_lora_model_list))
    elif tag:
        return gr.update(value=tag, label=label, visible=True), gr.update(visible=True),\
              gr.update(value=md, visible=True), gr.update(value=str(new_path))
    # No trigger words known: show "None" and leave the selector's value alone.
    else:
        return gr.update(value=value, label=label, visible=True), gr.update(visible=True),\
              gr.update(value=md, visible=True), gr.update(visible=True)


def apply_lora_prompt(prompt_gui: str, lora_trigger_gui: str):
    """Merge a LoRA's trigger words into the prompt.

    Trigger words are split on '/' and ',' and stripped, then appended after
    the comma-separated prompt tags. Duplicates are removed keeping the first
    occurrence, and the result ends with ", ". Returns a gr.update; when there
    is no trigger text (None or "None") the prompt is passed through unchanged.
    """
    if lora_trigger_gui is None or lora_trigger_gui == "None": return gr.update(value=prompt_gui)
    prompts = normalize_prompt_list(prompt_gui.split(",") if prompt_gui else [])
    lora_prompts = normalize_prompt_list(lora_trigger_gui.replace("/", ",").split(","))
    prompt = ", ".join(list_uniq(prompts + lora_prompts) + [""])
    return gr.update(value=prompt)


def upload_file_lora(files):
    """Show the uploaded files' paths and reveal the companion component."""
    uploaded_paths = list(map(lambda uploaded: uploaded.name, files))
    return gr.update(value=uploaded_paths, visible=True), gr.update(visible=True)


def move_file_lora(filepaths):
    """Move uploaded LoRA files into the LoRA directory and refresh five LoRA dropdowns.

    Each file is moved into ``./{directory_loras}`` (overwriting a file of the
    same name) and renamed to its escaped basename (escape_lora_basename). Its
    Civitai metadata is then cached via update_lora_dict. Returns five
    gr.update objects with the refreshed choices: "None" first, then private
    and local LoRAs. The first update also selects the last entry.
    """
    import shutil
    from pathlib import Path
    lora_dir = Path(f"./{directory_loras}").resolve()
    # Without this, moving into a missing directory would create a FILE named
    # after the directory instead.
    lora_dir.mkdir(parents=True, exist_ok=True)
    for file in filepaths or []:
        source = Path(file).resolve()
        # Explicit destination file: shutil.move into a directory raises if
        # the name already exists there.
        path = Path(shutil.move(str(source), str(lora_dir / source.name)))
        newpath = Path(directory_loras) / f'{escape_lora_basename(path.stem)}{path.suffix}'
        path.rename(newpath.resolve())
        update_lora_dict(str(newpath))

    new_lora_model_list = list_uniq(get_private_lora_model_lists() + get_model_list(directory_loras))
    new_lora_model_list.insert(0, "None")
    # Build the choices once: each call may query Civitai for uncached local files.
    choices = get_lora_tupled_list(new_lora_model_list)

    return gr.update(choices=choices, value=new_lora_model_list[-1]), gr.update(choices=choices),\
        gr.update(choices=choices), gr.update(choices=choices), gr.update(choices=choices)


def search_lora_on_civitai(query: str, allow_model: list[str]):
    """Search Civitai for LoRA models matching ``query``.

    Only model versions whose baseModel is in ``allow_model`` are kept.
    Returns a list of dicts with keys name, creator, tags, model_name,
    base_model, dl_url and md (thumbnail + model-page link markup). Returns
    None when the query is empty, the request fails, or the response lacks
    'items'.
    """
    if not query: return None
    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util import Retry
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/models'
    # NOTE(review): 'supportsGeneration ' has a trailing space, so the API most
    # likely ignores it. Kept as-is because fixing it would narrow results.
    params = {'query': query, 'types': ['LORA'], 'sort': 'Highest Rated', 'period': 'AllTime',
              'nsfw': 'true', 'supportsGeneration ': 'true'}
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        try:
            r = session.get(base_url, params=params, headers=headers, timeout=(3.0, 30))
        except requests.exceptions.RequestException:
            return None
    if not r.ok: return None
    try:
        data = r.json()
    except ValueError:
        return None
    if not isinstance(data, dict) or 'items' not in data: return None
    allowed = set(allow_model or [])  # built once, not per model version
    items = []
    for j in data['items']:
        model_url = f'https://civitai.com/models/{j["id"]}'
        for model in j.get('modelVersions') or []:
            if model.get('baseModel') not in allowed: continue
            images = model.get('images') or []
            if images and images[0].get('url'):
                md = f'<img src="{images[0]["url"]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL]({model_url})'
            else:
                md = f'[LoRA Model URL]({model_url})'
            items.append({
                'name': j['name'],
                'creator': (j.get('creator') or {}).get('username', ""),
                'tags': j.get('tags') or [],
                'model_name': model['name'],
                'base_model': model['baseModel'],
                'dl_url': model['downloadUrl'],
                'md': md,
            })
    return items


# Download URL -> thumbnail/link markup for the most recent Civitai search.
# Rebuilt by search_civitai_lora; read by select_civitai_lora.
civitai_lora_last_results = {}


def search_civitai_lora(query, base_model):
    """Run a Civitai LoRA search and fill the result dropdown and preview.

    Returns four gr.update objects: the result dropdown, the preview markup,
    and two components that are always made visible. With no results, the
    dropdown and preview are hidden and the previous results are kept.
    """
    global civitai_lora_last_results

    def no_results():
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    found = search_lora_on_civitai(query, base_model)
    if not found: return no_results()
    civitai_lora_last_results = {}
    choices = []
    for entry in found:
        base_label = "Pony🐴" if entry['base_model'] == "Pony" else entry['base_model']
        url = entry['dl_url']
        label = f"{entry['name']} (for {base_label} / By: {entry['creator']} / Tags: {', '.join(entry['tags'])})"
        choices.append((label, url))
        civitai_lora_last_results[url] = entry['md']
    if not choices: return no_results()
    first_url = choices[0][1]
    preview = civitai_lora_last_results.get(first_url, "None")
    return gr.update(choices=choices, value=first_url, visible=True), gr.update(value=preview, visible=True),\
        gr.update(visible=True), gr.update(visible=True)


def select_civitai_lora(search_result):
    """Show the preview for the selected Civitai search result.

    ``search_result`` is a download URL from the latest search. Returns two
    updates: the value for the URL field and the preview markup. Values
    without "http", including None and "", clear the URL and show "None".
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    md = civitai_lora_last_results.get(search_result, "None")
    return gr.update(value=search_result), gr.update(value=md, visible=True)


# Quality-tag presets. Each entry has a display "name" plus the "prompt" and
# "negative_prompt" tags it contributes. preset_quality indexes these by name,
# and process_style_prompt collects every entry's tags.
quality_prompt_list = [
    {
        "name": "None",
        "prompt": "",
        "negative_prompt": "lowres",
    },
    {
        "name": "Animagine Common",
        "prompt": "anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Pony Anime Common",
        "prompt": "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    },
    {
        "name": "Pony Common",
        "prompt": "source_anime, score_9, score_8_up, score_7_up",
        "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    },
    {
        "name": "Animagine Standard v3.0",
        "prompt": "masterpiece, best quality",
        "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
    },
    {
        "name": "Animagine Standard v3.1",
        "prompt": "masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Animagine Light v3.1",
        "prompt": "(masterpiece), best quality, very aesthetic, perfect face",
        "negative_prompt": "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
    },
    {
        "name": "Animagine Heavy v3.1",
        "prompt": "(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
        "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
    },
]


# Style presets, in the same shape as quality_prompt_list: a display "name"
# plus "prompt" and "negative_prompt" tags. preset_styles indexes these by
# name, and process_style_prompt collects every entry's tags.
style_list = [
    {
        "name": "None",
        "prompt": "",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]


# Sampler presets keyed by UI name. Each value is, in order:
# [sampler_gui, steps_gui, cfg_gui, clip_skip_gui, img_width_gui, img_height_gui, optimization_gui]
# The last item is a key of optimization_list. set_sampler_settings() emits the
# seven values in this order.
preset_sampler_setting = {
    "None": ["Euler a", 28, 7., True, 1024, 1024, "None"],
    "Anime 3:4 Fast": ["LCM", 8, 2.5, True, 896, 1152, "DPO Turbo"],
    "Anime 3:4 Standard": ["Euler a", 28, 7., True, 896, 1152, "None"],
    "Anime 3:4 Heavy": ["Euler a", 40, 7., True, 896, 1152, "None"],
    "Anime 1:1 Fast": ["LCM", 8, 2.5, True, 1024, 1024, "DPO Turbo"],
    "Anime 1:1 Standard": ["Euler a", 28, 7., True, 1024, 1024, "None"],
    "Anime 1:1 Heavy": ["Euler a", 40, 7., True, 1024, 1024, "None"],
    "Photo 3:4 Fast": ["LCM", 8, 2.5, False, 896, 1152, "DPO Turbo"],
    "Photo 3:4 Standard": ["DPM++ 2M Karras", 28, 7., False, 896, 1152, "None"],
    "Photo 3:4 Heavy": ["DPM++ 2M Karras", 40, 7., False, 896, 1152, "None"],
    "Photo 1:1 Fast": ["LCM", 8, 2.5, False, 1024, 1024, "DPO Turbo"],
    "Photo 1:1 Standard": ["DPM++ 2M Karras", 28, 7., False, 1024, 1024, "None"],
    "Photo 1:1 Heavy": ["DPM++ 2M Karras", 40, 7., False, 1024, 1024, "None"],
}


def set_sampler_settings(sampler_setting):
    """Return seven Gradio updates applying a sampler preset.

    Order: sampler, steps, cfg, clip_skip, width, height, optimization.
    Unknown presets and "None" give "Euler a" / 28 / 7.0 / True / 1024 / 1024
    / "None".
    """
    if sampler_setting in preset_sampler_setting and sampler_setting != "None":
        chosen = preset_sampler_setting[sampler_setting]
    else:
        chosen = ["Euler a", 28, 7., True, 1024, 1024, "None"]
    return tuple(gr.update(value=setting) for setting in chosen[:7])


# Name -> (prompt, negative_prompt) lookup tables built from style_list and
# quality_prompt_list.
preset_styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
preset_quality = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}


def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None", type: str = "None"):
    """Rebuild the prompt pair for the selected model type, style preset and quality preset.

    Tags belonging to *any* known preset are first stripped from the user's prompts:
    Animagine tags, Pony tags, and the tags of every style and quality entry. The tags
    of the currently selected presets are then appended. Duplicates are removed and the
    first occurrence is kept.

    Args:
        prompt: Comma-separated positive prompt (None is treated as empty).
        neg_prompt: Comma-separated negative prompt (None is treated as empty).
        styles_key: Name of an entry in ``preset_styles``; an unknown name adds no tags.
        quality_key: Name of an entry in ``preset_quality``; an unknown name adds no tags.
        type: "Animagine" or "Pony" adds that model family's tags; anything else adds none.
            (The name shadows the builtin ``type``; it is kept for caller compatibility.)

    Returns:
        Tuple ``(prompt, neg_prompt)`` of comma-joined strings. A trailing separator is
        added when the user's own tags end up empty and type, style and quality are all
        set to something other than "None".
    """
    def to_list(s):
        # Filter on each item, not on the whole string, so "a, , b" and trailing commas
        # do not leave empty tags behind.
        return [x.strip() for x in (s or "").split(",") if x.strip()]

    def list_sub(a, b):
        exclude = set(b)
        return [e for e in a if e not in exclude]

    def list_uniq(l):
        # Order-preserving de-duplication.
        return list(dict.fromkeys(l))

    animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
    prompts = to_list(prompt)
    neg_prompts = to_list(neg_prompt)

    # Every tag that any style / quality preset could have inserted earlier.
    all_styles_ps = []
    all_styles_nps = []
    for d in style_list:
        all_styles_ps.extend(to_list(str(d.get("prompt", ""))))
        all_styles_nps.extend(to_list(str(d.get("negative_prompt", ""))))

    all_quality_ps = []
    all_quality_nps = []
    for d in quality_prompt_list:
        all_quality_ps.extend(to_list(str(d.get("prompt", ""))))
        all_quality_nps.extend(to_list(str(d.get("negative_prompt", ""))))

    quality_p, quality_np = preset_quality.get(quality_key, ("", ""))
    styles_p, styles_np = preset_styles.get(styles_key, ("", ""))
    quality_ps = to_list(quality_p)
    quality_nps = to_list(quality_np)
    styles_ps = to_list(styles_p)
    styles_nps = to_list(styles_np)

    prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
    neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)

    presets_selected = type != "None" and styles_key != "None" and quality_key != "None"
    last_empty_p = [""] if not prompts and presets_selected else []
    last_empty_np = [""] if not neg_prompts and presets_selected else []

    if type == "Animagine":
        prompts = prompts + animagine_ps
        neg_prompts = neg_prompts + animagine_nps
    elif type == "Pony":
        prompts = prompts + pony_ps
        neg_prompts = neg_prompts + pony_nps

    prompts = prompts + styles_ps + quality_ps
    neg_prompts = neg_prompts + styles_nps + quality_nps

    prompt = ", ".join(list_uniq(prompts) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)

    return prompt, neg_prompt


def set_quick_presets(genre:str = "None", type:str = "None", speed:str = "None", aspect:str = "None"):
    """Translate the quick-preset selectors into quality / style / sampler / optimization updates.

    Args:
        genre: "Anime" or "Photo"; any other value leaves style, quality and sampler at "None".
        type: "Pony" selects the Pony quality tags; anything else selects the default ones.
        speed: "Fast", "Heavy", or anything else, which counts as Standard.
        aspect: "1:1" or "3:4"; any other value leaves the sampler preset at "None".

    Returns:
        Four ``gr.update`` objects: quality, style, sampler preset, optimization.
    """
    quality, style, sampler, opt = "None", "None", "None", "None"
    is_pony = type == "Pony"

    if genre in ("Anime", "Photo"):
        style = "Anime" if genre == "Anime" else "Photographic"
        if aspect in ("1:1", "3:4"):
            # Preset keys follow the "<genre> <aspect> <tier>" naming of preset_sampler_setting.
            tier = speed if speed in ("Heavy", "Fast") else "Standard"
            sampler = f"{genre} {aspect} {tier}"
        if genre == "Anime":
            quality = "Pony Anime Common" if is_pony else "Animagine Common"
        else:
            quality = "Pony Common" if is_pony else "None"

    if speed == "Fast":
        opt = "DPO Turbo"
        if genre == "Anime" and not is_pony:
            quality = "Animagine Light v3.1"

    return gr.update(value=quality), gr.update(value=style), gr.update(value=sampler), gr.update(value=opt)


# Textual-inversion metadata loaded once at import time. Presumably maps an embedding
# file name -> [trigger token, is_positive flag] (that is how the lookups below index it).
with open('textual_inversion_dict.json', encoding='utf-8') as f:
    textual_inversion_dict = json.load(f)


# Trigger tokens recorded by get_tupled_embed_list; used to recognise TI tags in prompts.
textual_inversion_file_token_list = []


def get_tupled_embed_list(embed_list):
    """Pair each textual-inversion file with its trigger token.

    The token comes from ``textual_inversion_dict`` (keyed by file name) when present.
    Otherwise it is the file stem with commas removed. Each token is also recorded once
    in ``textual_inversion_file_token_list`` so that prompt editing can recognise it later.

    Args:
        embed_list: Iterable of embedding file paths.

    Returns:
        List of ``(token, file)`` tuples, in input order.
    """
    from pathlib import Path
    tupled_list = []
    for file in embed_list:
        path = Path(file)
        token = textual_inversion_dict.get(path.name, [path.stem.replace(",", ""), False])[0]
        tupled_list.append((token, file))
        # Record each token once; appending unconditionally grew the list on every call.
        if token not in textual_inversion_file_token_list:
            textual_inversion_file_token_list.append(token)
    return tupled_list


def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_gui, prompt_syntax_gui):
    """Rewrite both prompts so they contain exactly the selected textual-inversion tokens.

    Known TI tokens (from ``textual_inversion_dict`` and ``textual_inversion_file_token_list``)
    are first removed from both prompts. Each selected embedding's token is then appended:

    - to the positive prompt if its dict flag equals True or its parent directory name
      contains "positive";
    - to the negative prompt otherwise.

    Args:
        textual_inversion_gui: Selected embedding file paths (None is treated as empty).
        prompt_gui: Comma-separated positive prompt.
        neg_prompt_gui: Comma-separated negative prompt.
        prompt_syntax_gui: Unused; kept for the event-handler signature.

    Returns:
        Two ``gr.update`` objects (positive, negative). Each value carries a trailing
        ", " when non-empty.
    """
    from pathlib import Path

    # Dict values are [token, flag] pairs, so compare tags against the token only;
    # comparing against the whole list never matched any string tag.
    ti_tags = {v[0] if isinstance(v, (list, tuple)) else v for v in textual_inversion_dict.values()}
    ti_tags.update(textual_inversion_file_token_list)

    def strip_ti(text):
        tags = (str(tag).strip() for tag in text.split(",")) if text else ()
        return [tag for tag in tags if tag and tag not in ti_tags]

    prompts = strip_ti(prompt_gui)
    neg_prompts = strip_ti(neg_prompt_gui)

    ti_prompts = []
    ti_neg_prompts = []
    for ti in textual_inversion_gui or []:
        path = Path(ti)
        tokens = textual_inversion_dict.get(path.name, [path.stem.replace(",", ""), False])
        is_positive = tokens[1] == True or "positive" in path.parent.name  # noqa: E712
        if is_positive:
            ti_prompts.append(tokens[0])
        else:  # negative prompt (default)
            ti_neg_prompts.append(tokens[0])

    prompt = ", ".join(prompts + ti_prompts + [""])
    neg_prompt = ", ".join(neg_prompts + ti_neg_prompts + [""])

    return gr.update(value=prompt), gr.update(value=neg_prompt)


def get_model_pipeline(repo_id: str):
    """Guess the diffusers pipeline class name for a Hugging Face model repo.

    Args:
        repo_id: Hub repository id such as "user/model".

    Returns:
        "StableDiffusionXLPipeline" if the repo carries the
        ``diffusers:StableDiffusionXLPipeline`` tag. Otherwise returns
        "StableDiffusionPipeline", which is also the fallback when any of these hold:

        - the id contains a space;
        - the repo does not exist;
        - the Hub lookup fails;
        - the repo is private or gated;
        - it lacks the "diffusers" tag.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    default = "StableDiffusionPipeline"
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return default
        model = api.model_info(repo_id=repo_id)
    except Exception:
        # Best-effort probe: any Hub/network error just means "use the default pipeline".
        return default
    if model.private or model.gated:
        return default
    # ModelInfo.tags can be None; previously that raised TypeError outside the try.
    tags = model.tags or []
    if "diffusers" not in tags:
        return default
    if "diffusers:StableDiffusionXLPipeline" in tags:
        return "StableDiffusionXLPipeline"
    if "diffusers:StableDiffusionPipeline" in tags:
        return "StableDiffusionPipeline"
    return default


def load_model_prompt_dict(path: str = 'model_dict.json'):
    """Load the per-model recommended-prompt table from a UTF-8 JSON file.

    Args:
        path: JSON file to read; defaults to 'model_dict.json' in the working directory.

    Returns:
        The decoded JSON object (expected to map model names to entries with
        "prompt" / "negative_prompt" keys).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Avoid shadowing the builtin ``dict``; json is already imported at module level.
    with open(path, encoding='utf-8') as f:
        return json.load(f)


# Model name -> recommended-prompt entry (with "prompt" / "negative_prompt" keys),
# read once at import time; consumed by insert_model_recom_prompt.
model_prompt_dict = load_model_prompt_dict()


def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
    """Append a model's recommended tags (from ``model_prompt_dict``) to the prompts.

    If *model_name* is empty or not a key of ``model_prompt_dict``, the prompts are
    returned unchanged. Otherwise the Animagine, Pony and common style tags are stripped
    from the user's prompts, and the model entry's "prompt" / "negative_prompt" tags are
    appended. Duplicates are removed and the first occurrence is kept.

    Returns:
        Tuple ``(prompt, neg_prompt)``. A trailing separator is appended to a side whose
        user tags end up empty.
    """
    def to_list(s):
        # Filter on each item so "a, , b" and trailing commas leave no empty tags.
        return [x.strip() for x in (s or "").split(",") if x.strip()]

    def list_sub(a, b):
        exclude = set(b)
        return [e for e in a if e not in exclude]

    def list_uniq(l):
        # Order-preserving de-duplication.
        return list(dict.fromkeys(l))

    if not model_name or model_name not in model_prompt_dict:
        return prompt, neg_prompt
    animagine_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
    other_ps = to_list("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
    other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
    prompts = list_sub(to_list(prompt), animagine_ps + pony_ps + other_ps)
    neg_prompts = list_sub(to_list(neg_prompt), animagine_nps + pony_nps + other_nps)
    # The original condition also tested `type != "None"`, but `type` here is the builtin
    # (this function has no such parameter), so that test was always true.
    last_empty_p = [""] if not prompts else []
    last_empty_np = [""] if not neg_prompts else []
    entry = model_prompt_dict[model_name]
    prompts = prompts + to_list(entry.get("prompt", ""))
    neg_prompts = neg_prompts + to_list(entry.get("negative_prompt", ""))
    prompt = ", ".join(list_uniq(prompts) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
    return prompt, neg_prompt