alvdansen commited on
Commit
85da865
Β·
verified Β·
1 Parent(s): 78842a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -40
app.py CHANGED
@@ -7,6 +7,7 @@ import spaces
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
 
 
10
  with open("sdxl_lora.json", "r") as file:
11
  data = json.load(file)
12
  sdxl_loras_raw = [
@@ -31,24 +32,19 @@ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
31
 
32
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
33
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
34
- pipe.load_lora_weights("alvdansen/flash-lora-araminta-k-styles", adapter_name="lora")
35
  pipe.to(device=DEVICE, dtype=torch.float16)
36
 
37
-
38
  MAX_SEED = np.iinfo(np.int32).max
39
  MAX_IMAGE_SIZE = 1024
40
 
41
-
42
- def update_selection(
43
- selected_state: gr.SelectData,
44
- gr_sdxl_loras,
45
- ):
46
-
47
  lora_id = gr_sdxl_loras[selected_state.index]["repo"]
48
  trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
49
-
50
  return lora_id, trigger_word
51
 
 
 
 
52
 
53
  @spaces.GPU
54
  def infer(
@@ -63,19 +59,8 @@ def infer(
63
  user_lora_weight,
64
  progress=gr.Progress(track_tqdm=True),
65
  ):
66
- flash_sdxl_id = "jasperai/flash-sdxl"
67
-
68
- new_adapter_id = user_lora_selector.replace("/", "_")
69
- loaded_adapters = pipe.get_list_adapters()
70
-
71
- if new_adapter_id not in loaded_adapters["unet"]:
72
- gr.Info("Swapping LoRA")
73
- pipe.unload_lora_weights()
74
- pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
75
- pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
76
-
77
- pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
78
- gr.Info("LoRA setup done")
79
 
80
  if randomize_seed:
81
  seed = random.randint(0, MAX_SEED)
@@ -95,19 +80,15 @@ def infer(
95
 
96
  return image
97
 
98
-
99
  css = """
100
-
101
  h1 {
102
  text-align: center;
103
  display:block;
104
  }
105
-
106
  p {
107
  text-align: justify;
108
  display:block;
109
  }
110
-
111
  """
112
 
113
  if torch.cuda.is_available():
@@ -116,11 +97,9 @@ else:
116
  power_device = "CPU"
117
 
118
  with gr.Blocks(css=css) as demo:
119
-
120
  gr.Markdown(
121
  f"""
122
  # ⚑ FlashDiffusion: FlashLoRA ⚑
123
-
124
  This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
125
 
126
  The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
@@ -137,11 +116,8 @@ with gr.Blocks(css=css) as demo:
137
  gr_lora_id = gr.State(value="")
138
 
139
  with gr.Row():
140
-
141
  with gr.Blocks():
142
-
143
  with gr.Column():
144
-
145
  user_lora_selector = gr.Textbox(
146
  label="Current Selected LoRA",
147
  max_lines=1,
@@ -166,9 +142,7 @@ with gr.Blocks(css=css) as demo:
166
  )
167
 
168
  with gr.Column():
169
-
170
  with gr.Row():
171
-
172
  prompt = gr.Text(
173
  label="Prompt",
174
  show_label=False,
@@ -183,7 +157,6 @@ with gr.Blocks(css=css) as demo:
183
  result = gr.Image(label="Result", show_label=False)
184
 
185
  with gr.Accordion("Advanced Settings", open=False):
186
-
187
  pre_prompt = gr.Text(
188
  label="Pre-Prompt",
189
  show_label=True,
@@ -204,7 +177,6 @@ with gr.Blocks(css=css) as demo:
204
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
205
 
206
  with gr.Row():
207
-
208
  num_inference_steps = gr.Slider(
209
  label="Number of inference steps",
210
  minimum=4,
@@ -214,7 +186,6 @@ with gr.Blocks(css=css) as demo:
214
  )
215
 
216
  with gr.Row():
217
-
218
  guidance_scale = gr.Slider(
219
  label="Guidance Scale",
220
  minimum=1,
@@ -242,7 +213,6 @@ with gr.Blocks(css=css) as demo:
242
  run_button.click,
243
  seed.change,
244
  randomize_seed.change,
245
- # prompt.change,
246
  prompt.submit,
247
  negative_prompt.change,
248
  negative_prompt.submit,
@@ -261,7 +231,6 @@ with gr.Blocks(css=css) as demo:
261
  user_lora_weight,
262
  ],
263
  outputs=[result],
264
- # show_progress="full",
265
  )
266
 
267
  gallery.select(
@@ -279,5 +248,4 @@ with gr.Blocks(css=css) as demo:
279
  "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
280
  )
281
 
282
-
283
- demo.queue().launch()
 
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
 
10
+ # Load the JSON data
11
  with open("sdxl_lora.json", "r") as file:
12
  data = json.load(file)
13
  sdxl_loras_raw = [
 
32
 
33
  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
34
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
 
35
  pipe.to(device=DEVICE, dtype=torch.float16)
36
 
 
37
  MAX_SEED = np.iinfo(np.int32).max
38
  MAX_IMAGE_SIZE = 1024
39
 
40
+ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
 
 
 
 
 
41
  lora_id = gr_sdxl_loras[selected_state.index]["repo"]
42
  trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
 
43
  return lora_id, trigger_word
44
 
45
+ def load_lora_for_style(style_repo):
46
+ pipe.unload_lora_weights() # Unload any previously loaded weights
47
+ pipe.load_lora_weights(style_repo, adapter_name="lora")
48
 
49
  @spaces.GPU
50
  def infer(
 
59
  user_lora_weight,
60
  progress=gr.Progress(track_tqdm=True),
61
  ):
62
+ # Load the appropriate LoRA weights
63
+ load_lora_for_style(user_lora_selector)
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  if randomize_seed:
66
  seed = random.randint(0, MAX_SEED)
 
80
 
81
  return image
82
 
 
83
  css = """
 
84
  h1 {
85
  text-align: center;
86
  display:block;
87
  }
 
88
  p {
89
  text-align: justify;
90
  display:block;
91
  }
 
92
  """
93
 
94
  if torch.cuda.is_available():
 
97
  power_device = "CPU"
98
 
99
  with gr.Blocks(css=css) as demo:
 
100
  gr.Markdown(
101
  f"""
102
  # ⚑ FlashDiffusion: FlashLoRA ⚑
 
103
  This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
104
 
105
  The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
 
116
  gr_lora_id = gr.State(value="")
117
 
118
  with gr.Row():
 
119
  with gr.Blocks():
 
120
  with gr.Column():
 
121
  user_lora_selector = gr.Textbox(
122
  label="Current Selected LoRA",
123
  max_lines=1,
 
142
  )
143
 
144
  with gr.Column():
 
145
  with gr.Row():
 
146
  prompt = gr.Text(
147
  label="Prompt",
148
  show_label=False,
 
157
  result = gr.Image(label="Result", show_label=False)
158
 
159
  with gr.Accordion("Advanced Settings", open=False):
 
160
  pre_prompt = gr.Text(
161
  label="Pre-Prompt",
162
  show_label=True,
 
177
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
178
 
179
  with gr.Row():
 
180
  num_inference_steps = gr.Slider(
181
  label="Number of inference steps",
182
  minimum=4,
 
186
  )
187
 
188
  with gr.Row():
 
189
  guidance_scale = gr.Slider(
190
  label="Guidance Scale",
191
  minimum=1,
 
213
  run_button.click,
214
  seed.change,
215
  randomize_seed.change,
 
216
  prompt.submit,
217
  negative_prompt.change,
218
  negative_prompt.submit,
 
231
  user_lora_weight,
232
  ],
233
  outputs=[result],
 
234
  )
235
 
236
  gallery.select(
 
248
  "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
249
  )
250
 
251
+ demo.queue().launch()