jiuface committed
Commit 53d0f2f · 1 Parent(s): 5f10541

Try custom attention control

Files changed (1): app.py (+49, -4)
app.py CHANGED
@@ -17,6 +17,7 @@ import time
 import boto3
 from io import BytesIO
 from datetime import datetime
+from transformers import AutoTokenizer
 
 from diffusers import UNet2DConditionModel
 
@@ -37,7 +38,8 @@ base_model = "black-forest-labs/FLUX.1-dev"
 # use_safetensors=True,
 # variant="fp16",
 # subfolder="unet",
-# ).to("cuda")
+# # ).to("cuda")
+# tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 
 
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
@@ -85,9 +87,37 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
     return image_file
 
 
-
+@spaces.GPU
 def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
     pipe.to("cuda")
+
+    text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to("cuda")
+    input_ids = text_inputs.input_ids[0]
+
+    # Get the token ID for each subject
+    boy_token_id = pipe.tokenizer.convert_tokens_to_ids("boy_asia_05")
+    print(boy_token_id)
+    girl_token_id = pipe.tokenizer.convert_tokens_to_ids("girl_asia_04")
+    print(girl_token_id)
+    # Find the index positions of each subject in the input
+    boy_indices = (input_ids == boy_token_id).nonzero(as_tuple=True)[0]
+    girl_indices = (input_ids == girl_token_id).nonzero(as_tuple=True)[0]
+
+    # Prepare cross_attention_kwargs
+    def attention_control(attention_probs, adapter_name):
+        # Steer attention based on adapter_name and the token indices
+        print("attention_control", adapter_name)
+        if adapter_name == "boy_asia_05":
+            # Zero out attention to the girl's tokens
+            attention_probs[:, :, :, girl_indices] = 0
+        elif adapter_name == "girl_asia_04":
+            # Zero out attention to the boy's tokens
+            attention_probs[:, :, :, boy_indices] = 0
+        return attention_probs
+
+    joint_attention_kwargs = {"attention_control": attention_control}
+
+
     generator = torch.Generator(device="cuda").manual_seed(seed)
     with calculateDuration("Generating image"):
         # Generate image
@@ -98,13 +128,26 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
             width=width,
             height=height,
             generator=generator,
-            joint_attention_kwargs={"scale": 1}
+            joint_attention_kwargs=joint_attention_kwargs
         ).images[0]
 
     progress(99, "Generate success!")
     return generate_image
 
-@spaces.GPU
+# Custom attention processor for the transformer
+class CustomAttentionProcessor(torch.nn.Module):
+    def __init__(self, attention_control, adapter_name):
+        super().__init__()
+        self.attention_control = attention_control
+        self.adapter_name = adapter_name
+
+    def forward(self, attention_probs):
+        # Call the custom attention-control function
+        attention_probs = self.attention_control(attention_probs, self.adapter_name)
+        return attention_probs
+
+
+
 def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
 
@@ -127,6 +170,8 @@ def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width
     adapter_weights = [lora_scale] * len(adapter_names)
     # Call pipeline.set_adapters to set the adapters and their weights
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+
+
 
     # Set random seed for reproducibility
     if randomize_seed:
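A quick way to sanity-check the masking logic introduced in this commit, outside the pipeline: the callback simply zeroes the attention columns that belong to the other subject's token positions. The sketch below exercises attention_control and CustomAttentionProcessor on a random tensor; the 4-D (batch, heads, query_tokens, key_tokens) shape and the index values are illustrative assumptions, not what the FLUX pipeline actually passes.

    import torch

    # Illustrative attention map: (batch, heads, query_tokens, key_tokens).
    attention_probs = torch.softmax(torch.randn(1, 2, 4, 6), dim=-1)

    # Assume the girl trigger token sits at key positions 2 and 3,
    # and the boy trigger token at position 5 (made-up values).
    girl_indices = torch.tensor([2, 3])
    boy_indices = torch.tensor([5])

    def attention_control(attention_probs, adapter_name):
        # Same rule as in the commit: each adapter ignores the other subject's tokens.
        if adapter_name == "boy_asia_05":
            attention_probs[:, :, :, girl_indices] = 0
        elif adapter_name == "girl_asia_04":
            attention_probs[:, :, :, boy_indices] = 0
        return attention_probs

    class CustomAttentionProcessor(torch.nn.Module):
        def __init__(self, attention_control, adapter_name):
            super().__init__()
            self.attention_control = attention_control
            self.adapter_name = adapter_name

        def forward(self, attention_probs):
            return self.attention_control(attention_probs, self.adapter_name)

    processor = CustomAttentionProcessor(attention_control, "boy_asia_05")
    masked = processor(attention_probs.clone())
    print(masked[0, 0])              # columns 2 and 3 are zeroed
    print(masked[0, 0].sum(dim=-1))  # rows no longer sum to 1

Two caveats under these assumptions: the masked rows no longer sum to one, so renormalizing after zeroing (masked / masked.sum(-1, keepdim=True)) may be needed; and the callback only takes effect once an attention processor that accepts and applies attention_control is actually installed on the transformer, whereas the CustomAttentionProcessor class added here is defined but not yet registered anywhere in this commit.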