alvdansen commited on
Commit
73ee311
Β·
verified Β·
1 Parent(s): bd1e951

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -185
app.py CHANGED
@@ -1,14 +1,44 @@
1
  import json
2
  import random
3
  import requests
4
- import os
5
- from PIL import Image
6
-
7
  import gradio as gr
8
  import numpy as np
9
  import spaces
10
  import torch
11
  from diffusers import DiffusionPipeline, LCMScheduler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def get_image(image_data):
14
  if isinstance(image_data, str):
@@ -21,20 +51,22 @@ def get_image(image_data):
21
  print(f"Unexpected image_data format: {type(image_data)}")
22
  return None
23
 
 
24
  if local_path and os.path.exists(local_path):
25
  try:
26
- Image.open(local_path).verify()
27
  return local_path
28
  except Exception as e:
29
  print(f"Error loading local image {local_path}: {e}")
30
 
 
31
  if hf_url:
32
  try:
33
  response = requests.get(hf_url)
34
  if response.status_code == 200:
35
  img = Image.open(requests.get(hf_url, stream=True).raw)
36
- img.verify()
37
- img.save(local_path)
38
  return local_path
39
  else:
40
  print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
@@ -44,51 +76,6 @@ def get_image(image_data):
44
  print(f"Failed to load image for {image_data}")
45
  return None
46
 
47
- with open("sdxl_lora.json", "r") as file:
48
- data = json.load(file)
49
- sdxl_loras_raw = [
50
- {
51
- "image": get_image(item["image"]),
52
- "title": item["title"],
53
- "repo": item["repo"],
54
- "trigger_word": item["trigger_word"],
55
- "weights": item["weights"],
56
- "is_pivotal": item.get("is_pivotal", False),
57
- "text_embedding_weights": item.get("text_embedding_weights", None),
58
- "likes": item.get("likes", 0),
59
- }
60
- for item in data
61
- ]
62
-
63
- sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
64
-
65
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
66
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
67
-
68
- pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
69
-
70
- # Create LCMScheduler with default config
71
- lcm_scheduler = LCMScheduler.from_config(pipe.scheduler.config)
72
-
73
- # Remove the 'skip_prk_steps' if it exists in the config
74
- if hasattr(lcm_scheduler.config, 'skip_prk_steps'):
75
- delattr(lcm_scheduler.config, 'skip_prk_steps')
76
-
77
- pipe.scheduler = lcm_scheduler
78
- pipe.to(device=DEVICE, dtype=torch.float16)
79
-
80
- # Load Flash SDXL LoRA
81
- flash_sdxl_id = "jasperai/flash-sdxl"
82
- pipe.load_lora_weights(flash_sdxl_id, adapter_name="flash_lora")
83
-
84
- MAX_SEED = np.iinfo(np.int32).max
85
- MAX_IMAGE_SIZE = 1024
86
-
87
- def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
88
- lora_id = gr_sdxl_loras[selected_state.index]["repo"]
89
- trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
90
- return lora_id, trigger_word
91
-
92
  @spaces.GPU
93
  def infer(
94
  pre_prompt,
@@ -102,168 +89,150 @@ def infer(
102
  user_lora_weight,
103
  progress=gr.Progress(track_tqdm=True),
104
  ):
105
- try:
106
- # Load the user-selected LoRA
107
- new_adapter_id = user_lora_selector.replace("/", "_")
108
- pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
109
-
110
- # Set adapter weights
111
- pipe.set_adapters(["flash_lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
112
- gr.Info("LoRA setup complete")
113
-
114
- if randomize_seed:
115
- seed = random.randint(0, MAX_SEED)
116
-
117
- generator = torch.Generator().manual_seed(seed)
118
-
119
- if pre_prompt != "":
120
- prompt = f"{pre_prompt} {prompt}"
121
-
122
- # Use Flash Diffusion settings
123
- image = pipe(
124
- prompt=prompt,
125
- negative_prompt=negative_prompt,
126
- guidance_scale=1.0, # Flash Diffusion typically uses guidance_scale=1
127
- num_inference_steps=4, # Flash Diffusion uses fewer steps
128
- generator=generator,
129
- ).images[0]
130
-
131
- return image
132
- except Exception as e:
133
- gr.Error(f"An error occurred: {str(e)}")
134
- return None
135
 
136
  css = """
137
- h1 {
 
 
 
 
 
 
 
 
 
 
138
  text-align: center;
139
- display:block;
140
  }
141
- p {
142
- text-align: justify;
143
- display:block;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  }
145
  """
146
 
147
  with gr.Blocks(css=css) as demo:
148
  gr.Markdown(
149
- f"""
150
- # ⚑ FlashDiffusion: FlashLoRA ⚑
151
- This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
152
-
153
- The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
154
- The LoRAs can be added **without** any retraining for similar results in most cases. Feel free to tweak the parameters and use your own LoRAs by giving a look at the [Github Repo](https://github.com/gojasper/flash-diffusion)
155
- """
156
- )
157
- gr.Markdown(
158
- "If you enjoy the space, please also promote *open-source* by giving a ⭐ to our repo [![GitHub Stars](https://img.shields.io/github/stars/gojasper/flash-diffusion?style=social)](https://github.com/gojasper/flash-diffusion)"
 
 
159
  )
160
 
161
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
162
  gr_lora_id = gr.State(value="")
163
 
164
  with gr.Row():
165
- with gr.Blocks():
166
- with gr.Column():
167
- user_lora_selector = gr.Textbox(
168
- label="Current Selected LoRA",
169
- max_lines=1,
170
- interactive=False,
171
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- user_lora_weight = gr.Slider(
174
- label="Selected LoRA Weight",
175
- minimum=0.5,
176
- maximum=3,
177
- step=0.1,
178
- value=1,
179
- )
180
-
181
- gallery = gr.Gallery(
182
- value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
183
- label="SDXL LoRA Gallery",
184
- allow_preview=False,
185
- columns=3,
186
- elem_id="gallery",
187
- show_share_button=False,
188
- )
189
-
190
- with gr.Column():
191
  with gr.Row():
192
- prompt = gr.Text(
193
- label="Prompt",
194
- show_label=False,
195
- max_lines=1,
196
- placeholder="Enter your prompt",
197
- container=False,
198
- scale=5,
199
- )
200
 
201
- run_button = gr.Button("Run", scale=1)
202
-
203
- result = gr.Image(label="Result", show_label=False)
204
 
205
  with gr.Accordion("Advanced Settings", open=False):
206
- pre_prompt = gr.Text(
207
  label="Pre-Prompt",
208
- show_label=True,
209
- max_lines=1,
210
  placeholder="Pre Prompt from the LoRA config",
211
- container=True,
212
- scale=5,
213
  )
214
 
215
- seed = gr.Slider(
216
- label="Seed",
217
- minimum=0,
218
- maximum=MAX_SEED,
219
- step=1,
220
- value=0,
221
- )
222
-
223
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
224
-
225
  with gr.Row():
226
- num_inference_steps = gr.Slider(
227
- label="Number of inference steps",
228
- minimum=4,
229
- maximum=8,
230
  step=1,
231
- value=4,
232
  )
 
233
 
234
- with gr.Row():
235
- guidance_scale = gr.Slider(
236
- label="Guidance Scale",
237
- minimum=1,
238
- maximum=6,
239
- step=0.5,
240
- value=1,
241
- )
242
 
243
- hint_negative = gr.Markdown(
244
- """πŸ’‘ _Hint : Negative Prompt will only work with Guidance > 1 but the model was
245
- trained to be used with guidance = 1 (ie. without guidance).
246
- Can degrade the results, use cautiously._"""
 
 
247
  )
248
 
249
- negative_prompt = gr.Text(
250
  label="Negative Prompt",
251
- show_label=False,
252
- max_lines=1,
253
  placeholder="Enter a negative Prompt",
254
- container=False,
255
  )
256
 
257
  gr.on(
258
- [
259
- run_button.click,
260
- seed.change,
261
- randomize_seed.change,
262
- prompt.submit,
263
- negative_prompt.change,
264
- negative_prompt.submit,
265
- guidance_scale.change,
266
- ],
267
  fn=infer,
268
  inputs=[
269
  pre_prompt,
@@ -274,24 +243,30 @@ with gr.Blocks(css=css) as demo:
274
  negative_prompt,
275
  guidance_scale,
276
  user_lora_selector,
277
- user_lora_weight,
278
  ],
279
  outputs=[result],
280
  )
281
 
 
 
282
  gallery.select(
283
  fn=update_selection,
284
  inputs=[gr_sdxl_loras],
285
- outputs=[
286
- user_lora_selector,
287
- pre_prompt,
288
- ],
289
- show_progress="hidden",
290
  )
291
 
292
- gr.Markdown("**Disclaimer:**")
293
  gr.Markdown(
294
- "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
 
 
 
 
 
 
 
 
 
295
  )
296
 
297
  demo.queue().launch()
 
1
  import json
2
  import random
3
  import requests
 
 
 
4
  import gradio as gr
5
  import numpy as np
6
  import spaces
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
+ from PIL import Image
10
+ import os
11
+
12
+ # Load the JSON data
13
+ with open("sdxl_lora.json", "r") as file:
14
+ data = json.load(file)
15
+ sdxl_loras_raw = sorted(data, key=lambda x: x["likes"], reverse=True)
16
+
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
19
+
20
+ pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
21
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
22
+ pipe.to(device=DEVICE, dtype=torch.float16)
23
+
24
+ MAX_SEED = np.iinfo(np.int32).max
25
+ MAX_IMAGE_SIZE = 1024
26
+
27
+ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
28
+ lora_id = gr_sdxl_loras[selected_state.index]["repo"]
29
+ trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
30
+ return lora_id, trigger_word
31
+
32
+ def load_lora_for_style(style_repo):
33
+ pipe.unload_lora_weights()
34
+ pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
35
+ pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
36
+
37
+ if new_adapter_id not in loaded_adapters["unet"]:
38
+ gr.Info("Swapping LoRA")
39
+ pipe.unload_lora_weights()
40
+ pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
41
+ pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
42
 
43
  def get_image(image_data):
44
  if isinstance(image_data, str):
 
51
  print(f"Unexpected image_data format: {type(image_data)}")
52
  return None
53
 
54
+ # Try loading from local path first
55
  if local_path and os.path.exists(local_path):
56
  try:
57
+ Image.open(local_path).verify() # Verify that it's a valid image
58
  return local_path
59
  except Exception as e:
60
  print(f"Error loading local image {local_path}: {e}")
61
 
62
+ # If local path fails or doesn't exist, try URL
63
  if hf_url:
64
  try:
65
  response = requests.get(hf_url)
66
  if response.status_code == 200:
67
  img = Image.open(requests.get(hf_url, stream=True).raw)
68
+ img.verify() # Verify that it's a valid image
69
+ img.save(local_path) # Save for future use
70
  return local_path
71
  else:
72
  print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
 
76
  print(f"Failed to load image for {image_data}")
77
  return None
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  @spaces.GPU
80
  def infer(
81
  pre_prompt,
 
89
  user_lora_weight,
90
  progress=gr.Progress(track_tqdm=True),
91
  ):
92
+ load_lora_for_style(user_lora_selector)
93
+
94
+ if randomize_seed:
95
+ seed = random.randint(0, MAX_SEED)
96
+
97
+ generator = torch.Generator().manual_seed(seed)
98
+
99
+ if pre_prompt != "":
100
+ prompt = f"{pre_prompt} {prompt}"
101
+
102
+ image = pipe(
103
+ prompt=prompt,
104
+ negative_prompt=negative_prompt,
105
+ guidance_scale=guidance_scale,
106
+ num_inference_steps=num_inference_steps,
107
+ generator=generator,
108
+ ).images[0]
109
+
110
+ return image
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  css = """
113
+ body {
114
+ background-color: #1a1a1a;
115
+ color: #ffffff;
116
+ }
117
+ .container {
118
+ max-width: 900px;
119
+ margin: auto;
120
+ padding: 20px;
121
+ }
122
+ h1, h2 {
123
+ color: #4CAF50;
124
  text-align: center;
 
125
  }
126
+ .gallery {
127
+ display: flex;
128
+ flex-wrap: wrap;
129
+ justify-content: center;
130
+ }
131
+ .gallery img {
132
+ margin: 10px;
133
+ border-radius: 10px;
134
+ transition: transform 0.3s ease;
135
+ }
136
+ .gallery img:hover {
137
+ transform: scale(1.05);
138
+ }
139
+ .gradio-slider input[type="range"] {
140
+ background-color: #4CAF50;
141
+ }
142
+ .gradio-button {
143
+ background-color: #4CAF50 !important;
144
  }
145
  """
146
 
147
  with gr.Blocks(css=css) as demo:
148
  gr.Markdown(
149
+ """
150
+ # ⚑ FlashDiffusion: Araminta K's FlashLoRA Showcase ⚑
151
+
152
+ This interactive demo showcases [Araminta K's models](https://huggingface.co/alvdansen) using [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) technology.
153
+
154
+ ## Acknowledgments
155
+ - Original Flash Diffusion technology by the Jasper AI team
156
+ - Based on the paper: [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin
157
+ - Models showcased here are created by Araminta K at Alvdansen Labs
158
+
159
+ Explore the power of FlashLoRA with Araminta K's unique artistic styles!
160
+ """
161
  )
162
 
163
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
164
  gr_lora_id = gr.State(value="")
165
 
166
  with gr.Row():
167
+ with gr.Column(scale=2):
168
+ gallery = gr.Gallery(
169
+ value=[(get_image(item["image"]), item["title"]) for item in sdxl_loras_raw if get_image(item["image"]) is not None],
170
+ label="SDXL LoRA Gallery",
171
+ show_label=False,
172
+ elem_id="gallery",
173
+ columns=3,
174
+ height=600,
175
+ )
176
+
177
+ user_lora_selector = gr.Textbox(
178
+ label="Current Selected LoRA",
179
+ interactive=False,
180
+ )
181
+
182
+ with gr.Column(scale=3):
183
+ prompt = gr.Textbox(
184
+ label="Prompt",
185
+ placeholder="Enter your prompt",
186
+ lines=3,
187
+ )
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  with gr.Row():
190
+ run_button = gr.Button("Run", variant="primary")
191
+ clear_button = gr.Button("Clear")
 
 
 
 
 
 
192
 
193
+ result = gr.Image(label="Result", height=512)
 
 
194
 
195
  with gr.Accordion("Advanced Settings", open=False):
196
+ pre_prompt = gr.Textbox(
197
  label="Pre-Prompt",
 
 
198
  placeholder="Pre Prompt from the LoRA config",
199
+ lines=2,
 
200
  )
201
 
 
 
 
 
 
 
 
 
 
 
202
  with gr.Row():
203
+ seed = gr.Slider(
204
+ label="Seed",
205
+ minimum=0,
206
+ maximum=MAX_SEED,
207
  step=1,
208
+ value=0,
209
  )
210
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
211
 
212
+ num_inference_steps = gr.Slider(
213
+ label="Number of inference steps",
214
+ minimum=4,
215
+ maximum=8,
216
+ step=1,
217
+ value=4,
218
+ )
 
219
 
220
+ guidance_scale = gr.Slider(
221
+ label="Guidance Scale",
222
+ minimum=1,
223
+ maximum=6,
224
+ step=0.5,
225
+ value=1,
226
  )
227
 
228
+ negative_prompt = gr.Textbox(
229
  label="Negative Prompt",
 
 
230
  placeholder="Enter a negative Prompt",
231
+ lines=2,
232
  )
233
 
234
  gr.on(
235
+ [run_button.click, prompt.submit],
 
 
 
 
 
 
 
 
236
  fn=infer,
237
  inputs=[
238
  pre_prompt,
 
243
  negative_prompt,
244
  guidance_scale,
245
  user_lora_selector,
246
+ gr.Slider(label="Selected LoRA Weight", minimum=0.5, maximum=3, step=0.1, value=1),
247
  ],
248
  outputs=[result],
249
  )
250
 
251
+ clear_button.click(lambda: "", outputs=[prompt, result])
252
+
253
  gallery.select(
254
  fn=update_selection,
255
  inputs=[gr_sdxl_loras],
256
+ outputs=[user_lora_selector, pre_prompt],
 
 
 
 
257
  )
258
 
 
259
  gr.Markdown(
260
+ """
261
+ ## Unleash Your Creativity!
262
+
263
+ This showcase brings together the speed of Flash Diffusion and the artistic flair of Araminta K's models.
264
+ Craft your prompts, adjust the settings, and watch as AI brings your ideas to life in stunning detail.
265
+
266
+ Remember to use this tool ethically and respect copyright and individual privacy.
267
+
268
+ Enjoy exploring these unique artistic styles!
269
+ """
270
  )
271
 
272
  demo.queue().launch()