Spaces: Running on Zero

try to custom attention control
app.py CHANGED
@@ -17,6 +17,7 @@ import time
 import boto3
 from io import BytesIO
 from datetime import datetime
+from transformers import AutoTokenizer
 
 from diffusers import UNet2DConditionModel
 
@@ -37,7 +38,8 @@ base_model = "black-forest-labs/FLUX.1-dev"
 # use_safetensors=True,
 # variant="fp16",
 # subfolder="unet",
-# ).to("cuda")
+# # ).to("cuda")
+# tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 
 
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
@@ -85,9 +87,37 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
     return image_file
 
 
-
+@spaces.GPU
 def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
     pipe.to("cuda")
+
+    text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to("cuda")
+    input_ids = text_inputs.input_ids[0]
+
+    # Get the token id for each subject
+    boy_token_id = pipe.tokenizer.convert_tokens_to_ids("boy_asia_05")
+    print(boy_token_id)
+    girl_token_id = pipe.tokenizer.convert_tokens_to_ids("girl_asia_04")
+    print(girl_token_id)
+    # Find the index positions of each subject in the input
+    boy_indices = (input_ids == boy_token_id).nonzero(as_tuple=True)[0]
+    girl_indices = (input_ids == girl_token_id).nonzero(as_tuple=True)[0]
+
+    # Prepare cross_attention_kwargs
+    def attention_control(attention_probs, adapter_name):
+        # Control attention based on adapter_name and the token indices
+        print("attention_control", adapter_name)
+        if adapter_name == "boy_asia_05":
+            # Zero out attention to the girl's tokens
+            attention_probs[:, :, :, girl_indices] = 0
+        elif adapter_name == "girl_asia_04":
+            # Zero out attention to the boy's tokens
+            attention_probs[:, :, :, boy_indices] = 0
+        return attention_probs
+
+    joint_attention_kwargs = {"attention_control": attention_control}
+
+
     generator = torch.Generator(device="cuda").manual_seed(seed)
     with calculateDuration("Generating image"):
         # Generate image
@@ -98,13 +128,26 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
             width=width,
             height=height,
             generator=generator,
-            joint_attention_kwargs=
+            joint_attention_kwargs=joint_attention_kwargs
         ).images[0]
 
     progress(99, "Generate success!")
     return generate_image
 
-
+# Custom attention processor for the transformer
+class CustomAttentionProcessor(torch.nn.Module):
+    def __init__(self, attention_control, adapter_name):
+        super().__init__()
+        self.attention_control = attention_control
+        self.adapter_name = adapter_name
+
+    def forward(self, attention_probs):
+        # Call the custom attention-control function
+        attention_probs = self.attention_control(attention_probs, self.adapter_name)
+        return attention_probs
+
+
+
 def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
 
@@ -127,6 +170,8 @@ def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width
     adapter_weights = [lora_scale] * len(adapter_names)
     # Call pipeline.set_adapters to set the adapters and their weights
     pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+
+
 
     # Set random seed for reproducibility
     if randomize_seed:
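A caveat on the token lookup in generate_image: convert_tokens_to_ids only works when the trigger word is a single vocabulary entry. A CLIP-style BPE tokenizer will usually split a string like "boy_asia_05" into several sub-tokens and return the unknown-token id here, in which case boy_indices and girl_indices match nothing (or match unrelated tokens). Note also that FLUX.1-dev encodes the prompt with two tokenizers (CLIP and T5), so positions computed against pipe.tokenizer are not guaranteed to line up with the sequence the transformer attends over. Below is a minimal sketch of a sub-token-aware lookup; find_token_span is a hypothetical helper, not part of the commit:

import torch

def find_token_span(tokenizer, input_ids, word):
    # Tokenize the trigger word on its own, then search for that
    # sub-token sequence inside the already-tokenized prompt.
    word_ids = tokenizer(word, add_special_tokens=False).input_ids
    ids = input_ids.tolist()
    for start in range(len(ids) - len(word_ids) + 1):
        if ids[start:start + len(word_ids)] == word_ids:
            return torch.arange(start, start + len(word_ids))
    return torch.empty(0, dtype=torch.long)  # trigger word not found in the prompt

# boy_indices = find_token_span(pipe.tokenizer, input_ids, "boy_asia_05")
# girl_indices = find_token_span(pipe.tokenizer, input_ids, "girl_asia_04")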
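A second detail in attention_control: if attention_probs holds softmax outputs, zeroing columns leaves each row summing to less than one, so the remaining weights no longer form a distribution and attention mass shifts in an uncontrolled way. A minimal sketch that renormalizes after masking, assuming (as the indexing in the diff implies) that the text-token axis is the last one; mask_and_renormalize is a hypothetical helper:

import torch

def mask_and_renormalize(attention_probs, token_indices):
    attention_probs = attention_probs.clone()
    # Zero out the unwanted token columns ...
    attention_probs[:, :, :, token_indices] = 0
    # ... then rescale each row so the surviving weights sum to one again.
    denom = attention_probs.sum(dim=-1, keepdim=True).clamp_min(1e-6)
    return attention_probs / denom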
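Finally, CustomAttentionProcessor is defined but never registered on the model, and its forward(attention_probs) signature is not what diffusers calls: attention processors are invoked as processor(attn, hidden_states, ...), and the stock Flux processor computes attention via scaled_dot_product_attention, which never materializes attention_probs. As written, joint_attention_kwargs={"attention_control": ...} will likely reach the processors as an unexpected keyword rather than being applied. A sketch of the registration side, assuming diffusers' FluxAttnProcessor2_0 and the usual attn_processors/set_attn_processor interface; actually editing the attention weights would additionally require a processor body that computes them explicitly:

from diffusers.models.attention_processor import FluxAttnProcessor2_0

class ControlledFluxProcessor:
    # Hypothetical wrapper: delegates to the stock Flux processor and
    # swallows the custom kwarg so it does not raise a TypeError.
    def __init__(self, attention_control, adapter_name):
        self.inner = FluxAttnProcessor2_0()
        self.attention_control = attention_control
        self.adapter_name = adapter_name

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
        kwargs.pop("attention_control", None)  # stock processor does not accept it
        return self.inner(attn, hidden_states,
                          encoder_hidden_states=encoder_hidden_states, **kwargs)

# Register one wrapper per attention layer:
# pipe.transformer.set_attn_processor(
#     {name: ControlledFluxProcessor(attention_control, "boy_asia_05")
#      for name in pipe.transformer.attn_processors}
# )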