diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,1456 +1,1269 @@
-import spaces
-import gradio as gr
-import json
-import torch
-from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, AutoPipelineForInpainting
-from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
-from diffusers.utils import load_image
-from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline, FluxTransformer2DModel, FluxControlNetInpaintPipeline, FluxInpaintPipeline
-from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
-import os
-import copy
-import random
-import time
-import requests
-import pandas as pd
-from pathlib import Path
-
-from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
-from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
- description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
- get_trigger_word, enhance_prompt, set_control_union_image,
- get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en)
-from modutils import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
- download_my_lora_flux, get_all_lora_tupled_list, apply_lora_prompt_flux,
- update_loras_flux, update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD,
- get_t2i_model_info, download_hf_file, save_image_history)
-from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
-from tagger.fl2flux import predict_tags_fl2_flux
-
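-# Local fallback preview image, used when a LoRA entry has no usable image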
-CUSTOM_PLACEHOLDER = os.path.join(os.getcwd(), "custom.png")
-
-# Load prompts for randomization
-df = pd.read_csv('prompts.csv', header=None)
-prompt_values = df.values.flatten()
-
-# Load LoRAs from JSON file
-with open('loras.json', 'r') as f:
- loras = json.load(f)
-
-# Initialize the base model
-base_model = models[0]
-controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
-#controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
-dtype = torch.bfloat16
-#dtype = torch.float8_e4m3fn
-#device = "cuda" if torch.cuda.is_available() else "cpu"
-taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
-good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
-pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
-pipe_ip = AutoPipelineForInpainting.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
-controlnet_union = None
-controlnet = None
-last_model = models[0]
-last_cn_on = False
-#controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
-#controlnet = FluxMultiControlNetModel([controlnet_union])
-#controlnet.config = controlnet_union.config
-
-MAX_SEED = 2**32-1
-
-# Helper function for tests
-def run_test(input_text, debug_log):
-    try:
-        # Check the input against known tests
-        if input_text == "get_custom_image":
-            result = get_custom_image()
-        else:
-            result = f"Unknown test: {input_text}"
-        # Write the result to the debug log
-        updated_log = debug_log + f"\nTest '{input_text}': {result}"
-        return updated_log
-    except Exception as e:
-        # Write errors to the debug log as well
-        updated_log = debug_log + f"\nError in test '{input_text}': {str(e)}"
-        return updated_log
-
-# Helper function for appending log entries
-def append_debug_log(log_text, current_logs=""):
-    """Appends a new entry to the log."""
-    updated_logs = current_logs + f"\n{log_text}"
-    return updated_logs
-
-# Define the debug Gradio Blocks (superseded by the main app below)
-with gr.Blocks() as app:
-    # Debug log field
-    debug_log = gr.Textbox(
-        label="Debug Log",
-        interactive=False,
-        lines=10,
-        placeholder="Debug information appears here...",
-        type="text"  # only text is accepted
-    )
-
-    # Test input field and button
-    with gr.Row():
-        test_input = gr.Textbox(
-            label="Test Input",
-            placeholder="Enter the name of a function, e.g. 'get_custom_image'.",
-        )
-        test_button = gr.Button("Run Test")
-
-    # Wire the test button to the test helper
-    test_button.click(
-        fn=run_test,
-        inputs=[test_input, debug_log],
-        outputs=debug_log
-    )
-
-    # Example functionality: dummy echo
-    with gr.Row():
-        input_box = gr.Textbox(label="Input")
-        output_box = gr.Textbox(label="Output")
-        dummy_button = gr.Button("Dummy Test")
-
-    def dummy_function(text):
-        return f"Echo: {text}"
-
-    dummy_button.click(dummy_function, inputs=[input_box], outputs=[output_box])
-
-# Note: launching here would block the rest of the module from loading;
-# the main app defined below replaces this debug app.
-
-
-def unload_lora():
- global pipe, pipe_i2i, pipe_ip
- try:
- #pipe.unfuse_lora()
- pipe.unload_lora_weights()
- #pipe_i2i.unfuse_lora()
- pipe_i2i.unload_lora_weights()
- pipe_ip.unload_lora_weights()
- except Exception as e:
- print(e)
-
-def download_file_mod(url, directory=os.getcwd()):
- path = download_hf_file(directory, url, hf_token=HF_TOKEN)
- if not path: raise Exception(f"Download error: {url}")
- return path
-
-# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
-# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
-# https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
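-# Rebuilds all three pipelines (text-to-image, image-to-image, inpainting) whenever the
-# base model or the ControlNet toggle changes. Single *.safetensors checkpoints are loaded
-# via FluxTransformer2DModel.from_single_file() on top of the selected base repo; the text
-# encoders, tokenizers and VAE are shared across the pipelines to save memory.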
-#@spaces.GPU()
-def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, progress=gr.Progress(track_tqdm=True)):
- global pipe, pipe_i2i, pipe_ip, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
- safetensors_file = None
- single_file_base_model = single_file_base_models.get(model_type, models[0])
- try:
- #if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
- if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and not ".safetensors" in repo_id): return gr.update()
- unload_lora()
- pipe.to("cpu")
- pipe_i2i.to("cpu")
- pipe_ip.to("cpu")
- good_vae.to("cpu")
- taef1.to("cpu")
- if controlnet is not None: controlnet.to("cpu")
- if controlnet_union is not None: controlnet_union.to("cpu")
- clear_cache()
- if cn_on:
- progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
- print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
- controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=HF_TOKEN)
- controlnet = FluxMultiControlNetModel([controlnet_union])
- controlnet.config = controlnet_union.config
- if ".safetensors" in repo_id:
- safetensors_file = download_file_mod(repo_id)
- transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
- pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
- pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- else:
- pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
- pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- last_model = repo_id
- last_cn_on = cn_on
- progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
- print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
- else:
- progress(0, desc=f"Loading model: {repo_id}")
- print(f"Loading model: {repo_id}")
- if ".safetensors" in repo_id:
- safetensors_file = download_file_mod(repo_id)
- transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
- pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- pipe_ip = AutoPipelineForInpainting.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- else:
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- pipe_ip = AutoPipelineForInpainting.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
- last_model = repo_id
- last_cn_on = cn_on
- progress(1, desc=f"Model loaded: {repo_id}")
- print(f"Model loaded: {repo_id}")
- except Exception as e:
- print(f"Model load Error: {repo_id} {e}")
- raise gr.Error(f"Model load Error: {repo_id} {e}") from e
- finally:
- if safetensors_file and Path(safetensors_file).exists(): Path(safetensors_file).unlink()
- return gr.update()
-
-change_base_model.zerogpu = True
-
-def is_repo_public(repo_id: str):
- api = HfApi()
- try:
-        return api.repo_exists(repo_id=repo_id, token=False)
- except Exception as e:
- print(f"Error: Failed to connect {repo_id}. {e}")
- return False
-
-class calculateDuration:
- def __init__(self, activity_name=""):
- self.activity_name = activity_name
-
- def __enter__(self):
- self.start_time = time.time()
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.end_time = time.time()
- self.elapsed_time = self.end_time - self.start_time
- if self.activity_name:
- print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
- else:
- print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
-
-def download_file(url, directory=None):
- if directory is None:
- directory = os.getcwd() # Use current working directory if not specified
-
- # Get the filename from the URL
- filename = url.split('/')[-1]
-
- # Full path for the downloaded file
- filepath = os.path.join(directory, filename)
-
- # Download the file
- response = requests.get(url)
- response.raise_for_status() # Raise an exception for bad status codes
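-    # Note: the entire response is buffered in memory; for very large files a
-    # streamed download (requests.get(url, stream=True)) would be preferable.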
-
- # Write the content to the file
- with open(filepath, 'wb') as file:
- file.write(response.content)
-
- return filepath
-
-def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
- selected_index = evt.index
- selected_indices = selected_indices or []
- if selected_index in selected_indices:
- selected_indices.remove(selected_index)
- else:
- if len(selected_indices) < 2:
- selected_indices.append(selected_index)
- else:
- gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
- return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
-
- selected_info_1 = "Select a LoRA 1"
- selected_info_2 = "Select a LoRA 2"
- lora_scale_1 = 1.15
- lora_scale_2 = 1.15
- lora_image_1 = None
- lora_image_2 = None
- if len(selected_indices) >= 1:
- lora1 = loras_state[selected_indices[0]]
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
- lora_image_1 = lora1['image']
- if len(selected_indices) >= 2:
- lora2 = loras_state[selected_indices[1]]
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
- lora_image_2 = lora2['image']
-
- if selected_indices:
- last_selected_lora = loras_state[selected_indices[-1]]
- new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
- else:
- new_placeholder = "Type a prompt"
-
- return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
-
-def remove_lora_1(selected_indices, loras_state):
- if len(selected_indices) >= 1:
- selected_indices.pop(0)
- selected_info_1 = "Select LoRA 1"
- selected_info_2 = "Select LoRA 2"
- lora_scale_1 = 1.15
- lora_scale_2 = 1.15
- lora_image_1 = None
- lora_image_2 = None
- if len(selected_indices) >= 1:
- lora1 = loras_state[selected_indices[0]]
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
- lora_image_1 = lora1['image']
- if len(selected_indices) >= 2:
- lora2 = loras_state[selected_indices[1]]
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
- lora_image_2 = lora2['image']
- return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
-
-def remove_lora_2(selected_indices, loras_state):
- if len(selected_indices) >= 2:
- selected_indices.pop(1)
- selected_info_1 = "Select LoRA 1"
- selected_info_2 = "Select LoRA 2"
- lora_scale_1 = 1.15
- lora_scale_2 = 1.15
- lora_image_1 = None
- lora_image_2 = None
- if len(selected_indices) >= 1:
- lora1 = loras_state[selected_indices[0]]
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
- lora_image_1 = lora1['image']
- if len(selected_indices) >= 2:
- lora2 = loras_state[selected_indices[1]]
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
- lora_image_2 = lora2['image']
- return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
-
-def randomize_loras(selected_indices, loras_state):
- if len(loras_state) < 2:
- raise gr.Error("Not enough LoRAs to randomize.")
- selected_indices = random.sample(range(len(loras_state)), 2)
- lora1 = loras_state[selected_indices[0]]
- lora2 = loras_state[selected_indices[1]]
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
- lora_scale_1 = 1.15
- lora_scale_2 = 1.15
- lora_image_1 = lora1['image']
- lora_image_2 = lora2['image']
- random_prompt = random.choice(prompt_values)
- return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
-
-def download_loras_images(loras_json_orig: list[dict]):
-    """
-    Resolves the preview image URL for each LoRA entry, with fallback logic.
-    """
-    default_placeholder = CUSTOM_PLACEHOLDER  # placeholder image for missing previews
-    loras_json = []
-
-    for lora in loras_json_orig:
-        repo = lora.get("repo", None)
-        image_url = lora.get("image", None)
-
-        # Defaults and fallbacks
-        lora["title"] = lora.get("title", "Unknown LoRA")
-        lora["trigger_word"] = lora.get("trigger_word", "")
-        resolved_image_url = None
-
-        # 1. Try the image hosted in the repository
-        if repo:
-            repo_image_url = f"https://huggingface.co/{repo}/resolve/main/{image_url}" if image_url else None
-            try:
-                # follow redirects explicitly; requests.head() does not by default
-                if repo_image_url and requests.head(repo_image_url, allow_redirects=True).status_code == 200:
-                    resolved_image_url = repo_image_url
-            except Exception as e:
-                print(f"Failed to load repo image: {repo_image_url}: {e}")
-
-        # 2. Fallback: the hotlinked image URL from the JSON
-        if not resolved_image_url and image_url:
-            try:
-                if requests.head(image_url, allow_redirects=True).status_code == 200:
-                    resolved_image_url = image_url
-            except Exception as e:
-                print(f"Failed to load hotlinked image: {image_url}: {e}")
-
-        # 3. Fallback: use the placeholder image
-        lora["image"] = resolved_image_url if resolved_image_url else default_placeholder
-        loras_json.append(lora)
-
-    return loras_json
-
-
-def handle_gallery_click(evt: gr.SelectData, loras_state):
-    """
-    Handles clicks on gallery items.
-    Shows the clicked image in the large view and hides the gallery.
-    """
-    selected_index = evt.index
-    selected_lora = loras_state[selected_index]
-
-    # Fetch the data of the selected LoRA entry
-    large_image = selected_lora.get("image", CUSTOM_PLACEHOLDER)
-    title = selected_lora.get("title", "Unknown LoRA")
-
-    # Hide the gallery, show the large view and the select button
-    return (
-        gr.update(visible=False),  # hide the gallery
-        gr.update(value=large_image, visible=True),  # show the large view
-        gr.update(visible=True)  # show the select button
-    )
-
-def toggle_large_view(selected_indices, loras_state):
-    """
-    Switches from the large view back to the gallery view.
-    """
-    # Hide the large view and the select button, show the gallery
-    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
-
-def select_lora(selected_indices, loras_state):
-    """
-    Applies the selected LoRA and returns to the gallery.
-    """
-    # Logic for adding the LoRA would go here (if needed),
-    # e.g. updating `selected_indices` or modifying `loras_state`.
-
-    # Back to the gallery view
-    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
-
-
-def add_custom_lora(custom_lora, selected_indices, current_loras, gallery, debug_log):
- logs = debug_log
- try:
- logs = append_debug_log(f"Adding custom LoRA: {custom_lora}", logs)
- if custom_lora:
- title, repo, path, trigger_word, image = check_custom_model(custom_lora)
- logs = append_debug_log(f"Loaded custom LoRA: {repo}", logs)
-
- if image is not None and "http" in image and not is_repo_public(repo):
- try:
- image = download_file_mod(image)
- except Exception as e:
- logs = append_debug_log(f"Error downloading image: {e}", logs)
-                    image = get_custom_image()  # use the fallback
-
- existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
- if existing_item_index is None:
- new_item = {
- "image": image or get_custom_image(), # Fallback erneut prüfen
- "title": title,
- "repo": repo,
- "weights": path,
- "trigger_word": trigger_word
- }
- current_loras.append(new_item)
- logs = append_debug_log(f"Added new LoRA: {title}", logs)
-
- gallery_items = [(item["image"], item["title"]) for item in current_loras]
- return current_loras, gr.update(value=gallery_items), gr.update(value=logs)
- else:
- logs = append_debug_log("No custom LoRA provided.", logs)
- return current_loras, gallery, gr.update(value=logs)
- except Exception as e:
- logs = append_debug_log(f"Error in add_custom_lora: {e}", logs)
- return current_loras, gallery, gr.update(value=logs)
-
-
-def update_gallery_with_loras(selected_indices, loras_state, gallery):
-    """
-    Updates the gallery based on the selection. Implements the preview logic.
-    """
-    if not selected_indices:
-        # Gallery view: nothing selected
-        gallery_items = [(lora["image"], lora["title"]) for lora in loras_state]
-        return gr.update(value=gallery_items), gr.update(visible=False), gr.update(visible=True)
-
-    # Preview view: an image has been selected
-    selected_lora = loras_state[selected_indices[0]]  # only the first selected image
-    preview_image = selected_lora["image"]
-    preview_title = selected_lora["title"]
-    preview_trigger_word = selected_lora.get("trigger_word", "")
-    preview_button_visible = True
-
-    # Build micro thumbnails
-    micro_thumbnails = [(lora["image"], "") for lora in loras_state]
-
-    return (
-        gr.update(value=[(preview_image, preview_title)], visible=True),
-        gr.update(value=micro_thumbnails, visible=True),
-        gr.update(visible=False),  # hide the gallery
-        gr.update(value=preview_trigger_word, visible=True),
-        gr.update(visible=preview_button_visible),
-    )
-
-
-def get_custom_image():
-    """
-    Returns a fallback image.
-    Uses the local placeholder image if it exists.
-    """
-    placeholder_path = CUSTOM_PLACEHOLDER  # path to the placeholder image in the working directory
-
-    try:
-        # Check whether the placeholder image exists
-        if os.path.exists(placeholder_path):
-            return placeholder_path
-        else:
-            raise FileNotFoundError(f"Placeholder image not found: {placeholder_path}")
-    except Exception as e:
-        print(f"Error in get_custom_image: {e}")
-        # Safe fallback in case the placeholder image is missing
-        return None
-
-
-def remove_custom_lora(selected_indices, current_loras, gallery):
- if current_loras:
- custom_lora_repo = current_loras[-1]['repo']
- # Remove from loras list
- current_loras = current_loras[:-1]
- # Remove from selected_indices if selected
- custom_lora_index = len(current_loras)
- if custom_lora_index in selected_indices:
- selected_indices.remove(custom_lora_index)
- # Update gallery
- gallery_items = [(item["image"], item["title"]) for item in current_loras]
- # Update selected_info and images
- selected_info_1 = "Select a LoRA 1"
- selected_info_2 = "Select a LoRA 2"
- lora_scale_1 = 1.15
- lora_scale_2 = 1.15
- lora_image_1 = None
- lora_image_2 = None
- if len(selected_indices) >= 1:
- lora1 = current_loras[selected_indices[0]]
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
- lora_image_1 = lora1['image']
- if len(selected_indices) >= 2:
- lora2 = current_loras[selected_indices[1]]
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
- lora_image_2 = lora2['image']
- return (
- current_loras,
- gr.update(value=gallery_items),
- selected_info_1,
- selected_info_2,
- selected_indices,
- lora_scale_1,
- lora_scale_2,
- lora_image_1,
- lora_image_2
- )
-
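-# Text-to-image path. Without ControlNet, intermediate previews are streamed through the
-# tiny taef1 VAE via flux_pipe_call_that_returns_an_iterable_of_images (the final image is
-# decoded with the full VAE); with ControlNet enabled, a single standard pipeline call runs.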
-@spaces.GPU(duration=70)
-@torch.inference_mode()
-def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, progress=gr.Progress(track_tqdm=True)):
- global pipe, taef1, good_vae, controlnet, controlnet_union
- try:
- good_vae.to("cuda")
- taef1.to("cuda")
- generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
-
- with calculateDuration("Generating image"):
- # Generate image
- modes, images, scales = get_control_params()
- if not cn_on or len(modes) == 0:
- pipe.to("cuda")
- pipe.vae = taef1
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
- progress(0, desc="Start Inference.")
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
- prompt=prompt_mash,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- output_type="pil",
- good_vae=good_vae,
- ):
- yield img
- else:
- pipe.to("cuda")
- pipe.vae = good_vae
- if controlnet_union is not None: controlnet_union.to("cuda")
- if controlnet is not None: controlnet.to("cuda")
- pipe.enable_model_cpu_offload()
- progress(0, desc="Start Inference with ControlNet.")
- for img in pipe(
- prompt=prompt_mash,
- control_image=images,
- control_mode=modes,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- controlnet_conditioning_scale=scales,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- ).images:
- yield img
- except Exception as e:
- print(e)
- raise gr.Error(f"Inference Error: {e}") from e
-
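-# Image-to-image / inpainting path. The ImageEditor payload provides the background image
-# and a painted mask layer; for inpainting the mask can optionally be blurred before use.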
-@spaces.GPU(duration=70)
-@torch.inference_mode()
-def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on, progress=gr.Progress(track_tqdm=True)):
- global pipe_i2i, pipe_ip, good_vae, controlnet, controlnet_union
- try:
- good_vae.to("cuda")
- generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
- image_input_path = image_input_path_dict['background']
- mask_path = image_input_path_dict['layers'][0]
-
- with calculateDuration("Generating image"):
- # Generate image
- modes, images, scales = get_control_params()
- if not cn_on or len(modes) == 0:
- if is_inpaint: # Inpainting
- pipe_ip.to("cuda")
- pipe_ip.vae = good_vae
- image_input = load_image(image_input_path)
- mask_input = load_image(mask_path)
- if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
- progress(0, desc="Start Inpainting Inference.")
- final_image = pipe_ip(
- prompt=prompt_mash,
- image=image_input,
- mask_image=mask_input,
- strength=image_strength,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- output_type="pil",
- ).images[0]
- return final_image
- else:
- pipe_i2i.to("cuda")
- pipe_i2i.vae = good_vae
- image_input = load_image(image_input_path)
- progress(0, desc="Start I2I Inference.")
- final_image = pipe_i2i(
- prompt=prompt_mash,
- image=image_input,
- strength=image_strength,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- output_type="pil",
- ).images[0]
- return final_image
- else:
- if is_inpaint: # Inpainting
- pipe_ip.to("cuda")
- pipe_ip.vae = good_vae
- image_input = load_image(image_input_path)
- mask_input = load_image(mask_path)
- if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
- if controlnet_union is not None: controlnet_union.to("cuda")
- if controlnet is not None: controlnet.to("cuda")
- pipe_ip.enable_model_cpu_offload()
- progress(0, desc="Start Inpainting Inference with ControlNet.")
- final_image = pipe_ip(
- prompt=prompt_mash,
- control_image=images,
- control_mode=modes,
- image=image_input,
- mask_image=mask_input,
- strength=image_strength,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- controlnet_conditioning_scale=scales,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- output_type="pil",
- ).images[0]
- return final_image
- else:
- pipe_i2i.to("cuda")
- pipe_i2i.vae = good_vae
-                image_input = load_image(image_input_path)
- if controlnet_union is not None: controlnet_union.to("cuda")
- if controlnet is not None: controlnet.to("cuda")
- pipe_i2i.enable_model_cpu_offload()
- progress(0, desc="Start I2I Inference with ControlNet.")
- final_image = pipe_i2i(
- prompt=prompt_mash,
- control_image=images,
- control_mode=modes,
- image=image_input,
- strength=image_strength,
- num_inference_steps=steps,
- guidance_scale=cfg_scale,
- width=width,
- height=height,
- controlnet_conditioning_scale=scales,
- generator=generator,
- joint_attention_kwargs={"scale": 1.0},
- output_type="pil",
- ).images[0]
- return final_image
- except Exception as e:
- print(e)
- raise gr.Error(f"I2I Inference Error: {e}") from e
-
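-# Orchestrates a full generation request: optionally translates the prompt, splices in
-# LoRA trigger words, loads gallery and external LoRA weights onto the pipeline matching
-# the task type, then dispatches to generate_image() or generate_image_to_image().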
-def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
- randomize_seed, seed, width, height, loras_state, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)):
- global pipe, pipe_i2i, pipe_ip
- if not selected_indices and not is_valid_lora(lora_json):
-        gr.Info("No LoRA is selected.")
- # raise gr.Error("You must select a LoRA before proceeding.")
- progress(0, desc="Preparing Inference.")
-
- selected_loras = [loras_state[idx] for idx in selected_indices]
-
- if task_type == "Inpainting":
- is_inpaint = True
- is_i2i = True
- elif task_type == "Image-to-Image":
- is_inpaint = False
- is_i2i = True
- else: # "Text-to-Image"
- is_inpaint = False
- is_i2i = False
-
- if translate_on: prompt = translate_to_en(prompt)
-
- # Build the prompt with trigger words
- prepends = []
- appends = []
- for lora in selected_loras:
- trigger_word = lora.get('trigger_word', '')
- if trigger_word:
- if lora.get("trigger_position") == "prepend":
- prepends.append(trigger_word)
- else:
- appends.append(trigger_word)
- prompt_mash = " ".join(prepends + [prompt] + appends)
- print("Prompt Mash: ", prompt_mash) #
-
- # Unload previous LoRA weights
- with calculateDuration("Unloading LoRA"):
- unload_lora()
-
- print(pipe.get_active_adapters()) #
- print(pipe_i2i.get_active_adapters()) #
- print(pipe_ip.get_active_adapters()) #
-
- clear_cache() #
-
- # Build the prompt for External LoRAs
- prompt_mash = prompt_mash + get_model_trigger(last_model)
- lora_names = []
- lora_weights = []
- if is_valid_lora(lora_json): # Load External LoRA weights
- with calculateDuration("Loading External LoRA weights"):
- if is_inpaint:
- pipe_ip, lora_names, lora_weights = fuse_loras(pipe_ip, lora_json)
- elif is_i2i:
- pipe_i2i, lora_names, lora_weights = fuse_loras(pipe_i2i, lora_json)
- else: pipe, lora_names, lora_weights = fuse_loras(pipe, lora_json)
- trigger_word = get_trigger_word(lora_json)
- prompt_mash = f"{prompt_mash} {trigger_word}"
- print("Prompt Mash: ", prompt_mash) #
-
- # Load LoRA weights with respective scales
- if selected_indices:
- with calculateDuration("Loading LoRA weights"):
- for idx, lora in enumerate(selected_loras):
- lora_name = f"lora_{idx}"
- lora_names.append(lora_name)
- print(f"Lora Name: {lora_name}")
- lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
- lora_path = lora['repo']
- weight_name = lora.get("weights")
- print(f"Lora Path: {lora_path}")
- if is_inpaint:
- pipe_ip.load_lora_weights(
- lora_path,
- weight_name=weight_name if weight_name else None,
- low_cpu_mem_usage=False,
- adapter_name=lora_name,
- token=HF_TOKEN
- )
- elif is_i2i:
- pipe_i2i.load_lora_weights(
- lora_path,
- weight_name=weight_name if weight_name else None,
- low_cpu_mem_usage=False,
- adapter_name=lora_name,
- token=HF_TOKEN
- )
- else:
- pipe.load_lora_weights(
- lora_path,
- weight_name=weight_name if weight_name else None,
- low_cpu_mem_usage=False,
- adapter_name=lora_name,
- token=HF_TOKEN
- )
- print("Loaded LoRAs:", lora_names)
- if selected_indices or is_valid_lora(lora_json):
- if is_inpaint:
- pipe_ip.set_adapters(lora_names, adapter_weights=lora_weights)
- elif is_i2i:
- pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
- else:
- pipe.set_adapters(lora_names, adapter_weights=lora_weights)
-
- print(pipe.get_active_adapters()) #
- print(pipe_i2i.get_active_adapters()) #
- print(pipe_ip.get_active_adapters()) #
-
- # Set random seed for reproducibility
- with calculateDuration("Randomizing seed"):
- if randomize_seed:
- seed = random.randint(0, MAX_SEED)
-
- # Generate image
- progress(0, desc="Running Inference.")
- if is_i2i:
- final_image = generate_image_to_image(prompt_mash, image_input, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on)
- yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
- else:
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on)
- # Consume the generator to get the final image
- final_image = None
- step_counter = 0
- for image in image_generator:
- step_counter+=1
- final_image = image
-            # Progress-bar HTML; the --current/--total CSS variables drive the width of
-            # .progress-bar via the rules defined in `css` below.
-            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter};--total: {steps};"></div></div>'
-            yield image, seed, gr.update(value=progress_bar, visible=True)
- yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(value=progress_bar, visible=False)
-
-run_lora.zerogpu = True
-
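-# Resolves a "user/repo" Hugging Face link to (title, repo, weights filename, trigger word,
-# preview image URL) by reading the model card and listing the repository files.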
-def get_huggingface_safetensors(link):
- split_link = link.split("/")
- if len(split_link) == 2:
- model_card = ModelCard.load(link, token=HF_TOKEN)
- base_model = model_card.data.get("base_model")
- print(f"Base model: {base_model}")
- if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
- #raise Exception("Not a FLUX LoRA!")
- gr.Warning("Not a FLUX LoRA?")
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
- trigger_word = model_card.data.get("instance_prompt", "")
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
- fs = HfFileSystem(token=HF_TOKEN)
- safetensors_name = None
- try:
- list_of_files = fs.ls(link, detail=False)
- for file in list_of_files:
- if file.endswith(".safetensors"):
- safetensors_name = file.split("/")[-1]
- if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
- image_elements = file.split("/")
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
- except Exception as e:
- print(e)
- raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
- if not safetensors_name:
- raise gr.Error("No *.safetensors file found in the repository")
- return split_link[1], link, safetensors_name, trigger_word, image_url
- else:
- raise gr.Error("Invalid Hugging Face repository link")
-
-def check_custom_model(link):
- if link.endswith(".safetensors"):
- # Treat as direct link to the LoRA weights
- title = os.path.basename(link)
- repo = link
- path = None # No specific weight name
- trigger_word = ""
- image_url = None
- return title, repo, path, trigger_word, image_url
- elif link.startswith("https://"):
- if "huggingface.co" in link:
- link_split = link.split("huggingface.co/")
- return get_huggingface_safetensors(link_split[1])
- else:
- raise Exception("Unsupported URL")
- else:
- # Assume it's a Hugging Face model path
- return get_huggingface_safetensors(link)
-
-def update_history(new_image, history):
- """Updates the history gallery with the new image."""
- if history is None:
- history = []
- history.insert(0, new_image)
- return history
-
-loras = download_loras_images(loras)
-
-css = '''
-#gen_column{align-self: stretch}
-#gen_btn{height: 100%}
-#title{text-align: center}
-#title h1{font-size: 3em; display:inline-flex; align-items:center}
-#title img{width: 100px; margin-right: 0.25em}
-#gallery .grid-wrap{height: 5vh}
-#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
-.custom_lora_card{margin-bottom: 1em}
-.card_internal{display: flex;height: 100px;margin-top: .5em}
-.card_internal img{margin-right: 1em}
-.styler{--form-gap-width: 0px !important}
-#progress{height:30px}
-#progress .generating{display:none}
-.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
-.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
-#component-8, .button_total{height: 100%; align-self: stretch;}
-#loaded_loras [data-testid="block-info"]{font-size:80%}
-#custom_lora_structure{background: var(--block-background-fill)}
-#custom_lora_btn{margin-top: auto;margin-bottom: 11px}
-#random_btn{font-size: 300%}
-#component-11{align-self: stretch;}
-.info { align-items: center; text-align: center; }
-.desc [src$='#float'] { float: right; margin: 20px; }
-'''
-with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_cache=(60, 3600)) as app:
-
- debug_log = gr.Textbox(
- label="Debug Log",
- interactive=False,
- lines=10,
- placeholder="Hier erscheinen Debug-Informationen...",
- type="text" # Stelle sicher, dass sie nur Text akzeptiert
- )
-
-    # Test input field and button
- with gr.Row():
- test_input = gr.Textbox(
- label="Test Input",
- placeholder="Gib den Namen einer Funktion ein, z.B. 'get_custom_image'.",
- )
- test_button = gr.Button("Run Test")
-
-
- with gr.Tab("FLUX LoRA the Explorer"):
- title = gr.HTML(
- """FLUX LoRA Explorer Mod Reloaded
""",
- elem_id="title",
- )
- loras_state = gr.State(loras)
- selected_indices = gr.State([])
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Group():
- with gr.Accordion("Generate Prompt from Image", open=False):
- tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
- with gr.Accordion(label="Advanced options", open=False):
- tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
- tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
- neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
- v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2, visible=False)
- v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2, visible=False)
- v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False, visible=False)
- tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
- tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
- prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt", show_copy_button=True)
- with gr.Row():
- prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
- auto_trans = gr.Checkbox(label="Auto translate to English", value=False, elem_classes="info")
- with gr.Column(scale=1, elem_id="gen_column"):
- generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn", elem_classes=["button_total"])
- with gr.Row(elem_id="loaded_loras"):
- with gr.Column(scale=1, min_width=25):
- randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
- with gr.Column(scale=8):
- with gr.Row():
- with gr.Column(scale=0, min_width=50):
- lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
- with gr.Column(scale=3, min_width=100):
- selected_info_1 = gr.Markdown("Select a LoRA 1")
- with gr.Column(scale=5, min_width=50):
- lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
- with gr.Row():
- remove_button_1 = gr.Button("Remove", size="sm")
- with gr.Column(scale=8):
- with gr.Row():
- with gr.Column(scale=0, min_width=50):
- lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
- with gr.Column(scale=3, min_width=100):
- selected_info_2 = gr.Markdown("Select a LoRA 2")
- with gr.Column(scale=5, min_width=50):
- lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
- with gr.Row():
- remove_button_2 = gr.Button("Remove", size="sm")
- with gr.Row():
- with gr.Column():
- selected_info = gr.Markdown("")
-                # Gallery component
-                gallery = gr.Gallery(
-                    label="LoRA Gallery",
-                    value=[(lora["image"], lora["title"]) for lora in loras],  # initial LoRAs
-                    columns=4,
-                    interactive=True  # keep the gallery interactive
-                )
-
-                # Large view for the selected image
-                large_view = gr.Image(
-                    label="Selected Image",
-                    visible=False,  # hidden by default
-                    interactive=False  # not interactive
-                )
-
-                # Select button to confirm the selected image
-                select_button = gr.Button(
-                    "Select",
-                    visible=False  # only visible once an image is selected
-                )
-
-                # Event handler: click on a gallery image
-                gallery.select(
-                    handle_gallery_click,  # processes the gallery click
-                    inputs=[loras_state],  # input: state
-                    outputs=[gallery, large_view, select_button]  # outputs: gallery, large view, button
-                )
-
-
- with gr.Group():
- with gr.Row(elem_id="custom_lora_structure"):
- custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
- add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
- remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
- gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
- with gr.Column():
- progress_bar = gr.Markdown(elem_id="progress",visible=False)
- result = gr.Image(label="Generated Image", format="png", type="filepath", show_share_button=False, interactive=False)
- with gr.Accordion("History", open=False):
- history_gallery = gr.Gallery(label="History", columns=4, rows=1, object_fit="contain", interactive=False, format="png",
- show_share_button=False, show_download_button=True)
- history_files = gr.Files(interactive=False, visible=False)
- history_clear_button = gr.Button(value="Clear History", variant="secondary")
- history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
- with gr.Group():
- with gr.Row():
-                    model_name = gr.Dropdown(label="Base Model", info="Enter a Hugging Face model repo_id or the URL/path of a single safetensors file you want to use.",
- choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
- model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
- choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
- model_info = gr.Markdown(elem_classes="info")
-
- with gr.Row():
- with gr.Accordion("Advanced Settings", open=False):
- with gr.Row():
- with gr.Column():
- #input_image = gr.Image(label="Input image", type="filepath", height=256, sources=["upload", "clipboard"], show_share_button=False)
- input_image = gr.ImageEditor(label='Input image', type='filepath', sources=["upload", "clipboard"], image_mode='RGB', show_share_button=False, show_fullscreen_button=False,
- layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), value=None,
- canvas_size=(384, 384), width=384, height=512)
- with gr.Column():
- task_type = gr.Radio(label="Task", choices=["Text-to-Image", "Image-to-Image", "Inpainting"], value="Text-to-Image")
-                        image_strength = gr.Slider(label="Strength", info="Lower values keep more of the input image in I2I; the opposite applies to Inpainting", minimum=0.01, maximum=1.0, step=0.01, value=0.75)
- blur_mask = gr.Checkbox(label="Blur mask", value=False)
- blur_factor = gr.Slider(label="Blur factor", minimum=0, maximum=50, step=1, value=33)
- input_image_preprocess = gr.Checkbox(True, label="Preprocess Input image")
- with gr.Column():
- with gr.Row():
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
- with gr.Row():
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
- disable_model_cache = gr.Checkbox(False, label="Disable model caching")
- with gr.Accordion("External LoRA", open=True):
- with gr.Column():
- deselect_lora_button = gr.Button("Remove External LoRAs", variant="secondary")
- lora_repo_json = gr.JSON(value=[{}] * num_loras, visible=False)
- lora_repo = [None] * num_loras
- lora_weights = [None] * num_loras
- lora_trigger = [None] * num_loras
- lora_wt = [None] * num_loras
- lora_info = [None] * num_loras
- lora_copy = [None] * num_loras
- lora_md = [None] * num_loras
- lora_num = [None] * num_loras
- with gr.Row():
- for i in range(num_loras):
- with gr.Column():
- lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True, min_width=320)
- with gr.Row():
- lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
- lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
- lora_wt[i] = gr.Slider(label=f"LoRA {int(i+1)} Scale", minimum=-3, maximum=3, step=0.01, value=1.00)
- with gr.Row():
- lora_info[i] = gr.Textbox(label="", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
- lora_copy[i] = gr.Button(value="Copy example to prompt", visible=False)
- lora_md[i] = gr.Markdown(value="", visible=False)
- lora_num[i] = gr.Number(i, visible=False)
- with gr.Accordion("From URL", open=True, visible=True):
- with gr.Row():
- lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D"])
- lora_search_civitai_sort = gr.Radio(label="Sort", choices=CIVITAI_SORT, value="Most Downloaded")
- lora_search_civitai_period = gr.Radio(label="Period", choices=CIVITAI_PERIOD, value="Month")
- with gr.Row():
- lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
- lora_search_civitai_tag = gr.Dropdown(label="Tag", choices=get_civitai_tag(), value=get_civitai_tag()[0], allow_custom_value=True)
- lora_search_civitai_user = gr.Textbox(label="Username", lines=1)
- lora_search_civitai_submit = gr.Button("Search on Civitai")
- with gr.Row():
- lora_search_civitai_json = gr.JSON(value={}, visible=False)
- lora_search_civitai_desc = gr.Markdown(value="", visible=False, elem_classes="desc")
- with gr.Accordion("Select from Gallery", open=False):
- lora_search_civitai_gallery = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
- lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
- lora_download_url = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
- with gr.Row():
- lora_download = [None] * num_loras
- for i in range(num_loras):
- lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
- with gr.Accordion("ControlNet (extremely slow)", open=True, visible=False):
- with gr.Column():
- cn_on = gr.Checkbox(False, label="Use ControlNet")
- cn_mode = [None] * num_cns
- cn_scale = [None] * num_cns
- cn_image = [None] * num_cns
- cn_image_ref = [None] * num_cns
- cn_res = [None] * num_cns
- cn_num = [None] * num_cns
- with gr.Row():
- for i in range(num_cns):
- with gr.Column():
- cn_mode[i] = gr.Radio(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0])
- with gr.Row():
- cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
- cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
- cn_num[i] = gr.Number(i, visible=False)
- with gr.Row():
- cn_image_ref[i] = gr.Image(label="Image Reference", type="pil", format="png", height=256, sources=["upload", "clipboard"], show_share_button=False)
- cn_image[i] = gr.Image(label="Control Image", type="pil", format="png", height=256, show_share_button=False, interactive=False)
-
- gallery.select(
- update_selection,
- inputs=[selected_indices, loras_state, width, height],
- outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
- remove_button_1.click(
- remove_lora_1,
- inputs=[selected_indices, loras_state],
- outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
- )
- remove_button_2.click(
- remove_lora_2,
- inputs=[selected_indices, loras_state],
- outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
- )
- randomize_button.click(
- randomize_loras,
- inputs=[selected_indices, loras_state],
- outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
- )
- add_custom_lora_button.click(
- add_custom_lora,
- inputs=[custom_lora, selected_indices, loras_state, gallery, debug_log],
- outputs=[loras_state, gallery, debug_log]
- )
- remove_custom_lora_button.click(
- remove_custom_lora,
- inputs=[selected_indices, loras_state, gallery],
- outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
- )
- gr.on(
- triggers=[generate_button.click, prompt.submit],
- fn=change_base_model,
- inputs=[model_name, cn_on, disable_model_cache, model_type],
- outputs=[result],
- queue=True,
- show_api=False,
- trigger_mode="once",
- ).success(
- fn=run_lora,
- inputs=[prompt, input_image, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
- randomize_seed, seed, width, height, loras_state, lora_repo_json, cn_on, auto_trans],
- outputs=[result, seed, progress_bar],
- queue=True,
- show_api=True,
- #).then( # Update the history gallery
- # fn=lambda x, history: update_history(x, history),
- # inputs=[result, history_gallery],
- # outputs=history_gallery,
- ).success(save_image_history, [result, history_gallery, history_files, model_name], [history_gallery, history_files], queue=False, show_api=False)
-
- input_image.clear(lambda: gr.update(value="Text-to-Image"), None, [task_type], queue=False, show_api=False)
- input_image.upload(preprocess_i2i_image, [input_image, input_image_preprocess, height, width], [input_image], queue=False, show_api=False)\
- .success(lambda: gr.update(value="Image-to-Image"), None, [task_type], queue=False, show_api=False)
- gr.on(
- triggers=[model_name.change, cn_on.change],
- fn=get_t2i_model_info,
- inputs=[model_name],
- outputs=[model_info],
- queue=False,
- show_api=False,
- trigger_mode="once",
- )#.then(change_base_model, [model_name, cn_on, disable_model_cache, model_type], [result], queue=True, show_api=False)
- prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
-
- gr.on(
- triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
- fn=search_civitai_lora,
- inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period,
- lora_search_civitai_tag, lora_search_civitai_user, lora_search_civitai_gallery],
- outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query, lora_search_civitai_gallery],
- scroll_to_output=True,
- queue=True,
- show_api=False,
- )
- lora_search_civitai_json.change(search_civitai_lora_json, [lora_search_civitai_query, lora_search_civitai_basemodel], [lora_search_civitai_json], queue=True, show_api=True) # fn for api
- lora_search_civitai_result.change(select_civitai_lora, [lora_search_civitai_result], [lora_download_url, lora_search_civitai_desc], scroll_to_output=True, queue=False, show_api=False)
- lora_search_civitai_gallery.select(update_civitai_selection, None, [lora_search_civitai_result], queue=False, show_api=False)
-
- for i, l in enumerate(lora_repo):
- deselect_lora_button.click(lambda: ("", 1.0), None, [lora_repo[i], lora_wt[i]], queue=False, show_api=False)
- gr.on(
- triggers=[lora_download[i].click],
- fn=download_my_lora_flux,
- inputs=[lora_download_url, lora_repo[i]],
- outputs=[lora_repo[i]],
- scroll_to_output=True,
- queue=True,
- show_api=False,
- )
- gr.on(
- triggers=[lora_repo[i].change, lora_wt[i].change],
- fn=update_loras_flux,
- inputs=[prompt, lora_repo[i], lora_wt[i]],
- outputs=[prompt, lora_repo[i], lora_wt[i], lora_info[i], lora_md[i]],
- queue=False,
- trigger_mode="once",
- show_api=False,
- ).success(get_repo_safetensors, [lora_repo[i]], [lora_weights[i]], queue=False, show_api=False
- ).success(apply_lora_prompt_flux, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
- ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
-
- for i, m in enumerate(cn_mode):
- gr.on(
- triggers=[cn_mode[i].change, cn_scale[i].change],
- fn=set_control_union_mode,
- inputs=[cn_num[i], cn_mode[i], cn_scale[i]],
- outputs=[cn_on],
- queue=True,
- show_api=False,
- ).success(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
- cn_image_ref[i].upload(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
-
- tagger_generate_from_image.click(lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
- ).success(
- predict_tags_wd,
- [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
- [v2_series, v2_character, prompt, v2_copy],
- show_api=False,
- ).success(predict_tags_fl2_flux, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
- ).success(compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False)
-
- with gr.Tab("FLUX Prompt Generator"):
- from prompt import (PromptGenerator, HuggingFaceInferenceNode, florence_caption,
- ARTFORM, PHOTO_TYPE, ROLES, HAIRSTYLES, LIGHTING, COMPOSITION, POSE, BACKGROUND,
- PHOTOGRAPHY_STYLES, DEVICE, PHOTOGRAPHER, ARTIST, DIGITAL_ARTFORM, PLACE,
- FEMALE_DEFAULT_TAGS, MALE_DEFAULT_TAGS, FEMALE_BODY_TYPES, MALE_BODY_TYPES,
- FEMALE_CLOTHING, MALE_CLOTHING, FEMALE_ADDITIONAL_DETAILS, MALE_ADDITIONAL_DETAILS, pg_title)
-
- prompt_generator = PromptGenerator()
- huggingface_node = HuggingFaceInferenceNode()
-
- gr.HTML(pg_title)
-
- with gr.Row():
- with gr.Column(scale=2):
- with gr.Accordion("Basic Settings"):
- pg_custom = gr.Textbox(label="Custom Input Prompt (optional)")
- pg_subject = gr.Textbox(label="Subject (optional)")
- pg_gender = gr.Radio(["female", "male"], label="Gender", value="female")
-
- # Add the radio button for global option selection
- pg_global_option = gr.Radio(
- ["Disabled", "Random", "No Figure Rand"],
- label="Set all options to:",
- value="Disabled"
- )
-
- with gr.Accordion("Artform and Photo Type", open=False):
- pg_artform = gr.Dropdown(["disabled", "random"] + ARTFORM, label="Artform", value="disabled")
- pg_photo_type = gr.Dropdown(["disabled", "random"] + PHOTO_TYPE, label="Photo Type", value="disabled")
-
- with gr.Accordion("Character Details", open=False):
- pg_body_types = gr.Dropdown(["disabled", "random"] + FEMALE_BODY_TYPES + MALE_BODY_TYPES, label="Body Types", value="disabled")
- pg_default_tags = gr.Dropdown(["disabled", "random"] + FEMALE_DEFAULT_TAGS + MALE_DEFAULT_TAGS, label="Default Tags", value="disabled")
- pg_roles = gr.Dropdown(["disabled", "random"] + ROLES, label="Roles", value="disabled")
- pg_hairstyles = gr.Dropdown(["disabled", "random"] + HAIRSTYLES, label="Hairstyles", value="disabled")
- pg_clothing = gr.Dropdown(["disabled", "random"] + FEMALE_CLOTHING + MALE_CLOTHING, label="Clothing", value="disabled")
-
- with gr.Accordion("Scene Details", open=False):
- pg_place = gr.Dropdown(["disabled", "random"] + PLACE, label="Place", value="disabled")
- pg_lighting = gr.Dropdown(["disabled", "random"] + LIGHTING, label="Lighting", value="disabled")
- pg_composition = gr.Dropdown(["disabled", "random"] + COMPOSITION, label="Composition", value="disabled")
- pg_pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="disabled")
- pg_background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="disabled")
-
- with gr.Accordion("Style and Artist", open=False):
- pg_additional_details = gr.Dropdown(["disabled", "random"] + FEMALE_ADDITIONAL_DETAILS + MALE_ADDITIONAL_DETAILS, label="Additional Details", value="disabled")
- pg_photography_styles = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHY_STYLES, label="Photography Styles", value="disabled")
- pg_device = gr.Dropdown(["disabled", "random"] + DEVICE, label="Device", value="disabled")
- pg_photographer = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHER, label="Photographer", value="disabled")
- pg_artist = gr.Dropdown(["disabled", "random"] + ARTIST, label="Artist", value="disabled")
- pg_digital_artform = gr.Dropdown(["disabled", "random"] + DIGITAL_ARTFORM, label="Digital Artform", value="disabled")
-
- pg_generate_button = gr.Button("Generate Prompt")
-
- with gr.Column(scale=2):
- with gr.Accordion("Image and Caption", open=False):
- pg_input_image = gr.Image(label="Input Image (optional)")
- pg_caption_output = gr.Textbox(label="Generated Caption", lines=3)
- pg_create_caption_button = gr.Button("Create Caption")
- pg_add_caption_button = gr.Button("Add Caption to Prompt")
-
- with gr.Accordion("Prompt Generation", open=True):
- pg_output = gr.Textbox(label="Generated Prompt / Input Text", lines=4)
- pg_t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
- pg_clip_l_output = gr.Textbox(label="CLIP L Output", visible=True)
- pg_clip_g_output = gr.Textbox(label="CLIP G Output", visible=True)
-
- with gr.Column(scale=2):
- with gr.Accordion("Prompt Generation with LLM", open=False):
- pg_happy_talk = gr.Checkbox(label="Happy Talk", value=True)
- pg_compress = gr.Checkbox(label="Compress", value=True)
- pg_compression_level = gr.Radio(["soft", "medium", "hard"], label="Compression Level", value="hard")
- pg_poster = gr.Checkbox(label="Poster", value=False)
- pg_custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
- pg_generate_text_button = gr.Button("Generate Prompt with LLM (Llama 3.1 70B)")
- pg_text_output = gr.Textbox(label="Generated Text", lines=10)
-
- def create_caption(image):
- if image is not None:
- return florence_caption(image)
- return ""
-
- pg_create_caption_button.click(
- create_caption,
- inputs=[pg_input_image],
- outputs=[pg_caption_output]
- )
-
- def generate_prompt_with_dynamic_seed(*args):
- # Generate a new random seed
- dynamic_seed = random.randint(0, 1000000)
-
- # Call the generate_prompt function with the dynamic seed
- result = prompt_generator.generate_prompt(dynamic_seed, *args)
-
- # Return the result along with the used seed
- return [dynamic_seed] + list(result)
-
- pg_generate_button.click(
- generate_prompt_with_dynamic_seed,
- inputs=[pg_custom, pg_subject, pg_gender, pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles,
- pg_additional_details, pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform,
- pg_place, pg_lighting, pg_clothing, pg_composition, pg_pose, pg_background, pg_input_image],
- outputs=[gr.Number(label="Used Seed", visible=False), pg_output, gr.Number(visible=False), pg_t5xxl_output, pg_clip_l_output, pg_clip_g_output]
- ) #
-
- pg_add_caption_button.click(
- prompt_generator.add_caption_to_prompt,
- inputs=[pg_output, pg_caption_output],
- outputs=[pg_output]
- )
-
- pg_generate_text_button.click(
- huggingface_node.generate,
- inputs=[pg_output, pg_happy_talk, pg_compress, pg_compression_level, pg_poster, pg_custom_base_prompt],
- outputs=pg_text_output
- )
-
- def update_all_options(choice):
- updates = {}
- if choice == "Disabled":
- for dropdown in [
- pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
- pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
- pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
- ]:
- updates[dropdown] = gr.update(value="disabled")
- elif choice == "Random":
- for dropdown in [
- pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
- pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
- pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
- ]:
- updates[dropdown] = gr.update(value="random")
- else: # No Figure Random
- for dropdown in [pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing, pg_pose, pg_additional_details]:
- updates[dropdown] = gr.update(value="disabled")
- for dropdown in [pg_artform, pg_place, pg_lighting, pg_composition, pg_background, pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform]:
- updates[dropdown] = gr.update(value="random")
- return updates
-
- pg_global_option.change(
- update_all_options,
- inputs=[pg_global_option],
- outputs=[
- pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
- pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
- pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
- ]
- )
-
- with gr.Tab("PNG Info"):
- def extract_exif_data(image):
- if image is None: return ""
-
- try:
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
-
- for key in metadata_keys:
- if key in image.info:
- return image.info[key]
-
- return str(image.info)
-
- except Exception as e:
- return f"Error extracting metadata: {str(e)}"
-
- with gr.Row():
- with gr.Column():
- image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
-
- with gr.Column():
- result_metadata = gr.Textbox(label="Metadata", show_label=True, show_copy_button=True, interactive=False, container=True, max_lines=99)
-
- image_metadata.change(
- fn=extract_exif_data,
- inputs=[image_metadata],
- outputs=[result_metadata],
- )
-
- description_ui()
- gr.LoginButton()
- gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
-
-app.queue()
-app.launch(ssr_mode=False)
+import spaces
+import gradio as gr
+import json
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, AutoPipelineForInpainting
+from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+from diffusers.utils import load_image
+from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline, FluxTransformer2DModel, FluxControlNetInpaintPipeline, FluxInpaintPipeline
+from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download, HfApi
+import os
+import copy
+import random
+import time
+import requests
+import pandas as pd
+from pathlib import Path
+
+from env import models, num_loras, num_cns, HF_TOKEN, single_file_base_models
+from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
+ description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
+ get_trigger_word, enhance_prompt, set_control_union_image,
+ get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en)
+from modutils import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
+ download_my_lora_flux, get_all_lora_tupled_list, apply_lora_prompt_flux,
+ update_loras_flux, update_civitai_selection, get_civitai_tag, CIVITAI_SORT, CIVITAI_PERIOD,
+ get_t2i_model_info, download_hf_file, save_image_history)
+from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
+from tagger.fl2flux import predict_tags_fl2_flux
+
+#Load prompts for randomization
+df = pd.read_csv('prompts.csv', header=None)
+prompt_values = df.values.flatten()
+
+# Load LoRAs from JSON file
+with open('loras.json', 'r') as f:
+ loras = json.load(f)
+
+# Initialize the base model
+base_model = models[0]
+controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
+#controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
+dtype = torch.bfloat16
+#dtype = torch.float8_e4m3fn
+#device = "cuda" if torch.cuda.is_available() else "cpu"
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, token=HF_TOKEN)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, token=HF_TOKEN)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, token=HF_TOKEN)
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+pipe_ip = AutoPipelineForInpainting.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+controlnet_union = None
+controlnet = None
+last_model = models[0]
+last_cn_on = False
+#controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
+#controlnet = FluxMultiControlNetModel([controlnet_union])
+#controlnet.config = controlnet_union.config
+
+MAX_SEED = 2**32-1
+
+def unload_lora():
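+    """Detach any LoRA weights currently loaded into the t2i, i2i, and inpainting pipelines."""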
+ global pipe, pipe_i2i, pipe_ip
+ try:
+ #pipe.unfuse_lora()
+ pipe.unload_lora_weights()
+ #pipe_i2i.unfuse_lora()
+ pipe_i2i.unload_lora_weights()
+ pipe_ip.unload_lora_weights()
+ except Exception as e:
+ print(e)
+
+def download_file_mod(url, directory=os.getcwd()):
+ path = download_hf_file(directory, url, hf_token=HF_TOKEN)
+ if not path: raise Exception(f"Download error: {url}")
+ return path
+
+# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
+# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
+# https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
+#@spaces.GPU()
+def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, model_type: str, progress=gr.Progress(track_tqdm=True)):
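+    """Swap the global pipelines to a new base model.
+
+    `repo_id` may be a Hugging Face repo id or a single *.safetensors file
+    (resolved against the matching `single_file_base_models` entry); when
+    `cn_on` is set, ControlNet union pipelines are built instead.
+    """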
+ global pipe, pipe_i2i, pipe_ip, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
+ safetensors_file = None
+ single_file_base_model = single_file_base_models.get(model_type, models[0])
+ try:
+ #if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
+        if not disable_model_cache and (repo_id == last_model and cn_on is last_cn_on) or ((not is_repo_name(repo_id) or not is_repo_exists(repo_id)) and ".safetensors" not in repo_id): return gr.update()
+ unload_lora()
+ pipe.to("cpu")
+ pipe_i2i.to("cpu")
+ pipe_ip.to("cpu")
+ good_vae.to("cpu")
+ taef1.to("cpu")
+ if controlnet is not None: controlnet.to("cpu")
+ if controlnet_union is not None: controlnet_union.to("cpu")
+ clear_cache()
+ if cn_on:
+ progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
+ print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
+ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype, token=HF_TOKEN)
+ controlnet = FluxMultiControlNetModel([controlnet_union])
+ controlnet.config = controlnet_union.config
+ if ".safetensors" in repo_id:
+ safetensors_file = download_file_mod(repo_id)
+ transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
+ pipe = FluxControlNetPipeline.from_pretrained(single_file_base_model, transformer=transformer, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(single_file_base_model, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ else:
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_ip = FluxControlNetInpaintPipeline.from_pretrained(repo_id, controlnet=controlnet, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ last_model = repo_id
+ last_cn_on = cn_on
+ progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
+ print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
+ else:
+ progress(0, desc=f"Loading model: {repo_id}")
+ print(f"Loading model: {repo_id}")
+ if ".safetensors" in repo_id:
+ safetensors_file = download_file_mod(repo_id)
+ transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model)
+ pipe = DiffusionPipeline.from_pretrained(single_file_base_model, transformer=transformer, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_ip = AutoPipelineForInpainting.from_pretrained(single_file_base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ else:
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ pipe_ip = AutoPipelineForInpainting.from_pretrained(repo_id, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
+ tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype, token=HF_TOKEN)
+ last_model = repo_id
+ last_cn_on = cn_on
+ progress(1, desc=f"Model loaded: {repo_id}")
+ print(f"Model loaded: {repo_id}")
+ except Exception as e:
+ print(f"Model load Error: {repo_id} {e}")
+ raise gr.Error(f"Model load Error: {repo_id} {e}") from e
+ finally:
+ if safetensors_file and Path(safetensors_file).exists(): Path(safetensors_file).unlink()
+ return gr.update()
+
+change_base_model.zerogpu = True
+
+def is_repo_public(repo_id: str):
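+    """Return True if `repo_id` is reachable on the Hugging Face Hub without a token."""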
+ api = HfApi()
+ try:
+        return api.repo_exists(repo_id=repo_id, token=False)
+ except Exception as e:
+ print(f"Error: Failed to connect {repo_id}. {e}")
+ return False
+
+class calculateDuration:
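+    """Context manager that prints the wall-clock time spent inside its block."""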
+ def __init__(self, activity_name=""):
+ self.activity_name = activity_name
+
+ def __enter__(self):
+ self.start_time = time.time()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.end_time = time.time()
+ self.elapsed_time = self.end_time - self.start_time
+ if self.activity_name:
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+ else:
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+def download_file(url, directory=None):
+ if directory is None:
+ directory = os.getcwd() # Use current working directory if not specified
+
+ # Get the filename from the URL
+ filename = url.split('/')[-1]
+
+ # Full path for the downloaded file
+ filepath = os.path.join(directory, filename)
+
+ # Download the file
+ response = requests.get(url)
+ response.raise_for_status() # Raise an exception for bad status codes
+
+ # Write the content to the file
+ with open(filepath, 'wb') as file:
+ file.write(response.content)
+
+ return filepath
+
+def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
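+    """Toggle the clicked gallery LoRA (two selections at most) and refresh slot labels, scales, and preview thumbnails."""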
+ selected_index = evt.index
+ selected_indices = selected_indices or []
+ if selected_index in selected_indices:
+ selected_indices.remove(selected_index)
+ else:
+ if len(selected_indices) < 2:
+ selected_indices.append(selected_index)
+ else:
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
+ return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
+
+ selected_info_1 = "Select a LoRA 1"
+ selected_info_2 = "Select a LoRA 2"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = None
+ lora_image_2 = None
+ if len(selected_indices) >= 1:
+ lora1 = loras_state[selected_indices[0]]
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
+ lora_image_1 = lora1['image']
+ if len(selected_indices) >= 2:
+ lora2 = loras_state[selected_indices[1]]
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
+ lora_image_2 = lora2['image']
+
+ if selected_indices:
+ last_selected_lora = loras_state[selected_indices[-1]]
+ new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
+ else:
+ new_placeholder = "Type a prompt"
+
+ return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
+
+def remove_lora_1(selected_indices, loras_state):
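+    """Deselect the first LoRA slot and rebuild both slots' labels, scales, and images."""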
+ if len(selected_indices) >= 1:
+ selected_indices.pop(0)
+ selected_info_1 = "Select LoRA 1"
+ selected_info_2 = "Select LoRA 2"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = None
+ lora_image_2 = None
+ if len(selected_indices) >= 1:
+ lora1 = loras_state[selected_indices[0]]
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
+ lora_image_1 = lora1['image']
+ if len(selected_indices) >= 2:
+ lora2 = loras_state[selected_indices[1]]
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
+ lora_image_2 = lora2['image']
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
+
+def remove_lora_2(selected_indices, loras_state):
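+    """Deselect the second LoRA slot and rebuild both slots' labels, scales, and images."""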
+ if len(selected_indices) >= 2:
+ selected_indices.pop(1)
+ selected_info_1 = "Select LoRA 1"
+ selected_info_2 = "Select LoRA 2"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = None
+ lora_image_2 = None
+ if len(selected_indices) >= 1:
+ lora1 = loras_state[selected_indices[0]]
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
+ lora_image_1 = lora1['image']
+ if len(selected_indices) >= 2:
+ lora2 = loras_state[selected_indices[1]]
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
+ lora_image_2 = lora2['image']
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
+
+def randomize_loras(selected_indices, loras_state):
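+    """Pick two distinct LoRAs at random and pair them with a random prompt from prompts.csv."""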
+ if len(loras_state) < 2:
+ raise gr.Error("Not enough LoRAs to randomize.")
+ selected_indices = random.sample(range(len(loras_state)), 2)
+ lora1 = loras_state[selected_indices[0]]
+ lora2 = loras_state[selected_indices[1]]
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = lora1['image']
+ lora_image_2 = lora2['image']
+ random_prompt = random.choice(prompt_values)
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
+
+def download_loras_images(loras_json_orig: list[dict]):
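+    """Drop LoRAs whose repo is missing, backfill title/trigger/image metadata, and localize non-public preview images."""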
+ api = HfApi(token=HF_TOKEN)
+ loras_json = []
+ for lora in loras_json_orig:
+ repo = lora.get("repo", None)
+ if repo is None or not api.repo_exists(repo_id=repo, token=HF_TOKEN):
+ print(f"LoRA '{repo}' is not exsit.")
+ continue
+ if "title" not in lora.keys() or "trigger_word" not in lora.keys() or "image" not in lora.keys():
+ title, _repo, _path, trigger_word, image_def = check_custom_model(repo)
+ if "title" not in lora.keys(): lora["title"] = title
+ if "trigger_word" not in lora.keys(): lora["trigger_word"] = trigger_word
+ if "image" not in lora.keys(): lora["image"] = image_def
+ image = lora.get("image", None)
+ try:
+ if not is_repo_public(repo) and image is not None and "http" in image and repo in image: image = download_file_mod(image)
+ lora["image"] = image if image else "/home/user/app/custom.png"
+ except Exception as e:
+ print(f"Failed to download LoRA '{repo}''s image '{image if image else ''}'. {e}")
+ lora["image"] = "/home/user/app/custom.png"
+ loras_json.append(lora)
+ return loras_json
+
+def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
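+    """Resolve a user-supplied LoRA (repo id, path, or URL), append it to the gallery, and select it if a slot is free."""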
+ if custom_lora:
+ try:
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
+ if image is not None and "http" in image and not is_repo_public(repo) and repo in image:
+ try:
+ image = download_file_mod(image)
+ except Exception as e:
+ print(e)
+ image = None
+ print(f"Loaded custom LoRA: {repo}")
+ existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
+ if existing_item_index is None:
+ if repo.endswith(".safetensors") and repo.startswith("http"):
+ #repo = download_file(repo)
+ repo = download_file_mod(repo)
+ new_item = {
+ "image": image if image else "/home/user/app/custom.png",
+ "title": title,
+ "repo": repo,
+ "weights": path,
+ "trigger_word": trigger_word
+ }
+ print(f"New LoRA: {new_item}")
+ existing_item_index = len(current_loras)
+ current_loras.append(new_item)
+
+ # Update gallery
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
+ # Update selected_indices if there's room
+ if len(selected_indices) < 2:
+ selected_indices.append(existing_item_index)
+ else:
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
+
+ # Update selected_info and images
+ selected_info_1 = "Select a LoRA 1"
+ selected_info_2 = "Select a LoRA 2"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = None
+ lora_image_2 = None
+ if len(selected_indices) >= 1:
+ lora1 = current_loras[selected_indices[0]]
+ selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
+ lora_image_1 = lora1['image'] if lora1['image'] else None
+ if len(selected_indices) >= 2:
+ lora2 = current_loras[selected_indices[1]]
+ selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
+ lora_image_2 = lora2['image'] if lora2['image'] else None
+ print("Finished adding custom LoRA")
+ return (
+ current_loras,
+ gr.update(value=gallery_items),
+ selected_info_1,
+ selected_info_2,
+ selected_indices,
+ lora_scale_1,
+ lora_scale_2,
+ lora_image_1,
+ lora_image_2
+ )
+ except Exception as e:
+ print(e)
+ gr.Warning(str(e))
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
+ else:
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
+
+def remove_custom_lora(selected_indices, current_loras, gallery):
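+    """Remove the most recently added LoRA, deselect it if needed, and rebuild the gallery and slot labels."""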
+ if current_loras:
+ custom_lora_repo = current_loras[-1]['repo']
+ # Remove from loras list
+ current_loras = current_loras[:-1]
+ # Remove from selected_indices if selected
+ custom_lora_index = len(current_loras)
+ if custom_lora_index in selected_indices:
+ selected_indices.remove(custom_lora_index)
+ # Update gallery
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
+ # Update selected_info and images
+ selected_info_1 = "Select a LoRA 1"
+ selected_info_2 = "Select a LoRA 2"
+ lora_scale_1 = 1.15
+ lora_scale_2 = 1.15
+ lora_image_1 = None
+ lora_image_2 = None
+ if len(selected_indices) >= 1:
+ lora1 = current_loras[selected_indices[0]]
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
+ lora_image_1 = lora1['image']
+ if len(selected_indices) >= 2:
+ lora2 = current_loras[selected_indices[1]]
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
+ lora_image_2 = lora2['image']
+ return (
+ current_loras,
+ gr.update(value=gallery_items),
+ selected_info_1,
+ selected_info_2,
+ selected_indices,
+ lora_scale_1,
+ lora_scale_2,
+ lora_image_1,
+ lora_image_2
+        )
+    else:
+        return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
+
+@spaces.GPU(duration=70)
+@torch.inference_mode()
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on, progress=gr.Progress(track_tqdm=True)):
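+    """Yield progressively refined preview images for text-to-image generation; route through ControlNet when enabled."""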
+ global pipe, taef1, good_vae, controlnet, controlnet_union
+ try:
+ good_vae.to("cuda")
+ taef1.to("cuda")
+ generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
+
+ with calculateDuration("Generating image"):
+ # Generate image
+ modes, images, scales = get_control_params()
+ if not cn_on or len(modes) == 0:
+ pipe.to("cuda")
+ pipe.vae = taef1
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+ progress(0, desc="Start Inference.")
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+ prompt=prompt_mash,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ output_type="pil",
+ good_vae=good_vae,
+ ):
+ yield img
+ else:
+ pipe.to("cuda")
+ pipe.vae = good_vae
+ if controlnet_union is not None: controlnet_union.to("cuda")
+ if controlnet is not None: controlnet.to("cuda")
+ pipe.enable_model_cpu_offload()
+ progress(0, desc="Start Inference with ControlNet.")
+ for img in pipe(
+ prompt=prompt_mash,
+ control_image=images,
+ control_mode=modes,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ controlnet_conditioning_scale=scales,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ ).images:
+ yield img
+ except Exception as e:
+ print(e)
+ raise gr.Error(f"Inference Error: {e}") from e
+
+@spaces.GPU(duration=70)
+@torch.inference_mode()
+def generate_image_to_image(prompt_mash, image_input_path_dict, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on, progress=gr.Progress(track_tqdm=True)):
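+    """Run image-to-image or inpainting (optionally ControlNet-guided) and return the final PIL image."""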
+ global pipe_i2i, pipe_ip, good_vae, controlnet, controlnet_union
+ try:
+ good_vae.to("cuda")
+ generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
+ image_input_path = image_input_path_dict['background']
+ mask_path = image_input_path_dict['layers'][0]
+
+ with calculateDuration("Generating image"):
+ # Generate image
+ modes, images, scales = get_control_params()
+ if not cn_on or len(modes) == 0:
+ if is_inpaint: # Inpainting
+ pipe_ip.to("cuda")
+ pipe_ip.vae = good_vae
+ image_input = load_image(image_input_path)
+ mask_input = load_image(mask_path)
+ if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
+ progress(0, desc="Start Inpainting Inference.")
+ final_image = pipe_ip(
+ prompt=prompt_mash,
+ image=image_input,
+ mask_image=mask_input,
+ strength=image_strength,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ output_type="pil",
+ ).images[0]
+ return final_image
+ else:
+ pipe_i2i.to("cuda")
+ pipe_i2i.vae = good_vae
+ image_input = load_image(image_input_path)
+ progress(0, desc="Start I2I Inference.")
+ final_image = pipe_i2i(
+ prompt=prompt_mash,
+ image=image_input,
+ strength=image_strength,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ output_type="pil",
+ ).images[0]
+ return final_image
+ else:
+ if is_inpaint: # Inpainting
+ pipe_ip.to("cuda")
+ pipe_ip.vae = good_vae
+ image_input = load_image(image_input_path)
+ mask_input = load_image(mask_path)
+ if blur_mask: mask_input = pipe_ip.mask_processor.blur(mask_input, blur_factor=blur_factor)
+ if controlnet_union is not None: controlnet_union.to("cuda")
+ if controlnet is not None: controlnet.to("cuda")
+ pipe_ip.enable_model_cpu_offload()
+ progress(0, desc="Start Inpainting Inference with ControlNet.")
+ final_image = pipe_ip(
+ prompt=prompt_mash,
+ control_image=images,
+ control_mode=modes,
+ image=image_input,
+ mask_image=mask_input,
+ strength=image_strength,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ controlnet_conditioning_scale=scales,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ output_type="pil",
+ ).images[0]
+ return final_image
+ else:
+ pipe_i2i.to("cuda")
+ pipe_i2i.vae = good_vae
+                    image_input = load_image(image_input_path)
+ if controlnet_union is not None: controlnet_union.to("cuda")
+ if controlnet is not None: controlnet.to("cuda")
+ pipe_i2i.enable_model_cpu_offload()
+ progress(0, desc="Start I2I Inference with ControlNet.")
+ final_image = pipe_i2i(
+ prompt=prompt_mash,
+ control_image=images,
+ control_mode=modes,
+ image=image_input,
+ strength=image_strength,
+ num_inference_steps=steps,
+ guidance_scale=cfg_scale,
+ width=width,
+ height=height,
+ controlnet_conditioning_scale=scales,
+ generator=generator,
+ joint_attention_kwargs={"scale": 1.0},
+ output_type="pil",
+ ).images[0]
+ return final_image
+ except Exception as e:
+ print(e)
+ raise gr.Error(f"I2I Inference Error: {e}") from e
+
+def run_lora(prompt, image_input, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
+ randomize_seed, seed, width, height, loras_state, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)):
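+    """Main generation handler: translate and assemble the prompt with trigger words, attach the selected and external LoRAs to the right pipeline, then run the chosen task."""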
+ global pipe, pipe_i2i, pipe_ip
+ if not selected_indices and not is_valid_lora(lora_json):
+ gr.Info("LoRA isn't selected.")
+ # raise gr.Error("You must select a LoRA before proceeding.")
+ progress(0, desc="Preparing Inference.")
+
+ selected_loras = [loras_state[idx] for idx in selected_indices]
+
+ if task_type == "Inpainting":
+ is_inpaint = True
+ is_i2i = True
+ elif task_type == "Image-to-Image":
+ is_inpaint = False
+ is_i2i = True
+ else: # "Text-to-Image"
+ is_inpaint = False
+ is_i2i = False
+
+ if translate_on: prompt = translate_to_en(prompt)
+
+ # Build the prompt with trigger words
+ prepends = []
+ appends = []
+ for lora in selected_loras:
+ trigger_word = lora.get('trigger_word', '')
+ if trigger_word:
+ if lora.get("trigger_position") == "prepend":
+ prepends.append(trigger_word)
+ else:
+ appends.append(trigger_word)
+ prompt_mash = " ".join(prepends + [prompt] + appends)
+ print("Prompt Mash: ", prompt_mash) #
+
+ # Unload previous LoRA weights
+ with calculateDuration("Unloading LoRA"):
+ unload_lora()
+
+ print(pipe.get_active_adapters()) #
+ print(pipe_i2i.get_active_adapters()) #
+ print(pipe_ip.get_active_adapters()) #
+
+ clear_cache() #
+
+ # Build the prompt for External LoRAs
+ prompt_mash = prompt_mash + get_model_trigger(last_model)
+ lora_names = []
+ lora_weights = []
+ if is_valid_lora(lora_json): # Load External LoRA weights
+ with calculateDuration("Loading External LoRA weights"):
+ if is_inpaint:
+ pipe_ip, lora_names, lora_weights = fuse_loras(pipe_ip, lora_json)
+ elif is_i2i:
+ pipe_i2i, lora_names, lora_weights = fuse_loras(pipe_i2i, lora_json)
+ else: pipe, lora_names, lora_weights = fuse_loras(pipe, lora_json)
+ trigger_word = get_trigger_word(lora_json)
+ prompt_mash = f"{prompt_mash} {trigger_word}"
+ print("Prompt Mash: ", prompt_mash) #
+
+ # Load LoRA weights with respective scales
+ if selected_indices:
+ with calculateDuration("Loading LoRA weights"):
+ for idx, lora in enumerate(selected_loras):
+ lora_name = f"lora_{idx}"
+ lora_names.append(lora_name)
+ print(f"Lora Name: {lora_name}")
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
+ lora_path = lora['repo']
+ weight_name = lora.get("weights")
+ print(f"Lora Path: {lora_path}")
+ if is_inpaint:
+ pipe_ip.load_lora_weights(
+ lora_path,
+ weight_name=weight_name if weight_name else None,
+ low_cpu_mem_usage=False,
+ adapter_name=lora_name,
+ token=HF_TOKEN
+ )
+ elif is_i2i:
+ pipe_i2i.load_lora_weights(
+ lora_path,
+ weight_name=weight_name if weight_name else None,
+ low_cpu_mem_usage=False,
+ adapter_name=lora_name,
+ token=HF_TOKEN
+ )
+ else:
+ pipe.load_lora_weights(
+ lora_path,
+ weight_name=weight_name if weight_name else None,
+ low_cpu_mem_usage=False,
+ adapter_name=lora_name,
+ token=HF_TOKEN
+ )
+ print("Loaded LoRAs:", lora_names)
+ if selected_indices or is_valid_lora(lora_json):
+ if is_inpaint:
+ pipe_ip.set_adapters(lora_names, adapter_weights=lora_weights)
+ elif is_i2i:
+ pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
+ else:
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
+
+ print(pipe.get_active_adapters()) #
+ print(pipe_i2i.get_active_adapters()) #
+ print(pipe_ip.get_active_adapters()) #
+
+ # Set random seed for reproducibility
+ with calculateDuration("Randomizing seed"):
+ if randomize_seed:
+ seed = random.randint(0, MAX_SEED)
+
+ # Generate image
+ progress(0, desc="Running Inference.")
+ if is_i2i:
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, is_inpaint, blur_mask, blur_factor, steps, cfg_scale, width, height, seed, cn_on)
+ yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
+ else:
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, cn_on)
+ # Consume the generator to get the final image
+ final_image = None
+ step_counter = 0
+ for image in image_generator:
+            step_counter += 1
+            final_image = image
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+ yield image, seed, gr.update(value=progress_bar, visible=True)
+ yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(value=progress_bar, visible=False)
+
+run_lora.zerogpu = True
+
+def get_huggingface_safetensors(link):
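+    """Resolve a `user/repo` Hub path into (title, repo, weight filename, trigger word, preview image URL)."""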
+ split_link = link.split("/")
+ if len(split_link) == 2:
+ model_card = ModelCard.load(link, token=HF_TOKEN)
+ base_model = model_card.data.get("base_model")
+ print(f"Base model: {base_model}")
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
+ #raise Exception("Not a FLUX LoRA!")
+ gr.Warning("Not a FLUX LoRA?")
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+ trigger_word = model_card.data.get("instance_prompt", "")
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+ fs = HfFileSystem(token=HF_TOKEN)
+ safetensors_name = None
+ try:
+ list_of_files = fs.ls(link, detail=False)
+ for file in list_of_files:
+ if file.endswith(".safetensors"):
+ safetensors_name = file.split("/")[-1]
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
+ image_elements = file.split("/")
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+ except Exception as e:
+ print(e)
+ raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
+ if not safetensors_name:
+ raise gr.Error("No *.safetensors file found in the repository")
+ return split_link[1], link, safetensors_name, trigger_word, image_url
+ else:
+ raise gr.Error("Invalid Hugging Face repository link")
+
+def check_custom_model(link):
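+    """Normalize a LoRA reference: a direct *.safetensors URL, a huggingface.co URL, or a plain repo path."""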
+ if link.endswith(".safetensors"):
+ # Treat as direct link to the LoRA weights
+ title = os.path.basename(link)
+ repo = link
+ path = None # No specific weight name
+ trigger_word = ""
+ image_url = None
+ return title, repo, path, trigger_word, image_url
+ elif link.startswith("https://"):
+ if "huggingface.co" in link:
+ link_split = link.split("huggingface.co/")
+ return get_huggingface_safetensors(link_split[1])
+ else:
+ raise Exception("Unsupported URL")
+ else:
+ # Assume it's a Hugging Face model path
+ return get_huggingface_safetensors(link)
+
+def update_history(new_image, history):
+ """Updates the history gallery with the new image."""
+ if history is None:
+ history = []
+ history.insert(0, new_image)
+ return history
+
+loras = download_loras_images(loras)
+
+css = '''
+#gen_column{align-self: stretch}
+#gen_btn{height: 100%}
+#title{text-align: center}
+#title h1{font-size: 3em; display:inline-flex; align-items:center}
+#title img{width: 100px; margin-right: 0.25em}
+#gallery .grid-wrap{height: 5vh}
+#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
+.custom_lora_card{margin-bottom: 1em}
+.card_internal{display: flex;height: 100px;margin-top: .5em}
+.card_internal img{margin-right: 1em}
+.styler{--form-gap-width: 0px !important}
+#progress{height:30px}
+#progress .generating{display:none}
+.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
+.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+#component-8, .button_total{height: 100%; align-self: stretch;}
+#loaded_loras [data-testid="block-info"]{font-size:80%}
+#custom_lora_structure{background: var(--block-background-fill)}
+#custom_lora_btn{margin-top: auto;margin-bottom: 11px}
+#random_btn{font-size: 300%}
+#component-11{align-self: stretch;}
+.info { align-items: center; text-align: center; }
+.desc [src$='#float'] { float: right; margin: 20px; }
+'''
+with gr.Blocks(theme='NoCrypt/miku@>=1.2.2', fill_width=True, css=css, delete_cache=(60, 3600)) as app:
+ with gr.Tab("FLUX LoRA the Explorer"):
+ title = gr.HTML(
+ """FLUX LoRA the Explorer Mod
""",
+ elem_id="title",
+ )
+ loras_state = gr.State(loras)
+ selected_indices = gr.State([])
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Group():
+ with gr.Accordion("Generate Prompt from Image", open=False):
+ tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
+ with gr.Accordion(label="Advanced options", open=False):
+ tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
+ tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
+ neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
+ v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2, visible=False)
+ v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2, visible=False)
+ v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False, visible=False)
+ tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
+ tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
+ prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt", show_copy_button=True)
+ with gr.Row():
+ prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
+ auto_trans = gr.Checkbox(label="Auto translate to English", value=False, elem_classes="info")
+ with gr.Column(scale=1, elem_id="gen_column"):
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn", elem_classes=["button_total"])
+ with gr.Row(elem_id="loaded_loras"):
+ with gr.Column(scale=1, min_width=25):
+ randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
+ with gr.Column(scale=8):
+ with gr.Row():
+ with gr.Column(scale=0, min_width=50):
+ lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
+ with gr.Column(scale=3, min_width=100):
+ selected_info_1 = gr.Markdown("Select a LoRA 1")
+ with gr.Column(scale=5, min_width=50):
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
+ with gr.Row():
+ remove_button_1 = gr.Button("Remove", size="sm")
+ with gr.Column(scale=8):
+ with gr.Row():
+ with gr.Column(scale=0, min_width=50):
+ lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
+ with gr.Column(scale=3, min_width=100):
+ selected_info_2 = gr.Markdown("Select a LoRA 2")
+ with gr.Column(scale=5, min_width=50):
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
+ with gr.Row():
+ remove_button_2 = gr.Button("Remove", size="sm")
+ with gr.Row():
+ with gr.Column():
+ selected_info = gr.Markdown("")
+ gallery = gr.Gallery([(item["image"], item["title"]) for item in loras], label="LoRA Gallery", allow_preview=False,
+ columns=4, elem_id="gallery", show_share_button=False, interactive=False)
+ with gr.Group():
+ with gr.Row(elem_id="custom_lora_structure"):
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
+ add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
+ remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
+ with gr.Column():
+ progress_bar = gr.Markdown(elem_id="progress",visible=False)
+ result = gr.Image(label="Generated Image", format="png", type="filepath", show_share_button=False, interactive=False)
+ with gr.Accordion("History", open=False):
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False, format="png",
+ show_share_button=False, show_download_button=True)
+ history_files = gr.Files(interactive=False, visible=False)
+ history_clear_button = gr.Button(value="Clear History", variant="secondary")
+ history_clear_button.click(lambda: ([], []), None, [history_gallery, history_files], queue=False, show_api=False)
+ with gr.Group():
+ with gr.Row():
+                model_name = gr.Dropdown(label="Base Model", info="Enter a Hugging Face model repo_id or the path of a single safetensors file you want to use.",
+ choices=models, value=models[0], allow_custom_value=True, min_width=320, scale=5)
+ model_type = gr.Radio(label="Model type", info="Model type of single safetensors file",
+ choices=list(single_file_base_models.keys()), value=list(single_file_base_models.keys())[0], scale=1)
+ model_info = gr.Markdown(elem_classes="info")
+
+ with gr.Row():
+ with gr.Accordion("Advanced Settings", open=False):
+ with gr.Row():
+ with gr.Column():
+ #input_image = gr.Image(label="Input image", type="filepath", height=256, sources=["upload", "clipboard"], show_share_button=False)
+ input_image = gr.ImageEditor(label='Input image', type='filepath', sources=["upload", "clipboard"], image_mode='RGB', show_share_button=False, show_fullscreen_button=False,
+ layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed", default_size=32), value=None,
+ canvas_size=(384, 384), width=384, height=512)
+ with gr.Column():
+ task_type = gr.Radio(label="Task", choices=["Text-to-Image", "Image-to-Image", "Inpainting"], value="Text-to-Image")
+                        image_strength = gr.Slider(label="Strength", info="Lower means more input-image influence in I2I; the opposite in Inpainting", minimum=0.01, maximum=1.0, step=0.01, value=0.75)
+ blur_mask = gr.Checkbox(label="Blur mask", value=False)
+ blur_factor = gr.Slider(label="Blur factor", minimum=0, maximum=50, step=1, value=33)
+ input_image_preprocess = gr.Checkbox(True, label="Preprocess Input image")
+ with gr.Column():
+ with gr.Row():
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
+ with gr.Row():
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+ disable_model_cache = gr.Checkbox(False, label="Disable model caching")
+ with gr.Accordion("External LoRA", open=True):
+ with gr.Column():
+ deselect_lora_button = gr.Button("Remove External LoRAs", variant="secondary")
+ lora_repo_json = gr.JSON(value=[{}] * num_loras, visible=False)
+ lora_repo = [None] * num_loras
+ lora_weights = [None] * num_loras
+ lora_trigger = [None] * num_loras
+ lora_wt = [None] * num_loras
+ lora_info = [None] * num_loras
+ lora_copy = [None] * num_loras
+ lora_md = [None] * num_loras
+ lora_num = [None] * num_loras
+ with gr.Row():
+ for i in range(num_loras):
+ with gr.Column():
+ lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True, min_width=320)
+ with gr.Row():
+ lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
+ lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
+ lora_wt[i] = gr.Slider(label=f"LoRA {int(i+1)} Scale", minimum=-3, maximum=3, step=0.01, value=1.00)
+ with gr.Row():
+ lora_info[i] = gr.Textbox(label="", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
+ lora_copy[i] = gr.Button(value="Copy example to prompt", visible=False)
+ lora_md[i] = gr.Markdown(value="", visible=False)
+ lora_num[i] = gr.Number(i, visible=False)
+ with gr.Accordion("From URL", open=True, visible=True):
+ with gr.Row():
+ lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D"])
+ lora_search_civitai_sort = gr.Radio(label="Sort", choices=CIVITAI_SORT, value="Most Downloaded")
+ lora_search_civitai_period = gr.Radio(label="Period", choices=CIVITAI_PERIOD, value="Month")
+ with gr.Row():
+ lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
+ lora_search_civitai_tag = gr.Dropdown(label="Tag", choices=get_civitai_tag(), value=get_civitai_tag()[0], allow_custom_value=True)
+ lora_search_civitai_user = gr.Textbox(label="Username", lines=1)
+ lora_search_civitai_submit = gr.Button("Search on Civitai")
+ with gr.Row():
+ lora_search_civitai_json = gr.JSON(value={}, visible=False)
+ lora_search_civitai_desc = gr.Markdown(value="", visible=False, elem_classes="desc")
+ with gr.Accordion("Select from Gallery", open=False):
+ lora_search_civitai_gallery = gr.Gallery([], label="Results", allow_preview=False, columns=5, show_share_button=False, interactive=False)
+ lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
+ lora_download_url = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
+ with gr.Row():
+ lora_download = [None] * num_loras
+ for i in range(num_loras):
+ lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
+ with gr.Accordion("ControlNet (extremely slow)", open=True, visible=False):
+ with gr.Column():
+ cn_on = gr.Checkbox(False, label="Use ControlNet")
+ cn_mode = [None] * num_cns
+ cn_scale = [None] * num_cns
+ cn_image = [None] * num_cns
+ cn_image_ref = [None] * num_cns
+ cn_res = [None] * num_cns
+ cn_num = [None] * num_cns
+ with gr.Row():
+ for i in range(num_cns):
+ with gr.Column():
+ cn_mode[i] = gr.Radio(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0])
+ with gr.Row():
+ cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
+ cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
+ cn_num[i] = gr.Number(i, visible=False)
+ with gr.Row():
+ cn_image_ref[i] = gr.Image(label="Image Reference", type="pil", format="png", height=256, sources=["upload", "clipboard"], show_share_button=False)
+ cn_image[i] = gr.Image(label="Control Image", type="pil", format="png", height=256, show_share_button=False, interactive=False)
+
+ gallery.select(
+ update_selection,
+ inputs=[selected_indices, loras_state, width, height],
+ outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
+ remove_button_1.click(
+ remove_lora_1,
+ inputs=[selected_indices, loras_state],
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
+ )
+ remove_button_2.click(
+ remove_lora_2,
+ inputs=[selected_indices, loras_state],
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
+ )
+ randomize_button.click(
+ randomize_loras,
+ inputs=[selected_indices, loras_state],
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
+ )
+ add_custom_lora_button.click(
+ add_custom_lora,
+ inputs=[custom_lora, selected_indices, loras_state, gallery],
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
+ )
+ remove_custom_lora_button.click(
+ remove_custom_lora,
+ inputs=[selected_indices, loras_state, gallery],
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
+ )
+ gr.on(
+ triggers=[generate_button.click, prompt.submit],
+ fn=change_base_model,
+ inputs=[model_name, cn_on, disable_model_cache, model_type],
+ outputs=[result],
+ queue=True,
+ show_api=False,
+ trigger_mode="once",
+ ).success(
+ fn=run_lora,
+ inputs=[prompt, input_image, image_strength, task_type, blur_mask, blur_factor, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2,
+ randomize_seed, seed, width, height, loras_state, lora_repo_json, cn_on, auto_trans],
+ outputs=[result, seed, progress_bar],
+ queue=True,
+ show_api=True,
+ #).then( # Update the history gallery
+ # fn=lambda x, history: update_history(x, history),
+ # inputs=[result, history_gallery],
+ # outputs=history_gallery,
+ ).success(save_image_history, [result, history_gallery, history_files, model_name], [history_gallery, history_files], queue=False, show_api=False)
+
+ input_image.clear(lambda: gr.update(value="Text-to-Image"), None, [task_type], queue=False, show_api=False)
+ input_image.upload(preprocess_i2i_image, [input_image, input_image_preprocess, height, width], [input_image], queue=False, show_api=False)\
+ .success(lambda: gr.update(value="Image-to-Image"), None, [task_type], queue=False, show_api=False)
+ gr.on(
+ triggers=[model_name.change, cn_on.change],
+ fn=get_t2i_model_info,
+ inputs=[model_name],
+ outputs=[model_info],
+ queue=False,
+ show_api=False,
+ trigger_mode="once",
+ )#.then(change_base_model, [model_name, cn_on, disable_model_cache, model_type], [result], queue=True, show_api=False)
+ prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
+
+ gr.on(
+ triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
+ fn=search_civitai_lora,
+ inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period,
+ lora_search_civitai_tag, lora_search_civitai_user, lora_search_civitai_gallery],
+ outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query, lora_search_civitai_gallery],
+ scroll_to_output=True,
+ queue=True,
+ show_api=False,
+ )
+ lora_search_civitai_json.change(search_civitai_lora_json, [lora_search_civitai_query, lora_search_civitai_basemodel], [lora_search_civitai_json], queue=True, show_api=True) # fn for api
+ lora_search_civitai_result.change(select_civitai_lora, [lora_search_civitai_result], [lora_download_url, lora_search_civitai_desc], scroll_to_output=True, queue=False, show_api=False)
+ lora_search_civitai_gallery.select(update_civitai_selection, None, [lora_search_civitai_result], queue=False, show_api=False)
+
+        for i in range(num_loras):
+ deselect_lora_button.click(lambda: ("", 1.0), None, [lora_repo[i], lora_wt[i]], queue=False, show_api=False)
+ gr.on(
+ triggers=[lora_download[i].click],
+ fn=download_my_lora_flux,
+ inputs=[lora_download_url, lora_repo[i]],
+ outputs=[lora_repo[i]],
+ scroll_to_output=True,
+ queue=True,
+ show_api=False,
+ )
+ gr.on(
+ triggers=[lora_repo[i].change, lora_wt[i].change],
+ fn=update_loras_flux,
+ inputs=[prompt, lora_repo[i], lora_wt[i]],
+ outputs=[prompt, lora_repo[i], lora_wt[i], lora_info[i], lora_md[i]],
+ queue=False,
+ trigger_mode="once",
+ show_api=False,
+ ).success(get_repo_safetensors, [lora_repo[i]], [lora_weights[i]], queue=False, show_api=False
+ ).success(apply_lora_prompt_flux, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
+ ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
+
+        for i in range(num_cns):
+ gr.on(
+ triggers=[cn_mode[i].change, cn_scale[i].change],
+ fn=set_control_union_mode,
+ inputs=[cn_num[i], cn_mode[i], cn_scale[i]],
+ outputs=[cn_on],
+ queue=True,
+ show_api=False,
+ ).success(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
+ cn_image_ref[i].upload(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
+
+ tagger_generate_from_image.click(lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
+ ).success(
+ predict_tags_wd,
+ [tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
+ [v2_series, v2_character, prompt, v2_copy],
+ show_api=False,
+ ).success(predict_tags_fl2_flux, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
+ ).success(compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False)
+
+ with gr.Tab("FLUX Prompt Generator"):
+ from prompt import (PromptGenerator, HuggingFaceInferenceNode, florence_caption,
+ ARTFORM, PHOTO_TYPE, ROLES, HAIRSTYLES, LIGHTING, COMPOSITION, POSE, BACKGROUND,
+ PHOTOGRAPHY_STYLES, DEVICE, PHOTOGRAPHER, ARTIST, DIGITAL_ARTFORM, PLACE,
+ FEMALE_DEFAULT_TAGS, MALE_DEFAULT_TAGS, FEMALE_BODY_TYPES, MALE_BODY_TYPES,
+ FEMALE_CLOTHING, MALE_CLOTHING, FEMALE_ADDITIONAL_DETAILS, MALE_ADDITIONAL_DETAILS, pg_title)
+
+ prompt_generator = PromptGenerator()
+ huggingface_node = HuggingFaceInferenceNode()
+
+ gr.HTML(pg_title)
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ with gr.Accordion("Basic Settings"):
+ pg_custom = gr.Textbox(label="Custom Input Prompt (optional)")
+ pg_subject = gr.Textbox(label="Subject (optional)")
+ pg_gender = gr.Radio(["female", "male"], label="Gender", value="female")
+
+ # Add the radio button for global option selection
+ pg_global_option = gr.Radio(
+ ["Disabled", "Random", "No Figure Rand"],
+ label="Set all options to:",
+ value="Disabled"
+ )
+
+ with gr.Accordion("Artform and Photo Type", open=False):
+ pg_artform = gr.Dropdown(["disabled", "random"] + ARTFORM, label="Artform", value="disabled")
+ pg_photo_type = gr.Dropdown(["disabled", "random"] + PHOTO_TYPE, label="Photo Type", value="disabled")
+
+ with gr.Accordion("Character Details", open=False):
+ pg_body_types = gr.Dropdown(["disabled", "random"] + FEMALE_BODY_TYPES + MALE_BODY_TYPES, label="Body Types", value="disabled")
+ pg_default_tags = gr.Dropdown(["disabled", "random"] + FEMALE_DEFAULT_TAGS + MALE_DEFAULT_TAGS, label="Default Tags", value="disabled")
+ pg_roles = gr.Dropdown(["disabled", "random"] + ROLES, label="Roles", value="disabled")
+ pg_hairstyles = gr.Dropdown(["disabled", "random"] + HAIRSTYLES, label="Hairstyles", value="disabled")
+ pg_clothing = gr.Dropdown(["disabled", "random"] + FEMALE_CLOTHING + MALE_CLOTHING, label="Clothing", value="disabled")
+
+ with gr.Accordion("Scene Details", open=False):
+ pg_place = gr.Dropdown(["disabled", "random"] + PLACE, label="Place", value="disabled")
+ pg_lighting = gr.Dropdown(["disabled", "random"] + LIGHTING, label="Lighting", value="disabled")
+ pg_composition = gr.Dropdown(["disabled", "random"] + COMPOSITION, label="Composition", value="disabled")
+ pg_pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="disabled")
+ pg_background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="disabled")
+
+ with gr.Accordion("Style and Artist", open=False):
+ pg_additional_details = gr.Dropdown(["disabled", "random"] + FEMALE_ADDITIONAL_DETAILS + MALE_ADDITIONAL_DETAILS, label="Additional Details", value="disabled")
+ pg_photography_styles = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHY_STYLES, label="Photography Styles", value="disabled")
+ pg_device = gr.Dropdown(["disabled", "random"] + DEVICE, label="Device", value="disabled")
+ pg_photographer = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHER, label="Photographer", value="disabled")
+ pg_artist = gr.Dropdown(["disabled", "random"] + ARTIST, label="Artist", value="disabled")
+ pg_digital_artform = gr.Dropdown(["disabled", "random"] + DIGITAL_ARTFORM, label="Digital Artform", value="disabled")
+
+ pg_generate_button = gr.Button("Generate Prompt")
+
+ with gr.Column(scale=2):
+ with gr.Accordion("Image and Caption", open=False):
+ pg_input_image = gr.Image(label="Input Image (optional)")
+ pg_caption_output = gr.Textbox(label="Generated Caption", lines=3)
+ pg_create_caption_button = gr.Button("Create Caption")
+ pg_add_caption_button = gr.Button("Add Caption to Prompt")
+
+ with gr.Accordion("Prompt Generation", open=True):
+ pg_output = gr.Textbox(label="Generated Prompt / Input Text", lines=4)
+ pg_t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
+ pg_clip_l_output = gr.Textbox(label="CLIP L Output", visible=True)
+ pg_clip_g_output = gr.Textbox(label="CLIP G Output", visible=True)
+
+ with gr.Column(scale=2):
+ with gr.Accordion("Prompt Generation with LLM", open=False):
+ pg_happy_talk = gr.Checkbox(label="Happy Talk", value=True)
+ pg_compress = gr.Checkbox(label="Compress", value=True)
+ pg_compression_level = gr.Radio(["soft", "medium", "hard"], label="Compression Level", value="hard")
+ pg_poster = gr.Checkbox(label="Poster", value=False)
+ pg_custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
+ pg_generate_text_button = gr.Button("Generate Prompt with LLM (Llama 3.1 70B)")
+ pg_text_output = gr.Textbox(label="Generated Text", lines=10)
+
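+ # Caption helper: delegates to florence_caption from prompt.py, returning ""
+ # when no image is set.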
+ def create_caption(image):
+ if image is not None:
+ return florence_caption(image)
+ return ""
+
+ pg_create_caption_button.click(
+ create_caption,
+ inputs=[pg_input_image],
+ outputs=[pg_caption_output]
+ )
+
+ def generate_prompt_with_dynamic_seed(*args):
+ # Generate a new random seed
+ dynamic_seed = random.randint(0, 1000000)
+
+ # Call the generate_prompt function with the dynamic seed
+ result = prompt_generator.generate_prompt(dynamic_seed, *args)
+
+ # Return the result along with the used seed
+ return [dynamic_seed] + list(result)
+
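+ # The hidden gr.Number components below absorb returned seed values that the
+ # UI does not display.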
+ pg_generate_button.click(
+ generate_prompt_with_dynamic_seed,
+ inputs=[pg_custom, pg_subject, pg_gender, pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles,
+ pg_additional_details, pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform,
+ pg_place, pg_lighting, pg_clothing, pg_composition, pg_pose, pg_background, pg_input_image],
+ outputs=[gr.Number(label="Used Seed", visible=False), pg_output, gr.Number(visible=False), pg_t5xxl_output, pg_clip_l_output, pg_clip_g_output]
+ )
+
+ pg_add_caption_button.click(
+ prompt_generator.add_caption_to_prompt,
+ inputs=[pg_output, pg_caption_output],
+ outputs=[pg_output]
+ )
+
+ pg_generate_text_button.click(
+ huggingface_node.generate,
+ inputs=[pg_output, pg_happy_talk, pg_compress, pg_compression_level, pg_poster, pg_custom_base_prompt],
+ outputs=pg_text_output
+ )
+
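+ # Returning a {component: gr.update(...)} dict lets one handler drive many
+ # dropdowns at once; every key must also appear in the event's outputs list.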
+ def update_all_options(choice):
+ # All option dropdowns, in the same order as the event's outputs list
+ all_dropdowns = [
+ pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
+ pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
+ pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
+ ]
+ if choice == "Disabled":
+ return {dropdown: gr.update(value="disabled") for dropdown in all_dropdowns}
+ if choice == "Random":
+ return {dropdown: gr.update(value="random") for dropdown in all_dropdowns}
+ # "No Figure Rand": disable figure-specific options, randomize the rest
+ figure_dropdowns = [pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing, pg_pose, pg_additional_details]
+ return {d: gr.update(value="disabled" if d in figure_dropdowns else "random") for d in all_dropdowns}
+
+ pg_global_option.change(
+ update_all_options,
+ inputs=[pg_global_option],
+ outputs=[
+ pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
+ pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
+ pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
+ ]
+ )
+
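+ # PNG Info tab: surfaces any generation metadata embedded in an uploaded
+ # image's text chunks.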
+ with gr.Tab("PNG Info"):
+ def extract_exif_data(image):
+ if image is None: return ""
+
+ try:
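+ # Text-chunk keys commonly written by generators: 'parameters' (A1111-style
+ # UIs), 'prompt' (ComfyUI), 'Comment' (NovelAI); 'metadata' as a generic fallback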
+ metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
+
+ for key in metadata_keys:
+ if key in image.info:
+ return image.info[key]
+
+ return str(image.info)
+
+ except Exception as e:
+ return f"Error extracting metadata: {str(e)}"
+
+ with gr.Row():
+ with gr.Column():
+ image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
+
+ with gr.Column():
+ result_metadata = gr.Textbox(label="Metadata", show_label=True, show_copy_button=True, interactive=False, container=True, max_lines=99)
+
+ image_metadata.change(
+ fn=extract_exif_data,
+ inputs=[image_metadata],
+ outputs=[result_metadata],
+ )
+
+ description_ui()
+ gr.LoginButton()
+ gr.DuplicateButton(value="Duplicate Space for private use (requires a GPU Space; this demo does not run on CPU)")
+
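+# Queue requests so long-running GPU jobs are serialized; ssr_mode=False opts
+# out of Gradio's server-side rendering.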
+app.queue()
+app.launch(ssr_mode=False)