John6666 commited on
Commit
d34ac72
·
verified ·
1 Parent(s): 4a5e288

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -165
app.py CHANGED
@@ -1,166 +1,213 @@
1
- import gradio as gr
2
- import json
3
- import logging
4
- import torch
5
- from PIL import Image
6
- import spaces
7
- from diffusers import DiffusionPipeline
8
- import copy
9
- import random
10
- import time
11
-
12
- # Load LoRAs from JSON file
13
- with open('loras.json', 'r') as f:
14
- loras = json.load(f)
15
-
16
- # Initialize the base model
17
- base_model = "black-forest-labs/FLUX.1-dev"
18
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
19
-
20
- MAX_SEED = 2**32-1
21
-
22
- class calculateDuration:
23
- def __init__(self, activity_name=""):
24
- self.activity_name = activity_name
25
-
26
- def __enter__(self):
27
- self.start_time = time.time()
28
- return self
29
-
30
- def __exit__(self, exc_type, exc_value, traceback):
31
- self.end_time = time.time()
32
- self.elapsed_time = self.end_time - self.start_time
33
- if self.activity_name:
34
- print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
35
- else:
36
- print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
37
-
38
-
39
- def update_selection(evt: gr.SelectData, width, height):
40
- selected_lora = loras[evt.index]
41
- new_placeholder = f"Type a prompt for {selected_lora['title']}"
42
- lora_repo = selected_lora["repo"]
43
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
44
- if "aspect" in selected_lora:
45
- if selected_lora["aspect"] == "portrait":
46
- width = 768
47
- height = 1024
48
- elif selected_lora["aspect"] == "landscape":
49
- width = 1024
50
- height = 768
51
- return (
52
- gr.update(placeholder=new_placeholder),
53
- updated_text,
54
- evt.index,
55
- width,
56
- height,
57
- )
58
-
59
- @spaces.GPU(duration=70)
60
- def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
61
- pipe.to("cuda")
62
- generator = torch.Generator(device="cuda").manual_seed(seed)
63
-
64
- with calculateDuration("Generating image"):
65
- # Generate image
66
- image = pipe(
67
- prompt=f"{prompt} {trigger_word}",
68
- num_inference_steps=steps,
69
- guidance_scale=cfg_scale,
70
- width=width,
71
- height=height,
72
- generator=generator,
73
- joint_attention_kwargs={"scale": lora_scale},
74
- ).images[0]
75
- return image
76
-
77
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
78
- if selected_index is None:
79
- raise gr.Error("You must select a LoRA before proceeding.")
80
-
81
- selected_lora = loras[selected_index]
82
- lora_path = selected_lora["repo"]
83
- trigger_word = selected_lora["trigger_word"]
84
-
85
- # Load LoRA weights
86
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
87
- if "weights" in selected_lora:
88
- pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
89
- else:
90
- pipe.load_lora_weights(lora_path)
91
-
92
- # Set random seed for reproducibility
93
- with calculateDuration("Randomizing seed"):
94
- if randomize_seed:
95
- seed = random.randint(0, MAX_SEED)
96
-
97
- image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
98
- pipe.to("cpu")
99
- pipe.unload_lora_weights()
100
- return image, seed
101
-
102
- run_lora.zerogpu = True
103
-
104
- css = '''
105
- #gen_btn{height: 100%}
106
- #title{text-align: center}
107
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
108
- #title img{width: 100px; margin-right: 0.5em}
109
- #gallery .grid-wrap{height: 10vh}
110
- '''
111
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
112
- title = gr.HTML(
113
- """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
114
- elem_id="title",
115
- )
116
- selected_index = gr.State(None)
117
- with gr.Row():
118
- with gr.Column(scale=3):
119
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
120
- with gr.Column(scale=1, elem_id="gen_column"):
121
- generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
122
- with gr.Row():
123
- with gr.Column(scale=3):
124
- selected_info = gr.Markdown("")
125
- gallery = gr.Gallery(
126
- [(item["image"], item["title"]) for item in loras],
127
- label="LoRA Gallery",
128
- allow_preview=False,
129
- columns=3,
130
- elem_id="gallery"
131
- )
132
-
133
- with gr.Column(scale=4):
134
- result = gr.Image(label="Generated Image")
135
-
136
- with gr.Row():
137
- with gr.Accordion("Advanced Settings", open=False):
138
- with gr.Column():
139
- with gr.Row():
140
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
141
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
142
-
143
- with gr.Row():
144
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
145
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
146
-
147
- with gr.Row():
148
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
149
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
150
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
151
-
152
- gallery.select(
153
- update_selection,
154
- inputs=[width, height],
155
- outputs=[prompt, selected_info, selected_index, width, height]
156
- )
157
-
158
- gr.on(
159
- triggers=[generate_button.click, prompt.submit],
160
- fn=run_lora,
161
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
162
- outputs=[result, seed]
163
- )
164
-
165
- app.queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  app.launch()
 
1
+ import gradio as gr
2
+ import json
3
+ import logging
4
+ import torch
5
+ from PIL import Image
6
+ import spaces
7
+ from diffusers import DiffusionPipeline
8
+ import copy
9
+ import random
10
+ import time
11
+
12
+ # Load LoRAs from JSON file
13
+ with open('loras.json', 'r') as f:
14
+ loras = json.load(f)
15
+
16
+ # Initialize the base model
17
+ models = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell"]
18
+ base_model = models[0]
19
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
20
+
21
+ MAX_SEED = 2**32-1
22
+
23
+ class calculateDuration:
24
+ def __init__(self, activity_name=""):
25
+ self.activity_name = activity_name
26
+
27
+ def __enter__(self):
28
+ self.start_time = time.time()
29
+ return self
30
+
31
+ def __exit__(self, exc_type, exc_value, traceback):
32
+ self.end_time = time.time()
33
+ self.elapsed_time = self.end_time - self.start_time
34
+ if self.activity_name:
35
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
36
+ else:
37
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
38
+
39
+
40
+ def update_selection(evt: gr.SelectData, width, height):
41
+ selected_lora = loras[evt.index]
42
+ new_placeholder = f"Type a prompt for {selected_lora['title']}"
43
+ lora_repo = selected_lora["repo"]
44
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
45
+ if "aspect" in selected_lora:
46
+ if selected_lora["aspect"] == "portrait":
47
+ width = 768
48
+ height = 1024
49
+ elif selected_lora["aspect"] == "landscape":
50
+ width = 1024
51
+ height = 768
52
+ return (
53
+ gr.update(placeholder=new_placeholder),
54
+ updated_text,
55
+ evt.index,
56
+ width,
57
+ height,
58
+ )
59
+
60
+ @spaces.GPU(duration=70)
61
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
62
+ pipe.to("cuda")
63
+ generator = torch.Generator(device="cuda").manual_seed(seed)
64
+
65
+ with calculateDuration("Generating image"):
66
+ # Generate image
67
+ image = pipe(
68
+ prompt=f"{prompt} {trigger_word}",
69
+ num_inference_steps=steps,
70
+ guidance_scale=cfg_scale,
71
+ width=width,
72
+ height=height,
73
+ generator=generator,
74
+ joint_attention_kwargs={"scale": lora_scale},
75
+ ).images[0]
76
+ return image
77
+
78
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
79
+ lora_scale, lora_repo, lora_weights, lora_trigger, progress=gr.Progress(track_tqdm=True)):
80
+ if selected_index is None and not lora_repo:
81
+ raise gr.Error("You must select a LoRA before proceeding.")
82
+
83
+ if selected_index is not None and not lora_repo:
84
+ selected_lora = loras[selected_index]
85
+ lora_path = selected_lora["repo"]
86
+ trigger_word = selected_lora["trigger_word"]
87
+ else: # override
88
+ selected_lora = loras[0]
89
+ lora_path = lora_repo
90
+ trigger_word = lora_trigger
91
+
92
+ # Load LoRA weights
93
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
94
+ if lora_weights: # override
95
+ pipe.load_lora_weights(lora_path, weight_name=lora_weights)
96
+ elif "weights" in selected_lora:
97
+ pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
98
+ else:
99
+ pipe.load_lora_weights(lora_path)
100
+
101
+ # Set random seed for reproducibility
102
+ with calculateDuration("Randomizing seed"):
103
+ if randomize_seed:
104
+ seed = random.randint(0, MAX_SEED)
105
+
106
+ image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
107
+ pipe.to("cpu")
108
+ pipe.unload_lora_weights()
109
+ return image, seed
110
+
111
+ run_lora.zerogpu = True
112
+
113
+ def get_repo_safetensors(repo_id: str):
114
+ from huggingface_hub import HfApi
115
+ api = HfApi()
116
+ try:
117
+ if " " in repo_id or not api.repo_exists(repo_id): return gr.update(value="", choices=[])
118
+ files = api.list_repo_files(repo_id=repo_id)
119
+ except Exception as e:
120
+ print(f"Error: Failed to get {repo_id}'s info. ")
121
+ print(e)
122
+ return gr.update(choices=[])
123
+ files = [f for f in files if f.endswith(".safetensors")]
124
+ if len(files) == 0: return gr.update(value="", choices=[])
125
+ else: return gr.update(value=files[0], choices=files)
126
+
127
+ def change_base_model(repo_id: str):
128
+ from huggingface_hub import HfApi
129
+ global pipe
130
+ api = HfApi()
131
+ try:
132
+ if " " in repo_id or not api.repo_exists(repo_id): return
133
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
134
+ except Exception as e:
135
+ print(e)
136
+
137
+ css = '''
138
+ #gen_btn{height: 100%}
139
+ #title{text-align: center}
140
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
141
+ #title img{width: 100px; margin-right: 0.5em}
142
+ #gallery .grid-wrap{height: 10vh}
143
+ '''
144
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
145
+ title = gr.HTML(
146
+ """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
147
+ elem_id="title",
148
+ )
149
+ selected_index = gr.State(None)
150
+ with gr.Row():
151
+ with gr.Column(scale=3):
152
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
153
+ with gr.Column(scale=1, elem_id="gen_column"):
154
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
155
+ with gr.Row():
156
+ with gr.Column(scale=3):
157
+ selected_info = gr.Markdown("")
158
+ gallery = gr.Gallery(
159
+ [(item["image"], item["title"]) for item in loras],
160
+ label="LoRA Gallery",
161
+ allow_preview=False,
162
+ columns=3,
163
+ elem_id="gallery"
164
+ )
165
+
166
+ with gr.Column(scale=4):
167
+ result = gr.Image(label="Generated Image")
168
+
169
+ with gr.Row():
170
+ with gr.Accordion("Advanced Settings", open=False):
171
+ with gr.Column():
172
+
173
+ with gr.Row():
174
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
175
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
176
+
177
+ with gr.Row():
178
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
179
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
180
+
181
+ with gr.Row():
182
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
183
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
184
+
185
+ with gr.Row():
186
+ lora_repo = gr.Dropdown(label="LoRA Repo", choices=[], info="Input LoRA Repo ID", value="", allow_custom_value=True)
187
+ lora_weights = gr.Dropdown(label="LoRA Filename", choices=[], info="Optional", value="", allow_custom_value=True)
188
+ lora_trigger = gr.Textbox(label="LoRA Trigger Prompt", value="")
189
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
190
+
191
+ with gr.Row():
192
+ model_name = gr.Dropdown(label="Base Model", choices=models, value=models[0], allow_custom_value=True)
193
+
194
+ gallery.select(
195
+ update_selection,
196
+ inputs=[width, height],
197
+ outputs=[prompt, selected_info, selected_index, width, height]
198
+ )
199
+
200
+ gr.on(
201
+ triggers=[generate_button.click, prompt.submit],
202
+ fn=run_lora,
203
+ inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
204
+ lora_scale, lora_repo, lora_weights, lora_trigger],
205
+ outputs=[result, seed]
206
+ )
207
+
208
+ lora_repo.change(get_repo_safetensors, [lora_repo], [lora_weights])
209
+ model_name.change(change_base_model, [model_name], None)
210
+
211
+
212
+ app.queue()
213
  app.launch()