Commit a3da8d7 (verified) · John6666 committed · Parent(s): 9f2f8c6

Upload 4 files

Files changed (4)
  1. app.py +115 -34
  2. env.py +5 -0
  3. modutils.py +63 -35
  4. requirements.txt +0 -1
app.py CHANGED
@@ -22,6 +22,7 @@ from stablepy import (
     SD15_TASKS,
     SDXL_TASKS,
 )
+import time
 #import urllib.parse
 
 PREPROCESSOR_CONTROLNET = {
@@ -393,6 +394,7 @@ class GuiSD:
             retain_task_model_in_cache=False,
             device="cpu",
         )
+        self.model.device = torch.device("cpu") #
 
     def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
 
@@ -404,7 +406,7 @@ class GuiSD:
         if vae_model:
             vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
             if model_type != vae_type:
-                gr.Info(msg_inc_vae)
+                gr.Warning(msg_inc_vae)
 
         self.model.device = torch.device("cpu")
         dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
@@ -418,7 +420,7 @@ class GuiSD:
         )
         yield f"Model loaded: {model_name}"
 
-    @spaces.GPU
+    #@spaces.GPU
     @torch.inference_mode()
     def generate_pipeline(
         self,
@@ -531,7 +533,7 @@ class GuiSD:
         vae_model = vae_model if vae_model != "None" else None
         loras_list = [lora1, lora2, lora3, lora4, lora5]
         vae_msg = f"VAE: {vae_model}" if vae_model else ""
-        msg_lora = []
+        msg_lora = ""
 
         print("Config model:", model_name, vae_model, loras_list)
 
@@ -539,7 +541,7 @@ class GuiSD:
         global lora_model_list
         lora_model_list = get_lora_model_list()
         lora1, lora_scale1, lora2, lora_scale2, lora3, lora_scale3, lora4, lora_scale4, lora5, lora_scale5 = \
-            set_prompt_loras(prompt, syntax_weights, lora1, lora_scale1, lora2, lora_scale2, lora3,
+            set_prompt_loras(prompt, syntax_weights, model_name, lora1, lora_scale1, lora2, lora_scale2, lora3,
                              lora_scale3, lora4, lora_scale4, lora5, lora_scale5)
         prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
         ## END MOD
@@ -703,17 +705,24 @@ class GuiSD:
 
         #progress(0, desc="Preparation completed. Starting inference...")
 
-        info_state = f"PROCESSING "
+        info_state = "PROCESSING "
         for img, seed, image_path, metadata in self.model(**pipe_params):
             info_state += ">"
             if image_path:
-                info_state = f"COMPLETED. Seeds: {str(seed)}"
+                info_state = f"COMPLETE. Seeds: {str(seed)}"
                 if vae_msg:
                     info_state = info_state + "<br>" + vae_msg
+
+                for status, lora in zip(self.model.lora_status, self.model.lora_memory):
+                    if status:
+                        msg_lora += f"<br>Loaded: {lora}"
+                    elif status is not None:
+                        msg_lora += f"<br>Error with: {lora}"
+
                 if msg_lora:
-                    info_state = info_state + "<br>" + "<br>".join(msg_lora)
+                    info_state += msg_lora
 
-                info_state = info_state + "<br>" + "GENERATION DATA:<br>" + "<br>-------<br>".join(metadata).replace("\n", "<br>")
+                info_state = info_state + "<br>" + "GENERATION DATA:<br>" + "<br>-------<br>".join(metadata).replace("\n", "<br>")
 
                 download_links = "<br>".join(
                     [
@@ -721,7 +730,8 @@ class GuiSD:
                         for i, path in enumerate(image_path)
                     ]
                 )
-                if save_generated_images: info_state += f"<br>{download_links}"
+                if save_generated_images:
+                    info_state += f"<br>{download_links}"
 
             img = save_images(img, metadata)
 
@@ -735,32 +745,90 @@ def update_task_options(model_name, task_name):
 
     return gr.update(value=task_name, choices=new_choices)
 
-# def sd_gen_generate_pipeline(*args):
+def dynamic_gpu_duration(func, duration, *args):
+
+    @spaces.GPU(duration=duration)
+    def wrapped_func():
+        yield from func(*args)
+
+    return wrapped_func()
+
+
+@spaces.GPU
+def dummy_gpu():
+    return None
+
+
+def sd_gen_generate_pipeline(*args):
+
+    gpu_duration_arg = int(args[-1]) if args[-1] else 59
+    verbose_arg = int(args[-2])
+    load_lora_cpu = args[-3]
+    generation_args = args[:-3]
+    lora_list = [
+        None if item == "None" or item == "" else item
+        for item in [args[7], args[9], args[11], args[13], args[15]]
+    ]
+    lora_status = [None] * 5
+
+    msg_load_lora = "Updating LoRAs in GPU..."
+    if load_lora_cpu:
+        msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
+
+    if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
+        yield None, msg_load_lora
 
-#     # Load lora in CPU
-#     status_lora = sd_gen.model.lora_merge(
-#         lora_A=args[7] if args[7] != "None" else None, lora_scale_A=args[8],
-#         lora_B=args[9] if args[9] != "None" else None, lora_scale_B=args[10],
-#         lora_C=args[11] if args[11] != "None" else None, lora_scale_C=args[12],
-#         lora_D=args[13] if args[13] != "None" else None, lora_scale_D=args[14],
-#         lora_E=args[15] if args[15] != "None" else None, lora_scale_E=args[16],
-#     )
+    # Load lora in CPU
+    if load_lora_cpu:
+        lora_status = sd_gen.model.lora_merge(
+            lora_A=lora_list[0], lora_scale_A=args[8],
+            lora_B=lora_list[1], lora_scale_B=args[10],
+            lora_C=lora_list[2], lora_scale_C=args[12],
+            lora_D=lora_list[3], lora_scale_D=args[14],
+            lora_E=lora_list[4], lora_scale_E=args[16],
+        )
+        print(lora_status)
 
-#     lora_list = [args[7], args[9], args[11], args[13], args[15]]
-#     print(status_lora)
-#     for status, lora in zip(status_lora, lora_list):
-#         if status:
-#             gr.Info(f"LoRA loaded: {lora}")
-#         elif status is not None:
-#             gr.Warning(f"Failed to load LoRA: {lora}")
+    if verbose_arg:
+        for status, lora in zip(lora_status, lora_list):
+            if status:
+                gr.Info(f"LoRA loaded in CPU: {lora}")
+            elif status is not None:
+                gr.Warning(f"Failed to load LoRA: {lora}")
 
-#     # if status_lora == [None] * 5 and self.model.lora_memory != [None] * 5:
-#     #     gr.Info(f"LoRAs in cache: {", ".join(str(x) for x in self.model.lora_memory if x is not None)}")
+    if lora_status == [None] * 5 and sd_gen.model.lora_memory != [None] * 5 and load_lora_cpu:
+        lora_cache_msg = ", ".join(
+            str(x) for x in sd_gen.model.lora_memory if x is not None
+        )
+        gr.Info(f"LoRAs in cache: {lora_cache_msg}")
 
-#     yield from sd_gen.generate_pipeline(*args)
+    msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
+    gr.Info(msg_request)
+    print(msg_request)
+
+    # yield from sd_gen.generate_pipeline(*generation_args)
+
+    start_time = time.time()
+
+    yield from dynamic_gpu_duration(
+        sd_gen.generate_pipeline,
+        gpu_duration_arg,
+        *generation_args,
+    )
+
+    end_time = time.time()
 
-# sd_gen_generate_pipeline.zerogpu = True
+    if verbose_arg:
+        execution_time = end_time - start_time
+        msg_task_complete = (
+            f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
+        )
+        gr.Info(msg_task_complete)
+        print(msg_task_complete)
+
+
+dynamic_gpu_duration.zerogpu = True
+sd_gen_generate_pipeline.zerogpu = True
 sd_gen = GuiSD()
 
 ## BEGIN MOD
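Note: the `dynamic_gpu_duration` wrapper added above applies `@spaces.GPU` at call time, so the ZeroGPU quota request can vary per click instead of being fixed by a module-level decorator. A minimal self-contained sketch of the same pattern, assuming the `spaces` package available on Hugging Face ZeroGPU Spaces; `slow_generator` is a hypothetical stand-in for `sd_gen.generate_pipeline`:

```python
import time
import spaces

def dynamic_gpu_duration(func, duration, *args):
    # The decorator is applied per call, so each request can ask for a
    # different GPU window instead of a fixed module-level duration.
    @spaces.GPU(duration=duration)
    def wrapped_func():
        yield from func(*args)
    return wrapped_func()

def slow_generator(n):
    # Hypothetical stand-in for the real generation pipeline.
    for i in range(n):
        time.sleep(1)
        yield i

# Ask for a 30-second GPU window for this single call.
for step in dynamic_gpu_duration(slow_generator, 30, 5):
    print(step)
```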
@@ -869,6 +937,12 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
 
                 actual_task_info = gr.HTML()
 
+                with gr.Row(equal_height=False, variant="default"):
+                    gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
+                    with gr.Column():
+                        verbose_info_gui = gr.Checkbox(value=False, container=False, label="Status info")
+                        load_lora_cpu_gui = gr.Checkbox(value=False, container=False, label="Load LoRAs on CPU (Save GPU time)")
+
             with gr.Column(scale=1):
                 with gr.Accordion("Generation settings", open=False, visible=True) as menu_gen:
                     with gr.Row():
@@ -1015,9 +1089,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
                     lora5_copy_gui = gr.Button(value="Copy example to prompt", visible=False)
                     lora5_desc_gui = gr.Markdown(value="", visible=False)
                 with gr.Accordion("From URL", open=True, visible=True):
+                    with gr.Row():
+                        search_civitai_basemodel_lora = gr.CheckboxGroup(label="Search LoRA for", choices=["Pony", "SD 1.5", "SDXL 1.0", "Flux.1 D", "Flux.1 S"], value=["Pony", "SDXL 1.0"])
+                        search_civitai_sort_lora = gr.Radio(label="Sort", choices=["Highest Rated", "Most Downloaded", "Newest"], value="Highest Rated")
+                        search_civitai_period_lora = gr.Radio(label="Period", choices=["AllTime", "Year", "Month", "Week", "Day"], value="AllTime")
                     with gr.Row():
                         search_civitai_query_lora = gr.Textbox(label="Query", placeholder="oomuro sakurako...", lines=1)
-                        search_civitai_basemodel_lora = gr.CheckboxGroup(label="Search LoRA for", choices=["Pony", "SD 1.5", "SDXL 1.0"], value=["Pony", "SDXL 1.0"])
+                        search_civitai_tag_lora = gr.Textbox(label="Tag", lines=1)
                         search_civitai_button_lora = gr.Button("Search on Civitai")
                     search_civitai_desc_lora = gr.Markdown(value="", visible=False)
                     search_civitai_result_lora = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
@@ -1269,7 +1347,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
                 "Euler a",
                 1024,
                 1024,
-                "votepurchase/animagine-xl-3.1",
+                "cagliostrolab/animagine-xl-3.1",
             ],
         ],
         fn=sd_gen.generate_pipeline,
@@ -1407,9 +1485,9 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
     lora4_copy_gui.click(apply_lora_prompt, [prompt_gui, lora4_info_gui], [prompt_gui], queue=False)
     lora5_copy_gui.click(apply_lora_prompt, [prompt_gui, lora5_info_gui], [prompt_gui], queue=False)
     gr.on(
-        triggers=[search_civitai_button_lora.click, search_civitai_query_lora.submit],
+        triggers=[search_civitai_button_lora.click, search_civitai_query_lora.submit, search_civitai_tag_lora.submit],
         fn=search_civitai_lora,
-        inputs=[search_civitai_query_lora, search_civitai_basemodel_lora],
+        inputs=[search_civitai_query_lora, search_civitai_basemodel_lora, search_civitai_sort_lora, search_civitai_period_lora, search_civitai_tag_lora],
         outputs=[search_civitai_result_lora, search_civitai_desc_lora, search_civitai_button_lora, search_civitai_query_lora],
         queue=True,
         scroll_to_output=True,
@@ -1463,7 +1541,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
         queue=True,
         show_progress="minimal",
     ).success(
-        fn=sd_gen.generate_pipeline,
+        fn=sd_gen_generate_pipeline,
         inputs=[
             prompt_gui,
             neg_prompt_gui,
@@ -1567,6 +1645,9 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', elem_id="main", fill_width=True, cs
             mode_ip2,
             scale_ip2,
             pag_scale_gui,
+            load_lora_cpu_gui,
+            verbose_info_gui,
+            gpu_duration_gui,
         ],
         outputs=[result_images, actual_task_info],
         queue=True,
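A note on the wiring: Gradio passes `inputs` positionally, so the three controls appended at the end of the `.success(...)` inputs list arrive as the last three positional arguments of `sd_gen_generate_pipeline`. A toy stand-in showing that unpacking contract (names are illustrative, not part of the commit):

```python
def unpack_tail(*args):
    # Mirrors the unpacking in sd_gen_generate_pipeline above:
    # gpu_duration_gui -> args[-1], verbose_info_gui -> args[-2],
    # load_lora_cpu_gui -> args[-3]; everything before is generation args.
    gpu_duration = int(args[-1]) if args[-1] else 59
    verbose = int(args[-2])
    load_lora_cpu = args[-3]
    return args[:-3], load_lora_cpu, verbose, gpu_duration

gen_args, lora_cpu, verbose, duration = unpack_tail("prompt", "neg", 28, False, True, 59)
print(lora_cpu, verbose, duration)  # False 1 59
```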
 
env.py CHANGED
@@ -72,6 +72,11 @@ load_diffusers_format_model = [
     "Raelina/Raemu-Flux",
 ]
 
+DIFFUSERS_FORMAT_LORAS = [
+    "nerijs/animation2k-flux",
+    "XLabs-AI/flux-RealismLora",
+]
+
 # List all Models for specified user
 HF_MODEL_USER_LIKES = ["votepurchase"] # sorted by number of likes
 HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule
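The new `DIFFUSERS_FORMAT_LORAS` list holds Hub repo ids for LoRAs stored in diffusers format. A hedged sketch of how such an id can be applied, assuming a diffusers build with PEFT-backed LoRA loading; the base model id below is an assumption for illustration, not part of this commit:

```python
import torch
from diffusers import FluxPipeline

# Base model id is an assumption; the LoRA repo id comes from the list above.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("XLabs-AI/flux-RealismLora")
```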
 
modutils.py CHANGED
@@ -4,13 +4,21 @@ import gradio as gr
 from huggingface_hub import HfApi
 import os
 from pathlib import Path
+from PIL import Image
 
 
 from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
-                 HF_MODEL_USER_EX, HF_MODEL_USER_LIKES,
+                 HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
                  directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
 
 
+MODEL_TYPE_DICT = {
+    "diffusers:StableDiffusionPipeline": "SD 1.5",
+    "diffusers:StableDiffusionXLPipeline": "SDXL",
+    "diffusers:FluxPipeline": "FLUX",
+}
+
+
 def get_user_agent():
     return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
 
@@ -27,6 +35,11 @@ def list_sub(a, b):
     return [e for e in a if e not in b]
 
 
+def is_repo_name(s):
+    import re
+    return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
+
+
 from translatepy import Translator
 translator = Translator()
 def translate_to_en(input: str):
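A quick sanity check of the new `is_repo_name` helper: `re.fullmatch` returns a `Match` object (truthy) only for ids with exactly one `/` separating two non-empty parts, and `None` otherwise:

```python
assert is_repo_name("John6666/some-model")   # "user/repo" id -> Match object
assert not is_repo_name("some-model")        # no "/"
assert not is_repo_name("a/b/c")             # more than one "/"
assert not is_repo_name("/leading-slash")    # empty first part
```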
@@ -64,7 +77,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
         if hf_token:
             os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
         else:
-            os.system (f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
+            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
     elif "civitai.com" in url:
         if "?" in url:
             url = url.split("?")[0]
@@ -100,7 +113,6 @@ def safe_float(input):
     return output
 
 
-from PIL import Image
 def save_images(images: list[Image.Image], metadatas: list[str]):
     from PIL import PngImagePlugin
     import uuid
@@ -245,10 +257,10 @@ model_id_list = get_model_id_list()
 
 
 def get_t2i_model_info(repo_id: str):
-    api = HfApi()
+    api = HfApi(token=HF_TOKEN)
     try:
-        if " " in repo_id or not api.repo_exists(repo_id): return ""
-        model = api.model_info(repo_id=repo_id)
+        if not is_repo_name(repo_id): return ""
+        model = api.model_info(repo_id=repo_id, timeout=5.0)
     except Exception as e:
         print(f"Error: Failed to get {repo_id}'s info.")
         print(e)
@@ -258,9 +270,8 @@ def get_t2i_model_info(repo_id: str):
     info = []
     url = f"https://huggingface.co/{repo_id}/"
     if not 'diffusers' in tags: return ""
-    if 'diffusers:FluxPipeline' in tags: info.append("FLUX.1")
-    elif 'diffusers:StableDiffusionXLPipeline' in tags: info.append("SDXL")
-    elif 'diffusers:StableDiffusionPipeline' in tags: info.append("SD1.5")
+    for k, v in MODEL_TYPE_DICT.items():
+        if k in tags: info.append(v)
     if model.card_data and model.card_data.tags:
         info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
     info.append(f"DLs: {model.downloads}")
@@ -285,12 +296,8 @@ def get_tupled_model_list(model_list):
         tags = model.tags
         info = []
         if not 'diffusers' in tags: continue
-        if 'diffusers:FluxPipeline' in tags:
-            info.append("FLUX.1")
-        if 'diffusers:StableDiffusionXLPipeline' in tags:
-            info.append("SDXL")
-        elif 'diffusers:StableDiffusionPipeline' in tags:
-            info.append("SD1.5")
+        for k, v in MODEL_TYPE_DICT.items():
+            if k in tags: info.append(v)
         if model.card_data and model.card_data.tags:
             info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
         if "pony" in info:
@@ -374,7 +381,7 @@ def get_civitai_info(path):
 
 
 def get_lora_model_list():
-    loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras))
+    loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
     loras.insert(0, "None")
     loras.insert(0, "")
     return loras
@@ -483,7 +490,7 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
         gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
 
 
-def get_valid_lora_name(query: str):
+def get_valid_lora_name(query: str, model_name: str):
     path = "None"
     if not query or query == "None": return "None"
     if to_lora_key(query) in loras_dict.keys(): return query
@@ -497,7 +504,7 @@ def get_valid_lora_name(query: str):
         dl_file = download_lora(query)
         if dl_file and Path(dl_file).exists(): return dl_file
     else:
-        dl_file = find_similar_lora(query)
+        dl_file = find_similar_lora(query, model_name)
         if dl_file and Path(dl_file).exists(): return dl_file
     return "None"
 
@@ -521,14 +528,14 @@ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
     return wt
 
 
-def set_prompt_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
+def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
     import re
     if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
-    lora1 = get_valid_lora_name(lora1)
-    lora2 = get_valid_lora_name(lora2)
-    lora3 = get_valid_lora_name(lora3)
-    lora4 = get_valid_lora_name(lora4)
-    lora5 = get_valid_lora_name(lora5)
+    lora1 = get_valid_lora_name(lora1, model_name)
+    lora2 = get_valid_lora_name(lora2, model_name)
+    lora3 = get_valid_lora_name(lora3, model_name)
+    lora4 = get_valid_lora_name(lora4, model_name)
+    lora5 = get_valid_lora_name(lora5, model_name)
     if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
     lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
     lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
@@ -790,16 +797,17 @@ def get_civitai_info(path):
     return items
 
 
-def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100):
+def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
+                           sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
     import requests
     from requests.adapters import HTTPAdapter
     from urllib3.util import Retry
-    if not query: return None
     user_agent = get_user_agent()
     headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
     base_url = 'https://civitai.com/api/v1/models'
-    params = {'query': query, 'types': ['LORA'], 'sort': 'Highest Rated', 'period': 'AllTime',
-              'nsfw': 'true', 'supportsGeneration ': 'true', 'limit': limit}
+    params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
+    if query: params["query"] = query
+    if tag: params["tag"] = tag
    session = requests.Session()
     retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
     session.mount("https://", HTTPAdapter(max_retries=retries))
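For reference, the same request can be exercised outside the app. A hedged sketch against the public Civitai REST endpoint used above; the `sort`, `period`, and `tag` parameters mirror the new UI controls, and the response field names follow the Civitai API as I understand it:

```python
import requests

params = {
    "types": ["LORA"],
    "sort": "Most Downloaded",  # "Highest Rated" | "Most Downloaded" | "Newest"
    "period": "Month",          # "AllTime" | "Year" | "Month" | "Week" | "Day"
    "limit": 5,
    "nsfw": "true",
    "query": "oomuro sakurako",
}
r = requests.get("https://civitai.com/api/v1/models", params=params, timeout=10)
r.raise_for_status()
for item in r.json().get("items", []):
    # Each item carries a baseModel per version, which allow_model filters on.
    print(item["name"], [v.get("baseModel") for v in item.get("modelVersions", [])])
```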
@@ -828,9 +836,9 @@ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1
     return items
 
 
-def search_civitai_lora(query, base_model):
+def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
     global civitai_lora_last_results
-    items = search_lora_on_civitai(query, base_model)
+    items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
     if not items: return gr.update(choices=[("", "")], value="", visible=False),\
          gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
     civitai_lora_last_results = {}
@@ -856,7 +864,27 @@ def select_civitai_lora(search_result):
     return gr.update(value=search_result), gr.update(value=md, visible=True)
 
 
-def find_similar_lora(q: str):
+LORA_BASE_MODEL_DICT = {
+    "diffusers:StableDiffusionPipeline": ["SD 1.5"],
+    "diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
+    "diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
+}
+
+
+def get_lora_base_model(model_name: str):
+    api = HfApi(token=HF_TOKEN)
+    default = ["Pony", "SDXL 1.0"]
+    try:
+        model = api.model_info(repo_id=model_name, timeout=5.0)
+        tags = model.tags
+        for tag in tags:
+            if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
+    except Exception:
+        return default
+    return default
+
+
+def find_similar_lora(q: str, model_name: str):
     from rapidfuzz.process import extractOne
     from rapidfuzz.utils import default_process
     query = to_lora_key(q)
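Usage note for the new helper pair: `get_lora_base_model` scopes the Civitai fallback search to base models compatible with the currently loaded checkpoint, and quietly falls back to the Pony/SDXL default on any lookup failure. Illustrative only (needs network access, and HF_TOKEN for private repos):

```python
# An SDXL checkpoint maps to the Pony/SDXL base-model filter...
print(get_lora_base_model("cagliostrolab/animagine-xl-3.1"))  # ["Pony", "SDXL 1.0"]
# ...and an invalid or unreachable repo id falls back to the same default.
print(get_lora_base_model("not-a-repo"))                      # ["Pony", "SDXL 1.0"]
```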
@@ -879,7 +907,7 @@ def find_similar_lora(q: str):
     print(f"Finding <lora:{query}:...> on Civitai...")
     civitai_query = Path(query).stem if Path(query).is_file() else query
     civitai_query = civitai_query.replace("_", " ").replace("-", " ")
-    base_model = ["Pony", "SDXL 1.0"]
+    base_model = get_lora_base_model(model_name)
     items = search_lora_on_civitai(civitai_query, base_model, 1)
     if items:
         item = items[0]
@@ -1241,11 +1269,11 @@ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_g
 
 def get_model_pipeline(repo_id: str):
     from huggingface_hub import HfApi
-    api = HfApi()
+    api = HfApi(token=HF_TOKEN)
     default = "StableDiffusionPipeline"
     try:
-        if " " in repo_id or not api.repo_exists(repo_id): return default
-        model = api.model_info(repo_id=repo_id)
+        if not is_repo_name(repo_id): return default
+        model = api.model_info(repo_id=repo_id, timeout=5.0)
     except Exception:
         return default
     if model.private or model.gated: return default
 
requirements.txt CHANGED
@@ -2,7 +2,6 @@ git+https://github.com/R3gm/stablepy.git@flux_beta
 torch==2.2.0
 gdown
 opencv-python
-yt-dlp
 torchvision
 accelerate
 transformers
 