File size: 7,523 Bytes
d34ac72
 
 
 
 
 
 
 
 
 
09fa6ac
d34ac72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09fa6ac
 
b397bfd
d34ac72
 
 
 
 
 
 
 
 
 
 
 
b397bfd
 
 
d34ac72
 
 
 
 
 
 
 
 
 
 
 
 
b397bfd
09fa6ac
d34ac72
 
 
 
09fa6ac
d34ac72
 
 
 
 
 
 
 
 
 
09fa6ac
d34ac72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6802e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline
import copy
import random
import time
from mod import models, clear_cache, get_repo_safetensors, change_base_model

# Gallery metadata: a JSON list of LoRA entries. This file reads the keys
# "image", "title", "repo", "trigger_word" and, optionally, "weights" and
# "aspect" from each entry.
# Read as UTF-8 explicitly: titles / trigger words may contain non-ASCII text,
# and the platform default encoding (e.g. cp1252 on Windows) could fail or
# mis-decode it.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Initialize the base model: the first entry of `models` (imported from `mod`).
# NOTE(review): the "Base Model" dropdown calls mod.change_base_model, which
# presumably cannot rebind this module's `pipe` — confirm the dropdown actually
# changes the model used for generation.
base_model = models[0]
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)

# Largest seed offered in the UI (maximum unsigned 32-bit value).
MAX_SEED = 2**32-1

class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Example::

        with calculateDuration("Generating image"):
            ...

    On exit (also when the body raises) it prints
    ``Elapsed time for <activity_name>: <seconds> seconds`` — or
    ``Elapsed time: <seconds> seconds`` when no name was given — and stores the
    duration in ``elapsed_time``. Exceptions from the body are not suppressed.
    ``start_time`` / ``end_time`` are ``time.perf_counter()`` readings and are
    only meaningful relative to each other.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump (NTP sync, manual changes) and give wrong
        # or even negative durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        label = f"Elapsed time for {self.activity_name}" if self.activity_name else "Elapsed time"
        print(f"{label}: {self.elapsed_time:.6f} seconds")
        # Implicitly returns None so any exception from the body propagates.


def update_selection(evt: gr.SelectData, width, height):
    """Gallery ``select`` handler for the clicked LoRA entry.

    Returns a 5-tuple matching the outputs
    ``[prompt, selected_info, selected_index, width, height]``:

    * a ``gr.update`` setting the prompt placeholder to the LoRA's title,
    * Markdown linking to the LoRA's Hugging Face repo,
    * the clicked gallery index,
    * width/height — 768x1024 for ``aspect == "portrait"``, 1024x768 for
      ``aspect == "landscape"``, otherwise the values passed in.
    """
    index = evt.index
    chosen = loras[index]

    placeholder_text = f"Type a prompt for {chosen['title']}"
    repo_id = chosen["repo"]
    info_markdown = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"

    # Entries without an "aspect" key (or with another value) keep the current size.
    aspect = chosen.get("aspect")
    if aspect == "portrait":
        width, height = 768, 1024
    elif aspect == "landscape":
        width, height = 1024, 768

    return gr.update(placeholder=placeholder_text), info_markdown, index, width, height

@spaces.GPU(duration=70)
def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
    """Generate one image with the module-level ``pipe`` on CUDA and return it.

    Args:
        prompt: user prompt text.
        trigger_word: LoRA trigger text appended after the prompt; skipped when
            empty so no dangling space (or literal "None") is added.
        steps: forwarded as ``num_inference_steps``.
        seed: seed for the CUDA ``torch.Generator``; cast to ``int`` because
            ``manual_seed`` expects an integer.
        cfg_scale: forwarded as ``guidance_scale``.
        width, height: requested output width and height.
        lora_scale: forwarded as ``joint_attention_kwargs={"scale": lora_scale}``.
        progress: not used in this function; accepted so the caller's
            signature stays unchanged.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    pipe.to("cuda")
    # NOTE(review): UI sliders may deliver whole numbers as floats; int() keeps manual_seed happy.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    full_prompt = " ".join(part for part in (prompt, trigger_word) if part)

    with calculateDuration("Generating image"):
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    return image

def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
             lora_scale, lora_repo, lora_weights, lora_trigger, progress=gr.Progress(track_tqdm=True)):
    """Load the requested LoRA (if any), generate one image, then clean up.

    LoRA source, in priority order:

    * ``lora_repo`` non-empty (manual override from "Advanced Settings"):
      load from that repo, using ``lora_weights`` as the file name when given;
      ``lora_trigger`` is appended to the prompt. Gallery metadata is NOT used,
      so a gallery entry's ``weights`` file is never requested from a foreign repo.
    * otherwise ``selected_index`` set: load ``loras[selected_index]["repo"]``,
      with file name ``lora_weights`` if given, else the entry's ``weights``
      (if present); the entry's ``trigger_word`` is appended to the prompt.
    * neither: show a ``gr.Info`` notice and generate with the base model only
      (``lora_trigger`` is still appended).

    When ``randomize_seed`` is true a fresh seed in ``[0, MAX_SEED]`` replaces
    ``seed``. Regardless of success or failure, the pipeline is moved back to
    CPU, any loaded LoRA is unloaded and ``clear_cache()`` is called, so a
    failed request does not leave LoRA weights attached for the next one.

    Returns:
        ``(image, seed)`` — the generated image and the seed actually used.
    """
    if lora_repo:
        # Manual override: only user-supplied repo/filename/trigger are used.
        lora_path = lora_repo
        weight_name = lora_weights or None
        trigger_word = lora_trigger
        lora_label = lora_repo
    elif selected_index is not None:
        selected_lora = loras[selected_index]
        lora_path = selected_lora["repo"]
        # A filename typed into "LoRA Filename" still overrides the entry's file.
        weight_name = lora_weights or selected_lora.get("weights")
        trigger_word = selected_lora["trigger_word"]
        lora_label = selected_lora["title"]
    else:
        # Deliberately not an error: fall through to base-model-only generation.
        gr.Info("LoRA isn't selected.")
        lora_path = None
        weight_name = None
        trigger_word = lora_trigger
        lora_label = None

    try:
        if lora_path is not None:
            with calculateDuration(f"Loading LoRA weights for {lora_label}"):
                if weight_name:
                    pipe.load_lora_weights(lora_path, weight_name=weight_name)
                else:
                    pipe.load_lora_weights(lora_path)

        # Set random seed for reproducibility
        with calculateDuration("Randomizing seed"):
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)

        image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
    finally:
        # Always restore the shared pipeline, even when loading/generation raised.
        pipe.to("cpu")
        if lora_path is not None:
            pipe.unload_lora_weights()
        clear_cache()
    return image, seed

# Flag attribute on the handler, left unchanged. NOTE(review): presumably read by
# the `spaces` / ZeroGPU integration — confirm.
run_lora.zerogpu = True



# Custom CSS: full-height Generate button, centered title with an inline logo,
# and a short gallery grid (10vh).
css = '''

#gen_btn{height: 100%}

#title{text-align: center}

#title h1{font-size: 3em; display:inline-flex; align-items:center}

#title img{width: 100px; margin-right: 0.5em}

#gallery .grid-wrap{height: 10vh}

'''
# UI layout and event wiring for the app.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer Mod</h1>""",
        elem_id="title",
    )
    # Index into `loras` of the last clicked gallery item; None until
    # update_selection sets it.
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            # One (image, caption) pair per entry in loras.json.
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )
            
        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)

                # Manual LoRA override: in run_lora a non-empty repo here takes
                # precedence over the gallery selection.
                with gr.Row():
                    lora_repo = gr.Dropdown(label="LoRA Repo", choices=[], info="Input LoRA Repo ID", value="", allow_custom_value=True)
                    lora_weights = gr.Dropdown(label="LoRA Filename", choices=[], info="Optional", value="", allow_custom_value=True)
                    lora_trigger = gr.Textbox(label="LoRA Trigger Prompt", value="")
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

                with gr.Row():
                    model_name = gr.Dropdown(label="Base Model", choices=models, value=models[0], allow_custom_value=True)

    # Clicking a gallery item updates the prompt placeholder, the info text,
    # the selected index and (depending on the entry's "aspect") width/height.
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )

    # Generate on button click or Enter in the prompt box; run_lora returns
    # (image, seed), so the seed slider shows the seed actually used.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
                 lora_scale, lora_repo, lora_weights, lora_trigger],
        outputs=[result, seed]
    )

    # get_repo_safetensors / change_base_model come from `mod`.
    # NOTE(review): presumably get_repo_safetensors lists the repo's
    # .safetensors files for the filename dropdown — confirm in mod.
    lora_repo.change(get_repo_safetensors, [lora_repo], [lora_weights])
    # NOTE(review): change_base_model has no outputs and presumably cannot
    # rebind this module's `pipe`; verify the base-model switch takes effect.
    model_name.change(change_base_model, [model_name], None)


# Enable request queuing on the Blocks app, then start the server.
# Blocks.queue() returns the app itself, so the calls can be chained.
app.queue().launch()