Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- app.py +1 -7
- attention_processor_faceid.py +426 -0
- helper.py +236 -0
- ipown.py +470 -0
- loader.py +95 -0
- requirements.txt +14 -0
- resampler.py +158 -0
- utils.py +170 -0
app.py
CHANGED
@@ -1,7 +1 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
-
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
demo.launch()
|
|
|
1 |
+
import loader
|
|
|
|
|
|
|
|
|
|
|
|
attention_processor_faceid.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from diffusers.models.lora import LoRALinearLayer
|
7 |
+
|
8 |
+
|
9 |
+
class LoRAAttnProcessor(nn.Module):
|
10 |
+
r"""
|
11 |
+
Default processor for performing attention-related computations.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
hidden_size=None,
|
17 |
+
cross_attention_dim=None,
|
18 |
+
rank=4,
|
19 |
+
network_alpha=None,
|
20 |
+
lora_scale=1.0,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.rank = rank
|
25 |
+
self.lora_scale = lora_scale
|
26 |
+
|
27 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
28 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
29 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
30 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
31 |
+
|
32 |
+
def __call__(
|
33 |
+
self,
|
34 |
+
attn,
|
35 |
+
hidden_states,
|
36 |
+
encoder_hidden_states=None,
|
37 |
+
attention_mask=None,
|
38 |
+
temb=None,
|
39 |
+
):
|
40 |
+
residual = hidden_states
|
41 |
+
|
42 |
+
if attn.spatial_norm is not None:
|
43 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
44 |
+
|
45 |
+
input_ndim = hidden_states.ndim
|
46 |
+
|
47 |
+
if input_ndim == 4:
|
48 |
+
batch_size, channel, height, width = hidden_states.shape
|
49 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
50 |
+
|
51 |
+
batch_size, sequence_length, _ = (
|
52 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
53 |
+
)
|
54 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
55 |
+
|
56 |
+
if attn.group_norm is not None:
|
57 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
58 |
+
|
59 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
60 |
+
|
61 |
+
if encoder_hidden_states is None:
|
62 |
+
encoder_hidden_states = hidden_states
|
63 |
+
elif attn.norm_cross:
|
64 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
65 |
+
|
66 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
67 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
68 |
+
|
69 |
+
query = attn.head_to_batch_dim(query)
|
70 |
+
key = attn.head_to_batch_dim(key)
|
71 |
+
value = attn.head_to_batch_dim(value)
|
72 |
+
|
73 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
74 |
+
hidden_states = torch.bmm(attention_probs, value)
|
75 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
76 |
+
|
77 |
+
# linear proj
|
78 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
79 |
+
# dropout
|
80 |
+
hidden_states = attn.to_out[1](hidden_states)
|
81 |
+
|
82 |
+
if input_ndim == 4:
|
83 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
84 |
+
|
85 |
+
if attn.residual_connection:
|
86 |
+
hidden_states = hidden_states + residual
|
87 |
+
|
88 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
89 |
+
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
|
93 |
+
class LoRAIPAttnProcessor(nn.Module):
|
94 |
+
r"""
|
95 |
+
Attention processor for IP-Adapater.
|
96 |
+
Args:
|
97 |
+
hidden_size (`int`):
|
98 |
+
The hidden size of the attention layer.
|
99 |
+
cross_attention_dim (`int`):
|
100 |
+
The number of channels in the `encoder_hidden_states`.
|
101 |
+
scale (`float`, defaults to 1.0):
|
102 |
+
the weight scale of image prompt.
|
103 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
104 |
+
The context length of the image features.
|
105 |
+
"""
|
106 |
+
|
107 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.rank = rank
|
111 |
+
self.lora_scale = lora_scale
|
112 |
+
|
113 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
114 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
115 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
116 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
117 |
+
|
118 |
+
self.hidden_size = hidden_size
|
119 |
+
self.cross_attention_dim = cross_attention_dim
|
120 |
+
self.scale = scale
|
121 |
+
self.num_tokens = num_tokens
|
122 |
+
|
123 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
124 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
125 |
+
|
126 |
+
def __call__(
|
127 |
+
self,
|
128 |
+
attn,
|
129 |
+
hidden_states,
|
130 |
+
encoder_hidden_states=None,
|
131 |
+
attention_mask=None,
|
132 |
+
temb=None,
|
133 |
+
):
|
134 |
+
residual = hidden_states
|
135 |
+
|
136 |
+
if attn.spatial_norm is not None:
|
137 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
138 |
+
|
139 |
+
input_ndim = hidden_states.ndim
|
140 |
+
|
141 |
+
if input_ndim == 4:
|
142 |
+
batch_size, channel, height, width = hidden_states.shape
|
143 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
144 |
+
|
145 |
+
batch_size, sequence_length, _ = (
|
146 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
147 |
+
)
|
148 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
149 |
+
|
150 |
+
if attn.group_norm is not None:
|
151 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
152 |
+
|
153 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
154 |
+
|
155 |
+
if encoder_hidden_states is None:
|
156 |
+
encoder_hidden_states = hidden_states
|
157 |
+
else:
|
158 |
+
# get encoder_hidden_states, ip_hidden_states
|
159 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
160 |
+
encoder_hidden_states, ip_hidden_states = (
|
161 |
+
encoder_hidden_states[:, :end_pos, :],
|
162 |
+
encoder_hidden_states[:, end_pos:, :],
|
163 |
+
)
|
164 |
+
if attn.norm_cross:
|
165 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
166 |
+
|
167 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
168 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
169 |
+
|
170 |
+
query = attn.head_to_batch_dim(query)
|
171 |
+
key = attn.head_to_batch_dim(key)
|
172 |
+
value = attn.head_to_batch_dim(value)
|
173 |
+
|
174 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
175 |
+
hidden_states = torch.bmm(attention_probs, value)
|
176 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
177 |
+
|
178 |
+
# for ip-adapter
|
179 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
180 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
181 |
+
|
182 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
183 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
184 |
+
|
185 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
186 |
+
self.attn_map = ip_attention_probs
|
187 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
188 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
189 |
+
|
190 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
191 |
+
|
192 |
+
# linear proj
|
193 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
194 |
+
# dropout
|
195 |
+
hidden_states = attn.to_out[1](hidden_states)
|
196 |
+
|
197 |
+
if input_ndim == 4:
|
198 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
199 |
+
|
200 |
+
if attn.residual_connection:
|
201 |
+
hidden_states = hidden_states + residual
|
202 |
+
|
203 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
204 |
+
|
205 |
+
return hidden_states
|
206 |
+
|
207 |
+
|
208 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
209 |
+
|
210 |
+
r"""
|
211 |
+
Default processor for performing attention-related computations.
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(
|
215 |
+
self,
|
216 |
+
hidden_size=None,
|
217 |
+
cross_attention_dim=None,
|
218 |
+
rank=4,
|
219 |
+
network_alpha=None,
|
220 |
+
lora_scale=1.0,
|
221 |
+
):
|
222 |
+
super().__init__()
|
223 |
+
|
224 |
+
self.rank = rank
|
225 |
+
self.lora_scale = lora_scale
|
226 |
+
|
227 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
228 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
229 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
230 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
231 |
+
|
232 |
+
def __call__(
|
233 |
+
self,
|
234 |
+
attn,
|
235 |
+
hidden_states,
|
236 |
+
encoder_hidden_states=None,
|
237 |
+
attention_mask=None,
|
238 |
+
temb=None,
|
239 |
+
):
|
240 |
+
residual = hidden_states
|
241 |
+
|
242 |
+
if attn.spatial_norm is not None:
|
243 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
244 |
+
|
245 |
+
input_ndim = hidden_states.ndim
|
246 |
+
|
247 |
+
if input_ndim == 4:
|
248 |
+
batch_size, channel, height, width = hidden_states.shape
|
249 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
250 |
+
|
251 |
+
batch_size, sequence_length, _ = (
|
252 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
253 |
+
)
|
254 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
255 |
+
|
256 |
+
if attn.group_norm is not None:
|
257 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
258 |
+
|
259 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
260 |
+
|
261 |
+
if encoder_hidden_states is None:
|
262 |
+
encoder_hidden_states = hidden_states
|
263 |
+
elif attn.norm_cross:
|
264 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
265 |
+
|
266 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
267 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
268 |
+
|
269 |
+
inner_dim = key.shape[-1]
|
270 |
+
head_dim = inner_dim // attn.heads
|
271 |
+
|
272 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
273 |
+
|
274 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
275 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
276 |
+
|
277 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
278 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
279 |
+
hidden_states = F.scaled_dot_product_attention(
|
280 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
281 |
+
)
|
282 |
+
|
283 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
284 |
+
hidden_states = hidden_states.to(query.dtype)
|
285 |
+
|
286 |
+
# linear proj
|
287 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
288 |
+
# dropout
|
289 |
+
hidden_states = attn.to_out[1](hidden_states)
|
290 |
+
|
291 |
+
if input_ndim == 4:
|
292 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
293 |
+
|
294 |
+
if attn.residual_connection:
|
295 |
+
hidden_states = hidden_states + residual
|
296 |
+
|
297 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
298 |
+
|
299 |
+
return hidden_states
|
300 |
+
|
301 |
+
|
302 |
+
class LoRAIPAttnProcessor2_0(nn.Module):
|
303 |
+
r"""
|
304 |
+
Processor for implementing the LoRA attention mechanism.
|
305 |
+
Args:
|
306 |
+
hidden_size (`int`, *optional*):
|
307 |
+
The hidden size of the attention layer.
|
308 |
+
cross_attention_dim (`int`, *optional*):
|
309 |
+
The number of channels in the `encoder_hidden_states`.
|
310 |
+
rank (`int`, defaults to 4):
|
311 |
+
The dimension of the LoRA update matrices.
|
312 |
+
network_alpha (`int`, *optional*):
|
313 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
314 |
+
"""
|
315 |
+
|
316 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
|
317 |
+
super().__init__()
|
318 |
+
|
319 |
+
self.rank = rank
|
320 |
+
self.lora_scale = lora_scale
|
321 |
+
self.num_tokens = num_tokens
|
322 |
+
|
323 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
324 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
325 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
326 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
327 |
+
|
328 |
+
|
329 |
+
self.hidden_size = hidden_size
|
330 |
+
self.cross_attention_dim = cross_attention_dim
|
331 |
+
self.scale = scale
|
332 |
+
|
333 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
334 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
335 |
+
|
336 |
+
def __call__(
|
337 |
+
self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
338 |
+
):
|
339 |
+
residual = hidden_states
|
340 |
+
|
341 |
+
if attn.spatial_norm is not None:
|
342 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
343 |
+
|
344 |
+
input_ndim = hidden_states.ndim
|
345 |
+
|
346 |
+
if input_ndim == 4:
|
347 |
+
batch_size, channel, height, width = hidden_states.shape
|
348 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
349 |
+
|
350 |
+
batch_size, sequence_length, _ = (
|
351 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
352 |
+
)
|
353 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
354 |
+
|
355 |
+
if attn.group_norm is not None:
|
356 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
357 |
+
|
358 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
359 |
+
#query = attn.head_to_batch_dim(query)
|
360 |
+
|
361 |
+
if encoder_hidden_states is None:
|
362 |
+
encoder_hidden_states = hidden_states
|
363 |
+
else:
|
364 |
+
# get encoder_hidden_states, ip_hidden_states
|
365 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
366 |
+
encoder_hidden_states, ip_hidden_states = (
|
367 |
+
encoder_hidden_states[:, :end_pos, :],
|
368 |
+
encoder_hidden_states[:, end_pos:, :],
|
369 |
+
)
|
370 |
+
if attn.norm_cross:
|
371 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
372 |
+
|
373 |
+
# for text
|
374 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
375 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
376 |
+
|
377 |
+
inner_dim = key.shape[-1]
|
378 |
+
head_dim = inner_dim // attn.heads
|
379 |
+
|
380 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
381 |
+
|
382 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
383 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
384 |
+
|
385 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
386 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
387 |
+
hidden_states = F.scaled_dot_product_attention(
|
388 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
389 |
+
)
|
390 |
+
|
391 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
392 |
+
hidden_states = hidden_states.to(query.dtype)
|
393 |
+
|
394 |
+
# for ip
|
395 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
396 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
397 |
+
|
398 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
399 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
400 |
+
|
401 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
402 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
403 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
404 |
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
405 |
+
)
|
406 |
+
|
407 |
+
|
408 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
409 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
410 |
+
|
411 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
412 |
+
|
413 |
+
# linear proj
|
414 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
415 |
+
# dropout
|
416 |
+
hidden_states = attn.to_out[1](hidden_states)
|
417 |
+
|
418 |
+
if input_ndim == 4:
|
419 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
420 |
+
|
421 |
+
if attn.residual_connection:
|
422 |
+
hidden_states = hidden_states + residual
|
423 |
+
|
424 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
425 |
+
|
426 |
+
return hidden_states
|
helper.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import uuid
|
4 |
+
import re
|
5 |
+
|
6 |
+
def parse_prompt_attention(text):
|
7 |
+
re_attention = re.compile(r"""
|
8 |
+
\\\(|
|
9 |
+
\\\)|
|
10 |
+
\\\[|
|
11 |
+
\\]|
|
12 |
+
\\\\|
|
13 |
+
\\|
|
14 |
+
\(|
|
15 |
+
\[|
|
16 |
+
:([+-]?[.\d]+)\)|
|
17 |
+
\)|
|
18 |
+
]|
|
19 |
+
[^\\()\[\]:]+|
|
20 |
+
:
|
21 |
+
""", re.X)
|
22 |
+
|
23 |
+
res = []
|
24 |
+
round_brackets = []
|
25 |
+
square_brackets = []
|
26 |
+
|
27 |
+
round_bracket_multiplier = 1.1
|
28 |
+
square_bracket_multiplier = 1 / 1.1
|
29 |
+
|
30 |
+
def multiply_range(start_position, multiplier):
|
31 |
+
for p in range(start_position, len(res)):
|
32 |
+
res[p][1] *= multiplier
|
33 |
+
|
34 |
+
for m in re_attention.finditer(text):
|
35 |
+
text = m.group(0)
|
36 |
+
weight = m.group(1)
|
37 |
+
|
38 |
+
if text.startswith('\\'):
|
39 |
+
res.append([text[1:], 1.0])
|
40 |
+
elif text == '(':
|
41 |
+
round_brackets.append(len(res))
|
42 |
+
elif text == '[':
|
43 |
+
square_brackets.append(len(res))
|
44 |
+
elif weight is not None and len(round_brackets) > 0:
|
45 |
+
multiply_range(round_brackets.pop(), float(weight))
|
46 |
+
elif text == ')' and len(round_brackets) > 0:
|
47 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
48 |
+
elif text == ']' and len(square_brackets) > 0:
|
49 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
50 |
+
else:
|
51 |
+
parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
|
52 |
+
for i, part in enumerate(parts):
|
53 |
+
if i > 0:
|
54 |
+
res.append(["BREAK", -1])
|
55 |
+
res.append([part, 1.0])
|
56 |
+
|
57 |
+
for pos in round_brackets:
|
58 |
+
multiply_range(pos, round_bracket_multiplier)
|
59 |
+
|
60 |
+
for pos in square_brackets:
|
61 |
+
multiply_range(pos, square_bracket_multiplier)
|
62 |
+
|
63 |
+
if len(res) == 0:
|
64 |
+
res = [["", 1.0]]
|
65 |
+
|
66 |
+
# merge runs of identical weights
|
67 |
+
i = 0
|
68 |
+
while i + 1 < len(res):
|
69 |
+
if res[i][1] == res[i + 1][1]:
|
70 |
+
res[i][0] += res[i + 1][0]
|
71 |
+
res.pop(i + 1)
|
72 |
+
else:
|
73 |
+
i += 1
|
74 |
+
|
75 |
+
return res
|
76 |
+
|
77 |
+
def prompt_attention_to_invoke_prompt(attention):
|
78 |
+
tokens = []
|
79 |
+
for text, weight in attention:
|
80 |
+
# Round weight to 2 decimal places
|
81 |
+
weight = round(weight, 2)
|
82 |
+
if weight == 1.0:
|
83 |
+
tokens.append(text)
|
84 |
+
elif weight < 1.0:
|
85 |
+
if weight < 0.8:
|
86 |
+
tokens.append(f"({text}){weight}")
|
87 |
+
else:
|
88 |
+
tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
|
89 |
+
else:
|
90 |
+
if weight < 1.3:
|
91 |
+
tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
|
92 |
+
else:
|
93 |
+
tokens.append(f"({text}){weight}")
|
94 |
+
return "".join(tokens)
|
95 |
+
|
96 |
+
def concat_tensor(t):
|
97 |
+
t_list = torch.split(t, 1, dim=0)
|
98 |
+
t = torch.cat(t_list, dim=1)
|
99 |
+
return t
|
100 |
+
|
101 |
+
def merge_embeds(prompt_chanks, compel):
|
102 |
+
num_chanks = len(prompt_chanks)
|
103 |
+
if num_chanks != 0:
|
104 |
+
power_prompt = 1/(num_chanks*(num_chanks+1)//2)
|
105 |
+
prompt_embs = compel(prompt_chanks)
|
106 |
+
t_list = list(torch.split(prompt_embs, 1, dim=0))
|
107 |
+
for i in range(num_chanks):
|
108 |
+
t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
|
109 |
+
prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
|
110 |
+
else:
|
111 |
+
prompt_emb = compel('')
|
112 |
+
return prompt_emb
|
113 |
+
|
114 |
+
def detokenize(chunk, actual_prompt):
|
115 |
+
chunk[-1] = chunk[-1].replace('</w>', '')
|
116 |
+
chanked_prompt = ''.join(chunk).strip()
|
117 |
+
while '</w>' in chanked_prompt:
|
118 |
+
if actual_prompt[chanked_prompt.find('</w>')] == ' ':
|
119 |
+
chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
|
120 |
+
else:
|
121 |
+
chanked_prompt = chanked_prompt.replace('</w>', '', 1)
|
122 |
+
actual_prompt = actual_prompt.replace(chanked_prompt,'')
|
123 |
+
return chanked_prompt.strip(), actual_prompt.strip()
|
124 |
+
|
125 |
+
def tokenize_line(line, tokenizer): # split into chunks
|
126 |
+
actual_prompt = line.lower().strip()
|
127 |
+
actual_tokens = tokenizer.tokenize(actual_prompt)
|
128 |
+
max_tokens = tokenizer.model_max_length - 2
|
129 |
+
comma_token = tokenizer.tokenize(',')[0]
|
130 |
+
|
131 |
+
chunks = []
|
132 |
+
chunk = []
|
133 |
+
for item in actual_tokens:
|
134 |
+
chunk.append(item)
|
135 |
+
if len(chunk) == max_tokens:
|
136 |
+
if chunk[-1] != comma_token:
|
137 |
+
for i in range(max_tokens-1, -1, -1):
|
138 |
+
if chunk[i] == comma_token:
|
139 |
+
actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
|
140 |
+
chunks.append(actual_chunk)
|
141 |
+
chunk = chunk[i+1:]
|
142 |
+
break
|
143 |
+
else:
|
144 |
+
actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
|
145 |
+
chunks.append(actual_chunk)
|
146 |
+
chunk = []
|
147 |
+
else:
|
148 |
+
actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
|
149 |
+
chunks.append(actual_chunk)
|
150 |
+
chunk = []
|
151 |
+
if chunk:
|
152 |
+
actual_chunk, _ = detokenize(chunk, actual_prompt)
|
153 |
+
chunks.append(actual_chunk)
|
154 |
+
|
155 |
+
return chunks
|
156 |
+
|
157 |
+
def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
|
158 |
+
|
159 |
+
if compel_process_sd:
|
160 |
+
return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
|
161 |
+
else:
|
162 |
+
# fix bug weights conversion excessive emphasis
|
163 |
+
prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
|
164 |
+
|
165 |
+
# Convert to Compel
|
166 |
+
attention = parse_prompt_attention(prompt)
|
167 |
+
global_attention_chanks = []
|
168 |
+
|
169 |
+
for att in attention:
|
170 |
+
for chank in att[0].split(','):
|
171 |
+
temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
|
172 |
+
for small_chank in temp_prompt_chanks:
|
173 |
+
temp_dict = {
|
174 |
+
"weight": round(att[1], 2),
|
175 |
+
"lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
|
176 |
+
"prompt": f'{small_chank},'
|
177 |
+
}
|
178 |
+
global_attention_chanks.append(temp_dict)
|
179 |
+
|
180 |
+
max_tokens = pipeline.tokenizer.model_max_length - 2
|
181 |
+
global_prompt_chanks = []
|
182 |
+
current_list = []
|
183 |
+
current_length = 0
|
184 |
+
for item in global_attention_chanks:
|
185 |
+
if current_length + item['lenght'] > max_tokens:
|
186 |
+
global_prompt_chanks.append(current_list)
|
187 |
+
current_list = [[item['prompt'], item['weight']]]
|
188 |
+
current_length = item['lenght']
|
189 |
+
else:
|
190 |
+
if not current_list:
|
191 |
+
current_list.append([item['prompt'], item['weight']])
|
192 |
+
else:
|
193 |
+
if item['weight'] != current_list[-1][1]:
|
194 |
+
current_list.append([item['prompt'], item['weight']])
|
195 |
+
else:
|
196 |
+
current_list[-1][0] += f" {item['prompt']}"
|
197 |
+
current_length += item['lenght']
|
198 |
+
if current_list:
|
199 |
+
global_prompt_chanks.append(current_list)
|
200 |
+
|
201 |
+
if only_convert_string:
|
202 |
+
return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])
|
203 |
+
|
204 |
+
return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
|
205 |
+
|
206 |
+
def add_comma_after_pattern_ti(text):
|
207 |
+
pattern = re.compile(r'\b\w+_\d+\b')
|
208 |
+
modified_text = pattern.sub(lambda x: x.group() + ',', text)
|
209 |
+
return modified_text
|
210 |
+
|
211 |
+
def save_image(img):
|
212 |
+
path = "./tmp/"
|
213 |
+
|
214 |
+
# Check if the input is a string (file path) and load the image if it is
|
215 |
+
if isinstance(img, str):
|
216 |
+
img = Image.open(img) # Load the image from the file path
|
217 |
+
|
218 |
+
# Ensure the Hugging Face path exists locally
|
219 |
+
if not os.path.exists(path):
|
220 |
+
os.makedirs(path)
|
221 |
+
|
222 |
+
# Generate a unique filename
|
223 |
+
unique_name = str(uuid.uuid4()) + ".webp"
|
224 |
+
unique_name = os.path.join(path, unique_name)
|
225 |
+
|
226 |
+
# Convert the image to WebP format
|
227 |
+
webp_img = img.convert("RGB") # Ensure the image is in RGB mode
|
228 |
+
|
229 |
+
# Save the image in WebP format with high quality
|
230 |
+
webp_img.save(unique_name, "WEBP", quality=90)
|
231 |
+
|
232 |
+
# Open the saved WebP file and return it as a PIL Image object
|
233 |
+
with Image.open(unique_name) as webp_file:
|
234 |
+
webp_image = webp_file.copy()
|
235 |
+
|
236 |
+
return unique_name
|
ipown.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
+
from diffusers.pipelines.controlnet import MultiControlNetModel
|
7 |
+
from PIL import Image
|
8 |
+
from safetensors import safe_open
|
9 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
10 |
+
|
11 |
+
from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
12 |
+
from utils import is_torch2_available
|
13 |
+
|
14 |
+
USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
|
15 |
+
if is_torch2_available() and (not USE_DAFAULT_ATTN):
|
16 |
+
from attention_processor_faceid import (
|
17 |
+
LoRAAttnProcessor2_0 as LoRAAttnProcessor,
|
18 |
+
)
|
19 |
+
from attention_processor_faceid import (
|
20 |
+
LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
|
21 |
+
)
|
22 |
+
else:
|
23 |
+
from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
24 |
+
from resampler import PerceiverAttention, FeedForward
|
25 |
+
|
26 |
+
|
27 |
+
class FacePerceiverResampler(torch.nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
*,
|
31 |
+
dim=768,
|
32 |
+
depth=4,
|
33 |
+
dim_head=64,
|
34 |
+
heads=16,
|
35 |
+
embedding_dim=1280,
|
36 |
+
output_dim=768,
|
37 |
+
ff_mult=4,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.proj_in = torch.nn.Linear(embedding_dim, dim)
|
42 |
+
self.proj_out = torch.nn.Linear(dim, output_dim)
|
43 |
+
self.norm_out = torch.nn.LayerNorm(output_dim)
|
44 |
+
self.layers = torch.nn.ModuleList([])
|
45 |
+
for _ in range(depth):
|
46 |
+
self.layers.append(
|
47 |
+
torch.nn.ModuleList(
|
48 |
+
[
|
49 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
50 |
+
FeedForward(dim=dim, mult=ff_mult),
|
51 |
+
]
|
52 |
+
)
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, latents, x):
|
56 |
+
x = self.proj_in(x)
|
57 |
+
for attn, ff in self.layers:
|
58 |
+
latents = attn(x, latents) + latents
|
59 |
+
latents = ff(latents) + latents
|
60 |
+
latents = self.proj_out(latents)
|
61 |
+
return self.norm_out(latents)
|
62 |
+
|
63 |
+
|
64 |
+
class MLPProjModel(torch.nn.Module):
|
65 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.cross_attention_dim = cross_attention_dim
|
69 |
+
self.num_tokens = num_tokens
|
70 |
+
|
71 |
+
self.proj = torch.nn.Sequential(
|
72 |
+
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
73 |
+
torch.nn.GELU(),
|
74 |
+
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
75 |
+
)
|
76 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
77 |
+
|
78 |
+
def forward(self, id_embeds):
|
79 |
+
x = self.proj(id_embeds)
|
80 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
81 |
+
x = self.norm(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class ProjPlusModel(torch.nn.Module):
|
86 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.cross_attention_dim = cross_attention_dim
|
90 |
+
self.num_tokens = num_tokens
|
91 |
+
|
92 |
+
self.proj = torch.nn.Sequential(
|
93 |
+
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
94 |
+
torch.nn.GELU(),
|
95 |
+
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
96 |
+
)
|
97 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
98 |
+
|
99 |
+
self.perceiver_resampler = FacePerceiverResampler(
|
100 |
+
dim=cross_attention_dim,
|
101 |
+
depth=4,
|
102 |
+
dim_head=64,
|
103 |
+
heads=cross_attention_dim // 64,
|
104 |
+
embedding_dim=clip_embeddings_dim,
|
105 |
+
output_dim=cross_attention_dim,
|
106 |
+
ff_mult=4,
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
|
110 |
+
|
111 |
+
x = self.proj(id_embeds)
|
112 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
113 |
+
x = self.norm(x)
|
114 |
+
out = self.perceiver_resampler(x, clip_embeds)
|
115 |
+
if shortcut:
|
116 |
+
out = x + scale * out
|
117 |
+
return out
|
118 |
+
|
119 |
+
|
120 |
+
class IPAdapterFaceID:
|
121 |
+
def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
|
122 |
+
self.device = device
|
123 |
+
self.ip_ckpt = ip_ckpt
|
124 |
+
self.lora_rank = lora_rank
|
125 |
+
self.num_tokens = num_tokens
|
126 |
+
self.torch_dtype = torch_dtype
|
127 |
+
|
128 |
+
self.pipe = sd_pipe.to(self.device)
|
129 |
+
self.set_ip_adapter()
|
130 |
+
|
131 |
+
# image proj model
|
132 |
+
self.image_proj_model = self.init_proj()
|
133 |
+
|
134 |
+
self.load_ip_adapter()
|
135 |
+
|
136 |
+
def init_proj(self):
|
137 |
+
image_proj_model = MLPProjModel(
|
138 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
139 |
+
id_embeddings_dim=512,
|
140 |
+
num_tokens=self.num_tokens,
|
141 |
+
).to(self.device, dtype=self.torch_dtype)
|
142 |
+
return image_proj_model
|
143 |
+
|
144 |
+
def set_ip_adapter(self):
|
145 |
+
unet = self.pipe.unet
|
146 |
+
attn_procs = {}
|
147 |
+
for name in unet.attn_processors.keys():
|
148 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
149 |
+
if name.startswith("mid_block"):
|
150 |
+
hidden_size = unet.config.block_out_channels[-1]
|
151 |
+
elif name.startswith("up_blocks"):
|
152 |
+
block_id = int(name[len("up_blocks.")])
|
153 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
154 |
+
elif name.startswith("down_blocks"):
|
155 |
+
block_id = int(name[len("down_blocks.")])
|
156 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
157 |
+
if cross_attention_dim is None:
|
158 |
+
attn_procs[name] = LoRAAttnProcessor(
|
159 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
|
160 |
+
).to(self.device, dtype=self.torch_dtype)
|
161 |
+
else:
|
162 |
+
attn_procs[name] = LoRAIPAttnProcessor(
|
163 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
|
164 |
+
).to(self.device, dtype=self.torch_dtype)
|
165 |
+
unet.set_attn_processor(attn_procs)
|
166 |
+
|
167 |
+
def load_ip_adapter(self):
|
168 |
+
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
|
169 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
170 |
+
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
|
171 |
+
for key in f.keys():
|
172 |
+
if key.startswith("image_proj."):
|
173 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
174 |
+
elif key.startswith("ip_adapter."):
|
175 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
176 |
+
else:
|
177 |
+
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
178 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
179 |
+
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
180 |
+
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
181 |
+
|
182 |
+
@torch.inference_mode()
|
183 |
+
def get_image_embeds(self, faceid_embeds):
|
184 |
+
|
185 |
+
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
|
186 |
+
print(faceid_embeds.device)
|
187 |
+
print(next(self.image_proj_model.parameters()).device)
|
188 |
+
image_prompt_embeds = self.image_proj_model(faceid_embeds)
|
189 |
+
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
|
190 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
191 |
+
|
192 |
+
def set_scale(self, scale):
|
193 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
194 |
+
if isinstance(attn_processor, LoRAIPAttnProcessor):
|
195 |
+
attn_processor.scale = scale
|
196 |
+
|
197 |
+
def generate(
|
198 |
+
self,
|
199 |
+
faceid_embeds=None,
|
200 |
+
prompt=None,
|
201 |
+
negative_prompt=None,
|
202 |
+
scale=1.0,
|
203 |
+
num_samples=4,
|
204 |
+
seed=None,
|
205 |
+
guidance_scale=7.5,
|
206 |
+
num_inference_steps=30,
|
207 |
+
**kwargs,
|
208 |
+
):
|
209 |
+
self.set_scale(scale)
|
210 |
+
|
211 |
+
|
212 |
+
num_prompts = faceid_embeds.size(0)
|
213 |
+
|
214 |
+
if prompt is None:
|
215 |
+
prompt = "best quality, high quality"
|
216 |
+
if negative_prompt is None:
|
217 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
218 |
+
|
219 |
+
if not isinstance(prompt, List):
|
220 |
+
prompt = [prompt] * num_prompts
|
221 |
+
if not isinstance(negative_prompt, List):
|
222 |
+
negative_prompt = [negative_prompt] * num_prompts
|
223 |
+
|
224 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
|
225 |
+
|
226 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
227 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
228 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
229 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
230 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
231 |
+
|
232 |
+
with torch.inference_mode():
|
233 |
+
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
|
234 |
+
prompt,
|
235 |
+
device=self.device,
|
236 |
+
num_images_per_prompt=num_samples,
|
237 |
+
do_classifier_free_guidance=True,
|
238 |
+
negative_prompt=negative_prompt,
|
239 |
+
)
|
240 |
+
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
241 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
242 |
+
|
243 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
244 |
+
images = self.pipe(
|
245 |
+
prompt_embeds=prompt_embeds,
|
246 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
247 |
+
guidance_scale=guidance_scale,
|
248 |
+
num_inference_steps=num_inference_steps,
|
249 |
+
generator=generator,
|
250 |
+
**kwargs,
|
251 |
+
).images
|
252 |
+
|
253 |
+
return images
|
254 |
+
|
255 |
+
|
256 |
+
class IPAdapterFaceIDPlus:
|
257 |
+
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
|
258 |
+
self.device = device
|
259 |
+
self.image_encoder_path = image_encoder_path
|
260 |
+
self.ip_ckpt = ip_ckpt
|
261 |
+
self.lora_rank = lora_rank
|
262 |
+
self.num_tokens = num_tokens
|
263 |
+
self.torch_dtype = torch_dtype
|
264 |
+
|
265 |
+
self.pipe = sd_pipe.to(self.device)
|
266 |
+
self.set_ip_adapter()
|
267 |
+
|
268 |
+
# load image encoder
|
269 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
|
270 |
+
self.device, dtype=self.torch_dtype
|
271 |
+
)
|
272 |
+
self.clip_image_processor = CLIPImageProcessor()
|
273 |
+
# image proj model
|
274 |
+
self.image_proj_model = self.init_proj()
|
275 |
+
|
276 |
+
self.load_ip_adapter()
|
277 |
+
|
278 |
+
def init_proj(self):
|
279 |
+
image_proj_model = ProjPlusModel(
|
280 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
281 |
+
id_embeddings_dim=512,
|
282 |
+
clip_embeddings_dim=self.image_encoder.config.hidden_size,
|
283 |
+
num_tokens=self.num_tokens,
|
284 |
+
).to(self.device, dtype=self.torch_dtype)
|
285 |
+
return image_proj_model
|
286 |
+
|
287 |
+
def set_ip_adapter(self):
|
288 |
+
unet = self.pipe.unet
|
289 |
+
attn_procs = {}
|
290 |
+
for name in unet.attn_processors.keys():
|
291 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
292 |
+
if name.startswith("mid_block"):
|
293 |
+
hidden_size = unet.config.block_out_channels[-1]
|
294 |
+
elif name.startswith("up_blocks"):
|
295 |
+
block_id = int(name[len("up_blocks.")])
|
296 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
297 |
+
elif name.startswith("down_blocks"):
|
298 |
+
block_id = int(name[len("down_blocks.")])
|
299 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
300 |
+
if cross_attention_dim is None:
|
301 |
+
attn_procs[name] = LoRAAttnProcessor(
|
302 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
|
303 |
+
).to(self.device, dtype=self.torch_dtype)
|
304 |
+
else:
|
305 |
+
attn_procs[name] = LoRAIPAttnProcessor(
|
306 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
|
307 |
+
).to(self.device, dtype=self.torch_dtype)
|
308 |
+
unet.set_attn_processor(attn_procs)
|
309 |
+
|
310 |
+
def load_ip_adapter(self):
|
311 |
+
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
|
312 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
313 |
+
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
|
314 |
+
for key in f.keys():
|
315 |
+
if key.startswith("image_proj."):
|
316 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
317 |
+
elif key.startswith("ip_adapter."):
|
318 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
319 |
+
else:
|
320 |
+
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
321 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
322 |
+
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
323 |
+
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
324 |
+
|
325 |
+
@torch.inference_mode()
|
326 |
+
def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
|
327 |
+
if isinstance(face_image, Image.Image):
|
328 |
+
pil_image = [face_image]
|
329 |
+
clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
|
330 |
+
clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
|
331 |
+
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
332 |
+
uncond_clip_image_embeds = self.image_encoder(
|
333 |
+
torch.zeros_like(clip_image), output_hidden_states=True
|
334 |
+
).hidden_states[-2]
|
335 |
+
|
336 |
+
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
|
337 |
+
image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
338 |
+
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
339 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
340 |
+
|
341 |
+
def set_scale(self, scale):
|
342 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
343 |
+
if isinstance(attn_processor, LoRAIPAttnProcessor):
|
344 |
+
attn_processor.scale = scale
|
345 |
+
|
346 |
+
def generate(
|
347 |
+
self,
|
348 |
+
face_image=None,
|
349 |
+
faceid_embeds=None,
|
350 |
+
prompt=None,
|
351 |
+
negative_prompt=None,
|
352 |
+
scale=1.0,
|
353 |
+
num_samples=4,
|
354 |
+
seed=None,
|
355 |
+
guidance_scale=7.5,
|
356 |
+
num_inference_steps=30,
|
357 |
+
s_scale=1.0,
|
358 |
+
shortcut=False,
|
359 |
+
**kwargs,
|
360 |
+
):
|
361 |
+
self.set_scale(scale)
|
362 |
+
|
363 |
+
|
364 |
+
num_prompts = faceid_embeds.size(0)
|
365 |
+
|
366 |
+
if prompt is None:
|
367 |
+
prompt = "best quality, high quality"
|
368 |
+
if negative_prompt is None:
|
369 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
370 |
+
|
371 |
+
if not isinstance(prompt, List):
|
372 |
+
prompt = [prompt] * num_prompts
|
373 |
+
if not isinstance(negative_prompt, List):
|
374 |
+
negative_prompt = [negative_prompt] * num_prompts
|
375 |
+
|
376 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
|
377 |
+
|
378 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
379 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
380 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
381 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
382 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
383 |
+
|
384 |
+
with torch.inference_mode():
|
385 |
+
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
|
386 |
+
prompt,
|
387 |
+
device=self.device,
|
388 |
+
num_images_per_prompt=num_samples,
|
389 |
+
do_classifier_free_guidance=True,
|
390 |
+
negative_prompt=negative_prompt,
|
391 |
+
)
|
392 |
+
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
393 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
394 |
+
|
395 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
396 |
+
images = self.pipe(
|
397 |
+
prompt_embeds=prompt_embeds,
|
398 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
399 |
+
guidance_scale=guidance_scale,
|
400 |
+
num_inference_steps=num_inference_steps,
|
401 |
+
generator=generator,
|
402 |
+
**kwargs,
|
403 |
+
).images
|
404 |
+
|
405 |
+
return images
|
406 |
+
|
407 |
+
|
408 |
+
class IPAdapterFaceIDXL(IPAdapterFaceID):
|
409 |
+
"""SDXL"""
|
410 |
+
|
411 |
+
def generate(
|
412 |
+
self,
|
413 |
+
faceid_embeds=None,
|
414 |
+
prompt=None,
|
415 |
+
negative_prompt=None,
|
416 |
+
scale=1.0,
|
417 |
+
num_samples=4,
|
418 |
+
seed=None,
|
419 |
+
num_inference_steps=30,
|
420 |
+
**kwargs,
|
421 |
+
):
|
422 |
+
self.set_scale(scale)
|
423 |
+
|
424 |
+
num_prompts = faceid_embeds.size(0)
|
425 |
+
|
426 |
+
if prompt is None:
|
427 |
+
prompt = "best quality, high quality"
|
428 |
+
if negative_prompt is None:
|
429 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
430 |
+
|
431 |
+
if not isinstance(prompt, List):
|
432 |
+
prompt = [prompt] * num_prompts
|
433 |
+
if not isinstance(negative_prompt, List):
|
434 |
+
negative_prompt = [negative_prompt] * num_prompts
|
435 |
+
|
436 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
|
437 |
+
|
438 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
439 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
440 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
441 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
442 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
443 |
+
|
444 |
+
with torch.inference_mode():
|
445 |
+
(
|
446 |
+
prompt_embeds,
|
447 |
+
negative_prompt_embeds,
|
448 |
+
pooled_prompt_embeds,
|
449 |
+
negative_pooled_prompt_embeds,
|
450 |
+
) = self.pipe.encode_prompt(
|
451 |
+
prompt,
|
452 |
+
num_images_per_prompt=num_samples,
|
453 |
+
do_classifier_free_guidance=True,
|
454 |
+
negative_prompt=negative_prompt,
|
455 |
+
)
|
456 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
457 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
458 |
+
|
459 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
460 |
+
images = self.pipe(
|
461 |
+
prompt_embeds=prompt_embeds,
|
462 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
463 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
464 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
465 |
+
num_inference_steps=num_inference_steps,
|
466 |
+
generator=generator,
|
467 |
+
**kwargs,
|
468 |
+
).images
|
469 |
+
|
470 |
+
return images
|
loader.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
|
4 |
+
def load_script(file_str: str):
|
5 |
+
"""
|
6 |
+
file_str: something like 'myorg/myrepo/mysubfolder/myscript.py'
|
7 |
+
This function downloads the file from the Hugging Face Hub into ./ (current directory).
|
8 |
+
"""
|
9 |
+
try:
|
10 |
+
# Split the path by "/"
|
11 |
+
parts = file_str.split("/")
|
12 |
+
|
13 |
+
if len(parts) < 3:
|
14 |
+
raise ValueError(
|
15 |
+
f"Invalid file specification '{file_str}'. "
|
16 |
+
f"Expected format: 'repo_id/[subfolder]/filename'"
|
17 |
+
)
|
18 |
+
|
19 |
+
# First two parts form the repo_id (e.g. 'myorg/myrepo')
|
20 |
+
repo_id = "/".join(parts[:2])
|
21 |
+
|
22 |
+
# Last part is the actual filename (e.g. 'myscript.py')
|
23 |
+
filename = parts[-1]
|
24 |
+
|
25 |
+
# Anything between the second and last parts is a subfolder path
|
26 |
+
subfolder = None
|
27 |
+
if len(parts) > 3:
|
28 |
+
subfolder = "/".join(parts[2:-1])
|
29 |
+
|
30 |
+
# Retrieve HF token from environment
|
31 |
+
hf_token = os.getenv("HF_TOKEN", None)
|
32 |
+
|
33 |
+
# Download the file into current directory "."
|
34 |
+
file_path = hf_hub_download(
|
35 |
+
repo_id=repo_id,
|
36 |
+
filename=filename,
|
37 |
+
subfolder=subfolder,
|
38 |
+
token=hf_token,
|
39 |
+
local_dir="." # Download into current directory
|
40 |
+
)
|
41 |
+
|
42 |
+
print(f"Downloaded {filename} from {repo_id} to {file_path}")
|
43 |
+
return file_path
|
44 |
+
|
45 |
+
except Exception as e:
|
46 |
+
print(f"Error downloading the script '{file_str}': {e}")
|
47 |
+
return None
|
48 |
+
|
49 |
+
|
50 |
+
def load_scripts():
|
51 |
+
"""
|
52 |
+
1. Get the path of the 'FILE_LIST' file from the environment variable FILE_LIST.
|
53 |
+
2. Download that file list using load_script().
|
54 |
+
3. Read its lines, and each line is another file to be downloaded using load_script().
|
55 |
+
4. After all lines are downloaded, execute the last file.
|
56 |
+
"""
|
57 |
+
file_list = os.getenv("FILE_LIST", "").strip()
|
58 |
+
if not file_list:
|
59 |
+
print("No FILE_LIST environment variable set. Nothing to download.")
|
60 |
+
return
|
61 |
+
|
62 |
+
# Step 1: Download the file list itself
|
63 |
+
file_list_path = load_script(file_list)
|
64 |
+
if not file_list_path or not os.path.exists(file_list_path):
|
65 |
+
print(f"Could not download or find file list: {file_list_path}")
|
66 |
+
return
|
67 |
+
|
68 |
+
# Step 2: Read each line in the downloaded file list
|
69 |
+
try:
|
70 |
+
with open(file_list_path, 'r') as f:
|
71 |
+
lines = [line.strip() for line in f if line.strip()]
|
72 |
+
except Exception as e:
|
73 |
+
print(f"Error reading file list: {e}")
|
74 |
+
return
|
75 |
+
|
76 |
+
# Step 3: Download each file from the lines
|
77 |
+
downloaded_files = []
|
78 |
+
for file_str in lines:
|
79 |
+
file_path = load_script(file_str)
|
80 |
+
if file_path:
|
81 |
+
downloaded_files.append(file_path)
|
82 |
+
|
83 |
+
# Step 4: Execute the last downloaded file
|
84 |
+
if downloaded_files:
|
85 |
+
last_file_path = downloaded_files[-1]
|
86 |
+
print(f"Executing the last downloaded script: {last_file_path}")
|
87 |
+
try:
|
88 |
+
with open(last_file_path, 'r') as f:
|
89 |
+
exec(f.read(), globals())
|
90 |
+
except Exception as e:
|
91 |
+
print(f"Error executing the last downloaded script: {e}")
|
92 |
+
|
93 |
+
|
94 |
+
# Run the load_scripts function
|
95 |
+
load_scripts()
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
insightface==0.7.3
|
2 |
+
diffusers
|
3 |
+
transformers
|
4 |
+
accelerate
|
5 |
+
safetensors
|
6 |
+
einops
|
7 |
+
onnxruntime-gpu
|
8 |
+
spaces==0.19.4
|
9 |
+
opencv-python
|
10 |
+
pyjwt
|
11 |
+
torchsde
|
12 |
+
compel
|
13 |
+
hidiffusion
|
14 |
+
git+https://github.com/tencent-ailab/IP-Adapter.git
|
resampler.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keep shape (bs, n_heads, length, dim_per_head); reshape also makes the tensor contiguous
    x = x.reshape(bs, heads, length, -1)
    return x

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
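As a quick sanity check, the Resampler above can be exercised on dummy image embeddings. This is a standalone sketch; the dimensions are illustrative choices, not values this Space is known to use.

import torch
from resampler import Resampler  # this file

model = Resampler(
    dim=1024,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=16,        # number of output image-prompt tokens
    embedding_dim=1280,    # size of the incoming image-encoder features (assumed)
    output_dim=2048,       # size expected by the downstream cross-attention (assumed)
)

image_embeds = torch.randn(2, 257, 1280)  # (batch, sequence, embedding_dim)
tokens = model(image_embeds)
print(tokens.shape)  # torch.Size([2, 16, 2048])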
utils.py
ADDED
@@ -0,0 +1,170 @@
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import requests
from datetime import datetime, timedelta
import re

attn_maps = {}

def hook_fn(name):
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet

def upscale(attn_map, target_size):
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map

def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):

    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps

def attnmaps2images(net_attn_maps):

    #total_attn_scores = 0
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        #total_attn_scores += attn_map.mean().item()

        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        #print("norm: ", normalized_attn_map.shape)
        image = Image.fromarray(normalized_attn_map)

        #image = fix_save_attn_map(attn_map)
        images.append(image)

    #print(total_attn_scores)
    return images

def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


class RemoteJson:
    def __init__(self, url, refresh_gap_seconds=3600, processor=None):
        """
        Initialize the RemoteJson manager.
        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback function to process the JSON after it's loaded successfully.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL. If loading fails, return None.
        """
        try:
            response = requests.get(self.url)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If no last update, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.
        """
        new_json = self._load_json()
        if new_json:
            self.json_data = new_json
            self.last_updated = datetime.now()
            print("JSON updated successfully.")
            if self.processor:
                self.json_data = self.processor(self.json_data)
        else:
            print("Failed to update JSON. Keeping the previous version.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If a refresh is required, fetch the new data and apply the processor.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")

        return self.json_data

def extract_key_value_pairs(input_string):
    # Define the regular expression to match [xxx:yyy] where yyy can have special characters
    pattern = r"\[([^\]]+):([^\]]+)\]"

    # Find all matches in the input string with the original matching string
    matches = re.finditer(pattern, input_string)

    # Convert matches to a list of dictionaries including the raw matching string
    result = [{"key": match.group(1), "value": match.group(2), "raw": match.group(0)} for match in matches]

    return result

def extract_characters(prefix, input_string):
    # Define the regular expression to match placeholders starting with the prefix and ending with a space or comma
    pattern = rf"{prefix}([^\s,$]+)(?=\s|,|$)"

    # Find all matches in the input string
    matches = re.findall(pattern, input_string)

    # Return a list of dictionaries with the extracted placeholders
    result = [{"raw": f"{prefix}{match}", "key": match} for match in matches]

    return result
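For illustration, the prompt-parsing helpers and RemoteJson above can be used as follows; the prompt string and URL are made-up examples.

from utils import RemoteJson, extract_key_value_pairs, extract_characters

prompt = "a photo of @alice and @bob [style:cinematic, moody] [seed:1234]"

print(extract_key_value_pairs(prompt))
# [{'key': 'style', 'value': 'cinematic, moody', 'raw': '[style:cinematic, moody]'},
#  {'key': 'seed', 'value': '1234', 'raw': '[seed:1234]'}]

print(extract_characters("@", prompt))
# [{'raw': '@alice', 'key': 'alice'}, {'raw': '@bob', 'key': 'bob'}]

# RemoteJson caches a remote JSON document and refetches it at most once per refresh window.
config = RemoteJson("https://example.com/config.json",  # placeholder URL
                    refresh_gap_seconds=3600,
                    processor=lambda data: data)        # optional post-processing hook
data = config.get()  # returns None if the URL could never be fetched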