nsfwalex committed
Commit 82635c8 · 1 Parent(s): 1924ecf
Files changed (8)
  1. app.py +1 -7
  2. attention_processor_faceid.py +426 -0
  3. helper.py +236 -0
  4. ipown.py +470 -0
  5. loader.py +95 -0
  6. requirements.txt +14 -0
  7. resampler.py +158 -0
  8. utils.py +170 -0
app.py CHANGED
@@ -1,7 +1 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import loader
attention_processor_faceid.py ADDED
@@ -0,0 +1,426 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers.models.lora import LoRALinearLayer
7
+
8
+
9
+ class LoRAAttnProcessor(nn.Module):
10
+ r"""
11
+ LoRA-enabled processor for performing attention-related computations.
12
+ """
13
+
14
+ def __init__(
15
+ self,
16
+ hidden_size=None,
17
+ cross_attention_dim=None,
18
+ rank=4,
19
+ network_alpha=None,
20
+ lora_scale=1.0,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.rank = rank
25
+ self.lora_scale = lora_scale
26
+
27
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
28
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
29
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
30
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
31
+
32
+ def __call__(
33
+ self,
34
+ attn,
35
+ hidden_states,
36
+ encoder_hidden_states=None,
37
+ attention_mask=None,
38
+ temb=None,
39
+ ):
40
+ residual = hidden_states
41
+
42
+ if attn.spatial_norm is not None:
43
+ hidden_states = attn.spatial_norm(hidden_states, temb)
44
+
45
+ input_ndim = hidden_states.ndim
46
+
47
+ if input_ndim == 4:
48
+ batch_size, channel, height, width = hidden_states.shape
49
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
50
+
51
+ batch_size, sequence_length, _ = (
52
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
53
+ )
54
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
55
+
56
+ if attn.group_norm is not None:
57
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
58
+
59
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
60
+
61
+ if encoder_hidden_states is None:
62
+ encoder_hidden_states = hidden_states
63
+ elif attn.norm_cross:
64
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
65
+
66
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
67
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
68
+
69
+ query = attn.head_to_batch_dim(query)
70
+ key = attn.head_to_batch_dim(key)
71
+ value = attn.head_to_batch_dim(value)
72
+
73
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
74
+ hidden_states = torch.bmm(attention_probs, value)
75
+ hidden_states = attn.batch_to_head_dim(hidden_states)
76
+
77
+ # linear proj
78
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
79
+ # dropout
80
+ hidden_states = attn.to_out[1](hidden_states)
81
+
82
+ if input_ndim == 4:
83
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
84
+
85
+ if attn.residual_connection:
86
+ hidden_states = hidden_states + residual
87
+
88
+ hidden_states = hidden_states / attn.rescale_output_factor
89
+
90
+ return hidden_states
91
+
92
+
93
+ class LoRAIPAttnProcessor(nn.Module):
94
+ r"""
95
+ Attention processor for IP-Adapter.
96
+ Args:
97
+ hidden_size (`int`):
98
+ The hidden size of the attention layer.
99
+ cross_attention_dim (`int`):
100
+ The number of channels in the `encoder_hidden_states`.
101
+ scale (`float`, defaults to 1.0):
102
+ the weight scale of image prompt.
103
+ num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
104
+ The context length of the image features.
105
+ """
106
+
107
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
108
+ super().__init__()
109
+
110
+ self.rank = rank
111
+ self.lora_scale = lora_scale
112
+
113
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
114
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
115
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
116
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
117
+
118
+ self.hidden_size = hidden_size
119
+ self.cross_attention_dim = cross_attention_dim
120
+ self.scale = scale
121
+ self.num_tokens = num_tokens
122
+
123
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
124
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
125
+
126
+ def __call__(
127
+ self,
128
+ attn,
129
+ hidden_states,
130
+ encoder_hidden_states=None,
131
+ attention_mask=None,
132
+ temb=None,
133
+ ):
134
+ residual = hidden_states
135
+
136
+ if attn.spatial_norm is not None:
137
+ hidden_states = attn.spatial_norm(hidden_states, temb)
138
+
139
+ input_ndim = hidden_states.ndim
140
+
141
+ if input_ndim == 4:
142
+ batch_size, channel, height, width = hidden_states.shape
143
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
144
+
145
+ batch_size, sequence_length, _ = (
146
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
147
+ )
148
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
149
+
150
+ if attn.group_norm is not None:
151
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
152
+
153
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
154
+
155
+ if encoder_hidden_states is None:
156
+ encoder_hidden_states = hidden_states
157
+ else:
158
+ # get encoder_hidden_states, ip_hidden_states
159
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
160
+ encoder_hidden_states, ip_hidden_states = (
161
+ encoder_hidden_states[:, :end_pos, :],
162
+ encoder_hidden_states[:, end_pos:, :],
163
+ )
164
+ if attn.norm_cross:
165
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
166
+
167
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
168
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
169
+
170
+ query = attn.head_to_batch_dim(query)
171
+ key = attn.head_to_batch_dim(key)
172
+ value = attn.head_to_batch_dim(value)
173
+
174
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
175
+ hidden_states = torch.bmm(attention_probs, value)
176
+ hidden_states = attn.batch_to_head_dim(hidden_states)
177
+
178
+ # for ip-adapter
179
+ ip_key = self.to_k_ip(ip_hidden_states)
180
+ ip_value = self.to_v_ip(ip_hidden_states)
181
+
182
+ ip_key = attn.head_to_batch_dim(ip_key)
183
+ ip_value = attn.head_to_batch_dim(ip_value)
184
+
185
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
186
+ self.attn_map = ip_attention_probs
187
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
188
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
189
+
190
+ hidden_states = hidden_states + self.scale * ip_hidden_states
191
+
192
+ # linear proj
193
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
194
+ # dropout
195
+ hidden_states = attn.to_out[1](hidden_states)
196
+
197
+ if input_ndim == 4:
198
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
199
+
200
+ if attn.residual_connection:
201
+ hidden_states = hidden_states + residual
202
+
203
+ hidden_states = hidden_states / attn.rescale_output_factor
204
+
205
+ return hidden_states
206
+
207
+
208
+ class LoRAAttnProcessor2_0(nn.Module):
209
+
210
+ r"""
211
+ LoRA-enabled processor for performing attention-related computations with PyTorch 2.0's scaled_dot_product_attention.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ hidden_size=None,
217
+ cross_attention_dim=None,
218
+ rank=4,
219
+ network_alpha=None,
220
+ lora_scale=1.0,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.rank = rank
225
+ self.lora_scale = lora_scale
226
+
227
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
228
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
229
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
230
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
231
+
232
+ def __call__(
233
+ self,
234
+ attn,
235
+ hidden_states,
236
+ encoder_hidden_states=None,
237
+ attention_mask=None,
238
+ temb=None,
239
+ ):
240
+ residual = hidden_states
241
+
242
+ if attn.spatial_norm is not None:
243
+ hidden_states = attn.spatial_norm(hidden_states, temb)
244
+
245
+ input_ndim = hidden_states.ndim
246
+
247
+ if input_ndim == 4:
248
+ batch_size, channel, height, width = hidden_states.shape
249
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
250
+
251
+ batch_size, sequence_length, _ = (
252
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
253
+ )
254
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
255
+
256
+ if attn.group_norm is not None:
257
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
258
+
259
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
260
+
261
+ if encoder_hidden_states is None:
262
+ encoder_hidden_states = hidden_states
263
+ elif attn.norm_cross:
264
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
265
+
266
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
267
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
268
+
269
+ inner_dim = key.shape[-1]
270
+ head_dim = inner_dim // attn.heads
271
+
272
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
273
+
274
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
275
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
276
+
277
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
278
+ # TODO: add support for attn.scale when we move to Torch 2.1
279
+ hidden_states = F.scaled_dot_product_attention(
280
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
281
+ )
282
+
283
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
284
+ hidden_states = hidden_states.to(query.dtype)
285
+
286
+ # linear proj
287
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
288
+ # dropout
289
+ hidden_states = attn.to_out[1](hidden_states)
290
+
291
+ if input_ndim == 4:
292
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
293
+
294
+ if attn.residual_connection:
295
+ hidden_states = hidden_states + residual
296
+
297
+ hidden_states = hidden_states / attn.rescale_output_factor
298
+
299
+ return hidden_states
300
+
301
+
302
+ class LoRAIPAttnProcessor2_0(nn.Module):
303
+ r"""
304
+ Processor for implementing the LoRA attention mechanism with IP-Adapter image tokens (PyTorch 2.0).
305
+ Args:
306
+ hidden_size (`int`, *optional*):
307
+ The hidden size of the attention layer.
308
+ cross_attention_dim (`int`, *optional*):
309
+ The number of channels in the `encoder_hidden_states`.
310
+ rank (`int`, defaults to 4):
311
+ The dimension of the LoRA update matrices.
312
+ network_alpha (`int`, *optional*):
313
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
314
+ """
315
+
316
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
317
+ super().__init__()
318
+
319
+ self.rank = rank
320
+ self.lora_scale = lora_scale
321
+ self.num_tokens = num_tokens
322
+
323
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
324
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
325
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
326
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
327
+
328
+
329
+ self.hidden_size = hidden_size
330
+ self.cross_attention_dim = cross_attention_dim
331
+ self.scale = scale
332
+
333
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
334
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
335
+
336
+ def __call__(
337
+ self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
338
+ ):
339
+ residual = hidden_states
340
+
341
+ if attn.spatial_norm is not None:
342
+ hidden_states = attn.spatial_norm(hidden_states, temb)
343
+
344
+ input_ndim = hidden_states.ndim
345
+
346
+ if input_ndim == 4:
347
+ batch_size, channel, height, width = hidden_states.shape
348
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
349
+
350
+ batch_size, sequence_length, _ = (
351
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
352
+ )
353
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
354
+
355
+ if attn.group_norm is not None:
356
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
357
+
358
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
359
+ #query = attn.head_to_batch_dim(query)
360
+
361
+ if encoder_hidden_states is None:
362
+ encoder_hidden_states = hidden_states
363
+ else:
364
+ # get encoder_hidden_states, ip_hidden_states
365
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
366
+ encoder_hidden_states, ip_hidden_states = (
367
+ encoder_hidden_states[:, :end_pos, :],
368
+ encoder_hidden_states[:, end_pos:, :],
369
+ )
370
+ if attn.norm_cross:
371
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
372
+
373
+ # for text
374
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
375
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
376
+
377
+ inner_dim = key.shape[-1]
378
+ head_dim = inner_dim // attn.heads
379
+
380
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
381
+
382
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
383
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
384
+
385
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
386
+ # TODO: add support for attn.scale when we move to Torch 2.1
387
+ hidden_states = F.scaled_dot_product_attention(
388
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
389
+ )
390
+
391
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
392
+ hidden_states = hidden_states.to(query.dtype)
393
+
394
+ # for ip
395
+ ip_key = self.to_k_ip(ip_hidden_states)
396
+ ip_value = self.to_v_ip(ip_hidden_states)
397
+
398
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
399
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
400
+
401
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
402
+ # TODO: add support for attn.scale when we move to Torch 2.1
403
+ ip_hidden_states = F.scaled_dot_product_attention(
404
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
405
+ )
406
+
407
+
408
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
409
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
410
+
411
+ hidden_states = hidden_states + self.scale * ip_hidden_states
412
+
413
+ # linear proj
414
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
415
+ # dropout
416
+ hidden_states = attn.to_out[1](hidden_states)
417
+
418
+ if input_ndim == 4:
419
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
420
+
421
+ if attn.residual_connection:
422
+ hidden_states = hidden_states + residual
423
+
424
+ hidden_states = hidden_states / attn.rescale_output_factor
425
+
426
+ return hidden_states
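
The IP-Adapter processors above rely on a simple convention: the image-prompt tokens are concatenated after the text tokens along the sequence dimension and split back off with `num_tokens`. A minimal sketch of that split, using made-up tensor shapes rather than a real pipeline:

import torch

num_tokens = 4                                   # context length of the image features
text_embeds = torch.randn(2, 77, 768)            # (batch, text_seq_len, dim)
image_embeds = torch.randn(2, num_tokens, 768)   # (batch, num_tokens, dim)

# The pipeline passes text and image tokens as one encoder_hidden_states tensor.
encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)

# This mirrors the split performed inside LoRAIPAttnProcessor.__call__.
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_part = encoder_hidden_states[:, :end_pos, :]
ip_part = encoder_hidden_states[:, end_pos:, :]
assert text_part.shape[1] == 77 and ip_part.shape[1] == num_tokens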
helper.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ from PIL import Image
3
+ import uuid
4
+ import re
+ import torch
5
+
6
+ def parse_prompt_attention(text):
7
+ re_attention = re.compile(r"""
8
+ \\\(|
9
+ \\\)|
10
+ \\\[|
11
+ \\]|
12
+ \\\\|
13
+ \\|
14
+ \(|
15
+ \[|
16
+ :([+-]?[.\d]+)\)|
17
+ \)|
18
+ ]|
19
+ [^\\()\[\]:]+|
20
+ :
21
+ """, re.X)
22
+
23
+ res = []
24
+ round_brackets = []
25
+ square_brackets = []
26
+
27
+ round_bracket_multiplier = 1.1
28
+ square_bracket_multiplier = 1 / 1.1
29
+
30
+ def multiply_range(start_position, multiplier):
31
+ for p in range(start_position, len(res)):
32
+ res[p][1] *= multiplier
33
+
34
+ for m in re_attention.finditer(text):
35
+ text = m.group(0)
36
+ weight = m.group(1)
37
+
38
+ if text.startswith('\\'):
39
+ res.append([text[1:], 1.0])
40
+ elif text == '(':
41
+ round_brackets.append(len(res))
42
+ elif text == '[':
43
+ square_brackets.append(len(res))
44
+ elif weight is not None and len(round_brackets) > 0:
45
+ multiply_range(round_brackets.pop(), float(weight))
46
+ elif text == ')' and len(round_brackets) > 0:
47
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
48
+ elif text == ']' and len(square_brackets) > 0:
49
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
50
+ else:
51
+ parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
52
+ for i, part in enumerate(parts):
53
+ if i > 0:
54
+ res.append(["BREAK", -1])
55
+ res.append([part, 1.0])
56
+
57
+ for pos in round_brackets:
58
+ multiply_range(pos, round_bracket_multiplier)
59
+
60
+ for pos in square_brackets:
61
+ multiply_range(pos, square_bracket_multiplier)
62
+
63
+ if len(res) == 0:
64
+ res = [["", 1.0]]
65
+
66
+ # merge runs of identical weights
67
+ i = 0
68
+ while i + 1 < len(res):
69
+ if res[i][1] == res[i + 1][1]:
70
+ res[i][0] += res[i + 1][0]
71
+ res.pop(i + 1)
72
+ else:
73
+ i += 1
74
+
75
+ return res
76
+
77
+ def prompt_attention_to_invoke_prompt(attention):
78
+ tokens = []
79
+ for text, weight in attention:
80
+ # Round weight to 2 decimal places
81
+ weight = round(weight, 2)
82
+ if weight == 1.0:
83
+ tokens.append(text)
84
+ elif weight < 1.0:
85
+ if weight < 0.8:
86
+ tokens.append(f"({text}){weight}")
87
+ else:
88
+ tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
89
+ else:
90
+ if weight < 1.3:
91
+ tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
92
+ else:
93
+ tokens.append(f"({text}){weight}")
94
+ return "".join(tokens)
95
+
96
+ def concat_tensor(t):
97
+ t_list = torch.split(t, 1, dim=0)
98
+ t = torch.cat(t_list, dim=1)
99
+ return t
100
+
101
+ def merge_embeds(prompt_chanks, compel):
102
+ num_chanks = len(prompt_chanks)
103
+ if num_chanks != 0:
104
+ power_prompt = 1/(num_chanks*(num_chanks+1)//2)
105
+ prompt_embs = compel(prompt_chanks)
106
+ t_list = list(torch.split(prompt_embs, 1, dim=0))
107
+ for i in range(num_chanks):
108
+ t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
109
+ prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
110
+ else:
111
+ prompt_emb = compel('')
112
+ return prompt_emb
113
+
114
+ def detokenize(chunk, actual_prompt):
115
+ chunk[-1] = chunk[-1].replace('</w>', '')
116
+ chanked_prompt = ''.join(chunk).strip()
117
+ while '</w>' in chanked_prompt:
118
+ if actual_prompt[chanked_prompt.find('</w>')] == ' ':
119
+ chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
120
+ else:
121
+ chanked_prompt = chanked_prompt.replace('</w>', '', 1)
122
+ actual_prompt = actual_prompt.replace(chanked_prompt,'')
123
+ return chanked_prompt.strip(), actual_prompt.strip()
124
+
125
+ def tokenize_line(line, tokenizer): # split into chunks
126
+ actual_prompt = line.lower().strip()
127
+ actual_tokens = tokenizer.tokenize(actual_prompt)
128
+ max_tokens = tokenizer.model_max_length - 2
129
+ comma_token = tokenizer.tokenize(',')[0]
130
+
131
+ chunks = []
132
+ chunk = []
133
+ for item in actual_tokens:
134
+ chunk.append(item)
135
+ if len(chunk) == max_tokens:
136
+ if chunk[-1] != comma_token:
137
+ for i in range(max_tokens-1, -1, -1):
138
+ if chunk[i] == comma_token:
139
+ actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
140
+ chunks.append(actual_chunk)
141
+ chunk = chunk[i+1:]
142
+ break
143
+ else:
144
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
145
+ chunks.append(actual_chunk)
146
+ chunk = []
147
+ else:
148
+ actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
149
+ chunks.append(actual_chunk)
150
+ chunk = []
151
+ if chunk:
152
+ actual_chunk, _ = detokenize(chunk, actual_prompt)
153
+ chunks.append(actual_chunk)
154
+
155
+ return chunks
156
+
157
+ def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):
158
+
159
+ if compel_process_sd:
160
+ return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel)
161
+ else:
162
+ # fix a weight-conversion bug caused by excessive emphasis
163
+ prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
164
+
165
+ # Convert to Compel
166
+ attention = parse_prompt_attention(prompt)
167
+ global_attention_chanks = []
168
+
169
+ for att in attention:
170
+ for chank in att[0].split(','):
171
+ temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer)
172
+ for small_chank in temp_prompt_chanks:
173
+ temp_dict = {
174
+ "weight": round(att[1], 2),
175
+ "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')),
176
+ "prompt": f'{small_chank},'
177
+ }
178
+ global_attention_chanks.append(temp_dict)
179
+
180
+ max_tokens = pipeline.tokenizer.model_max_length - 2
181
+ global_prompt_chanks = []
182
+ current_list = []
183
+ current_length = 0
184
+ for item in global_attention_chanks:
185
+ if current_length + item['lenght'] > max_tokens:
186
+ global_prompt_chanks.append(current_list)
187
+ current_list = [[item['prompt'], item['weight']]]
188
+ current_length = item['lenght']
189
+ else:
190
+ if not current_list:
191
+ current_list.append([item['prompt'], item['weight']])
192
+ else:
193
+ if item['weight'] != current_list[-1][1]:
194
+ current_list.append([item['prompt'], item['weight']])
195
+ else:
196
+ current_list[-1][0] += f" {item['prompt']}"
197
+ current_length += item['lenght']
198
+ if current_list:
199
+ global_prompt_chanks.append(current_list)
200
+
201
+ if only_convert_string:
202
+ return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks])
203
+
204
+ return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
205
+
206
+ def add_comma_after_pattern_ti(text):
207
+ pattern = re.compile(r'\b\w+_\d+\b')
208
+ modified_text = pattern.sub(lambda x: x.group() + ',', text)
209
+ return modified_text
210
+
211
+ def save_image(img):
212
+ path = "./tmp/"
213
+
214
+ # Check if the input is a string (file path) and load the image if it is
215
+ if isinstance(img, str):
216
+ img = Image.open(img) # Load the image from the file path
217
+
218
+ # Ensure the output directory exists
219
+ if not os.path.exists(path):
220
+ os.makedirs(path)
221
+
222
+ # Generate a unique filename
223
+ unique_name = str(uuid.uuid4()) + ".webp"
224
+ unique_name = os.path.join(path, unique_name)
225
+
226
+ # Convert the image to WebP format
227
+ webp_img = img.convert("RGB") # Ensure the image is in RGB mode
228
+
229
+ # Save the image in WebP format with high quality
230
+ webp_img.save(unique_name, "WEBP", quality=90)
231
+
232
+ # Re-open the saved WebP file to verify it is readable; the file path is returned below
233
+ with Image.open(unique_name) as webp_file:
234
+ webp_image = webp_file.copy()
235
+
236
+ return unique_name
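
helper.py implements A1111-style prompt weighting on top of Compel. A quick illustration of the two parsing helpers (the printed values are approximate and only for orientation):

from helper import parse_prompt_attention, prompt_attention_to_invoke_prompt

# (...) multiplies the enclosed text's weight by 1.1, [...] divides it by 1.1.
attention = parse_prompt_attention("a (red) cat [in a tree]")
print(attention)
# roughly: [['a ', 1.0], ['red', 1.1], [' cat ', 1.0], ['in a tree', 0.91]]

# Convert the weighted fragments into Compel/InvokeAI "+"/"-" syntax.
print(prompt_attention_to_invoke_prompt(attention))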
ipown.py ADDED
@@ -0,0 +1,470 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
12
+ from utils import is_torch2_available
13
+
14
+ USE_DEFAULT_ATTN = False  # should be True for attention-map visualization
15
+ if is_torch2_available() and (not USE_DEFAULT_ATTN):
16
+ from attention_processor_faceid import (
17
+ LoRAAttnProcessor2_0 as LoRAAttnProcessor,
18
+ )
19
+ from attention_processor_faceid import (
20
+ LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
21
+ )
22
+ else:
23
+ from attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
24
+ from resampler import PerceiverAttention, FeedForward
25
+
26
+
27
+ class FacePerceiverResampler(torch.nn.Module):
28
+ def __init__(
29
+ self,
30
+ *,
31
+ dim=768,
32
+ depth=4,
33
+ dim_head=64,
34
+ heads=16,
35
+ embedding_dim=1280,
36
+ output_dim=768,
37
+ ff_mult=4,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
42
+ self.proj_out = torch.nn.Linear(dim, output_dim)
43
+ self.norm_out = torch.nn.LayerNorm(output_dim)
44
+ self.layers = torch.nn.ModuleList([])
45
+ for _ in range(depth):
46
+ self.layers.append(
47
+ torch.nn.ModuleList(
48
+ [
49
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
50
+ FeedForward(dim=dim, mult=ff_mult),
51
+ ]
52
+ )
53
+ )
54
+
55
+ def forward(self, latents, x):
56
+ x = self.proj_in(x)
57
+ for attn, ff in self.layers:
58
+ latents = attn(x, latents) + latents
59
+ latents = ff(latents) + latents
60
+ latents = self.proj_out(latents)
61
+ return self.norm_out(latents)
62
+
63
+
64
+ class MLPProjModel(torch.nn.Module):
65
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
66
+ super().__init__()
67
+
68
+ self.cross_attention_dim = cross_attention_dim
69
+ self.num_tokens = num_tokens
70
+
71
+ self.proj = torch.nn.Sequential(
72
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
73
+ torch.nn.GELU(),
74
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
75
+ )
76
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
77
+
78
+ def forward(self, id_embeds):
79
+ x = self.proj(id_embeds)
80
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
81
+ x = self.norm(x)
82
+ return x
83
+
84
+
85
+ class ProjPlusModel(torch.nn.Module):
86
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
87
+ super().__init__()
88
+
89
+ self.cross_attention_dim = cross_attention_dim
90
+ self.num_tokens = num_tokens
91
+
92
+ self.proj = torch.nn.Sequential(
93
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
94
+ torch.nn.GELU(),
95
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
96
+ )
97
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
98
+
99
+ self.perceiver_resampler = FacePerceiverResampler(
100
+ dim=cross_attention_dim,
101
+ depth=4,
102
+ dim_head=64,
103
+ heads=cross_attention_dim // 64,
104
+ embedding_dim=clip_embeddings_dim,
105
+ output_dim=cross_attention_dim,
106
+ ff_mult=4,
107
+ )
108
+
109
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
110
+
111
+ x = self.proj(id_embeds)
112
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
113
+ x = self.norm(x)
114
+ out = self.perceiver_resampler(x, clip_embeds)
115
+ if shortcut:
116
+ out = x + scale * out
117
+ return out
118
+
119
+
120
+ class IPAdapterFaceID:
121
+ def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
122
+ self.device = device
123
+ self.ip_ckpt = ip_ckpt
124
+ self.lora_rank = lora_rank
125
+ self.num_tokens = num_tokens
126
+ self.torch_dtype = torch_dtype
127
+
128
+ self.pipe = sd_pipe.to(self.device)
129
+ self.set_ip_adapter()
130
+
131
+ # image proj model
132
+ self.image_proj_model = self.init_proj()
133
+
134
+ self.load_ip_adapter()
135
+
136
+ def init_proj(self):
137
+ image_proj_model = MLPProjModel(
138
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
139
+ id_embeddings_dim=512,
140
+ num_tokens=self.num_tokens,
141
+ ).to(self.device, dtype=self.torch_dtype)
142
+ return image_proj_model
143
+
144
+ def set_ip_adapter(self):
145
+ unet = self.pipe.unet
146
+ attn_procs = {}
147
+ for name in unet.attn_processors.keys():
148
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
149
+ if name.startswith("mid_block"):
150
+ hidden_size = unet.config.block_out_channels[-1]
151
+ elif name.startswith("up_blocks"):
152
+ block_id = int(name[len("up_blocks.")])
153
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
154
+ elif name.startswith("down_blocks"):
155
+ block_id = int(name[len("down_blocks.")])
156
+ hidden_size = unet.config.block_out_channels[block_id]
157
+ if cross_attention_dim is None:
158
+ attn_procs[name] = LoRAAttnProcessor(
159
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
160
+ ).to(self.device, dtype=self.torch_dtype)
161
+ else:
162
+ attn_procs[name] = LoRAIPAttnProcessor(
163
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
164
+ ).to(self.device, dtype=self.torch_dtype)
165
+ unet.set_attn_processor(attn_procs)
166
+
167
+ def load_ip_adapter(self):
168
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
169
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
170
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
171
+ for key in f.keys():
172
+ if key.startswith("image_proj."):
173
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
174
+ elif key.startswith("ip_adapter."):
175
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
176
+ else:
177
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
178
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
179
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
180
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
181
+
182
+ @torch.inference_mode()
183
+ def get_image_embeds(self, faceid_embeds):
184
+
185
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
186
+ print(faceid_embeds.device)
187
+ print(next(self.image_proj_model.parameters()).device)
188
+ image_prompt_embeds = self.image_proj_model(faceid_embeds)
189
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
190
+ return image_prompt_embeds, uncond_image_prompt_embeds
191
+
192
+ def set_scale(self, scale):
193
+ for attn_processor in self.pipe.unet.attn_processors.values():
194
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
195
+ attn_processor.scale = scale
196
+
197
+ def generate(
198
+ self,
199
+ faceid_embeds=None,
200
+ prompt=None,
201
+ negative_prompt=None,
202
+ scale=1.0,
203
+ num_samples=4,
204
+ seed=None,
205
+ guidance_scale=7.5,
206
+ num_inference_steps=30,
207
+ **kwargs,
208
+ ):
209
+ self.set_scale(scale)
210
+
211
+
212
+ num_prompts = faceid_embeds.size(0)
213
+
214
+ if prompt is None:
215
+ prompt = "best quality, high quality"
216
+ if negative_prompt is None:
217
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
218
+
219
+ if not isinstance(prompt, List):
220
+ prompt = [prompt] * num_prompts
221
+ if not isinstance(negative_prompt, List):
222
+ negative_prompt = [negative_prompt] * num_prompts
223
+
224
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
225
+
226
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
227
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
228
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
229
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
230
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
231
+
232
+ with torch.inference_mode():
233
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
234
+ prompt,
235
+ device=self.device,
236
+ num_images_per_prompt=num_samples,
237
+ do_classifier_free_guidance=True,
238
+ negative_prompt=negative_prompt,
239
+ )
240
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
241
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
242
+
243
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
244
+ images = self.pipe(
245
+ prompt_embeds=prompt_embeds,
246
+ negative_prompt_embeds=negative_prompt_embeds,
247
+ guidance_scale=guidance_scale,
248
+ num_inference_steps=num_inference_steps,
249
+ generator=generator,
250
+ **kwargs,
251
+ ).images
252
+
253
+ return images
254
+
255
+
256
+ class IPAdapterFaceIDPlus:
257
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
258
+ self.device = device
259
+ self.image_encoder_path = image_encoder_path
260
+ self.ip_ckpt = ip_ckpt
261
+ self.lora_rank = lora_rank
262
+ self.num_tokens = num_tokens
263
+ self.torch_dtype = torch_dtype
264
+
265
+ self.pipe = sd_pipe.to(self.device)
266
+ self.set_ip_adapter()
267
+
268
+ # load image encoder
269
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
270
+ self.device, dtype=self.torch_dtype
271
+ )
272
+ self.clip_image_processor = CLIPImageProcessor()
273
+ # image proj model
274
+ self.image_proj_model = self.init_proj()
275
+
276
+ self.load_ip_adapter()
277
+
278
+ def init_proj(self):
279
+ image_proj_model = ProjPlusModel(
280
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
281
+ id_embeddings_dim=512,
282
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
283
+ num_tokens=self.num_tokens,
284
+ ).to(self.device, dtype=self.torch_dtype)
285
+ return image_proj_model
286
+
287
+ def set_ip_adapter(self):
288
+ unet = self.pipe.unet
289
+ attn_procs = {}
290
+ for name in unet.attn_processors.keys():
291
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
292
+ if name.startswith("mid_block"):
293
+ hidden_size = unet.config.block_out_channels[-1]
294
+ elif name.startswith("up_blocks"):
295
+ block_id = int(name[len("up_blocks.")])
296
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
297
+ elif name.startswith("down_blocks"):
298
+ block_id = int(name[len("down_blocks.")])
299
+ hidden_size = unet.config.block_out_channels[block_id]
300
+ if cross_attention_dim is None:
301
+ attn_procs[name] = LoRAAttnProcessor(
302
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
303
+ ).to(self.device, dtype=self.torch_dtype)
304
+ else:
305
+ attn_procs[name] = LoRAIPAttnProcessor(
306
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
307
+ ).to(self.device, dtype=self.torch_dtype)
308
+ unet.set_attn_processor(attn_procs)
309
+
310
+ def load_ip_adapter(self):
311
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
312
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
313
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
314
+ for key in f.keys():
315
+ if key.startswith("image_proj."):
316
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
317
+ elif key.startswith("ip_adapter."):
318
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
319
+ else:
320
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
321
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
322
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
323
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
324
+
325
+ @torch.inference_mode()
326
+ def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
327
+ if isinstance(face_image, Image.Image):
328
+ pil_image = [face_image]
329
+ clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
330
+ clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
331
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
332
+ uncond_clip_image_embeds = self.image_encoder(
333
+ torch.zeros_like(clip_image), output_hidden_states=True
334
+ ).hidden_states[-2]
335
+
336
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
337
+ image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
338
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
339
+ return image_prompt_embeds, uncond_image_prompt_embeds
340
+
341
+ def set_scale(self, scale):
342
+ for attn_processor in self.pipe.unet.attn_processors.values():
343
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
344
+ attn_processor.scale = scale
345
+
346
+ def generate(
347
+ self,
348
+ face_image=None,
349
+ faceid_embeds=None,
350
+ prompt=None,
351
+ negative_prompt=None,
352
+ scale=1.0,
353
+ num_samples=4,
354
+ seed=None,
355
+ guidance_scale=7.5,
356
+ num_inference_steps=30,
357
+ s_scale=1.0,
358
+ shortcut=False,
359
+ **kwargs,
360
+ ):
361
+ self.set_scale(scale)
362
+
363
+
364
+ num_prompts = faceid_embeds.size(0)
365
+
366
+ if prompt is None:
367
+ prompt = "best quality, high quality"
368
+ if negative_prompt is None:
369
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
370
+
371
+ if not isinstance(prompt, List):
372
+ prompt = [prompt] * num_prompts
373
+ if not isinstance(negative_prompt, List):
374
+ negative_prompt = [negative_prompt] * num_prompts
375
+
376
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
377
+
378
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
379
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
380
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
381
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
382
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
383
+
384
+ with torch.inference_mode():
385
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
386
+ prompt,
387
+ device=self.device,
388
+ num_images_per_prompt=num_samples,
389
+ do_classifier_free_guidance=True,
390
+ negative_prompt=negative_prompt,
391
+ )
392
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
393
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
394
+
395
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
396
+ images = self.pipe(
397
+ prompt_embeds=prompt_embeds,
398
+ negative_prompt_embeds=negative_prompt_embeds,
399
+ guidance_scale=guidance_scale,
400
+ num_inference_steps=num_inference_steps,
401
+ generator=generator,
402
+ **kwargs,
403
+ ).images
404
+
405
+ return images
406
+
407
+
408
+ class IPAdapterFaceIDXL(IPAdapterFaceID):
409
+ """SDXL"""
410
+
411
+ def generate(
412
+ self,
413
+ faceid_embeds=None,
414
+ prompt=None,
415
+ negative_prompt=None,
416
+ scale=1.0,
417
+ num_samples=4,
418
+ seed=None,
419
+ num_inference_steps=30,
420
+ **kwargs,
421
+ ):
422
+ self.set_scale(scale)
423
+
424
+ num_prompts = faceid_embeds.size(0)
425
+
426
+ if prompt is None:
427
+ prompt = "best quality, high quality"
428
+ if negative_prompt is None:
429
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
430
+
431
+ if not isinstance(prompt, List):
432
+ prompt = [prompt] * num_prompts
433
+ if not isinstance(negative_prompt, List):
434
+ negative_prompt = [negative_prompt] * num_prompts
435
+
436
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
437
+
438
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
439
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
440
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
441
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
442
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
443
+
444
+ with torch.inference_mode():
445
+ (
446
+ prompt_embeds,
447
+ negative_prompt_embeds,
448
+ pooled_prompt_embeds,
449
+ negative_pooled_prompt_embeds,
450
+ ) = self.pipe.encode_prompt(
451
+ prompt,
452
+ num_images_per_prompt=num_samples,
453
+ do_classifier_free_guidance=True,
454
+ negative_prompt=negative_prompt,
455
+ )
456
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
457
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
458
+
459
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
460
+ images = self.pipe(
461
+ prompt_embeds=prompt_embeds,
462
+ negative_prompt_embeds=negative_prompt_embeds,
463
+ pooled_prompt_embeds=pooled_prompt_embeds,
464
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
465
+ num_inference_steps=num_inference_steps,
466
+ generator=generator,
467
+ **kwargs,
468
+ ).images
469
+
470
+ return images
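
A hedged usage sketch for IPAdapterFaceID; the base model ID, checkpoint filename, and input image are assumptions, and the insightface steps follow the upstream IP-Adapter FaceID examples:

import cv2
import torch
from diffusers import StableDiffusionPipeline
from insightface.app import FaceAnalysis

from ipown import IPAdapterFaceID

# Extract a face identity embedding with insightface.
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
faces = app.get(cv2.imread("face.jpg"))  # hypothetical input photo
faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed SD 1.5 base
)
ip_model = IPAdapterFaceID(pipe, "ip-adapter-faceid_sd15.bin", device="cuda")  # assumed checkpoint path

images = ip_model.generate(
    faceid_embeds=faceid_embeds,
    prompt="photo of a person in a garden",
    num_samples=1,
    num_inference_steps=30,
    seed=42,
)
images[0].save("out.png")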
loader.py ADDED
@@ -0,0 +1,95 @@
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+
4
+ def load_script(file_str: str):
5
+ """
6
+ file_str: something like 'myorg/myrepo/mysubfolder/myscript.py'
7
+ This function downloads the file from the Hugging Face Hub into ./ (current directory).
8
+ """
9
+ try:
10
+ # Split the path by "/"
11
+ parts = file_str.split("/")
12
+
13
+ if len(parts) < 3:
14
+ raise ValueError(
15
+ f"Invalid file specification '{file_str}'. "
16
+ f"Expected format: 'repo_id/[subfolder]/filename'"
17
+ )
18
+
19
+ # First two parts form the repo_id (e.g. 'myorg/myrepo')
20
+ repo_id = "/".join(parts[:2])
21
+
22
+ # Last part is the actual filename (e.g. 'myscript.py')
23
+ filename = parts[-1]
24
+
25
+ # Anything between the second and last parts is a subfolder path
26
+ subfolder = None
27
+ if len(parts) > 3:
28
+ subfolder = "/".join(parts[2:-1])
29
+
30
+ # Retrieve HF token from environment
31
+ hf_token = os.getenv("HF_TOKEN", None)
32
+
33
+ # Download the file into current directory "."
34
+ file_path = hf_hub_download(
35
+ repo_id=repo_id,
36
+ filename=filename,
37
+ subfolder=subfolder,
38
+ token=hf_token,
39
+ local_dir="." # Download into current directory
40
+ )
41
+
42
+ print(f"Downloaded {filename} from {repo_id} to {file_path}")
43
+ return file_path
44
+
45
+ except Exception as e:
46
+ print(f"Error downloading the script '{file_str}': {e}")
47
+ return None
48
+
49
+
50
+ def load_scripts():
51
+ """
52
+ 1. Get the path of the 'FILE_LIST' file from the environment variable FILE_LIST.
53
+ 2. Download that file list using load_script().
54
+ 3. Read its lines, and each line is another file to be downloaded using load_script().
55
+ 4. After all lines are downloaded, execute the last file.
56
+ """
57
+ file_list = os.getenv("FILE_LIST", "").strip()
58
+ if not file_list:
59
+ print("No FILE_LIST environment variable set. Nothing to download.")
60
+ return
61
+
62
+ # Step 1: Download the file list itself
63
+ file_list_path = load_script(file_list)
64
+ if not file_list_path or not os.path.exists(file_list_path):
65
+ print(f"Could not download or find file list: {file_list_path}")
66
+ return
67
+
68
+ # Step 2: Read each line in the downloaded file list
69
+ try:
70
+ with open(file_list_path, 'r') as f:
71
+ lines = [line.strip() for line in f if line.strip()]
72
+ except Exception as e:
73
+ print(f"Error reading file list: {e}")
74
+ return
75
+
76
+ # Step 3: Download each file from the lines
77
+ downloaded_files = []
78
+ for file_str in lines:
79
+ file_path = load_script(file_str)
80
+ if file_path:
81
+ downloaded_files.append(file_path)
82
+
83
+ # Step 4: Execute the last downloaded file
84
+ if downloaded_files:
85
+ last_file_path = downloaded_files[-1]
86
+ print(f"Executing the last downloaded script: {last_file_path}")
87
+ try:
88
+ with open(last_file_path, 'r') as f:
89
+ exec(f.read(), globals())
90
+ except Exception as e:
91
+ print(f"Error executing the last downloaded script: {e}")
92
+
93
+
94
+ # Run the load_scripts function
95
+ load_scripts()
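
loader.py is driven entirely by environment variables; a hypothetical example of how it might be wired up (repo names and paths are invented):

import os

# FILE_LIST points at a text file on the Hub, itself given as 'org/repo[/subfolder]/filename'.
os.environ["FILE_LIST"] = "myorg/myrepo/config/file_list.txt"  # hypothetical
os.environ["HF_TOKEN"] = "hf_xxx"                              # only needed for private repos

# Each line of the downloaded file_list.txt uses the same 'org/repo[/subfolder]/filename'
# form; every listed file is downloaded and the last one is executed with exec().
import loader  # importing the module runs load_scripts() at import time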
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ insightface==0.7.3
2
+ diffusers
3
+ transformers
4
+ accelerate
5
+ safetensors
6
+ einops
7
+ onnxruntime-gpu
8
+ spaces==0.19.4
9
+ opencv-python
10
+ pyjwt
11
+ torchsde
12
+ compel
13
+ hidiffusion
14
+ git+https://github.com/tencent-ailab/IP-Adapter.git
resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1024,
85
+ depth=8,
86
+ dim_head=64,
87
+ heads=16,
88
+ num_queries=8,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
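
A small shape check for the Resampler defined above, with dimensions matching its defaults (purely illustrative):

import torch
from resampler import Resampler

resampler = Resampler(
    dim=1024, depth=8, dim_head=64, heads=16,
    num_queries=8, embedding_dim=768, output_dim=1024,
)

clip_tokens = torch.randn(2, 257, 768)  # (batch, CLIP patch tokens + CLS, embedding_dim)
out = resampler(clip_tokens)
print(out.shape)  # torch.Size([2, 8, 1024]): num_queries latent tokens per image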
utils.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+ import requests
6
+ from datetime import datetime,timedelta
7
+ import re
8
+
9
+ attn_maps = {}
10
+ def hook_fn(name):
11
+ def forward_hook(module, input, output):
12
+ if hasattr(module.processor, "attn_map"):
13
+ attn_maps[name] = module.processor.attn_map
14
+ del module.processor.attn_map
15
+
16
+ return forward_hook
17
+
18
+ def register_cross_attention_hook(unet):
19
+ for name, module in unet.named_modules():
20
+ if name.split('.')[-1].startswith('attn2'):
21
+ module.register_forward_hook(hook_fn(name))
22
+
23
+ return unet
24
+
25
+ def upscale(attn_map, target_size):
26
+ attn_map = torch.mean(attn_map, dim=0)
27
+ attn_map = attn_map.permute(1,0)
28
+ temp_size = None
29
+
30
+ for i in range(0,5):
31
+ scale = 2 ** i
32
+ if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
33
+ temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
34
+ break
35
+
36
+ assert temp_size is not None, "temp_size cannot be None"
37
+
38
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
39
+
40
+ attn_map = F.interpolate(
41
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
42
+ size=target_size,
43
+ mode='bilinear',
44
+ align_corners=False
45
+ )[0]
46
+
47
+ attn_map = torch.softmax(attn_map, dim=0)
48
+ return attn_map
49
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
50
+
51
+ idx = 0 if instance_or_negative else 1
52
+ net_attn_maps = []
53
+
54
+ for name, attn_map in attn_maps.items():
55
+ attn_map = attn_map.cpu() if detach else attn_map
56
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
57
+ attn_map = upscale(attn_map, image_size)
58
+ net_attn_maps.append(attn_map)
59
+
60
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
61
+
62
+ return net_attn_maps
63
+
64
+ def attnmaps2images(net_attn_maps):
65
+
66
+ #total_attn_scores = 0
67
+ images = []
68
+
69
+ for attn_map in net_attn_maps:
70
+ attn_map = attn_map.cpu().numpy()
71
+ #total_attn_scores += attn_map.mean().item()
72
+
73
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
74
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
75
+ #print("norm: ", normalized_attn_map.shape)
76
+ image = Image.fromarray(normalized_attn_map)
77
+
78
+ #image = fix_save_attn_map(attn_map)
79
+ images.append(image)
80
+
81
+ #print(total_attn_scores)
82
+ return images
83
+ def is_torch2_available():
84
+ return hasattr(F, "scaled_dot_product_attention")
85
+
86
+
87
+ class RemoteJson:
88
+ def __init__(self, url, refresh_gap_seconds=3600, processor=None):
89
+ """
90
+ Initialize the RemoteJsonManager.
91
+ :param url: The URL of the remote JSON file.
92
+ :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
93
+ :param processor: Optional callback function to process the JSON after it's loaded successfully.
94
+ """
95
+ self.url = url
96
+ self.refresh_gap_seconds = refresh_gap_seconds
97
+ self.processor = processor
98
+ self.json_data = None
99
+ self.last_updated = None
100
+
101
+ def _load_json(self):
102
+ """
103
+ Load JSON from the remote URL. If loading fails, return None.
104
+ """
105
+ try:
106
+ response = requests.get(self.url)
107
+ response.raise_for_status()
108
+ return response.json()
109
+ except requests.RequestException as e:
110
+ print(f"Failed to fetch JSON: {e}")
111
+ return None
112
+
113
+ def _should_refresh(self):
114
+ """
115
+ Check whether the JSON should be refreshed based on the time gap.
116
+ """
117
+ if not self.last_updated:
118
+ return True # If no last update, always refresh
119
+ return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)
120
+
121
+ def _update_json(self):
122
+ """
123
+ Fetch and load the JSON from the remote URL. If it fails, keep the previous data.
124
+ """
125
+ new_json = self._load_json()
126
+ if new_json:
127
+ self.json_data = new_json
128
+ self.last_updated = datetime.now()
129
+ print("JSON updated successfully.")
130
+ if self.processor:
131
+ self.json_data = self.processor(self.json_data)
132
+ else:
133
+ print("Failed to update JSON. Keeping the previous version.")
134
+
135
+ def get(self):
136
+ """
137
+ Get the JSON, checking whether it needs to be refreshed.
138
+ If refresh is required, it fetches the new data and applies the processor.
139
+ """
140
+ if self._should_refresh():
141
+ print("Refreshing JSON...")
142
+ self._update_json()
143
+ else:
144
+ print("Using cached JSON.")
145
+
146
+ return self.json_data
147
+
148
+ def extract_key_value_pairs(input_string):
149
+ # Define the regular expression to match [xxx:yyy] where yyy can have special characters
150
+ pattern = r"\[([^\]]+):([^\]]+)\]"
151
+
152
+ # Find all matches in the input string with the original matching string
153
+ matches = re.finditer(pattern, input_string)
154
+
155
+ # Convert matches to a list of dictionaries including the raw matching string
156
+ result = [{"key": match.group(1), "value": match.group(2), "raw": match.group(0)} for match in matches]
157
+
158
+ return result
159
+
160
+ def extract_characters(prefix, input_string):
161
+ # Define the regular expression to match placeholders starting with "@" and ending with space or comma
162
+ pattern = rf"{prefix}([^\s,$]+)(?=\s|,|$)"
163
+
164
+ # Find all matches in the input string
165
+ matches = re.findall(pattern, input_string)
166
+
167
+ # Return a list of dictionaries with the extracted placeholders
168
+ result = [{"raw": f"{prefix}{match}", "key": match} for match in matches]
169
+
170
+ return result
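
The two regex helpers at the end of utils.py are easiest to understand from an example (the input string here is invented):

from utils import extract_key_value_pairs, extract_characters

prompt = "a portrait, [lora:style_v2] of @alice and @bob, [seed:1234]"

print(extract_key_value_pairs(prompt))
# [{'key': 'lora', 'value': 'style_v2', 'raw': '[lora:style_v2]'},
#  {'key': 'seed', 'value': '1234', 'raw': '[seed:1234]'}]

print(extract_characters("@", prompt))
# [{'raw': '@alice', 'key': 'alice'}, {'raw': '@bob', 'key': 'bob'}]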