Spaces: Running on Zero

Upload app.py

app.py CHANGED
```diff
@@ -5,7 +5,7 @@ import torch
 from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
 from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
 from diffusers.utils import load_image
-from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline
+from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline, FluxTransformer2DModel
 from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
 import os
 import copy
```
```diff
@@ -13,8 +13,9 @@ import random
 import time
 import requests
 import pandas as pd
+from pathlib import Path
 
-from env import models, num_loras, num_cns, HF_TOKEN
+from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
 from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
                  description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
                  get_trigger_word, enhance_prompt, set_control_union_image,
```
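The new `single_file_base_models` import comes from the Space's local `env` module, which this commit doesn't show. From how it is used below (`.get(model_type, models[0])` for a default, `.keys()` to populate the Model-type radio), it is evidently a dict mapping a model-type label to a diffusers-format base repo. A plausible sketch (keys and repo ids here are illustrative guesses, not the actual file):

```python
# Hypothetical sketch of env.py's single_file_base_models (not part of this commit).
# Keys feed the "Model type" radio; values are diffusers-format repos that supply
# the config, VAE, text encoders, and tokenizers around a single-file transformer.
single_file_base_models = {
    "dev": "black-forest-labs/FLUX.1-dev",
    "schnell": "black-forest-labs/FLUX.1-schnell",
}
```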
```diff
@@ -41,11 +42,11 @@ controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
 dtype = torch.bfloat16
 #dtype = torch.float8_e4m3fn
 #device = "cuda" if torch.cuda.is_available() else "cpu"
-taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
 pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
-                                                      tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
+                                                      tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
 controlnet_union = None
 controlnet = None
 last_model = models[0]
```
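All of the loader calls above now pass `token=HF_TOKEN`, which lets `from_pretrained` download gated or private repos (FLUX.1-dev itself is gated on the Hub). The pattern in isolation, assuming the token is provided as a Space secret:

```python
import os
import torch
from diffusers import AutoencoderKL

HF_TOKEN = os.getenv("HF_TOKEN")  # assumed: set as a Space secret

# Gated repos reject anonymous downloads, so the token is passed explicitly.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
```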
```diff
@@ -70,10 +71,13 @@ def unload_lora():
 # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
 # https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
 #@spaces.GPU()
-def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, progress=gr.Progress(track_tqdm=True)):
+def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, progress=gr.Progress(track_tqdm=True)):
     global pipe, pipe_i2i, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
+    safetensors_file = None
+    single_file_base_model = single_file_base_models.get(model_type, models[0])
     try:
-        if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
+        #if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
+        if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and not ".safetensors" in repo_id): return gr.update(visible=True)
         unload_lora()
         pipe.to("cpu")
         pipe_i2i.to("cpu")
```
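The rewritten cache guard is dense, and Python's precedence (`not` binds tighter than `and`, which binds tighter than `or`) matters here: the early return fires when the cache is warm, or when the id is neither a resolvable repo nor a `.safetensors` path. A standalone restatement with explicit grouping (stubbed names, since the real helpers live in `mod.py`):

```python
# Stub values standing in for the app's state and helpers.
disable_model_cache, repo_id, last_model = False, "repo/x", "repo/x"
cn_on, last_cn_on = False, False
def is_repo_name(r): return True     # stub for mod.is_repo_name
def is_repo_exists(r): return True   # stub for mod.is_repo_exists

# How the one-line guard actually groups:
cache_hit = (not disable_model_cache
             and repo_id == last_model and cn_on is last_cn_on)
bad_input = ((not is_repo_name(repo_id) or not is_repo_exists(repo_id))
             and ".safetensors" not in repo_id)
assert cache_hit or bad_input  # here the cache is warm, so the guard returns early
```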
```diff
@@ -85,12 +89,19 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, progress=gr.Progress(track_tqdm=True)):
         if cn_on:
             progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
             print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
-            controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
+            controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=HF_TOKEN)
             controlnet = FluxMultiControlNetModel([controlnet_union])
             controlnet.config = controlnet_union.config
-            pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
-            pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
-                                                                     tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
+            if ".safetensors" in repo_id:
+                safetensors_file = download_file_mod(repo_id)
+                transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
+                pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
+                pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+                                                                         tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+            else:
+                pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
+                pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+                                                                         tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
             last_model = repo_id
             last_cn_on = cn_on
             progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
```
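`FluxMultiControlNetModel([controlnet_union])` wraps the single InstantX union checkpoint so one model can serve several control images at once; the `controlnet.config = controlnet_union.config` line appears to be a shim for code that expects a plain model's config on the wrapper. At inference the union model takes per-image mode ids, roughly like this (a sketch; mode ids follow the InstantX model card, and the control-map paths are illustrative):

```python
from diffusers.utils import load_image

# Assumes `pipe` is the FluxControlNetPipeline assembled above.
canny_image = load_image("canny.png")  # preprocessed control maps
depth_image = load_image("depth.png")

image = pipe(
    prompt="a city street at night",
    control_image=[canny_image, depth_image],
    control_mode=[0, 2],                       # union mode ids (0=canny, 2=depth per the model card)
    controlnet_conditioning_scale=[0.5, 0.5],
    num_inference_steps=24,
    guidance_scale=3.5,
).images[0]
```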
```diff
@@ -98,16 +109,25 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, progress=gr.Progress(track_tqdm=True)):
         else:
             progress(0, desc=f"Loading model: {repo_id}")
             print(f"Loading model: {repo_id}")
-            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
-            pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
-                                                                  tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
+            if ".safetensors" in repo_id:
+                safetensors_file = download_file_mod(repo_id)
+                transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
+                pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
+                pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+                                                                      tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+            else:
+                pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
+                pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+                                                                      tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
             last_model = repo_id
             last_cn_on = cn_on
             progress(1, desc=f"Model loaded: {repo_id}")
             print(f"Model loaded: {repo_id}")
     except Exception as e:
-        print(f"Model load Error: {e}")
-        raise gr.Error(f"Model load Error: {e}") from e
+        print(f"Model load Error: {repo_id} {e}")
+        raise gr.Error(f"Model load Error: {repo_id} {e}") from e
+    finally:
+        if safetensors_file and Path(safetensors_file).exists(): Path(safetensors_file).unlink()
     return gr.update(visible=True)
 
 change_base_model.zerogpu = True
```
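This hunk is the point of the commit: a bare `.safetensors` checkpoint holds only transformer weights, so it can't be loaded as a full pipeline. The code downloads the file (`download_file_mod` is a helper from `mod.py`), builds just the transformer with `from_single_file` against the chosen base config, then assembles the rest of the pipeline (VAE, text encoders, tokenizers) from the diffusers-format base repo. Condensed, the flow looks like this (file name illustrative; the gated base repo needs a token):

```python
import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

dtype = torch.bfloat16
base = "black-forest-labs/FLUX.1-dev"  # diffusers-format repo supplying config + other modules

# 1) Materialize only the transformer from a single checkpoint file.
transformer = FluxTransformer2DModel.from_single_file(
    "my-finetune.safetensors", config=base, subfolder="transformer", torch_dtype=dtype,
)

# 2) Build the full pipeline around it; everything else comes from `base`.
pipe = DiffusionPipeline.from_pretrained(base, transformer=transformer, torch_dtype=dtype)
```

The new `finally` block then unlinks the downloaded file whether or not loading succeeded, which keeps the Space's limited disk from filling up as users switch checkpoints.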
```diff
@@ -260,7 +280,6 @@ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
     except Exception as e:
         print(e)
         image = None
-    #image = image if is_repo_public(repo) else None
     print(f"Loaded custom LoRA: {repo}")
     existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
     if existing_item_index is None:
```
```diff
@@ -521,20 +540,27 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
     for idx, lora in enumerate(selected_loras):
         lora_name = f"lora_{idx}"
         lora_names.append(lora_name)
+        print(f"Lora Name: {lora_name}")
         lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
         lora_path = lora['repo']
         weight_name = lora.get("weights")
         print(f"Lora Path: {lora_path}")
         if image_input is not None:
-            pipe_i2i.load_lora_weights(lora_path,
-                                       weight_name=weight_name if weight_name else None,
-                                       low_cpu_mem_usage=True,
-                                       adapter_name=lora_name, token=HF_TOKEN)
+            pipe_i2i.load_lora_weights(
+                lora_path,
+                weight_name=weight_name if weight_name else None,
+                low_cpu_mem_usage=True,
+                adapter_name=lora_name,
+                token=HF_TOKEN
+            )
         else:
-            pipe.load_lora_weights(lora_path,
-                                   weight_name=weight_name if weight_name else None,
-                                   low_cpu_mem_usage=True,
-                                   adapter_name=lora_name, token=HF_TOKEN)
+            pipe.load_lora_weights(
+                lora_path,
+                weight_name=weight_name if weight_name else None,
+                low_cpu_mem_usage=True,
+                adapter_name=lora_name,
+                token=HF_TOKEN
+            )
     print("Loaded LoRAs:", lora_names)
     if image_input is not None:
         pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
```
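The reformatted loop loads every selected LoRA under a distinct `adapter_name`, so several can coexist on one pipeline; `set_adapters` then activates them with the per-LoRA scales gathered in `lora_weights`. The same pattern in isolation (repo ids illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Each LoRA gets its own named adapter slot.
pipe.load_lora_weights("author/lora-one", adapter_name="lora_0", low_cpu_mem_usage=True)
pipe.load_lora_weights("author/lora-two", adapter_name="lora_1", low_cpu_mem_usage=True)

# Activate both at once, each with its own strength.
pipe.set_adapters(["lora_0", "lora_1"], adapter_weights=[1.0, 0.8])
```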
```diff
@@ -722,8 +748,13 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
         with gr.Accordion("History", open=False):
             history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
         with gr.Group():
-            model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id to want to use.", choices=models, value=models[0], allow_custom_value=True)
+            with gr.Row():
+                model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id or path of single safetensors file to want to use.",
+                                         choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
+                model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
+                                      choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
             model_info = gr.Markdown(elem_classes="info")
+
         with gr.Row():
             with gr.Accordion("Advanced Settings", open=False):
                 with gr.Row():
```
```diff
@@ -756,7 +787,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
         with gr.Row():
             for i in range(num_loras):
                 with gr.Column():
-                    lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True)
+                    lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True, min_width=320)
                     with gr.Row():
                         lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
                         lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
```
```diff
@@ -837,7 +868,7 @@ with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_ca
     gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=change_base_model,
-        inputs=[model_name, cn_on, disable_model_cache],
+        inputs=[model_name, cn_on, disable_model_cache, model_type],
        outputs=[result],
         queue=True,
         show_api=False,
```
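Since `change_base_model` gained a `model_type` parameter, the listener's `inputs` list grows to match: Gradio passes component values to the handler positionally, so the order of `inputs` must mirror the function signature. A minimal self-contained version of the same wiring pattern:

```python
import gradio as gr

def handler(text, flag):
    return f"{text} (flag={flag})"

with gr.Blocks() as demo:
    text = gr.Textbox()
    flag = gr.Checkbox()
    btn = gr.Button("Run")
    out = gr.Textbox()
    # One handler bound to several triggers; inputs map positionally to handler args.
    gr.on(triggers=[btn.click, text.submit], fn=handler, inputs=[text, flag], outputs=[out])

demo.launch()
```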