Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- app.py +5 -2
- mod.py +31 -4
- requirements.txt +6 -1
app.py
CHANGED
@@ -11,7 +11,8 @@ import random
|
|
11 |
import time
|
12 |
|
13 |
from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
|
14 |
-
description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
|
|
|
15 |
from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
|
16 |
download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
|
17 |
update_loras)
|
@@ -241,6 +242,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
|
|
241 |
tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use CogFlorence-2.1-Large", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
|
242 |
tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
|
243 |
prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
|
|
|
244 |
with gr.Column(scale=1, elem_id="gen_column"):
|
245 |
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
246 |
with gr.Row():
|
@@ -306,8 +308,8 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
|
|
306 |
with gr.Accordion("From URL", open=True, visible=True):
|
307 |
with gr.Row():
|
308 |
lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
|
309 |
-
lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
|
310 |
lora_search_civitai_submit = gr.Button("Search on Civitai")
|
|
|
311 |
lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
312 |
lora_search_civitai_json = gr.JSON(value={}, visible=False)
|
313 |
lora_search_civitai_desc = gr.Markdown(value="", visible=False)
|
@@ -344,6 +346,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
|
|
344 |
)
|
345 |
|
346 |
model_name.change(change_base_model, [model_name], [result])
|
|
|
347 |
|
348 |
gr.on(
|
349 |
triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
|
|
|
11 |
import time
|
12 |
|
13 |
from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
|
14 |
+
description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
|
15 |
+
get_trigger_word, pipe, enhance_prompt)
|
16 |
from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
|
17 |
download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
|
18 |
update_loras)
|
|
|
242 |
tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use CogFlorence-2.1-Large", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
|
243 |
tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
|
244 |
prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
|
245 |
+
prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
|
246 |
with gr.Column(scale=1, elem_id="gen_column"):
|
247 |
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
248 |
with gr.Row():
|
|
|
308 |
with gr.Accordion("From URL", open=True, visible=True):
|
309 |
with gr.Row():
|
310 |
lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
|
|
|
311 |
lora_search_civitai_submit = gr.Button("Search on Civitai")
|
312 |
+
lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
|
313 |
lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
314 |
lora_search_civitai_json = gr.JSON(value={}, visible=False)
|
315 |
lora_search_civitai_desc = gr.Markdown(value="", visible=False)
|
|
|
346 |
)
|
347 |
|
348 |
model_name.change(change_base_model, [model_name], [result])
|
349 |
+
prompt_enhance.click(enhance_prompt, [prompt], [prompt])
|
350 |
|
351 |
gr.on(
|
352 |
triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
|
mod.py
CHANGED
@@ -7,6 +7,7 @@ import gc
|
|
7 |
import subprocess
|
8 |
|
9 |
|
|
|
10 |
subprocess.run('pip cache purge', shell=True)
|
11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
torch.set_grad_enabled(False)
|
@@ -61,7 +62,7 @@ def get_repo_safetensors(repo_id: str):
|
|
61 |
if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
|
62 |
files = api.list_repo_files(repo_id=repo_id)
|
63 |
except Exception as e:
|
64 |
-
print(f"Error: Failed to get {repo_id}'s info.
|
65 |
print(e)
|
66 |
return gr.update(choices=[])
|
67 |
files = [f for f in files if f.endswith(".safetensors")]
|
@@ -138,8 +139,7 @@ def fuse_loras(pipe, lorajson: list[dict]):
|
|
138 |
#pipe.unload_lora_weights()
|
139 |
|
140 |
|
141 |
-
|
142 |
-
fuse_loras.zerogpu = True
|
143 |
|
144 |
|
145 |
def description_ui():
|
@@ -148,4 +148,31 @@ def description_ui():
|
|
148 |
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
|
149 |
[gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
|
150 |
"""
|
151 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import subprocess
|
8 |
|
9 |
|
10 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
11 |
subprocess.run('pip cache purge', shell=True)
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
torch.set_grad_enabled(False)
|
|
|
62 |
if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
|
63 |
files = api.list_repo_files(repo_id=repo_id)
|
64 |
except Exception as e:
|
65 |
+
print(f"Error: Failed to get {repo_id}'s info.")
|
66 |
print(e)
|
67 |
return gr.update(choices=[])
|
68 |
files = [f for f in files if f.endswith(".safetensors")]
|
|
|
139 |
#pipe.unload_lora_weights()
|
140 |
|
141 |
|
142 |
+
|
|
|
143 |
|
144 |
|
145 |
def description_ui():
|
|
|
148 |
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
|
149 |
[gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
|
150 |
"""
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
155 |
+
def load_prompt_enhancer():
|
156 |
+
try:
|
157 |
+
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
|
158 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
159 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
|
160 |
+
enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
|
161 |
+
except Exception as e:
|
162 |
+
print(e)
|
163 |
+
enhancer_flux = None
|
164 |
+
return enhancer_flux
|
165 |
+
|
166 |
+
|
167 |
+
enhancer_flux = load_prompt_enhancer()
|
168 |
+
|
169 |
+
|
170 |
+
def enhance_prompt(input_prompt):
|
171 |
+
result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
|
172 |
+
enhanced_text = result[0]['generated_text']
|
173 |
+
return enhanced_text
|
174 |
+
|
175 |
+
|
176 |
+
load_prompt_enhancer.zerogpu = True
|
177 |
+
change_base_model.zerogpu = True
|
178 |
+
fuse_loras.zerogpu = True
|
requirements.txt
CHANGED
@@ -1,7 +1,12 @@
|
|
1 |
torch
|
|
|
|
|
|
|
2 |
git+https://github.com/huggingface/diffusers
|
3 |
spaces
|
4 |
transformers
|
5 |
peft
|
6 |
sentencepiece
|
7 |
-
timm
|
|
|
|
|
|
1 |
torch
|
2 |
+
torchvision
|
3 |
+
huggingface_hub
|
4 |
+
accelerate
|
5 |
git+https://github.com/huggingface/diffusers
|
6 |
spaces
|
7 |
transformers
|
8 |
peft
|
9 |
sentencepiece
|
10 |
+
timm
|
11 |
+
xformers
|
12 |
+
einops
|