jiuface commited on
Commit
2c50a6c
·
1 Parent(s): d067625

pure load lora

Browse files
Files changed (1) hide show
  1. app.py +61 -353
app.py CHANGED
@@ -14,9 +14,12 @@ from io import BytesIO
14
  # from diffusers.models.attention_processor import AttentionProcessor
15
  from diffusers.models.attention_processor import AttnProcessor2_0
16
  import torch.nn.functional as F
17
-
 
 
18
  import re
19
  import json
 
20
  # 登录 Hugging Face Hub
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
22
  login(token=HF_TOKEN)
@@ -49,262 +52,16 @@ class calculateDuration:
49
  else:
50
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
51
 
52
- # 定义位置、偏移和区域的映射
53
- valid_locations = { # x, y in 90*90
54
- 'in the center': (45, 45),
55
- 'on the left': (15, 45),
56
- 'on the right': (75, 45),
57
- 'on the top': (45, 15),
58
- 'on the bottom': (45, 75),
59
- 'on the top-left': (15, 15),
60
- 'on the top-right': (75, 15),
61
- 'on the bottom-left': (15, 75),
62
- 'on the bottom-right': (75, 75)
63
- }
64
-
65
- valid_offsets = { # x, y in 90*90
66
- 'no offset': (0, 0),
67
- 'slightly to the left': (-10, 0),
68
- 'slightly to the right': (10, 0),
69
- 'slightly to the upper': (0, -10),
70
- 'slightly to the lower': (0, 10),
71
- 'slightly to the upper-left': (-10, -10),
72
- 'slightly to the upper-right': (10, -10),
73
- 'slightly to the lower-left': (-10, 10),
74
- 'slightly to the lower-right': (10, 10)
75
- }
76
-
77
- valid_areas = { # w, h in 90*90
78
- "a small square area": (50, 50),
79
- "a small vertical area": (40, 60),
80
- "a small horizontal area": (60, 40),
81
- "a medium-sized square area": (60, 60),
82
- "a medium-sized vertical area": (50, 80),
83
- "a medium-sized horizontal area": (80, 50),
84
- "a large square area": (70, 70),
85
- "a large vertical area": (60, 90),
86
- "a large horizontal area": (90, 60)
87
- }
88
-
89
- # 解析角色位置的函数
90
- def parse_character_position(character_position):
91
- # 定义正则表达式模式
92
- location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
93
- offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
94
- area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())
95
-
96
- # 提取位置
97
- location_match = re.search(location_pattern, character_position, re.IGNORECASE)
98
- location = location_match.group(0) if location_match else 'in the center'
99
-
100
- # 提取偏移
101
- offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
102
- offset = offset_match.group(0) if offset_match else 'no offset'
103
-
104
- # 提取区域
105
- area_match = re.search(area_pattern, character_position, re.IGNORECASE)
106
- area = area_match.group(0) if area_match else 'a medium-sized square area'
107
-
108
- return {
109
- 'location': location,
110
- 'offset': offset,
111
- 'area': area
112
- }
113
-
114
- # 创建掩码的函数
115
- def create_attention_mask(image_width, image_height, location, offset, area):
116
- # 图像在生成时通常会被缩放为 90x90,因此先定义一个基础尺寸
117
- base_size = 90
118
-
119
- # 获取位置坐标
120
- loc_x, loc_y = valid_locations.get(location, (45, 45))
121
- # 获取偏移量
122
- offset_x, offset_y = valid_offsets.get(offset, (0, 0))
123
- # 获取区域大小
124
- area_width, area_height = valid_areas.get(area, (60, 60))
125
-
126
- # 计算最终位置
127
- final_x = loc_x + offset_x
128
- final_y = loc_y + offset_y
129
-
130
- # 将坐标和尺寸映射到实际图像尺寸
131
- scale_x = image_width / base_size
132
- scale_y = image_height / base_size
133
-
134
- center_x = final_x * scale_x
135
- center_y = final_y * scale_y
136
- width = area_width * scale_x
137
- height = area_height * scale_y
138
-
139
- # 计算左上角和右下角坐标
140
- x_start = int(max(center_x - width / 2, 0))
141
- y_start = int(max(center_y - height / 2, 0))
142
- x_end = int(min(center_x + width / 2, image_width))
143
- y_end = int(min(center_y + height / 2, image_height))
144
-
145
- # 创建掩码
146
- mask = torch.zeros((image_height, image_width), dtype=torch.float32, device="cuda")
147
- mask[y_start:y_end, x_start:x_end] = 1.0
148
-
149
- # 展平成一维
150
- mask_flat = mask.view(-1) # 形状为 (image_height * image_width,)
151
- return mask_flat
152
-
153
- # 自定义注意力处理器
154
-
155
- class CustomCrossAttentionProcessor(AttnProcessor2_0):
156
- def __init__(self, masks, adapter_names):
157
- super().__init__()
158
- self.masks = masks # 列表,包含每个角色的掩码 (shape: [key_length])
159
- self.adapter_names = adapter_names # 列表,包含每个角色的 LoRA 适配器名称
160
-
161
- def __call__(
162
- self,
163
- attn,
164
- hidden_states,
165
- encoder_hidden_states=None,
166
- attention_mask=None,
167
- temb=None,
168
- **kwargs,
169
- ):
170
- """
171
- 自定义的注意力处理器,用于在注意力计算中应用角色掩码。
172
-
173
- 参数:
174
- attn: 注意力模块实例。
175
- hidden_states: 输入的隐藏状态 (query)。
176
- encoder_hidden_states: 编码器的隐藏状态 (key/value)。
177
- attention_mask: 注意力掩码。
178
- temb: 时间嵌入(可能不需要)。
179
- **kwargs: 其他参数。
180
-
181
- 返回:
182
- ��理后的隐藏状态。
183
- """
184
- # 获取当前的 adapter_name
185
- adapter_name = getattr(attn, 'adapter_name', None)
186
- if adapter_name is None or adapter_name not in self.adapter_names:
187
- # 如果没有 adapter_name,或者不在我们的列表中,直接执行父类的 __call__ 方法
188
- return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)
189
-
190
- # 查找 adapter_name 对应的索引
191
- idx = self.adapter_names.index(adapter_name)
192
- mask = self.masks[idx] # 获取对应的掩码 (shape: [key_length])
193
-
194
- # 以下是 AttnProcessor2_0 的实现,我们在适当的位置加入自定义的掩码逻辑
195
-
196
- residual = hidden_states
197
- if attn.spatial_norm is not None:
198
- hidden_states = attn.spatial_norm(hidden_states, temb)
199
-
200
- input_ndim = hidden_states.ndim
201
-
202
- if input_ndim == 4:
203
- batch_size, channel, height, width = hidden_states.shape
204
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
205
- else:
206
- batch_size, sequence_length, _ = hidden_states.shape
207
-
208
- if encoder_hidden_states is None:
209
- encoder_hidden_states = hidden_states
210
- else:
211
- # 如果有 encoder_hidden_states,获取其形状
212
- encoder_batch_size, key_length, _ = encoder_hidden_states.shape
213
-
214
- if attention_mask is not None:
215
- # 处理 attention_mask,如果需要的话
216
- attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
217
- # attention_mask 的形状应为 (batch_size, attn.heads, query_length, key_length)
218
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
219
- else:
220
- # 如果没有 attention_mask,我们创建一个全 0 的掩码
221
- attention_mask = torch.zeros(
222
- batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
223
- )
224
-
225
- if attn.group_norm is not None:
226
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
227
-
228
- query = attn.to_q(hidden_states)
229
-
230
- if attn.norm_cross:
231
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
232
-
233
- key = attn.to_k(encoder_hidden_states)
234
- value = attn.to_v(encoder_hidden_states)
235
-
236
- inner_dim = key.shape[-1]
237
- head_dim = inner_dim // attn.heads
238
-
239
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
240
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
241
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
242
-
243
- if attn.norm_q is not None:
244
- query = attn.norm_q(query)
245
- if attn.norm_k is not None:
246
- key = attn.norm_k(key)
247
-
248
- # 计算原始的注意力得分
249
- # 我们需要在计算注意力得分前应用掩码
250
- # 但由于 PyTorch 的 scaled_dot_product_attention 接受 attention_mask 参数,我们需要调整我们的掩码
251
-
252
- # 创建自定义的 attention_mask
253
- # mask 的形状为 [key_length],需要调整为 (batch_size, attn.heads, 1, key_length)
254
- custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
255
- # 将有效位置设为 0,被掩蔽的位置设为 -1e9(对于 float16,使用 -65504)
256
- mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
257
- custom_attention_mask = (1.0 - custom_attention_mask) * mask_value # 有效位置为 0,无效位置为 -1e9
258
-
259
- # 将自定义掩码添加到 attention_mask
260
- attention_mask = attention_mask + custom_attention_mask
261
-
262
- # 计算注意力
263
- hidden_states = F.scaled_dot_product_attention(
264
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
265
- )
266
-
267
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
268
- hidden_states = hidden_states.to(query.dtype)
269
-
270
- # linear proj
271
- hidden_states = attn.to_out[0](hidden_states)
272
- # dropout
273
- hidden_states = attn.to_out[1](hidden_states)
274
-
275
- if input_ndim == 4:
276
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
277
-
278
- if attn.residual_connection:
279
- hidden_states = hidden_states + residual
280
-
281
- hidden_states = hidden_states / attn.rescale_output_factor
282
-
283
- return hidden_states
284
-
285
-
286
- # 替换注意力处理器的函数
287
- def replace_attention_processors(pipe, masks, adapter_names):
288
- custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
289
- for name, module in pipe.transformer.named_modules():
290
- if hasattr(module, 'attn'):
291
- module.attn.adapter_name = getattr(module, 'adapter_name', None)
292
- module.attn.processor = custom_processor
293
- if hasattr(module, 'cross_attn'):
294
- module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
295
- module.cross_attn.processor = custom_processor
296
-
297
  # 生成图像的函数
298
-
299
- def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
 
300
  pipe.to(device)
301
  generator = torch.Generator(device=device).manual_seed(seed)
302
-
303
  with calculateDuration("Generating image"):
304
  # Generate image
305
  generated_image = pipe(
306
- prompt_embeds=prompt_embeds,
307
- pooled_prompt_embeds=pooled_prompt_embeds,
308
  num_inference_steps=steps,
309
  guidance_scale=cfg_scale,
310
  width=width,
@@ -315,111 +72,67 @@ def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, s
315
  progress(99, "Generate success!")
316
  return generated_image
317
 
318
- # 主函数
319
- @spaces.GPU
320
- @torch.inference_mode()
321
- def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
322
-
323
- # 解析角色提示词、位置和 LoRA 字符串
324
- try:
325
- character_prompts = json.loads(character_prompts_json)
326
- character_positions = json.loads(character_positions_json)
327
- lora_strings = json.loads(lora_strings_json)
328
- except json.JSONDecodeError as e:
329
- raise ValueError(f"Invalid JSON input: {e}")
330
-
331
- # 确保提示词、位置和 LoRA 字符串的数量一致
332
- if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
333
- raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")
334
-
335
- # 角色的数量
336
- num_characters = len(character_prompts)
 
337
 
338
- # Load LoRA weights
339
- with calculateDuration("Loading LoRA weights"):
340
- pipe.unload_lora_weights()
341
- adapter_names = []
342
- for lora_info in lora_strings:
343
- lora_repo = lora_info.get("repo")
344
- weights = lora_info.get("weights")
345
- adapter_name = lora_info.get("adapter_name")
346
- if lora_repo and weights and adapter_name:
347
- # 调用 pipe.load_lora_weights() 方法加载权重
348
- pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
349
- adapter_names.append(adapter_name)
350
- # 将 adapter_name 设置为模型的属性
351
- setattr(pipe.transformer, 'adapter_name', adapter_name)
352
 
353
- else:
354
- raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
355
- adapter_weights = [lora_scale] * len(adapter_names)
356
- # 调用 pipeline.set_adapters 方法设置 adapter 和对应权重
357
- pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
358
 
359
- # 确保 adapter_names 的数量与角色数量匹配
360
- if len(adapter_names) != num_characters:
361
- raise ValueError("The number of LoRA adapters must match the number of characters.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  # Set random seed for reproducibility
364
  if randomize_seed:
365
  with calculateDuration("Set random seed"):
366
  seed = random.randint(0, MAX_SEED)
367
 
368
- with calculateDuration("Encoding prompts"):
369
- # 编码背景提示词
370
- # 使用 tokenizer_2 和 text_encoder_2
371
- bg_text_input_2 = pipe.tokenizer_2(prompt_bg, return_tensors="pt").to(device)
372
- bg_prompt_embeds = pipe.text_encoder_2(bg_text_input_2.input_ids.to(device))[0]
373
-
374
- # 使用 tokenizer 和 text_encoder
375
- bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
376
- bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
377
-
378
- # 编码角色提示词
379
- character_prompt_embeds = []
380
- character_pooled_embeds = []
381
- for prompt in character_prompts:
382
- # 使用 tokenizer_2 和 text_encoder_2
383
- char_text_input_2 = pipe.tokenizer_2(prompt, return_tensors="pt").to(device)
384
- char_prompt_embeds = pipe.text_encoder_2(char_text_input_2.input_ids.to(device))[0]
385
- # 使用 tokenizer 和 text_encoder
386
- char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
387
- char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
388
-
389
- character_prompt_embeds.append(char_prompt_embeds)
390
- character_pooled_embeds.append(char_pooled_embeds)
391
-
392
- # 编码互动细节提示词
393
- details_text_input_2 = pipe.tokenizer_2(prompt_details, return_tensors="pt").to(device)
394
- details_prompt_embeds = pipe.text_encoder_2(details_text_input_2.input_ids.to(device))[0]
395
-
396
- details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
397
- details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
398
-
399
- # 合并背景和互动细节的嵌入
400
- prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
401
- pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=-1)
402
-
403
- # 解析角色位置
404
- character_infos = []
405
- for position_str in character_positions:
406
- info = parse_character_position(position_str)
407
- character_infos.append(info)
408
-
409
- # 创建角色的掩码
410
- masks = []
411
- for info in character_infos:
412
- mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
413
- masks.append(mask)
414
-
415
- # 替换注意力处理器
416
- replace_attention_processors(pipe, masks, adapter_names)
417
-
418
  # Generate image
419
- final_image = generate_image_with_embeddings(prompt_embeddings, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress)
420
 
421
- # 您可以在此处添加上传图片的代码
422
- result = {"status": "success", "message": "Image generated"}
 
 
 
 
 
423
 
424
  progress(100, "Completed!")
425
 
@@ -439,11 +152,9 @@ with gr.Blocks(css=css) as demo:
439
 
440
  with gr.Column():
441
 
442
- prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
443
- character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
444
- character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
445
  lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
446
- prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
447
  run_button = gr.Button("Run", scale=0)
448
 
449
  with gr.Accordion("Advanced Settings", open=False):
@@ -474,11 +185,8 @@ with gr.Blocks(css=css) as demo:
474
  json_text = gr.Text(label="Result JSON")
475
 
476
  inputs = [
477
- prompt_bg,
478
- character_prompts,
479
- character_positions,
480
  lora_strings_json,
481
- prompt_details,
482
  cfg_scale,
483
  steps,
484
  randomize_seed,
 
14
  # from diffusers.models.attention_processor import AttentionProcessor
15
  from diffusers.models.attention_processor import AttnProcessor2_0
16
  import torch.nn.functional as F
17
+ import time
18
+ import boto3
19
+ from io import BytesIO
20
  import re
21
  import json
22
+
23
  # 登录 Hugging Face Hub
24
  HF_TOKEN = os.environ.get("HF_TOKEN")
25
  login(token=HF_TOKEN)
 
52
  else:
53
  print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # 生成图像的函数
56
+ @spaces.GPU
57
+ @torch.inference_mode()
58
+ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
59
  pipe.to(device)
60
  generator = torch.Generator(device=device).manual_seed(seed)
 
61
  with calculateDuration("Generating image"):
62
  # Generate image
63
  generated_image = pipe(
64
+ prompt=prompt,
 
65
  num_inference_steps=steps,
66
  guidance_scale=cfg_scale,
67
  width=width,
 
72
  progress(99, "Generate success!")
73
  return generated_image
74
 
75
+
76
+ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
77
+ print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
78
+ connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
79
+
80
+ s3 = boto3.client(
81
+ 's3',
82
+ endpoint_url=connectionUrl,
83
+ region_name='auto',
84
+ aws_access_key_id=access_key,
85
+ aws_secret_access_key=secret_key
86
+ )
87
+
88
+ current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
89
+ image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
90
+ buffer = BytesIO()
91
+ image.save(buffer, "PNG")
92
+ buffer.seek(0)
93
+ s3.upload_fileobj(buffer, bucket_name, image_file)
94
+ print("upload finish", image_file)
95
 
96
+ return image_file
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
99
 
100
+ # Load LoRA weights
101
+ if lora_strings_json:
102
+ try:
103
+ lora_strings_json = json.loads(lora_strings_json)
104
+ except:
105
+ lora_strings_json = None
106
+ if lora_strings_json:
107
+ with calculateDuration("Loading LoRA weights"):
108
+ pipe.unload_lora_weights()
109
+ adapter_names = []
110
+ for lora_info in lora_strings:
111
+ lora_repo = lora_info.get("repo")
112
+ weights = lora_info.get("weights")
113
+ adapter_name = lora_info.get("adapter_name")
114
+ if lora_repo and weights and adapter_name:
115
+ # 加载 LoRA 权重
116
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
117
+ adapter_names.append(adapter_name)
118
+ adapter_weights = [lora_scale] * len(adapter_names)
119
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
120
 
121
  # Set random seed for reproducibility
122
  if randomize_seed:
123
  with calculateDuration("Set random seed"):
124
  seed = random.randint(0, MAX_SEED)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  # Generate image
127
+ final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
128
 
129
+ if final_image:
130
+ if upload_to_r2:
131
+ with calculateDuration("Upload image"):
132
+ url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
133
+ result = {"status": "success", "message": "upload image success", "url": url}
134
+ else:
135
+ result = {"status": "success", "message": "Image generated but not uploaded"}
136
 
137
  progress(100, "Completed!")
138
 
 
152
 
153
  with gr.Column():
154
 
155
+ prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
 
 
156
  lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
157
+
158
  run_button = gr.Button("Run", scale=0)
159
 
160
  with gr.Accordion("Advanced Settings", open=False):
 
185
  json_text = gr.Text(label="Result JSON")
186
 
187
  inputs = [
188
+ prompt,
 
 
189
  lora_strings_json,
 
190
  cfg_scale,
191
  steps,
192
  randomize_seed,