John6666 commited on
Commit
5be865f
·
verified ·
1 Parent(s): 77c826e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -29
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
6
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
7
  from diffusers.utils import load_image
8
- from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline
9
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
10
  import os
11
  import copy
@@ -13,8 +13,9 @@ import random
13
  import time
14
  import requests
15
  import pandas as pd
 
16
 
17
- from env import models, num_loras, num_cns, HF_TOKEN
18
  from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
19
  description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
20
  get_trigger_word, enhance_prompt, set_control_union_image,
@@ -41,11 +42,11 @@ controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
41
  dtype = torch.bfloat16
42
  #dtype = torch.float8_e4m3fn
43
  #device = "cuda" if torch.cuda.is_available() else "cpu"
44
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
45
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
46
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
47
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
48
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
49
  controlnet_union = None
50
  controlnet = None
51
  last_model = models[0]
@@ -70,10 +71,13 @@ def unload_lora():
70
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
71
  # https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
72
  #@spaces.GPU()
73
- def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, progress=gr.Progress(track_tqdm=True)):
74
  global pipe, pipe_i2i, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
 
 
75
  try:
76
- if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
 
77
  unload_lora()
78
  pipe.to("cpu")
79
  pipe_i2i.to("cpu")
@@ -85,12 +89,19 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, prog
85
  if cn_on:
86
  progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
87
  print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
88
- controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
89
  controlnet = FluxMultiControlNetModel([controlnet_union])
90
  controlnet.config = controlnet_union.config
91
- pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
92
- pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
93
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
 
 
 
 
 
 
 
94
  last_model = repo_id
95
  last_cn_on = cn_on
96
  progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
@@ -98,16 +109,25 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, prog
98
  else:
99
  progress(0, desc=f"Loading model: {repo_id}")
100
  print(f"Loading model: {repo_id}")
101
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
102
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
103
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
 
 
 
 
 
 
 
104
  last_model = repo_id
105
  last_cn_on = cn_on
106
  progress(1, desc=f"Model loaded: {repo_id}")
107
  print(f"Model loaded: {repo_id}")
108
  except Exception as e:
109
- print(f"Model load Error: {e}")
110
- raise gr.Error(f"Model load Error: {e}") from e
 
 
111
  return gr.update(visible=True)
112
 
113
  change_base_model.zerogpu = True
@@ -260,7 +280,6 @@ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
260
  except Exception as e:
261
  print(e)
262
  image = None
263
- #image = image if is_repo_public(repo) else None
264
  print(f"Loaded custom LoRA: {repo}")
265
  existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
266
  if existing_item_index is None:
@@ -521,20 +540,27 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
521
  for idx, lora in enumerate(selected_loras):
522
  lora_name = f"lora_{idx}"
523
  lora_names.append(lora_name)
 
524
  lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
525
  lora_path = lora['repo']
526
  weight_name = lora.get("weights")
527
  print(f"Lora Path: {lora_path}")
528
  if image_input is not None:
529
- if weight_name:
530
- pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
531
- else:
532
- pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
 
 
 
533
  else:
534
- if weight_name:
535
- pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
536
- else:
537
- pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
 
 
 
538
  print("Loaded LoRAs:", lora_names)
539
  if image_input is not None:
540
  pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
@@ -722,8 +748,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
722
  with gr.Accordion("History", open=False):
723
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
724
  with gr.Group():
725
- model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id to want to use.", choices=models, value=models[0], allow_custom_value=True)
 
 
 
 
726
  model_info = gr.Markdown(elem_classes="info")
 
727
  with gr.Row():
728
  with gr.Accordion("Advanced Settings", open=False):
729
  with gr.Row():
@@ -756,7 +787,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
756
  with gr.Row():
757
  for i in range(num_loras):
758
  with gr.Column():
759
- lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True)
760
  with gr.Row():
761
  lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
762
  lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
@@ -837,7 +868,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
837
  gr.on(
838
  triggers=[generate_button.click, prompt.submit],
839
  fn=change_base_model,
840
- inputs=[model_name, cn_on, disable_model_cache],
841
  outputs=[result],
842
  queue=True,
843
  show_api=False,
 
5
  from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
6
  from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
7
  from diffusers.utils import load_image
8
+ from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline, FluxTransformer2DModel
9
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
10
  import os
11
  import copy
 
13
  import time
14
  import requests
15
  import pandas as pd
16
+ from pathlib import Path
17
 
18
+ from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
19
  from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
20
  description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
21
  get_trigger_word, enhance_prompt, set_control_union_image,
 
42
  dtype = torch.bfloat16
43
  #dtype = torch.float8_e4m3fn
44
  #device = "cuda" if torch.cuda.is_available() else "cpu"
45
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
46
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
47
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
48
  pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
49
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
50
  controlnet_union = None
51
  controlnet = None
52
  last_model = models[0]
 
71
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
72
  # https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
73
  #@spaces.GPU()
74
+ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, progress=gr.Progress(track_tqdm=True)):
75
  global pipe, pipe_i2i, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
76
+ safetensors_file = None
77
+ single_file_base_model = single_file_base_models.get(model_type, models[0])
78
  try:
79
+ #if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
80
+ if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and not ".safetensors" in repo_id): return gr.update(visible=True)
81
  unload_lora()
82
  pipe.to("cpu")
83
  pipe_i2i.to("cpu")
 
89
  if cn_on:
90
  progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
91
  print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
92
+ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=HF_TOKEN)
93
  controlnet = FluxMultiControlNetModel([controlnet_union])
94
  controlnet.config = controlnet_union.config
95
+ if ".safetensors" in repo_id:
96
+ safetensors_file = download_file_mod(repo_id)
97
+ transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
98
+ pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
99
+ pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
100
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
101
+ else:
102
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
103
+ pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
104
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
105
  last_model = repo_id
106
  last_cn_on = cn_on
107
  progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
 
109
  else:
110
  progress(0, desc=f"Loading model: {repo_id}")
111
  print(f"Loading model: {repo_id}")
112
+ if ".safetensors" in repo_id:
113
+ safetensors_file = download_file_mod(repo_id)
114
+ transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
115
+ pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
116
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
117
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
118
+ else:
119
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
120
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
121
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
122
  last_model = repo_id
123
  last_cn_on = cn_on
124
  progress(1, desc=f"Model loaded: {repo_id}")
125
  print(f"Model loaded: {repo_id}")
126
  except Exception as e:
127
+ print(f"Model load Error: {repo_id} {e}")
128
+ raise gr.Error(f"Model load Error: {repo_id} {e}") from e
129
+ finally:
130
+ if safetensors_file and Path(safetensors_file).exists(): Path(safetensors_file).unlink()
131
  return gr.update(visible=True)
132
 
133
  change_base_model.zerogpu = True
 
280
  except Exception as e:
281
  print(e)
282
  image = None
 
283
  print(f"Loaded custom LoRA: {repo}")
284
  existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
285
  if existing_item_index is None:
 
540
  for idx, lora in enumerate(selected_loras):
541
  lora_name = f"lora_{idx}"
542
  lora_names.append(lora_name)
543
+ print(f"Lora Name: {lora_name}")
544
  lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
545
  lora_path = lora['repo']
546
  weight_name = lora.get("weights")
547
  print(f"Lora Path: {lora_path}")
548
  if image_input is not None:
549
+ pipe_i2i.load_lora_weights(
550
+ lora_path,
551
+ weight_name=weight_name if weight_name else None,
552
+ low_cpu_mem_usage=True,
553
+ adapter_name=lora_name,
554
+ token=HF_TOKEN
555
+ )
556
  else:
557
+ pipe.load_lora_weights(
558
+ lora_path,
559
+ weight_name=weight_name if weight_name else None,
560
+ low_cpu_mem_usage=True,
561
+ adapter_name=lora_name,
562
+ token=HF_TOKEN
563
+ )
564
  print("Loaded LoRAs:", lora_names)
565
  if image_input is not None:
566
  pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
 
748
  with gr.Accordion("History", open=False):
749
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
750
  with gr.Group():
751
+ with gr.Row():
752
+ model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id or path of single safetensors file to want to use.",
753
+ choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
754
+ model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
755
+ choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
756
  model_info = gr.Markdown(elem_classes="info")
757
+
758
  with gr.Row():
759
  with gr.Accordion("Advanced Settings", open=False):
760
  with gr.Row():
 
787
  with gr.Row():
788
  for i in range(num_loras):
789
  with gr.Column():
790
+ lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True, min_width=320)
791
  with gr.Row():
792
  lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
793
  lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
 
868
  gr.on(
869
  triggers=[generate_button.click, prompt.submit],
870
  fn=change_base_model,
871
+ inputs=[model_name, cn_on, disable_model_cache, model_type],
872
  outputs=[result],
873
  queue=True,
874
  show_api=False,