jiuface commited on
Commit
8d7d2d7
·
1 Parent(s): f93e467
Files changed (1) hide show
  1. app.py +157 -62
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import gradio as gr
3
  import numpy as np
4
  import random
 
5
  import torch
6
  import json
7
  import logging
@@ -10,12 +11,17 @@ from huggingface_hub import login
10
  import time
11
  from datetime import datetime
12
  from io import BytesIO
13
- from diffusers.models.attention_processor import AttentionProcessor
 
 
 
14
  import re
15
  import json
16
  # 登录 Hugging Face Hub
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
  login(token=HF_TOKEN)
 
 
19
 
20
  # 初始化
21
  dtype = torch.float16 # 您可以根据需要调整数据类型
@@ -145,79 +151,160 @@ def create_attention_mask(image_width, image_height, location, offset, area):
145
  return mask_flat
146
 
147
  # 自定义注意力处理器
148
- class CustomCrossAttentionProcessor(AttentionProcessor):
149
- def __init__(self, masks, embeddings, adapter_names):
 
150
  super().__init__()
151
- self.masks = masks # 列表,包含每个角色的掩码
152
- self.embeddings = embeddings # 列表,包含每个角色的嵌入
153
  self.adapter_names = adapter_names # 列表,包含每个角色的 LoRA 适配器名称
154
 
155
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # 获取当前的 adapter_name
157
  adapter_name = getattr(attn, 'adapter_name', None)
158
  if adapter_name is None or adapter_name not in self.adapter_names:
159
- # 如果没有 adapter_name,直接执行默认的注意力计算
160
- return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
161
 
162
  # 查找 adapter_name 对应的索引
163
  idx = self.adapter_names.index(adapter_name)
164
- mask = self.masks[idx]
165
-
166
- # 标准的注意力计算
167
- batch_size, sequence_length, _ = hidden_states.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  query = attn.to_q(hidden_states)
 
 
 
 
170
  key = attn.to_k(encoder_hidden_states)
171
  value = attn.to_v(encoder_hidden_states)
172
 
173
- # 重塑以适应多头注意力
174
- query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
175
- key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
176
- value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # 计算注意力得分
179
- attention_scores = torch.matmul(query, key.transpose(-1, -2)) * attn.scale
180
 
181
- # 应用掩码调整注意力得分
182
- # mask 调整为与 attention_scores 兼容的形状
183
- # 假设 key_len mask 的长度一致
184
- mask_expanded = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) # (1, 1, 1, key_len)
185
- # 将掩码应用于 attention_scores
186
- attention_scores += mask_expanded * 1e6 # 增强对应位置的注意力
187
 
188
- # 计算注意力概率
189
- attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
190
 
191
- # 计算上下文向量
192
- context = torch.matmul(attention_probs, value)
 
 
193
 
194
- # 重塑回原始形状
195
- context = context.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)
 
 
 
 
 
196
 
197
- # 输出投影
198
- hidden_states = attn.to_out(context)
199
  return hidden_states
200
 
 
201
  # 替换注意力处理器的函数
202
- def replace_attention_processors(pipe, masks, embeddings, adapter_names):
203
- custom_processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
204
- for name, module in pipe.unet.named_modules():
205
- if hasattr(module, 'attn2'):
206
- # 设置 adapter_name 为模块的属性
207
- module.attn2.adapter_name = getattr(module, 'adapter_name', None)
208
- module.attn2.processor = custom_processor
 
 
209
 
210
  # 生成图像的函数
211
- @spaces.GPU
212
- @torch.inference_mode()
213
- def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
214
- pipe.to("cuda")
215
- generator = torch.Generator(device="cuda").manual_seed(seed)
216
 
217
  with calculateDuration("Generating image"):
218
  # Generate image
219
  generated_image = pipe(
220
- prompt_embeds=prompt_embeddings,
 
221
  num_inference_steps=steps,
222
  guidance_scale=cfg_scale,
223
  width=width,
@@ -229,7 +316,8 @@ def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, wi
229
  return generated_image
230
 
231
  # 主函数
232
-
 
233
  def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
234
 
235
  # 解析角色提示词、位置和 LoRA 字符串
@@ -260,7 +348,8 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
260
  pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
261
  adapter_names.append(adapter_name)
262
  # 将 adapter_name 设置为模型的属性
263
- setattr(pipe.unet, 'adapter_name', adapter_name)
 
264
  else:
265
  raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
266
  adapter_weights = [lora_scale] * len(adapter_names)
@@ -279,22 +368,28 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
279
  # 编码提示词
280
  with calculateDuration("Encoding prompts"):
281
  # 编码背景提示词
282
- bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to("cuda")
283
- bg_embeddings = pipe.text_encoder(bg_text_input.input_ids.to(device))[0]
284
-
 
285
  # 编码角色提示词
286
- character_embeddings = []
 
287
  for prompt in character_prompts:
288
- char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to("cuda")
289
- char_embeddings = pipe.text_encoder(char_text_input.input_ids.to(device))[0]
290
- character_embeddings.append(char_embeddings)
291
-
 
 
292
  # 编码互动细节提示词
293
- details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to("cuda")
294
- details_embeddings = pipe.text_encoder(details_text_input.input_ids.to(device))[0]
295
-
 
296
  # 合并背景和互动细节的嵌入
297
- prompt_embeddings = torch.cat([bg_embeddings, details_embeddings], dim=1)
 
298
 
299
  # 解析角色位置
300
  character_infos = []
@@ -309,10 +404,10 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
309
  masks.append(mask)
310
 
311
  # 替换注意力处理器
312
- replace_attention_processors(pipe, masks, character_embeddings, adapter_names)
313
 
314
  # Generate image
315
- final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)
316
 
317
  # 您可以在此处添加上传图片的代码
318
  result = {"status": "success", "message": "Image generated"}
@@ -334,7 +429,7 @@ with gr.Blocks(css=css) as demo:
334
  with gr.Row():
335
 
336
  with gr.Column():
337
-
338
  prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
339
  character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
340
  character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
 
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
+ import spaces
6
  import torch
7
  import json
8
  import logging
 
11
  import time
12
  from datetime import datetime
13
  from io import BytesIO
14
+ # from diffusers.models.attention_processor import AttentionProcessor
15
+ from diffusers.models.attention_processor import AttnProcessor2_0
16
+ import torch.nn.functional as F
17
+
18
  import re
19
  import json
20
  # 登录 Hugging Face Hub
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
22
  login(token=HF_TOKEN)
23
+ import diffusers
24
+ print(diffusers.__version__)
25
 
26
  # 初始化
27
  dtype = torch.float16 # 您可以根据需要调整数据类型
 
151
  return mask_flat
152
 
153
  # 自定义注意力处理器
154
+
155
+ class CustomCrossAttentionProcessor(AttnProcessor2_0):
156
+ def __init__(self, masks, adapter_names):
157
  super().__init__()
158
+ self.masks = masks # 列表,包含每个角色的掩码 (shape: [key_length])
 
159
  self.adapter_names = adapter_names # 列表,包含每个角色的 LoRA 适配器名称
160
 
161
+ def __call__(
162
+ self,
163
+ attn,
164
+ hidden_states,
165
+ encoder_hidden_states=None,
166
+ attention_mask=None,
167
+ temb=None,
168
+ **kwargs,
169
+ ):
170
+ """
171
+ 自定义的注意力处理器,用于在注意力计算中应用角色掩码。
172
+
173
+ 参数:
174
+ attn: 注意力模块实例。
175
+ hidden_states: 输入的隐藏状态 (query)。
176
+ encoder_hidden_states: 编码器的隐藏状态 (key/value)。
177
+ attention_mask: 注意力掩码。
178
+ temb: 时间嵌入(可能不需要)。
179
+ **kwargs: 其他参数。
180
+
181
+ 返回:
182
+ 处理后的隐藏状态。
183
+ """
184
  # 获取当前的 adapter_name
185
  adapter_name = getattr(attn, 'adapter_name', None)
186
  if adapter_name is None or adapter_name not in self.adapter_names:
187
+ # 如果没有 adapter_name,或者不在我们的列表中,直接执行父类的 __call__ 方法
188
+ return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)
189
 
190
  # 查找 adapter_name 对应的索引
191
  idx = self.adapter_names.index(adapter_name)
192
+ mask = self.masks[idx] # 获取对应的掩码 (shape: [key_length])
193
+
194
+ # 以下是 AttnProcessor2_0 的实现,我们在适当的位置加入自定义的掩码逻辑
195
+
196
+ residual = hidden_states
197
+ if attn.spatial_norm is not None:
198
+ hidden_states = attn.spatial_norm(hidden_states, temb)
199
+
200
+ input_ndim = hidden_states.ndim
201
+
202
+ if input_ndim == 4:
203
+ batch_size, channel, height, width = hidden_states.shape
204
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
205
+ else:
206
+ batch_size, sequence_length, _ = hidden_states.shape
207
+
208
+ if encoder_hidden_states is None:
209
+ encoder_hidden_states = hidden_states
210
+ else:
211
+ # 如果有 encoder_hidden_states,获取其形状
212
+ encoder_batch_size, key_length, _ = encoder_hidden_states.shape
213
+
214
+ if attention_mask is not None:
215
+ # 处理 attention_mask,如果需要的话
216
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
217
+ # attention_mask 的形状应为 (batch_size, attn.heads, query_length, key_length)
218
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
219
+ else:
220
+ # 如果没有 attention_mask,我们创建一个全 0 的掩码
221
+ attention_mask = torch.zeros(
222
+ batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
223
+ )
224
+
225
+ if attn.group_norm is not None:
226
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
227
 
228
  query = attn.to_q(hidden_states)
229
+
230
+ if attn.norm_cross:
231
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
232
+
233
  key = attn.to_k(encoder_hidden_states)
234
  value = attn.to_v(encoder_hidden_states)
235
 
236
+ inner_dim = key.shape[-1]
237
+ head_dim = inner_dim // attn.heads
238
+
239
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
240
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
241
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
242
+
243
+ if attn.norm_q is not None:
244
+ query = attn.norm_q(query)
245
+ if attn.norm_k is not None:
246
+ key = attn.norm_k(key)
247
+
248
+ # 计算原始的注意力得分
249
+ # 我们需要在计算注意力得分前应用掩码
250
+ # 但由于 PyTorch 的 scaled_dot_product_attention 接受 attention_mask 参数,我们需要调整我们的掩码
251
+
252
+ # 创建自定义的 attention_mask
253
+ # mask 的形状为 [key_length],需要调整为 (batch_size, attn.heads, 1, key_length)
254
+ custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
255
+ # 将有效位置设为 0,被掩蔽的位置设为 -1e9(对于 float16,使用 -65504)
256
+ mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
257
+ custom_attention_mask = (1.0 - custom_attention_mask) * mask_value # 有效位置为 0,无效位置为 -1e9
258
 
259
+ # 将自定义掩码添加到 attention_mask
260
+ attention_mask = attention_mask + custom_attention_mask
261
 
262
+ # 计算注意力
263
+ hidden_states = F.scaled_dot_product_attention(
264
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
265
+ )
 
 
266
 
267
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
268
+ hidden_states = hidden_states.to(query.dtype)
269
 
270
+ # linear proj
271
+ hidden_states = attn.to_out[0](hidden_states)
272
+ # dropout
273
+ hidden_states = attn.to_out[1](hidden_states)
274
 
275
+ if input_ndim == 4:
276
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
277
+
278
+ if attn.residual_connection:
279
+ hidden_states = hidden_states + residual
280
+
281
+ hidden_states = hidden_states / attn.rescale_output_factor
282
 
 
 
283
  return hidden_states
284
 
285
+
286
  # 替换注意力处理器的函数
287
+ def replace_attention_processors(pipe, masks, adapter_names):
288
+ custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
289
+ for name, module in pipe.transformer.named_modules():
290
+ if hasattr(module, 'attn'):
291
+ module.attn.adapter_name = getattr(module, 'adapter_name', None)
292
+ module.attn.processor = custom_processor
293
+ if hasattr(module, 'cross_attn'):
294
+ module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
295
+ module.cross_attn.processor = custom_processor
296
 
297
  # 生成图像的函数
298
+
299
+ def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
300
+ pipe.to(device)
301
+ generator = torch.Generator(device=device).manual_seed(seed)
 
302
 
303
  with calculateDuration("Generating image"):
304
  # Generate image
305
  generated_image = pipe(
306
+ prompt_embeds=prompt_embeds,
307
+ pooled_prompt_embeds=pooled_prompt_embeds,
308
  num_inference_steps=steps,
309
  guidance_scale=cfg_scale,
310
  width=width,
 
316
  return generated_image
317
 
318
  # 主函数
319
+ @spaces.GPU
320
+ @torch.inference_mode()
321
  def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
322
 
323
  # 解析角色提示词、位置和 LoRA 字符串
 
348
  pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
349
  adapter_names.append(adapter_name)
350
  # 将 adapter_name 设置为模型的属性
351
+ setattr(pipe.transformer, 'adapter_name', adapter_name)
352
+
353
  else:
354
  raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
355
  adapter_weights = [lora_scale] * len(adapter_names)
 
368
  # 编码提示词
369
  with calculateDuration("Encoding prompts"):
370
  # 编码背景提示词
371
+ bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
372
+ bg_prompt_embeds = pipe.text_encoder_2(bg_text_input.input_ids.to(device))[0]
373
+ bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
374
+
375
  # 编码角色提示词
376
+ character_prompt_embeds = []
377
+ character_pooled_embeds = []
378
  for prompt in character_prompts:
379
+ char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
380
+ char_prompt_embeds = pipe.text_encoder_2(char_text_input.input_ids.to(device))[0]
381
+ char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
382
+ character_prompt_embeds.append(char_prompt_embeds)
383
+ character_pooled_embeds.append(char_pooled_embeds)
384
+
385
  # 编码互动细节提示词
386
+ details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
387
+ details_prompt_embeds = pipe.text_encoder_2(details_text_input.input_ids.to(device))[0]
388
+ details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
389
+
390
  # 合并背景和互动细节的嵌入
391
+ prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
392
+ pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)
393
 
394
  # 解析角色位置
395
  character_infos = []
 
404
  masks.append(mask)
405
 
406
  # 替换注意力处理器
407
+ replace_attention_processors(pipe, masks, adapter_names)
408
 
409
  # Generate image
410
+ final_image = generate_image_with_embeddings(prompt_embeddings, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress)
411
 
412
  # 您可以在此处添加上传图片的代码
413
  result = {"status": "success", "message": "Image generated"}
 
429
  with gr.Row():
430
 
431
  with gr.Column():
432
+
433
  prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
434
  character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
435
  character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)