John6666 commited on
Commit
35b1cf8
·
verified ·
1 Parent(s): 2e4c5d9

Upload 12 files

Browse files
Files changed (5) hide show
  1. app.py +8 -4
  2. env.py +3 -1
  3. flux.py +17 -16
  4. mod.py +1 -1
  5. modutils.py +93 -42
app.py CHANGED
@@ -505,15 +505,19 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
505
  lora_md[i] = gr.Markdown(value="", visible=False)
506
  lora_num[i] = gr.Number(i, visible=False)
507
  with gr.Accordion("From URL", open=True, visible=True):
 
 
 
 
508
  with gr.Row():
509
  lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
 
510
  lora_search_civitai_submit = gr.Button("Search on Civitai")
511
- lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
512
  with gr.Row():
513
  lora_search_civitai_json = gr.JSON(value={}, visible=False)
514
  lora_search_civitai_desc = gr.Markdown(value="", visible=False)
515
  lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
516
- lora_download_url = gr.Textbox(label="URL", placeholder="http://...my_lora_url.safetensors", lines=1)
517
  with gr.Row():
518
  lora_download = [None] * num_loras
519
  for i in range(num_loras):
@@ -591,9 +595,9 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
591
  prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
592
 
593
  gr.on(
594
- triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
595
  fn=search_civitai_lora,
596
- inputs=[lora_search_civitai_query, lora_search_civitai_basemodel],
597
  outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
598
  scroll_to_output=True,
599
  queue=True,
 
505
  lora_md[i] = gr.Markdown(value="", visible=False)
506
  lora_num[i] = gr.Number(i, visible=False)
507
  with gr.Accordion("From URL", open=True, visible=True):
508
+ with gr.Row():
509
+ lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
510
+ lora_search_civitai_sort = gr.Radio(label="Sort", choices=["Highest Rated", "Most Downloaded", "Newest"], value="Highest Rated")
511
+ lora_search_civitai_period = gr.Radio(label="Period", choices=["AllTime", "Year", "Month", "Week", "Day"], value="AllTime")
512
  with gr.Row():
513
  lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
514
+ lora_search_civitai_tag = gr.Textbox(label="Tag", lines=1)
515
  lora_search_civitai_submit = gr.Button("Search on Civitai")
 
516
  with gr.Row():
517
  lora_search_civitai_json = gr.JSON(value={}, visible=False)
518
  lora_search_civitai_desc = gr.Markdown(value="", visible=False)
519
  lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
520
+ lora_download_url = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
521
  with gr.Row():
522
  lora_download = [None] * num_loras
523
  for i in range(num_loras):
 
595
  prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
596
 
597
  gr.on(
598
+ triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit, lora_search_civitai_tag.submit],
599
  fn=search_civitai_lora,
600
+ inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period, lora_search_civitai_tag],
601
  outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
602
  scroll_to_output=True,
603
  queue=True,
env.py CHANGED
@@ -2,7 +2,7 @@ import os
2
 
3
 
4
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
5
- hf_token = os.environ.get("HF_TOKEN")
6
  hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
7
 
8
 
@@ -67,6 +67,7 @@ HF_MODEL_USER_LIKES = [] # sorted by number of likes
67
  HF_MODEL_USER_EX = [] # sorted by a special rule
68
 
69
 
 
70
  # - **Download Models**
71
  download_model_list = [
72
  ]
@@ -79,6 +80,7 @@ download_vae_list = [
79
  download_lora_list = [
80
  ]
81
 
 
82
 
83
  directory_models = 'models'
84
  os.makedirs(directory_models, exist_ok=True)
 
2
 
3
 
4
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
5
+ HF_TOKEN = os.environ.get("HF_TOKEN")
6
  hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
7
 
8
 
 
67
  HF_MODEL_USER_EX = [] # sorted by a special rule
68
 
69
 
70
+
71
  # - **Download Models**
72
  download_model_list = [
73
  ]
 
80
  download_lora_list = [
81
  ]
82
 
83
+ DIFFUSERS_FORMAT_LORAS = []
84
 
85
  directory_models = 'models'
86
  os.makedirs(directory_models, exist_ok=True)
flux.py CHANGED
@@ -11,14 +11,15 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffuse
11
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
12
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
13
  from pathlib import Path
14
- from env import (hf_token, hf_read_token, # to use only for private repos
 
15
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
16
  HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
17
  download_model_list, download_lora_list, download_vae_list)
18
  from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
19
  safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
20
  get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
21
- get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai)
22
 
23
 
24
  def download_things(directory, url, hf_token="", civitai_api_key=""):
@@ -38,7 +39,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
38
  if hf_token:
39
  os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
40
  else:
41
- os.system (f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
42
  elif "civitai.com" in url:
43
  if "?" in url:
44
  url = url.split("?")[0]
@@ -94,14 +95,18 @@ vae_model_list = get_model_list(directory_vaes)
94
  vae_model_list.insert(0, "None")
95
 
96
 
 
 
 
 
 
97
  def get_t2i_model_info(repo_id: str):
98
- from huggingface_hub import HfApi
99
- api = HfApi()
100
  try:
101
- if " " in repo_id or not api.repo_exists(repo_id): return ""
102
- model = api.model_info(repo_id=repo_id)
103
  except Exception as e:
104
- print(f"Error: Failed to get {repo_id}'s info. ")
105
  print(e)
106
  return ""
107
  if model.private or model.gated: return ""
@@ -109,12 +114,8 @@ def get_t2i_model_info(repo_id: str):
109
  info = []
110
  url = f"https://huggingface.co/{repo_id}/"
111
  if not 'diffusers' in tags: return ""
112
- if 'diffusers:FluxPipeline' in tags:
113
- info.append("FLUX.1")
114
- elif 'diffusers:StableDiffusionXLPipeline' in tags:
115
- info.append("SDXL")
116
- elif 'diffusers:StableDiffusionPipeline' in tags:
117
- info.append("SD1.5")
118
  if model.card_data and model.card_data.tags:
119
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
120
  info.append(f"DLs: {model.downloads}")
@@ -246,9 +247,9 @@ def update_loras(prompt, lora, lora_wt):
246
  gr.update(value=tag, label=label, visible=on), gr.update(value=md, visible=on)
247
 
248
 
249
- def search_civitai_lora(query, base_model):
250
  global civitai_lora_last_results
251
- items = search_lora_on_civitai(query, base_model)
252
  if not items: return gr.update(choices=[("", "")], value="", visible=False),\
253
  gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
254
  civitai_lora_last_results = {}
 
11
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
12
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
13
  from pathlib import Path
14
+ from huggingface_hub import HfApi
15
+ from env import (HF_TOKEN, hf_read_token, # to use only for private repos
16
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
17
  HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
18
  download_model_list, download_lora_list, download_vae_list)
19
  from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
20
  safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
21
  get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
22
+ get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai, MODEL_TYPE_DICT)
23
 
24
 
25
  def download_things(directory, url, hf_token="", civitai_api_key=""):
 
39
  if hf_token:
40
  os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
41
  else:
42
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
43
  elif "civitai.com" in url:
44
  if "?" in url:
45
  url = url.split("?")[0]
 
95
  vae_model_list.insert(0, "None")
96
 
97
 
98
+ def is_repo_name(s):
99
+ import re
100
+ return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
101
+
102
+
103
  def get_t2i_model_info(repo_id: str):
104
+ api = HfApi(token=HF_TOKEN)
 
105
  try:
106
+ if not is_repo_name(repo_id): return ""
107
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
108
  except Exception as e:
109
+ print(f"Error: Failed to get {repo_id}'s info.")
110
  print(e)
111
  return ""
112
  if model.private or model.gated: return ""
 
114
  info = []
115
  url = f"https://huggingface.co/{repo_id}/"
116
  if not 'diffusers' in tags: return ""
117
+ for k, v in MODEL_TYPE_DICT.items():
118
+ if k in tags: info.append(v)
 
 
 
 
119
  if model.card_data and model.card_data.tags:
120
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
121
  info.append(f"DLs: {model.downloads}")
 
247
  gr.update(value=tag, label=label, visible=on), gr.update(value=md, visible=on)
248
 
249
 
250
+ def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
251
  global civitai_lora_last_results
252
+ items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
253
  if not items: return gr.update(choices=[("", "")], value="", visible=False),\
254
  gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
255
  civitai_lora_last_results = {}
mod.py CHANGED
@@ -347,7 +347,7 @@ def enhance_prompt(input_prompt):
347
 
348
  def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
349
  import uuid
350
- from PIL import Image, PngImagePlugin
351
  import json
352
  try:
353
  if savefile is None: savefile = f"{modelname.split('/')[-1]}_{str(uuid.uuid4())}.png"
 
347
 
348
  def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
349
  import uuid
350
+ from PIL import PngImagePlugin
351
  import json
352
  try:
353
  if savefile is None: savefile = f"{modelname.split('/')[-1]}_{str(uuid.uuid4())}.png"
modutils.py CHANGED
@@ -4,11 +4,19 @@ import gradio as gr
4
  from huggingface_hub import HfApi
5
  import os
6
  from pathlib import Path
 
7
 
8
 
9
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
10
- HF_MODEL_USER_EX, HF_MODEL_USER_LIKES,
11
- directory_loras, hf_read_token, hf_token, CIVITAI_API_KEY)
 
 
 
 
 
 
 
12
 
13
 
14
  def get_user_agent():
@@ -27,6 +35,11 @@ def list_sub(a, b):
27
  return [e for e in a if e not in b]
28
 
29
 
 
 
 
 
 
30
  from translatepy import Translator
31
  translator = Translator()
32
  def translate_to_en(input: str):
@@ -64,7 +77,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
64
  if hf_token:
65
  os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
66
  else:
67
- os.system (f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
68
  elif "civitai.com" in url:
69
  if "?" in url:
70
  url = url.split("?")[0]
@@ -100,6 +113,23 @@ def safe_float(input):
100
  return output
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
104
  from datetime import datetime, timezone, timedelta
105
  progress(0, desc="Updating gallery...")
@@ -209,11 +239,16 @@ def get_model_id_list():
209
  model_ids.append(model.id) if not model.private else ""
210
  anime_models = []
211
  real_models = []
 
 
212
  for model in models_ex:
213
- if not model.private and not model.gated and "diffusers:FluxPipeline" not in model.tags:
214
- anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
 
215
  model_ids.extend(anime_models)
216
  model_ids.extend(real_models)
 
 
217
  model_id_list = model_ids.copy()
218
  return model_ids
219
 
@@ -222,10 +257,10 @@ model_id_list = get_model_id_list()
222
 
223
 
224
  def get_t2i_model_info(repo_id: str):
225
- api = HfApi()
226
  try:
227
- if " " in repo_id or not api.repo_exists(repo_id): return ""
228
- model = api.model_info(repo_id=repo_id)
229
  except Exception as e:
230
  print(f"Error: Failed to get {repo_id}'s info.")
231
  print(e)
@@ -235,9 +270,8 @@ def get_t2i_model_info(repo_id: str):
235
  info = []
236
  url = f"https://huggingface.co/{repo_id}/"
237
  if not 'diffusers' in tags: return ""
238
- if 'diffusers:FluxPipeline' in tags: info.append("FLUX.1")
239
- elif 'diffusers:StableDiffusionXLPipeline' in tags: info.append("SDXL")
240
- elif 'diffusers:StableDiffusionPipeline' in tags: info.append("SD1.5")
241
  if model.card_data and model.card_data.tags:
242
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
243
  info.append(f"DLs: {model.downloads}")
@@ -262,12 +296,8 @@ def get_tupled_model_list(model_list):
262
  tags = model.tags
263
  info = []
264
  if not 'diffusers' in tags: continue
265
- if 'diffusers:FluxPipeline' in tags:
266
- info.append("FLUX.1")
267
- if 'diffusers:StableDiffusionXLPipeline' in tags:
268
- info.append("SDXL")
269
- elif 'diffusers:StableDiffusionPipeline' in tags:
270
- info.append("SD1.5")
271
  if model.card_data and model.card_data.tags:
272
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
273
  if "pony" in info:
@@ -351,7 +381,7 @@ def get_civitai_info(path):
351
 
352
 
353
  def get_lora_model_list():
354
- loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras))
355
  loras.insert(0, "None")
356
  loras.insert(0, "")
357
  return loras
@@ -408,7 +438,7 @@ def download_lora(dl_urls: str):
408
  for url in [url.strip() for url in dl_urls.split(',')]:
409
  local_path = f"{directory_loras}/{url.split('/')[-1]}"
410
  if not Path(local_path).exists():
411
- download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
412
  urls.append(url)
413
  after = get_local_model_list(directory_loras)
414
  new_files = list_sub(after, before)
@@ -460,7 +490,7 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
460
  gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
461
 
462
 
463
- def get_valid_lora_name(query: str):
464
  path = "None"
465
  if not query or query == "None": return "None"
466
  if to_lora_key(query) in loras_dict.keys(): return query
@@ -474,7 +504,7 @@ def get_valid_lora_name(query: str):
474
  dl_file = download_lora(query)
475
  if dl_file and Path(dl_file).exists(): return dl_file
476
  else:
477
- dl_file = find_similar_lora(query)
478
  if dl_file and Path(dl_file).exists(): return dl_file
479
  return "None"
480
 
@@ -498,14 +528,14 @@ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
498
  return wt
499
 
500
 
501
- def set_prompt_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
502
  import re
503
  if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
504
- lora1 = get_valid_lora_name(lora1)
505
- lora2 = get_valid_lora_name(lora2)
506
- lora3 = get_valid_lora_name(lora3)
507
- lora4 = get_valid_lora_name(lora4)
508
- lora5 = get_valid_lora_name(lora5)
509
  if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
510
  lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
511
  lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
@@ -670,7 +700,7 @@ def get_my_lora(link_url):
670
  before = get_local_model_list(directory_loras)
671
  for url in [url.strip() for url in link_url.split(',')]:
672
  if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
673
- download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
674
  after = get_local_model_list(directory_loras)
675
  new_files = list_sub(after, before)
676
  for file in new_files:
@@ -727,8 +757,7 @@ def move_file_lora(filepaths):
727
 
728
 
729
  def get_civitai_info(path):
730
- global civitai_not_exists_list
731
- global loras_url_to_path_dict
732
  import requests
733
  from requests.adapters import HTTPAdapter
734
  from urllib3.util import Retry
@@ -768,16 +797,18 @@ def get_civitai_info(path):
768
  return items
769
 
770
 
771
- def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100):
 
772
  import requests
773
  from requests.adapters import HTTPAdapter
774
  from urllib3.util import Retry
775
- if not query: return None
776
  user_agent = get_user_agent()
777
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
778
  base_url = 'https://civitai.com/api/v1/models'
779
- params = {'query': query, 'types': ['LORA'], 'sort': 'Highest Rated', 'period': 'AllTime',
780
- 'nsfw': 'true', 'supportsGeneration ': 'true', 'limit': limit}
 
781
  session = requests.Session()
782
  retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
783
  session.mount("https://", HTTPAdapter(max_retries=retries))
@@ -806,9 +837,9 @@ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1
806
  return items
807
 
808
 
809
- def search_civitai_lora(query, base_model):
810
  global civitai_lora_last_results
811
- items = search_lora_on_civitai(query, base_model)
812
  if not items: return gr.update(choices=[("", "")], value="", visible=False),\
813
  gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
814
  civitai_lora_last_results = {}
@@ -834,7 +865,27 @@ def select_civitai_lora(search_result):
834
  return gr.update(value=search_result), gr.update(value=md, visible=True)
835
 
836
 
837
- def find_similar_lora(q: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
  from rapidfuzz.process import extractOne
839
  from rapidfuzz.utils import default_process
840
  query = to_lora_key(q)
@@ -857,7 +908,7 @@ def find_similar_lora(q: str):
857
  print(f"Finding <lora:{query}:...> on Civitai...")
858
  civitai_query = Path(query).stem if Path(query).is_file() else query
859
  civitai_query = civitai_query.replace("_", " ").replace("-", " ")
860
- base_model = ["Pony", "SDXL 1.0"]
861
  items = search_lora_on_civitai(civitai_query, base_model, 1)
862
  if items:
863
  item = items[0]
@@ -1219,12 +1270,12 @@ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_g
1219
 
1220
  def get_model_pipeline(repo_id: str):
1221
  from huggingface_hub import HfApi
1222
- api = HfApi()
1223
  default = "StableDiffusionPipeline"
1224
  try:
1225
- if " " in repo_id or not api.repo_exists(repo_id): return default
1226
- model = api.model_info(repo_id=repo_id)
1227
- except Exception as e:
1228
  return default
1229
  if model.private or model.gated: return default
1230
  tags = model.tags
 
4
  from huggingface_hub import HfApi
5
  import os
6
  from pathlib import Path
7
+ from PIL import Image
8
 
9
 
10
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
11
+ HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
12
+ directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
13
+
14
+
15
+ MODEL_TYPE_DICT = {
16
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
17
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
18
+ "diffusers:FluxPipeline": "FLUX",
19
+ }
20
 
21
 
22
  def get_user_agent():
 
35
  return [e for e in a if e not in b]
36
 
37
 
38
+ def is_repo_name(s):
39
+ import re
40
+ return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
41
+
42
+
43
  from translatepy import Translator
44
  translator = Translator()
45
  def translate_to_en(input: str):
 
77
  if hf_token:
78
  os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
79
  else:
80
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
81
  elif "civitai.com" in url:
82
  if "?" in url:
83
  url = url.split("?")[0]
 
113
  return output
114
 
115
 
116
+ def save_images(images: list[Image.Image], metadatas: list[str]):
117
+ from PIL import PngImagePlugin
118
+ import uuid
119
+ try:
120
+ output_images = []
121
+ for image, metadata in zip(images, metadatas):
122
+ info = PngImagePlugin.PngInfo()
123
+ info.add_text("metadata", metadata)
124
+ savefile = f"{str(uuid.uuid4())}.png"
125
+ image.save(savefile, "PNG", pnginfo=info)
126
+ output_images.append(str(Path(savefile).resolve()))
127
+ return output_images
128
+ except Exception as e:
129
+ print(f"Failed to save image file: {e}")
130
+ raise Exception(f"Failed to save image file:") from e
131
+
132
+
133
  def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
134
  from datetime import datetime, timezone, timedelta
135
  progress(0, desc="Updating gallery...")
 
239
  model_ids.append(model.id) if not model.private else ""
240
  anime_models = []
241
  real_models = []
242
+ anime_models_flux = []
243
+ real_models_flux = []
244
  for model in models_ex:
245
+ if not model.private and not model.gated:
246
+ if "diffusers:FluxPipeline" in model.tags: anime_models_flux.append(model.id) if "anime" in model.tags else real_models_flux.append(model.id)
247
+ else: anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
248
  model_ids.extend(anime_models)
249
  model_ids.extend(real_models)
250
+ model_ids.extend(anime_models_flux)
251
+ model_ids.extend(real_models_flux)
252
  model_id_list = model_ids.copy()
253
  return model_ids
254
 
 
257
 
258
 
259
  def get_t2i_model_info(repo_id: str):
260
+ api = HfApi(token=HF_TOKEN)
261
  try:
262
+ if not is_repo_name(repo_id): return ""
263
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
264
  except Exception as e:
265
  print(f"Error: Failed to get {repo_id}'s info.")
266
  print(e)
 
270
  info = []
271
  url = f"https://huggingface.co/{repo_id}/"
272
  if not 'diffusers' in tags: return ""
273
+ for k, v in MODEL_TYPE_DICT.items():
274
+ if k in tags: info.append(v)
 
275
  if model.card_data and model.card_data.tags:
276
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
277
  info.append(f"DLs: {model.downloads}")
 
296
  tags = model.tags
297
  info = []
298
  if not 'diffusers' in tags: continue
299
+ for k, v in MODEL_TYPE_DICT.items():
300
+ if k in tags: info.append(v)
 
 
 
 
301
  if model.card_data and model.card_data.tags:
302
  info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
303
  if "pony" in info:
 
381
 
382
 
383
  def get_lora_model_list():
384
+ loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
385
  loras.insert(0, "None")
386
  loras.insert(0, "")
387
  return loras
 
438
  for url in [url.strip() for url in dl_urls.split(',')]:
439
  local_path = f"{directory_loras}/{url.split('/')[-1]}"
440
  if not Path(local_path).exists():
441
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
442
  urls.append(url)
443
  after = get_local_model_list(directory_loras)
444
  new_files = list_sub(after, before)
 
490
  gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
491
 
492
 
493
+ def get_valid_lora_name(query: str, model_name: str):
494
  path = "None"
495
  if not query or query == "None": return "None"
496
  if to_lora_key(query) in loras_dict.keys(): return query
 
504
  dl_file = download_lora(query)
505
  if dl_file and Path(dl_file).exists(): return dl_file
506
  else:
507
+ dl_file = find_similar_lora(query, model_name)
508
  if dl_file and Path(dl_file).exists(): return dl_file
509
  return "None"
510
 
 
528
  return wt
529
 
530
 
531
+ def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
532
  import re
533
  if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
534
+ lora1 = get_valid_lora_name(lora1, model_name)
535
+ lora2 = get_valid_lora_name(lora2, model_name)
536
+ lora3 = get_valid_lora_name(lora3, model_name)
537
+ lora4 = get_valid_lora_name(lora4, model_name)
538
+ lora5 = get_valid_lora_name(lora5, model_name)
539
  if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
540
  lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
541
  lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
 
700
  before = get_local_model_list(directory_loras)
701
  for url in [url.strip() for url in link_url.split(',')]:
702
  if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
703
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
704
  after = get_local_model_list(directory_loras)
705
  new_files = list_sub(after, before)
706
  for file in new_files:
 
757
 
758
 
759
  def get_civitai_info(path):
760
+ global civitai_not_exists_list, loras_url_to_path_dict
 
761
  import requests
762
  from requests.adapters import HTTPAdapter
763
  from urllib3.util import Retry
 
797
  return items
798
 
799
 
800
+ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
801
+ sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
802
  import requests
803
  from requests.adapters import HTTPAdapter
804
  from urllib3.util import Retry
805
+
806
  user_agent = get_user_agent()
807
  headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
808
  base_url = 'https://civitai.com/api/v1/models'
809
+ params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
810
+ if query: params["query"] = query
811
+ if tag: params["tag"] = tag
812
  session = requests.Session()
813
  retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
814
  session.mount("https://", HTTPAdapter(max_retries=retries))
 
837
  return items
838
 
839
 
840
+ def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
841
  global civitai_lora_last_results
842
+ items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
843
  if not items: return gr.update(choices=[("", "")], value="", visible=False),\
844
  gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
845
  civitai_lora_last_results = {}
 
865
  return gr.update(value=search_result), gr.update(value=md, visible=True)
866
 
867
 
868
+ LORA_BASE_MODEL_DICT = {
869
+ "diffusers:StableDiffusionPipeline": ["SD 1.5"],
870
+ "diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
871
+ "diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
872
+ }
873
+
874
+
875
+ def get_lora_base_model(model_name: str):
876
+ api = HfApi(token=HF_TOKEN)
877
+ default = ["Pony", "SDXL 1.0"]
878
+ try:
879
+ model = api.model_info(repo_id=model_name, timeout=5.0)
880
+ tags = model.tags
881
+ for tag in tags:
882
+ if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
883
+ except Exception:
884
+ return default
885
+ return default
886
+
887
+
888
+ def find_similar_lora(q: str, model_name: str):
889
  from rapidfuzz.process import extractOne
890
  from rapidfuzz.utils import default_process
891
  query = to_lora_key(q)
 
908
  print(f"Finding <lora:{query}:...> on Civitai...")
909
  civitai_query = Path(query).stem if Path(query).is_file() else query
910
  civitai_query = civitai_query.replace("_", " ").replace("-", " ")
911
+ base_model = get_lora_base_model(model_name)
912
  items = search_lora_on_civitai(civitai_query, base_model, 1)
913
  if items:
914
  item = items[0]
 
1270
 
1271
  def get_model_pipeline(repo_id: str):
1272
  from huggingface_hub import HfApi
1273
+ api = HfApi(token=HF_TOKEN)
1274
  default = "StableDiffusionPipeline"
1275
  try:
1276
+ if not is_repo_name(repo_id): return default
1277
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
1278
+ except Exception:
1279
  return default
1280
  if model.private or model.gated: return default
1281
  tags = model.tags