AIMS168 commited on
Commit
fb09aef
·
verified ·
1 Parent(s): 89866f3

Delete ip_adapter

Browse files
ip_adapter/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterXL_CS,IPAdapter_CS
2
- from .ip_adapter import CSGO
3
- __all__ = [
4
- "IPAdapter",
5
- "IPAdapterPlus",
6
- "IPAdapterPlusXL",
7
- "IPAdapterXL",
8
- "CSGO"
9
- "IPAdapterFull",
10
- ]
 
 
 
 
 
 
 
 
 
 
 
ip_adapter/attention_processor.py DELETED
@@ -1,754 +0,0 @@
1
- # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
-
7
- class AttnProcessor(nn.Module):
8
- r"""
9
- Default processor for performing attention-related computations.
10
- """
11
-
12
- def __init__(
13
- self,
14
- hidden_size=None,
15
- cross_attention_dim=None,
16
- save_in_unet='down',
17
- atten_control=None,
18
- ):
19
- super().__init__()
20
- self.atten_control = atten_control
21
- self.save_in_unet = save_in_unet
22
-
23
- def __call__(
24
- self,
25
- attn,
26
- hidden_states,
27
- encoder_hidden_states=None,
28
- attention_mask=None,
29
- temb=None,
30
- ):
31
- residual = hidden_states
32
-
33
- if attn.spatial_norm is not None:
34
- hidden_states = attn.spatial_norm(hidden_states, temb)
35
-
36
- input_ndim = hidden_states.ndim
37
-
38
- if input_ndim == 4:
39
- batch_size, channel, height, width = hidden_states.shape
40
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
41
-
42
- batch_size, sequence_length, _ = (
43
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
44
- )
45
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
46
-
47
- if attn.group_norm is not None:
48
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
49
-
50
- query = attn.to_q(hidden_states)
51
-
52
- if encoder_hidden_states is None:
53
- encoder_hidden_states = hidden_states
54
- elif attn.norm_cross:
55
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
56
-
57
- key = attn.to_k(encoder_hidden_states)
58
- value = attn.to_v(encoder_hidden_states)
59
-
60
- query = attn.head_to_batch_dim(query)
61
- key = attn.head_to_batch_dim(key)
62
- value = attn.head_to_batch_dim(value)
63
-
64
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
65
- hidden_states = torch.bmm(attention_probs, value)
66
- hidden_states = attn.batch_to_head_dim(hidden_states)
67
-
68
- # linear proj
69
- hidden_states = attn.to_out[0](hidden_states)
70
- # dropout
71
- hidden_states = attn.to_out[1](hidden_states)
72
-
73
- if input_ndim == 4:
74
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
75
-
76
- if attn.residual_connection:
77
- hidden_states = hidden_states + residual
78
-
79
- hidden_states = hidden_states / attn.rescale_output_factor
80
-
81
- return hidden_states
82
-
83
-
84
- class IPAttnProcessor(nn.Module):
85
- r"""
86
- Attention processor for IP-Adapater.
87
- Args:
88
- hidden_size (`int`):
89
- The hidden size of the attention layer.
90
- cross_attention_dim (`int`):
91
- The number of channels in the `encoder_hidden_states`.
92
- scale (`float`, defaults to 1.0):
93
- the weight scale of image prompt.
94
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
95
- The context length of the image features.
96
- """
97
-
98
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,save_in_unet='down', atten_control=None):
99
- super().__init__()
100
-
101
- self.hidden_size = hidden_size
102
- self.cross_attention_dim = cross_attention_dim
103
- self.scale = scale
104
- self.num_tokens = num_tokens
105
- self.skip = skip
106
-
107
- self.atten_control = atten_control
108
- self.save_in_unet = save_in_unet
109
-
110
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
111
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
112
-
113
- def __call__(
114
- self,
115
- attn,
116
- hidden_states,
117
- encoder_hidden_states=None,
118
- attention_mask=None,
119
- temb=None,
120
- ):
121
- residual = hidden_states
122
-
123
- if attn.spatial_norm is not None:
124
- hidden_states = attn.spatial_norm(hidden_states, temb)
125
-
126
- input_ndim = hidden_states.ndim
127
-
128
- if input_ndim == 4:
129
- batch_size, channel, height, width = hidden_states.shape
130
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
131
-
132
- batch_size, sequence_length, _ = (
133
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
134
- )
135
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
136
-
137
- if attn.group_norm is not None:
138
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
139
-
140
- query = attn.to_q(hidden_states)
141
-
142
- if encoder_hidden_states is None:
143
- encoder_hidden_states = hidden_states
144
- else:
145
- # get encoder_hidden_states, ip_hidden_states
146
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
147
- encoder_hidden_states, ip_hidden_states = (
148
- encoder_hidden_states[:, :end_pos, :],
149
- encoder_hidden_states[:, end_pos:, :],
150
- )
151
- if attn.norm_cross:
152
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
153
-
154
- key = attn.to_k(encoder_hidden_states)
155
- value = attn.to_v(encoder_hidden_states)
156
-
157
- query = attn.head_to_batch_dim(query)
158
- key = attn.head_to_batch_dim(key)
159
- value = attn.head_to_batch_dim(value)
160
-
161
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
162
- hidden_states = torch.bmm(attention_probs, value)
163
- hidden_states = attn.batch_to_head_dim(hidden_states)
164
-
165
- if not self.skip:
166
- # for ip-adapter
167
- ip_key = self.to_k_ip(ip_hidden_states)
168
- ip_value = self.to_v_ip(ip_hidden_states)
169
-
170
- ip_key = attn.head_to_batch_dim(ip_key)
171
- ip_value = attn.head_to_batch_dim(ip_value)
172
-
173
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
174
- self.attn_map = ip_attention_probs
175
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
176
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
177
-
178
- hidden_states = hidden_states + self.scale * ip_hidden_states
179
-
180
- # linear proj
181
- hidden_states = attn.to_out[0](hidden_states)
182
- # dropout
183
- hidden_states = attn.to_out[1](hidden_states)
184
-
185
- if input_ndim == 4:
186
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
187
-
188
- if attn.residual_connection:
189
- hidden_states = hidden_states + residual
190
-
191
- hidden_states = hidden_states / attn.rescale_output_factor
192
-
193
- return hidden_states
194
-
195
-
196
- class AttnProcessor2_0(torch.nn.Module):
197
- r"""
198
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
199
- """
200
-
201
- def __init__(
202
- self,
203
- hidden_size=None,
204
- cross_attention_dim=None,
205
- save_in_unet='down',
206
- atten_control=None,
207
- ):
208
- super().__init__()
209
- if not hasattr(F, "scaled_dot_product_attention"):
210
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
211
- self.atten_control = atten_control
212
- self.save_in_unet = save_in_unet
213
-
214
- def __call__(
215
- self,
216
- attn,
217
- hidden_states,
218
- encoder_hidden_states=None,
219
- attention_mask=None,
220
- temb=None,
221
- ):
222
- residual = hidden_states
223
-
224
- if attn.spatial_norm is not None:
225
- hidden_states = attn.spatial_norm(hidden_states, temb)
226
-
227
- input_ndim = hidden_states.ndim
228
-
229
- if input_ndim == 4:
230
- batch_size, channel, height, width = hidden_states.shape
231
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
232
-
233
- batch_size, sequence_length, _ = (
234
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
235
- )
236
-
237
- if attention_mask is not None:
238
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
239
- # scaled_dot_product_attention expects attention_mask shape to be
240
- # (batch, heads, source_length, target_length)
241
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
242
-
243
- if attn.group_norm is not None:
244
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
245
-
246
- query = attn.to_q(hidden_states)
247
-
248
- if encoder_hidden_states is None:
249
- encoder_hidden_states = hidden_states
250
- elif attn.norm_cross:
251
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
252
-
253
- key = attn.to_k(encoder_hidden_states)
254
- value = attn.to_v(encoder_hidden_states)
255
-
256
- inner_dim = key.shape[-1]
257
- head_dim = inner_dim // attn.heads
258
-
259
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
260
-
261
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
262
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
-
264
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
265
- # TODO: add support for attn.scale when we move to Torch 2.1
266
- hidden_states = F.scaled_dot_product_attention(
267
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
268
- )
269
-
270
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
271
- hidden_states = hidden_states.to(query.dtype)
272
-
273
- # linear proj
274
- hidden_states = attn.to_out[0](hidden_states)
275
- # dropout
276
- hidden_states = attn.to_out[1](hidden_states)
277
-
278
- if input_ndim == 4:
279
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
280
-
281
- if attn.residual_connection:
282
- hidden_states = hidden_states + residual
283
-
284
- hidden_states = hidden_states / attn.rescale_output_factor
285
-
286
- return hidden_states
287
-
288
-
289
- class IPAttnProcessor2_0(torch.nn.Module):
290
- r"""
291
- Attention processor for IP-Adapater for PyTorch 2.0.
292
- Args:
293
- hidden_size (`int`):
294
- The hidden size of the attention layer.
295
- cross_attention_dim (`int`):
296
- The number of channels in the `encoder_hidden_states`.
297
- scale (`float`, defaults to 1.0):
298
- the weight scale of image prompt.
299
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
300
- The context length of the image features.
301
- """
302
-
303
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,save_in_unet='down', atten_control=None):
304
- super().__init__()
305
-
306
- if not hasattr(F, "scaled_dot_product_attention"):
307
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
308
-
309
- self.hidden_size = hidden_size
310
- self.cross_attention_dim = cross_attention_dim
311
- self.scale = scale
312
- self.num_tokens = num_tokens
313
- self.skip = skip
314
-
315
- self.atten_control = atten_control
316
- self.save_in_unet = save_in_unet
317
-
318
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
319
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
320
-
321
- def __call__(
322
- self,
323
- attn,
324
- hidden_states,
325
- encoder_hidden_states=None,
326
- attention_mask=None,
327
- temb=None,
328
- ):
329
- residual = hidden_states
330
-
331
- if attn.spatial_norm is not None:
332
- hidden_states = attn.spatial_norm(hidden_states, temb)
333
-
334
- input_ndim = hidden_states.ndim
335
-
336
- if input_ndim == 4:
337
- batch_size, channel, height, width = hidden_states.shape
338
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
339
-
340
- batch_size, sequence_length, _ = (
341
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
342
- )
343
-
344
- if attention_mask is not None:
345
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
346
- # scaled_dot_product_attention expects attention_mask shape to be
347
- # (batch, heads, source_length, target_length)
348
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
349
-
350
- if attn.group_norm is not None:
351
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
352
-
353
- query = attn.to_q(hidden_states)
354
-
355
- if encoder_hidden_states is None:
356
- encoder_hidden_states = hidden_states
357
- else:
358
- # get encoder_hidden_states, ip_hidden_states
359
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
360
- encoder_hidden_states, ip_hidden_states = (
361
- encoder_hidden_states[:, :end_pos, :],
362
- encoder_hidden_states[:, end_pos:, :],
363
- )
364
- if attn.norm_cross:
365
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
366
-
367
- key = attn.to_k(encoder_hidden_states)
368
- value = attn.to_v(encoder_hidden_states)
369
-
370
- inner_dim = key.shape[-1]
371
- head_dim = inner_dim // attn.heads
372
-
373
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
374
-
375
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
376
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
377
-
378
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
379
- # TODO: add support for attn.scale when we move to Torch 2.1
380
- hidden_states = F.scaled_dot_product_attention(
381
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
382
- )
383
-
384
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
385
- hidden_states = hidden_states.to(query.dtype)
386
-
387
- if not self.skip:
388
- # for ip-adapter
389
- ip_key = self.to_k_ip(ip_hidden_states)
390
- ip_value = self.to_v_ip(ip_hidden_states)
391
-
392
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
393
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
394
-
395
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
396
- # TODO: add support for attn.scale when we move to Torch 2.1
397
- ip_hidden_states = F.scaled_dot_product_attention(
398
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
399
- )
400
- with torch.no_grad():
401
- self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
402
- #print(self.attn_map.shape)
403
-
404
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
405
- ip_hidden_states = ip_hidden_states.to(query.dtype)
406
-
407
- hidden_states = hidden_states + self.scale * ip_hidden_states
408
-
409
- # linear proj
410
- hidden_states = attn.to_out[0](hidden_states)
411
- # dropout
412
- hidden_states = attn.to_out[1](hidden_states)
413
-
414
- if input_ndim == 4:
415
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
416
-
417
- if attn.residual_connection:
418
- hidden_states = hidden_states + residual
419
-
420
- hidden_states = hidden_states / attn.rescale_output_factor
421
-
422
- return hidden_states
423
-
424
-
425
- class IP_CS_AttnProcessor2_0(torch.nn.Module):
426
- r"""
427
- Attention processor for IP-Adapater for PyTorch 2.0.
428
- Args:
429
- hidden_size (`int`):
430
- The hidden size of the attention layer.
431
- cross_attention_dim (`int`):
432
- The number of channels in the `encoder_hidden_states`.
433
- scale (`float`, defaults to 1.0):
434
- the weight scale of image prompt.
435
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
436
- The context length of the image features.
437
- """
438
-
439
- def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0,style_scale=1.0, num_content_tokens=4,num_style_tokens=4,
440
- skip=False,content=False, style=False):
441
- super().__init__()
442
-
443
- if not hasattr(F, "scaled_dot_product_attention"):
444
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
445
-
446
- self.hidden_size = hidden_size
447
- self.cross_attention_dim = cross_attention_dim
448
- self.content_scale = content_scale
449
- self.style_scale = style_scale
450
- self.num_content_tokens = num_content_tokens
451
- self.num_style_tokens = num_style_tokens
452
- self.skip = skip
453
-
454
- self.content = content
455
- self.style = style
456
-
457
- if self.content or self.style:
458
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
459
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
460
- self.to_k_ip_content =None
461
- self.to_v_ip_content =None
462
-
463
- def set_content_ipa(self,content_scale=1.0):
464
-
465
- self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
466
- self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
467
- self.content_scale=content_scale
468
- self.content =True
469
-
470
- def __call__(
471
- self,
472
- attn,
473
- hidden_states,
474
- encoder_hidden_states=None,
475
- attention_mask=None,
476
- temb=None,
477
- ):
478
- residual = hidden_states
479
-
480
- if attn.spatial_norm is not None:
481
- hidden_states = attn.spatial_norm(hidden_states, temb)
482
-
483
- input_ndim = hidden_states.ndim
484
-
485
- if input_ndim == 4:
486
- batch_size, channel, height, width = hidden_states.shape
487
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
488
-
489
- batch_size, sequence_length, _ = (
490
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
491
- )
492
-
493
- if attention_mask is not None:
494
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
495
- # scaled_dot_product_attention expects attention_mask shape to be
496
- # (batch, heads, source_length, target_length)
497
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
498
-
499
- if attn.group_norm is not None:
500
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
501
-
502
- query = attn.to_q(hidden_states)
503
-
504
- if encoder_hidden_states is None:
505
- encoder_hidden_states = hidden_states
506
- else:
507
- # get encoder_hidden_states, ip_hidden_states
508
- end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens-self.num_style_tokens
509
- encoder_hidden_states, ip_content_hidden_states,ip_style_hidden_states = (
510
- encoder_hidden_states[:, :end_pos, :],
511
- encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :],
512
- encoder_hidden_states[:, end_pos + self.num_content_tokens:, :],
513
- )
514
- if attn.norm_cross:
515
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
516
-
517
- key = attn.to_k(encoder_hidden_states)
518
- value = attn.to_v(encoder_hidden_states)
519
-
520
- inner_dim = key.shape[-1]
521
- head_dim = inner_dim // attn.heads
522
-
523
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
524
-
525
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
526
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
527
-
528
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
529
- # TODO: add support for attn.scale when we move to Torch 2.1
530
- hidden_states = F.scaled_dot_product_attention(
531
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
532
- )
533
-
534
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
535
- hidden_states = hidden_states.to(query.dtype)
536
-
537
- if not self.skip and self.content is True:
538
- # print('content#####################################################')
539
- # for ip-content-adapter
540
- if self.to_k_ip_content is None:
541
-
542
- ip_content_key = self.to_k_ip(ip_content_hidden_states)
543
- ip_content_value = self.to_v_ip(ip_content_hidden_states)
544
- else:
545
- ip_content_key = self.to_k_ip_content(ip_content_hidden_states)
546
- ip_content_value = self.to_v_ip_content(ip_content_hidden_states)
547
-
548
- ip_content_key = ip_content_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
549
- ip_content_value = ip_content_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
550
-
551
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
552
- # TODO: add support for attn.scale when we move to Torch 2.1
553
- ip_content_hidden_states = F.scaled_dot_product_attention(
554
- query, ip_content_key, ip_content_value, attn_mask=None, dropout_p=0.0, is_causal=False
555
- )
556
-
557
-
558
- ip_content_hidden_states = ip_content_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
559
- ip_content_hidden_states = ip_content_hidden_states.to(query.dtype)
560
-
561
-
562
- hidden_states = hidden_states + self.content_scale * ip_content_hidden_states
563
-
564
- if not self.skip and self.style is True:
565
- # for ip-style-adapter
566
- ip_style_key = self.to_k_ip(ip_style_hidden_states)
567
- ip_style_value = self.to_v_ip(ip_style_hidden_states)
568
-
569
- ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
570
- ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
571
-
572
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
573
- # TODO: add support for attn.scale when we move to Torch 2.1
574
- ip_style_hidden_states = F.scaled_dot_product_attention(
575
- query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
576
- )
577
-
578
- ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1,
579
- attn.heads * head_dim)
580
- ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)
581
-
582
- hidden_states = hidden_states + self.style_scale * ip_style_hidden_states
583
-
584
- # linear proj
585
- hidden_states = attn.to_out[0](hidden_states)
586
- # dropout
587
- hidden_states = attn.to_out[1](hidden_states)
588
-
589
- if input_ndim == 4:
590
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
591
-
592
- if attn.residual_connection:
593
- hidden_states = hidden_states + residual
594
-
595
- hidden_states = hidden_states / attn.rescale_output_factor
596
-
597
- return hidden_states
598
-
599
- ## for controlnet
600
- class CNAttnProcessor:
601
- r"""
602
- Default processor for performing attention-related computations.
603
- """
604
-
605
- def __init__(self, num_tokens=4,save_in_unet='down',atten_control=None):
606
- self.num_tokens = num_tokens
607
- self.atten_control = atten_control
608
- self.save_in_unet = save_in_unet
609
-
610
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
611
- residual = hidden_states
612
-
613
- if attn.spatial_norm is not None:
614
- hidden_states = attn.spatial_norm(hidden_states, temb)
615
-
616
- input_ndim = hidden_states.ndim
617
-
618
- if input_ndim == 4:
619
- batch_size, channel, height, width = hidden_states.shape
620
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
621
-
622
- batch_size, sequence_length, _ = (
623
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
624
- )
625
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
626
-
627
- if attn.group_norm is not None:
628
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
629
-
630
- query = attn.to_q(hidden_states)
631
-
632
- if encoder_hidden_states is None:
633
- encoder_hidden_states = hidden_states
634
- else:
635
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
636
- encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
637
- if attn.norm_cross:
638
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
639
-
640
- key = attn.to_k(encoder_hidden_states)
641
- value = attn.to_v(encoder_hidden_states)
642
-
643
- query = attn.head_to_batch_dim(query)
644
- key = attn.head_to_batch_dim(key)
645
- value = attn.head_to_batch_dim(value)
646
-
647
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
648
- hidden_states = torch.bmm(attention_probs, value)
649
- hidden_states = attn.batch_to_head_dim(hidden_states)
650
-
651
- # linear proj
652
- hidden_states = attn.to_out[0](hidden_states)
653
- # dropout
654
- hidden_states = attn.to_out[1](hidden_states)
655
-
656
- if input_ndim == 4:
657
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
658
-
659
- if attn.residual_connection:
660
- hidden_states = hidden_states + residual
661
-
662
- hidden_states = hidden_states / attn.rescale_output_factor
663
-
664
- return hidden_states
665
-
666
-
667
- class CNAttnProcessor2_0:
668
- r"""
669
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
670
- """
671
-
672
- def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
673
- if not hasattr(F, "scaled_dot_product_attention"):
674
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
675
- self.num_tokens = num_tokens
676
- self.atten_control = atten_control
677
- self.save_in_unet = save_in_unet
678
-
679
- def __call__(
680
- self,
681
- attn,
682
- hidden_states,
683
- encoder_hidden_states=None,
684
- attention_mask=None,
685
- temb=None,
686
- ):
687
- residual = hidden_states
688
-
689
- if attn.spatial_norm is not None:
690
- hidden_states = attn.spatial_norm(hidden_states, temb)
691
-
692
- input_ndim = hidden_states.ndim
693
-
694
- if input_ndim == 4:
695
- batch_size, channel, height, width = hidden_states.shape
696
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
697
-
698
- batch_size, sequence_length, _ = (
699
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
700
- )
701
-
702
- if attention_mask is not None:
703
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
704
- # scaled_dot_product_attention expects attention_mask shape to be
705
- # (batch, heads, source_length, target_length)
706
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
707
-
708
- if attn.group_norm is not None:
709
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
710
-
711
- query = attn.to_q(hidden_states)
712
-
713
- if encoder_hidden_states is None:
714
- encoder_hidden_states = hidden_states
715
- else:
716
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
717
- encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
718
- if attn.norm_cross:
719
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
720
-
721
- key = attn.to_k(encoder_hidden_states)
722
- value = attn.to_v(encoder_hidden_states)
723
-
724
- inner_dim = key.shape[-1]
725
- head_dim = inner_dim // attn.heads
726
-
727
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
728
-
729
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
730
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
731
-
732
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
733
- # TODO: add support for attn.scale when we move to Torch 2.1
734
- hidden_states = F.scaled_dot_product_attention(
735
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
736
- )
737
-
738
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
739
- hidden_states = hidden_states.to(query.dtype)
740
-
741
- # linear proj
742
- hidden_states = attn.to_out[0](hidden_states)
743
- # dropout
744
- hidden_states = attn.to_out[1](hidden_states)
745
-
746
- if input_ndim == 4:
747
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
748
-
749
- if attn.residual_connection:
750
- hidden_states = hidden_states + residual
751
-
752
- hidden_states = hidden_states / attn.rescale_output_factor
753
-
754
- return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ip_adapter/ip_adapter.py DELETED
@@ -1,1078 +0,0 @@
1
- import os
2
- from typing import List
3
-
4
- import torch
5
- from diffusers import StableDiffusionPipeline
6
- from diffusers.pipelines.controlnet import MultiControlNetModel
7
- from PIL import Image
8
- from safetensors import safe_open
9
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
- from torchvision import transforms
11
- from .utils import is_torch2_available, get_generator
12
-
13
- # import torchvision.transforms.functional as Func
14
-
15
- # from .clip_style_models import CSD_CLIP, convert_state_dict
16
-
17
- if is_torch2_available():
18
- from .attention_processor import (
19
- AttnProcessor2_0 as AttnProcessor,
20
- )
21
- from .attention_processor import (
22
- CNAttnProcessor2_0 as CNAttnProcessor,
23
- )
24
- from .attention_processor import (
25
- IPAttnProcessor2_0 as IPAttnProcessor,
26
- )
27
- from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
28
- else:
29
- from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
30
- from .resampler import Resampler
31
-
32
- from transformers import AutoImageProcessor, AutoModel
33
-
34
-
35
- class ImageProjModel(torch.nn.Module):
36
- """Projection Model"""
37
-
38
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
39
- super().__init__()
40
-
41
- self.generator = None
42
- self.cross_attention_dim = cross_attention_dim
43
- self.clip_extra_context_tokens = clip_extra_context_tokens
44
- # print(clip_embeddings_dim, self.clip_extra_context_tokens, cross_attention_dim)
45
- self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
46
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
47
-
48
- def forward(self, image_embeds):
49
- embeds = image_embeds
50
- clip_extra_context_tokens = self.proj(embeds).reshape(
51
- -1, self.clip_extra_context_tokens, self.cross_attention_dim
52
- )
53
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
54
- return clip_extra_context_tokens
55
-
56
-
57
- class MLPProjModel(torch.nn.Module):
58
- """SD model with image prompt"""
59
-
60
- def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
61
- super().__init__()
62
-
63
- self.proj = torch.nn.Sequential(
64
- torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
65
- torch.nn.GELU(),
66
- torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
67
- torch.nn.LayerNorm(cross_attention_dim)
68
- )
69
-
70
- def forward(self, image_embeds):
71
- clip_extra_context_tokens = self.proj(image_embeds)
72
- return clip_extra_context_tokens
73
-
74
-
75
- class IPAdapter:
76
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
77
- self.device = device
78
- self.image_encoder_path = image_encoder_path
79
- self.ip_ckpt = ip_ckpt
80
- self.num_tokens = num_tokens
81
- self.target_blocks = target_blocks
82
-
83
- self.pipe = sd_pipe.to(self.device)
84
- self.set_ip_adapter()
85
-
86
- # load image encoder
87
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
88
- self.device, dtype=torch.float16
89
- )
90
- self.clip_image_processor = CLIPImageProcessor()
91
- # image proj model
92
- self.image_proj_model = self.init_proj()
93
-
94
- self.load_ip_adapter()
95
-
96
- def init_proj(self):
97
- image_proj_model = ImageProjModel(
98
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
99
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
100
- clip_extra_context_tokens=self.num_tokens,
101
- ).to(self.device, dtype=torch.float16)
102
- return image_proj_model
103
-
104
- def set_ip_adapter(self):
105
- unet = self.pipe.unet
106
- attn_procs = {}
107
- for name in unet.attn_processors.keys():
108
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
109
- if name.startswith("mid_block"):
110
- hidden_size = unet.config.block_out_channels[-1]
111
- elif name.startswith("up_blocks"):
112
- block_id = int(name[len("up_blocks.")])
113
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
114
- elif name.startswith("down_blocks"):
115
- block_id = int(name[len("down_blocks.")])
116
- hidden_size = unet.config.block_out_channels[block_id]
117
- if cross_attention_dim is None:
118
- attn_procs[name] = AttnProcessor()
119
- else:
120
- selected = False
121
- for block_name in self.target_blocks:
122
- if block_name in name:
123
- selected = True
124
- break
125
- if selected:
126
- attn_procs[name] = IPAttnProcessor(
127
- hidden_size=hidden_size,
128
- cross_attention_dim=cross_attention_dim,
129
- scale=1.0,
130
- num_tokens=self.num_tokens,
131
- ).to(self.device, dtype=torch.float16)
132
- else:
133
- attn_procs[name] = IPAttnProcessor(
134
- hidden_size=hidden_size,
135
- cross_attention_dim=cross_attention_dim,
136
- scale=1.0,
137
- num_tokens=self.num_tokens,
138
- skip=True
139
- ).to(self.device, dtype=torch.float16)
140
- unet.set_attn_processor(attn_procs)
141
- if hasattr(self.pipe, "controlnet"):
142
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
143
- for controlnet in self.pipe.controlnet.nets:
144
- controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
145
- else:
146
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
147
-
148
- def load_ip_adapter(self):
149
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
150
- state_dict = {"image_proj": {}, "ip_adapter": {}}
151
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
152
- for key in f.keys():
153
- if key.startswith("image_proj."):
154
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
155
- elif key.startswith("ip_adapter."):
156
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
157
- else:
158
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
159
- self.image_proj_model.load_state_dict(state_dict["image_proj"])
160
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
161
- ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
162
-
163
- @torch.inference_mode()
164
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
165
- if pil_image is not None:
166
- if isinstance(pil_image, Image.Image):
167
- pil_image = [pil_image]
168
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
169
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
170
- else:
171
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
172
-
173
- if content_prompt_embeds is not None:
174
- clip_image_embeds = clip_image_embeds - content_prompt_embeds
175
-
176
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
177
- uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
178
- return image_prompt_embeds, uncond_image_prompt_embeds
179
-
180
- def set_scale(self, scale):
181
- for attn_processor in self.pipe.unet.attn_processors.values():
182
- if isinstance(attn_processor, IPAttnProcessor):
183
- attn_processor.scale = scale
184
-
185
- def generate(
186
- self,
187
- pil_image=None,
188
- clip_image_embeds=None,
189
- prompt=None,
190
- negative_prompt=None,
191
- scale=1.0,
192
- num_samples=4,
193
- seed=None,
194
- guidance_scale=7.5,
195
- num_inference_steps=30,
196
- neg_content_emb=None,
197
- **kwargs,
198
- ):
199
- self.set_scale(scale)
200
-
201
- if pil_image is not None:
202
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
203
- else:
204
- num_prompts = clip_image_embeds.size(0)
205
-
206
- if prompt is None:
207
- prompt = "best quality, high quality"
208
- if negative_prompt is None:
209
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
210
-
211
- if not isinstance(prompt, List):
212
- prompt = [prompt] * num_prompts
213
- if not isinstance(negative_prompt, List):
214
- negative_prompt = [negative_prompt] * num_prompts
215
-
216
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
217
- pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
218
- )
219
- bs_embed, seq_len, _ = image_prompt_embeds.shape
220
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
221
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
222
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
223
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
224
-
225
- with torch.inference_mode():
226
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
227
- prompt,
228
- device=self.device,
229
- num_images_per_prompt=num_samples,
230
- do_classifier_free_guidance=True,
231
- negative_prompt=negative_prompt,
232
- )
233
- prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
234
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
235
-
236
- generator = get_generator(seed, self.device)
237
-
238
- images = self.pipe(
239
- prompt_embeds=prompt_embeds,
240
- negative_prompt_embeds=negative_prompt_embeds,
241
- guidance_scale=guidance_scale,
242
- num_inference_steps=num_inference_steps,
243
- generator=generator,
244
- **kwargs,
245
- ).images
246
-
247
- return images
248
-
249
-
250
- class IPAdapter_CS:
251
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
252
- num_style_tokens=4,
253
- target_content_blocks=["block"], target_style_blocks=["block"], content_image_encoder_path=None,
254
- controlnet_adapter=False,
255
- controlnet_target_content_blocks=None,
256
- controlnet_target_style_blocks=None,
257
- content_model_resampler=False,
258
- style_model_resampler=False,
259
- ):
260
- self.device = device
261
- self.image_encoder_path = image_encoder_path
262
- self.ip_ckpt = ip_ckpt
263
- self.num_content_tokens = num_content_tokens
264
- self.num_style_tokens = num_style_tokens
265
- self.content_target_blocks = target_content_blocks
266
- self.style_target_blocks = target_style_blocks
267
-
268
- self.content_model_resampler = content_model_resampler
269
- self.style_model_resampler = style_model_resampler
270
-
271
- self.controlnet_adapter = controlnet_adapter
272
- self.controlnet_target_content_blocks = controlnet_target_content_blocks
273
- self.controlnet_target_style_blocks = controlnet_target_style_blocks
274
-
275
- self.pipe = sd_pipe.to(self.device)
276
- self.set_ip_adapter()
277
- self.content_image_encoder_path = content_image_encoder_path
278
-
279
-
280
- # load image encoder
281
- if content_image_encoder_path is not None:
282
- self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device,
283
- dtype=torch.float16)
284
- self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)
285
- else:
286
- self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
287
- self.device, dtype=torch.float16
288
- )
289
- self.content_image_processor = CLIPImageProcessor()
290
- # model.requires_grad_(False)
291
-
292
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
293
- self.device, dtype=torch.float16
294
- )
295
- # if self.use_CSD is not None:
296
- # self.style_image_encoder = CSD_CLIP("vit_large", "default",self.use_CSD+"/ViT-L-14.pt")
297
- # model_path = self.use_CSD+"/checkpoint.pth"
298
- # checkpoint = torch.load(model_path, map_location="cpu")
299
- # state_dict = convert_state_dict(checkpoint['model_state_dict'])
300
- # self.style_image_encoder.load_state_dict(state_dict, strict=False)
301
- #
302
- # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
303
- # self.style_preprocess = transforms.Compose([
304
- # transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC),
305
- # transforms.CenterCrop(224),
306
- # transforms.ToTensor(),
307
- # normalize,
308
- # ])
309
-
310
- self.clip_image_processor = CLIPImageProcessor()
311
- # image proj model
312
- self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',
313
- model_resampler=self.content_model_resampler)
314
- self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
315
- model_resampler=self.style_model_resampler)
316
-
317
- self.load_ip_adapter()
318
-
319
- def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
320
-
321
- # print('@@@@',self.pipe.unet.config.cross_attention_dim,self.image_encoder.config.projection_dim)
322
- if content_or_style_ == 'content' and self.content_image_encoder_path is not None:
323
- image_proj_model = ImageProjModel(
324
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
325
- clip_embeddings_dim=self.content_image_encoder.config.projection_dim,
326
- clip_extra_context_tokens=num_tokens,
327
- ).to(self.device, dtype=torch.float16)
328
- return image_proj_model
329
-
330
- image_proj_model = ImageProjModel(
331
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
332
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
333
- clip_extra_context_tokens=num_tokens,
334
- ).to(self.device, dtype=torch.float16)
335
- return image_proj_model
336
-
337
- def set_ip_adapter(self):
338
- unet = self.pipe.unet
339
- attn_procs = {}
340
- for name in unet.attn_processors.keys():
341
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
342
- if name.startswith("mid_block"):
343
- hidden_size = unet.config.block_out_channels[-1]
344
- elif name.startswith("up_blocks"):
345
- block_id = int(name[len("up_blocks.")])
346
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
347
- elif name.startswith("down_blocks"):
348
- block_id = int(name[len("down_blocks.")])
349
- hidden_size = unet.config.block_out_channels[block_id]
350
- if cross_attention_dim is None:
351
- attn_procs[name] = AttnProcessor()
352
- else:
353
- # layername_id += 1
354
- selected = False
355
- for block_name in self.style_target_blocks:
356
- if block_name in name:
357
- selected = True
358
- # print(name)
359
- attn_procs[name] = IP_CS_AttnProcessor(
360
- hidden_size=hidden_size,
361
- cross_attention_dim=cross_attention_dim,
362
- style_scale=1.0,
363
- style=True,
364
- num_content_tokens=self.num_content_tokens,
365
- num_style_tokens=self.num_style_tokens,
366
- )
367
- for block_name in self.content_target_blocks:
368
- if block_name in name:
369
- # selected = True
370
- if selected is False:
371
- attn_procs[name] = IP_CS_AttnProcessor(
372
- hidden_size=hidden_size,
373
- cross_attention_dim=cross_attention_dim,
374
- content_scale=1.0,
375
- content=True,
376
- num_content_tokens=self.num_content_tokens,
377
- num_style_tokens=self.num_style_tokens,
378
- )
379
- else:
380
- attn_procs[name].set_content_ipa(content_scale=1.0)
381
- # attn_procs[name].content=True
382
-
383
- if selected is False:
384
- attn_procs[name] = IP_CS_AttnProcessor(
385
- hidden_size=hidden_size,
386
- cross_attention_dim=cross_attention_dim,
387
- num_content_tokens=self.num_content_tokens,
388
- num_style_tokens=self.num_style_tokens,
389
- skip=True,
390
- )
391
-
392
- attn_procs[name].to(self.device, dtype=torch.float16)
393
- unet.set_attn_processor(attn_procs)
394
- if hasattr(self.pipe, "controlnet"):
395
- if self.controlnet_adapter is False:
396
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
397
- for controlnet in self.pipe.controlnet.nets:
398
- controlnet.set_attn_processor(CNAttnProcessor(
399
- num_tokens=self.num_content_tokens + self.num_style_tokens))
400
- else:
401
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
402
- num_tokens=self.num_content_tokens + self.num_style_tokens))
403
-
404
- else:
405
- controlnet_attn_procs = {}
406
- controlnet_style_target_blocks = self.controlnet_target_style_blocks
407
- controlnet_content_target_blocks = self.controlnet_target_content_blocks
408
- for name in self.pipe.controlnet.attn_processors.keys():
409
- # print(name)
410
- cross_attention_dim = None if name.endswith(
411
- "attn1.processor") else self.pipe.controlnet.config.cross_attention_dim
412
- if name.startswith("mid_block"):
413
- hidden_size = self.pipe.controlnet.config.block_out_channels[-1]
414
- elif name.startswith("up_blocks"):
415
- block_id = int(name[len("up_blocks.")])
416
- hidden_size = list(reversed(self.pipe.controlnet.config.block_out_channels))[block_id]
417
- elif name.startswith("down_blocks"):
418
- block_id = int(name[len("down_blocks.")])
419
- hidden_size = self.pipe.controlnet.config.block_out_channels[block_id]
420
- if cross_attention_dim is None:
421
- # layername_id += 1
422
- controlnet_attn_procs[name] = AttnProcessor()
423
-
424
- else:
425
- # layername_id += 1
426
- selected = False
427
- for block_name in controlnet_style_target_blocks:
428
- if block_name in name:
429
- selected = True
430
- # print(name)
431
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
432
- hidden_size=hidden_size,
433
- cross_attention_dim=cross_attention_dim,
434
- style_scale=1.0,
435
- style=True,
436
- num_content_tokens=self.num_content_tokens,
437
- num_style_tokens=self.num_style_tokens,
438
- )
439
-
440
- for block_name in controlnet_content_target_blocks:
441
- if block_name in name:
442
- if selected is False:
443
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
444
- hidden_size=hidden_size,
445
- cross_attention_dim=cross_attention_dim,
446
- content_scale=1.0,
447
- content=True,
448
- num_content_tokens=self.num_content_tokens,
449
- num_style_tokens=self.num_style_tokens,
450
- )
451
-
452
- selected = True
453
- elif selected is True:
454
- controlnet_attn_procs[name].set_content_ipa(content_scale=1.0)
455
-
456
- # if args.content_image_encoder_type !='dinov2':
457
- # weights = {
458
- # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
459
- # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
460
- # }
461
- # attn_procs[name].load_state_dict(weights)
462
- if selected is False:
463
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
464
- hidden_size=hidden_size,
465
- cross_attention_dim=cross_attention_dim,
466
- num_content_tokens=self.num_content_tokens,
467
- num_style_tokens=self.num_style_tokens,
468
- skip=True,
469
- )
470
- controlnet_attn_procs[name].to(self.device, dtype=torch.float16)
471
- # layer_name = name.split(".processor")[0]
472
- # # print(state_dict["ip_adapter"].keys())
473
- # weights = {
474
- # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
475
- # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
476
- # }
477
- # attn_procs[name].load_state_dict(weights)
478
- self.pipe.controlnet.set_attn_processor(controlnet_attn_procs)
479
-
480
- def load_ip_adapter(self):
481
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
482
- state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
483
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
484
- for key in f.keys():
485
- if key.startswith("content_image_proj."):
486
- state_dict["content_image_proj"][key.replace("content_image_proj.", "")] = f.get_tensor(key)
487
- elif key.startswith("style_image_proj."):
488
- state_dict["style_image_proj"][key.replace("style_image_proj.", "")] = f.get_tensor(key)
489
- elif key.startswith("ip_adapter."):
490
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
491
- else:
492
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
493
- self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"])
494
- self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])
495
-
496
- if 'conv_in_unet_sd' in state_dict.keys():
497
- self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)
498
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
499
- ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
500
-
501
- if self.controlnet_adapter is True:
502
- print('loading controlnet_adapter')
503
- self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False)
504
-
505
- @torch.inference_mode()
506
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,
507
- content_or_style_=''):
508
- # if pil_image is not None:
509
- # if isinstance(pil_image, Image.Image):
510
- # pil_image = [pil_image]
511
- # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
512
- # clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
513
- # else:
514
- # clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
515
-
516
- # if content_prompt_embeds is not None:
517
- # clip_image_embeds = clip_image_embeds - content_prompt_embeds
518
-
519
- if content_or_style_ == 'content':
520
- if pil_image is not None:
521
- if isinstance(pil_image, Image.Image):
522
- pil_image = [pil_image]
523
- if self.content_image_proj_model is not None:
524
- clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
525
- clip_image_embeds = self.content_image_encoder(
526
- clip_image.to(self.device, dtype=torch.float16)).image_embeds
527
- else:
528
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
529
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
530
- else:
531
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
532
-
533
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
534
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
535
- return image_prompt_embeds, uncond_image_prompt_embeds
536
- if content_or_style_ == 'style':
537
- if pil_image is not None:
538
- if self.use_CSD is not None:
539
- clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)
- clip_image_embeds = self.style_image_encoder(clip_image)
- else:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
-
-
- else:
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
- return image_prompt_embeds, uncond_image_prompt_embeds
-
- def set_scale(self, content_scale, style_scale):
- for attn_processor in self.pipe.unet.attn_processors.values():
- if isinstance(attn_processor, IP_CS_AttnProcessor):
- if attn_processor.content is True:
- attn_processor.content_scale = content_scale
-
- if attn_processor.style is True:
- attn_processor.style_scale = style_scale
- # print('style_scale:',style_scale)
- if self.controlnet_adapter is not None:
- for attn_processor in self.pipe.controlnet.attn_processors.values():
-
- if isinstance(attn_processor, IP_CS_AttnProcessor):
- if attn_processor.content is True:
- attn_processor.content_scale = content_scale
- # print(content_scale)
-
- if attn_processor.style is True:
- attn_processor.style_scale = style_scale
-
- def generate(
- self,
- pil_content_image=None,
- pil_style_image=None,
- clip_content_image_embeds=None,
- clip_style_image_embeds=None,
- prompt=None,
- negative_prompt=None,
- content_scale=1.0,
- style_scale=1.0,
- num_samples=4,
- seed=None,
- guidance_scale=7.5,
- num_inference_steps=30,
- neg_content_emb=None,
- **kwargs,
- ):
- self.set_scale(content_scale, style_scale)
-
- if pil_content_image is not None:
- num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
- else:
- num_prompts = clip_content_image_embeds.size(0)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
- pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds
- )
- style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
- pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds
- )
-
- bs_embed, seq_len, _ = content_image_prompt_embeds.shape
- content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
- content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
- -1)
-
- bs_style_embed, seq_style_len, _ = content_image_prompt_embeds.shape
- style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
- style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
- -1)
-
- with torch.inference_mode():
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
- prompt,
- device=self.device,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_,
- uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
- dim=1)
-
- generator = get_generator(seed, self.device)
-
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
- return images
-
-
- class IPAdapterXL_CS(IPAdapter_CS):
- """SDXL"""
-
- def generate(
- self,
- pil_content_image,
- pil_style_image,
- prompt=None,
- negative_prompt=None,
- content_scale=1.0,
- style_scale=1.0,
- num_samples=4,
- seed=None,
- content_image_embeds=None,
- style_image_embeds=None,
- num_inference_steps=30,
- neg_content_emb=None,
- neg_content_prompt=None,
- neg_content_scale=1.0,
- **kwargs,
- ):
- self.set_scale(content_scale, style_scale)
-
- num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(pil_content_image,
- content_image_embeds,
- content_or_style_='content')
-
-
-
- style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(pil_style_image,
- style_image_embeds,
- content_or_style_='style')
-
- bs_embed, seq_len, _ = content_image_prompt_embeds.shape
-
- content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
- content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
- -1)
- bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape
- style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
- style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
- -1)
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.pipe.encode_prompt(
- prompt,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds,
- uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
- dim=1)
-
- self.generator = get_generator(seed, self.device)
-
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=num_inference_steps,
- generator=self.generator,
- **kwargs,
- ).images
- return images
-
-
- class CSGO(IPAdapterXL_CS):
- """SDXL"""
-
- def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
- if content_or_style_ == 'content':
- if model_resampler:
- image_proj_model = Resampler(
- dim=self.pipe.unet.config.cross_attention_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=num_tokens,
- embedding_dim=self.content_image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- else:
- image_proj_model = ImageProjModel(
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
- clip_extra_context_tokens=num_tokens,
- ).to(self.device, dtype=torch.float16)
- if content_or_style_ == 'style':
- if model_resampler:
- image_proj_model = Resampler(
- dim=self.pipe.unet.config.cross_attention_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=num_tokens,
- embedding_dim=self.content_image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- else:
- image_proj_model = ImageProjModel(
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
- clip_extra_context_tokens=num_tokens,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
- @torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- if content_or_style_ == 'style':
-
- if self.style_model_resampler:
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16),
- output_hidden_states=True).hidden_states[-2]
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
- else:
-
-
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
- return image_prompt_embeds, uncond_image_prompt_embeds
-
-
- else:
-
- if self.content_image_encoder_path is not None:
- clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
- outputs = self.content_image_encoder(clip_image.to(self.device, dtype=torch.float16),
- output_hidden_states=True)
- clip_image_embeds = outputs.last_hidden_state
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
-
- # uncond_clip_image_embeds = self.image_encoder(
- # torch.zeros_like(clip_image), output_hidden_states=True
- # ).last_hidden_state
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
- return image_prompt_embeds, uncond_image_prompt_embeds
-
- else:
- if self.content_model_resampler:
-
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- # clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
- # uncond_clip_image_embeds = self.image_encoder(
- # torch.zeros_like(clip_image), output_hidden_states=True
- # ).hidden_states[-2]
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
- else:
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
-
- return image_prompt_embeds, uncond_image_prompt_embeds
-
- # # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- # clip_image = clip_image.to(self.device, dtype=torch.float16)
- # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- # image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
- # uncond_clip_image_embeds = self.image_encoder(
- # torch.zeros_like(clip_image), output_hidden_states=True
- # ).hidden_states[-2]
- # uncond_image_prompt_embeds = self.content_image_proj_model(uncond_clip_image_embeds)
- # return image_prompt_embeds, uncond_image_prompt_embeds
-
-
- class IPAdapterXL(IPAdapter):
- """SDXL"""
-
- def generate(
- self,
- pil_image,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- num_inference_steps=30,
- neg_content_emb=None,
- neg_content_prompt=None,
- neg_content_scale=1.0,
- **kwargs,
- ):
- self.set_scale(scale)
-
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- if neg_content_emb is None:
- if neg_content_prompt is not None:
- with torch.inference_mode():
- (
- prompt_embeds_, # torch.Size([1, 77, 2048])
- negative_prompt_embeds_,
- pooled_prompt_embeds_, # torch.Size([1, 1280])
- negative_pooled_prompt_embeds_,
- ) = self.pipe.encode_prompt(
- neg_content_prompt,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- pooled_prompt_embeds_ *= neg_content_scale
- else:
- pooled_prompt_embeds_ = neg_content_emb
- else:
- pooled_prompt_embeds_ = None
-
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image,
- content_prompt_embeds=pooled_prompt_embeds_)
- bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.pipe.encode_prompt(
- prompt,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
- self.generator = get_generator(seed, self.device)
-
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=num_inference_steps,
- generator=self.generator,
- **kwargs,
- ).images
-
- return images
-
-
- class IPAdapterPlus(IPAdapter):
- """IP-Adapter with fine-grained features"""
-
- def init_proj(self):
- image_proj_model = Resampler(
- dim=self.pipe.unet.config.cross_attention_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=self.num_tokens,
- embedding_dim=self.image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
- @torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
- uncond_clip_image_embeds = self.image_encoder(
- torch.zeros_like(clip_image), output_hidden_states=True
- ).hidden_states[-2]
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
- return image_prompt_embeds, uncond_image_prompt_embeds
-
-
- class IPAdapterFull(IPAdapterPlus):
- """IP-Adapter with full features"""
-
- def init_proj(self):
- image_proj_model = MLPProjModel(
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
- clip_embeddings_dim=self.image_encoder.config.hidden_size,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
-
- class IPAdapterPlusXL(IPAdapter):
- """SDXL"""
-
- def init_proj(self):
- image_proj_model = Resampler(
- dim=1280,
- depth=4,
- dim_head=64,
- heads=20,
- num_queries=self.num_tokens,
- embedding_dim=self.image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
- @torch.inference_mode()
- def get_image_embeds(self, pil_image):
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
- uncond_clip_image_embeds = self.image_encoder(
- torch.zeros_like(clip_image), output_hidden_states=True
- ).hidden_states[-2]
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
- return image_prompt_embeds, uncond_image_prompt_embeds
-
- def generate(
- self,
- pil_image,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- num_inference_steps=30,
- **kwargs,
- ):
- self.set_scale(scale)
-
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
- bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.pipe.encode_prompt(
- prompt,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
- generator = get_generator(seed, self.device)
-
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
- return images
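
Note: the classes deleted above were all driven through their generate() methods. As a reference for anyone migrating off this module, here is a minimal sketch of a call against the CSGO class, using only keyword arguments that appear in the diff above; the constructor arguments are defined earlier in the deleted file and are omitted here, so csgo is assumed to be an already-constructed instance, and the image paths are placeholders, not files from this repo.

    from PIL import Image

    # csgo: an already-constructed CSGO instance (see the class definition above)
    content_img = Image.open("content.png").convert("RGB")  # placeholder path
    style_img = Image.open("style.png").convert("RGB")      # placeholder path

    images = csgo.generate(
        pil_content_image=content_img,   # image whose layout/content is preserved
        pil_style_image=style_img,       # image whose style is transferred
        prompt="best quality, high quality",
        content_scale=1.0,               # weight of the content image branch
        style_scale=1.0,                 # weight of the style image branch
        num_samples=1,
        num_inference_steps=30,
        seed=42,
    )
    images[0].save("output.png")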
239
- prompt_embeds=prompt_embeds,
240
- negative_prompt_embeds=negative_prompt_embeds,
241
- guidance_scale=guidance_scale,
242
- num_inference_steps=num_inference_steps,
243
- generator=generator,
244
- **kwargs,
245
- ).images
246
-
247
- return images
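For orientation, a minimal usage sketch of the IPAdapter class defined above, assuming diffusers and a CUDA device are available; the SD 1.5 model id runs as written, while the encoder directory, checkpoint and image paths are placeholders, and the module path in the import is assumed from this repository's layout.

# Minimal sketch, not part of this repository's own scripts.
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

from ip_adapter.ip_adapter import IPAdapter  # module path assumed

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# The wrapper moves the pipeline to the device and patches its attention processors.
ip_model = IPAdapter(
    pipe,
    image_encoder_path="models/image_encoder",   # placeholder CLIP vision encoder directory
    ip_ckpt="models/ip-adapter_sd15.bin",        # placeholder adapter checkpoint
    device="cuda",
    num_tokens=4,
)

reference = Image.open("reference.png").convert("RGB")   # placeholder image
images = ip_model.generate(
    pil_image=reference,
    prompt="best quality, high quality",
    scale=0.8,                 # weight of the image prompt relative to the text prompt
    num_samples=1,
    seed=42,
    num_inference_steps=30,
)
images[0].save("out.png")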
248
-
249
-
250
- class IPAdapter_CS:
251
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
252
- num_style_tokens=4,
253
- target_content_blocks=["block"], target_style_blocks=["block"], content_image_encoder_path=None,
254
- controlnet_adapter=False,
255
- controlnet_target_content_blocks=None,
256
- controlnet_target_style_blocks=None,
257
- content_model_resampler=False,
258
- style_model_resampler=False,
259
- ):
260
- self.device = device
261
- self.image_encoder_path = image_encoder_path
262
- self.ip_ckpt = ip_ckpt
263
- self.num_content_tokens = num_content_tokens
264
- self.num_style_tokens = num_style_tokens
265
- self.content_target_blocks = target_content_blocks
266
- self.style_target_blocks = target_style_blocks
267
-
268
- self.content_model_resampler = content_model_resampler
269
- self.style_model_resampler = style_model_resampler
270
-
271
- self.controlnet_adapter = controlnet_adapter
272
- self.controlnet_target_content_blocks = controlnet_target_content_blocks
273
- self.controlnet_target_style_blocks = controlnet_target_style_blocks
274
-
275
- self.pipe = sd_pipe.to(self.device)
276
- self.set_ip_adapter()
277
- self.content_image_encoder_path = content_image_encoder_path
278
-
279
-
280
- # load image encoder
281
- if content_image_encoder_path is not None:
282
- self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device,
283
- dtype=torch.float16)
284
- self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)
285
- else:
286
- self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
287
- self.device, dtype=torch.float16
288
- )
289
- self.content_image_processor = CLIPImageProcessor()
290
- # model.requires_grad_(False)
291
-
292
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
293
- self.device, dtype=torch.float16
294
- )
295
- # if self.use_CSD is not None:
296
- # self.style_image_encoder = CSD_CLIP("vit_large", "default",self.use_CSD+"/ViT-L-14.pt")
297
- # model_path = self.use_CSD+"/checkpoint.pth"
298
- # checkpoint = torch.load(model_path, map_location="cpu")
299
- # state_dict = convert_state_dict(checkpoint['model_state_dict'])
300
- # self.style_image_encoder.load_state_dict(state_dict, strict=False)
301
- #
302
- # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
303
- # self.style_preprocess = transforms.Compose([
304
- # transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC),
305
- # transforms.CenterCrop(224),
306
- # transforms.ToTensor(),
307
- # normalize,
308
- # ])
309
-
310
- self.clip_image_processor = CLIPImageProcessor()
311
- # image proj model
312
- self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',
313
- model_resampler=self.content_model_resampler)
314
- self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
315
- model_resampler=self.style_model_resampler)
316
-
317
- self.load_ip_adapter()
318
-
319
- def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
320
-
321
- # print('@@@@',self.pipe.unet.config.cross_attention_dim,self.image_encoder.config.projection_dim)
322
- if content_or_style_ == 'content' and self.content_image_encoder_path is not None:
323
- image_proj_model = ImageProjModel(
324
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
325
- clip_embeddings_dim=self.content_image_encoder.config.projection_dim,
326
- clip_extra_context_tokens=num_tokens,
327
- ).to(self.device, dtype=torch.float16)
328
- return image_proj_model
329
-
330
- image_proj_model = ImageProjModel(
331
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
332
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
333
- clip_extra_context_tokens=num_tokens,
334
- ).to(self.device, dtype=torch.float16)
335
- return image_proj_model
336
-
337
- def set_ip_adapter(self):
338
- unet = self.pipe.unet
339
- attn_procs = {}
340
- for name in unet.attn_processors.keys():
341
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
342
- if name.startswith("mid_block"):
343
- hidden_size = unet.config.block_out_channels[-1]
344
- elif name.startswith("up_blocks"):
345
- block_id = int(name[len("up_blocks.")])
346
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
347
- elif name.startswith("down_blocks"):
348
- block_id = int(name[len("down_blocks.")])
349
- hidden_size = unet.config.block_out_channels[block_id]
350
- if cross_attention_dim is None:
351
- attn_procs[name] = AttnProcessor()
352
- else:
353
- # layername_id += 1
354
- selected = False
355
- for block_name in self.style_target_blocks:
356
- if block_name in name:
357
- selected = True
358
- # print(name)
359
- attn_procs[name] = IP_CS_AttnProcessor(
360
- hidden_size=hidden_size,
361
- cross_attention_dim=cross_attention_dim,
362
- style_scale=1.0,
363
- style=True,
364
- num_content_tokens=self.num_content_tokens,
365
- num_style_tokens=self.num_style_tokens,
366
- )
367
- for block_name in self.content_target_blocks:
368
- if block_name in name:
369
- # selected = True
370
- if selected is False:
371
- attn_procs[name] = IP_CS_AttnProcessor(
372
- hidden_size=hidden_size,
373
- cross_attention_dim=cross_attention_dim,
374
- content_scale=1.0,
375
- content=True,
376
- num_content_tokens=self.num_content_tokens,
377
- num_style_tokens=self.num_style_tokens,
378
- )
379
- else:
380
- attn_procs[name].set_content_ipa(content_scale=1.0)
381
- # attn_procs[name].content=True
382
-
383
- if selected is False:
384
- attn_procs[name] = IP_CS_AttnProcessor(
385
- hidden_size=hidden_size,
386
- cross_attention_dim=cross_attention_dim,
387
- num_content_tokens=self.num_content_tokens,
388
- num_style_tokens=self.num_style_tokens,
389
- skip=True,
390
- )
391
-
392
- attn_procs[name].to(self.device, dtype=torch.float16)
393
- unet.set_attn_processor(attn_procs)
394
- if hasattr(self.pipe, "controlnet"):
395
- if self.controlnet_adapter is False:
396
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
397
- for controlnet in self.pipe.controlnet.nets:
398
- controlnet.set_attn_processor(CNAttnProcessor(
399
- num_tokens=self.num_content_tokens + self.num_style_tokens))
400
- else:
401
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
402
- num_tokens=self.num_content_tokens + self.num_style_tokens))
403
-
404
- else:
405
- controlnet_attn_procs = {}
406
- controlnet_style_target_blocks = self.controlnet_target_style_blocks
407
- controlnet_content_target_blocks = self.controlnet_target_content_blocks
408
- for name in self.pipe.controlnet.attn_processors.keys():
409
- # print(name)
410
- cross_attention_dim = None if name.endswith(
411
- "attn1.processor") else self.pipe.controlnet.config.cross_attention_dim
412
- if name.startswith("mid_block"):
413
- hidden_size = self.pipe.controlnet.config.block_out_channels[-1]
414
- elif name.startswith("up_blocks"):
415
- block_id = int(name[len("up_blocks.")])
416
- hidden_size = list(reversed(self.pipe.controlnet.config.block_out_channels))[block_id]
417
- elif name.startswith("down_blocks"):
418
- block_id = int(name[len("down_blocks.")])
419
- hidden_size = self.pipe.controlnet.config.block_out_channels[block_id]
420
- if cross_attention_dim is None:
421
- # layername_id += 1
422
- controlnet_attn_procs[name] = AttnProcessor()
423
-
424
- else:
425
- # layername_id += 1
426
- selected = False
427
- for block_name in controlnet_style_target_blocks:
428
- if block_name in name:
429
- selected = True
430
- # print(name)
431
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
432
- hidden_size=hidden_size,
433
- cross_attention_dim=cross_attention_dim,
434
- style_scale=1.0,
435
- style=True,
436
- num_content_tokens=self.num_content_tokens,
437
- num_style_tokens=self.num_style_tokens,
438
- )
439
-
440
- for block_name in controlnet_content_target_blocks:
441
- if block_name in name:
442
- if selected is False:
443
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
444
- hidden_size=hidden_size,
445
- cross_attention_dim=cross_attention_dim,
446
- content_scale=1.0,
447
- content=True,
448
- num_content_tokens=self.num_content_tokens,
449
- num_style_tokens=self.num_style_tokens,
450
- )
451
-
452
- selected = True
453
- elif selected is True:
454
- controlnet_attn_procs[name].set_content_ipa(content_scale=1.0)
455
-
456
- # if args.content_image_encoder_type !='dinov2':
457
- # weights = {
458
- # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
459
- # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
460
- # }
461
- # attn_procs[name].load_state_dict(weights)
462
- if selected is False:
463
- controlnet_attn_procs[name] = IP_CS_AttnProcessor(
464
- hidden_size=hidden_size,
465
- cross_attention_dim=cross_attention_dim,
466
- num_content_tokens=self.num_content_tokens,
467
- num_style_tokens=self.num_style_tokens,
468
- skip=True,
469
- )
470
- controlnet_attn_procs[name].to(self.device, dtype=torch.float16)
471
- # layer_name = name.split(".processor")[0]
472
- # # print(state_dict["ip_adapter"].keys())
473
- # weights = {
474
- # "to_k_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_k_ip.weight"],
475
- # "to_v_ip.weight": state_dict["ip_adapter"][str(layername_id) + ".to_v_ip.weight"],
476
- # }
477
- # attn_procs[name].load_state_dict(weights)
478
- self.pipe.controlnet.set_attn_processor(controlnet_attn_procs)
479
-
480
- def load_ip_adapter(self):
481
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
482
- state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
483
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
484
- for key in f.keys():
485
- if key.startswith("content_image_proj."):
486
- state_dict["content_image_proj"][key.replace("content_image_proj.", "")] = f.get_tensor(key)
487
- elif key.startswith("style_image_proj."):
488
- state_dict["style_image_proj"][key.replace("style_image_proj.", "")] = f.get_tensor(key)
489
- elif key.startswith("ip_adapter."):
490
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
491
- else:
492
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
493
- self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"])
494
- self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])
495
-
496
- if 'conv_in_unet_sd' in state_dict.keys():
497
- self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)
498
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
499
- ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
500
-
501
- if self.controlnet_adapter is True:
502
- print('loading controlnet_adapter')
503
- self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False)
504
-
505
- @torch.inference_mode()
506
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,
507
- content_or_style_=''):
508
- # if pil_image is not None:
509
- # if isinstance(pil_image, Image.Image):
510
- # pil_image = [pil_image]
511
- # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
512
- # clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
513
- # else:
514
- # clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
515
-
516
- # if content_prompt_embeds is not None:
517
- # clip_image_embeds = clip_image_embeds - content_prompt_embeds
518
-
519
- if content_or_style_ == 'content':
520
- if pil_image is not None:
521
- if isinstance(pil_image, Image.Image):
522
- pil_image = [pil_image]
523
- if self.content_image_proj_model is not None:
524
- clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
525
- clip_image_embeds = self.content_image_encoder(
526
- clip_image.to(self.device, dtype=torch.float16)).image_embeds
527
- else:
528
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
529
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
530
- else:
531
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
532
-
533
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
534
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
535
- return image_prompt_embeds, uncond_image_prompt_embeds
536
- if content_or_style_ == 'style':
537
- if pil_image is not None:
538
-                 if getattr(self, "use_CSD", None) is not None:
539
- clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)
540
- clip_image_embeds = self.style_image_encoder(clip_image)
541
- else:
542
- if isinstance(pil_image, Image.Image):
543
- pil_image = [pil_image]
544
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
545
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
546
-
547
-
548
- else:
549
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
550
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
551
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
552
- return image_prompt_embeds, uncond_image_prompt_embeds
553
-
554
- def set_scale(self, content_scale, style_scale):
555
- for attn_processor in self.pipe.unet.attn_processors.values():
556
- if isinstance(attn_processor, IP_CS_AttnProcessor):
557
- if attn_processor.content is True:
558
- attn_processor.content_scale = content_scale
559
-
560
- if attn_processor.style is True:
561
- attn_processor.style_scale = style_scale
562
- # print('style_scale:',style_scale)
563
- if self.controlnet_adapter is not None:
564
- for attn_processor in self.pipe.controlnet.attn_processors.values():
565
-
566
- if isinstance(attn_processor, IP_CS_AttnProcessor):
567
- if attn_processor.content is True:
568
- attn_processor.content_scale = content_scale
569
- # print(content_scale)
570
-
571
- if attn_processor.style is True:
572
- attn_processor.style_scale = style_scale
573
-
574
- def generate(
575
- self,
576
- pil_content_image=None,
577
- pil_style_image=None,
578
- clip_content_image_embeds=None,
579
- clip_style_image_embeds=None,
580
- prompt=None,
581
- negative_prompt=None,
582
- content_scale=1.0,
583
- style_scale=1.0,
584
- num_samples=4,
585
- seed=None,
586
- guidance_scale=7.5,
587
- num_inference_steps=30,
588
- neg_content_emb=None,
589
- **kwargs,
590
- ):
591
- self.set_scale(content_scale, style_scale)
592
-
593
- if pil_content_image is not None:
594
- num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
595
- else:
596
- num_prompts = clip_content_image_embeds.size(0)
597
-
598
- if prompt is None:
599
- prompt = "best quality, high quality"
600
- if negative_prompt is None:
601
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
602
-
603
- if not isinstance(prompt, List):
604
- prompt = [prompt] * num_prompts
605
- if not isinstance(negative_prompt, List):
606
- negative_prompt = [negative_prompt] * num_prompts
607
-
608
- content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
609
-             pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds, content_or_style_='content'
610
- )
611
- style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
612
-             pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds, content_or_style_='style'
613
- )
614
-
615
- bs_embed, seq_len, _ = content_image_prompt_embeds.shape
616
- content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
617
- content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
618
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
619
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
620
- -1)
621
-
622
-         bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape
623
- style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
624
- style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
625
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
626
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
627
- -1)
628
-
629
- with torch.inference_mode():
630
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
631
- prompt,
632
- device=self.device,
633
- num_images_per_prompt=num_samples,
634
- do_classifier_free_guidance=True,
635
- negative_prompt=negative_prompt,
636
- )
637
- prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
638
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_,
639
- uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
640
- dim=1)
641
-
642
- generator = get_generator(seed, self.device)
643
-
644
- images = self.pipe(
645
- prompt_embeds=prompt_embeds,
646
- negative_prompt_embeds=negative_prompt_embeds,
647
- guidance_scale=guidance_scale,
648
- num_inference_steps=num_inference_steps,
649
- generator=generator,
650
- **kwargs,
651
- ).images
652
-
653
- return images
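The practical difference from the plain IPAdapter is that the class above keeps separate projection models and token groups for the content and the style image, so their strengths can be tuned independently at generation time. A short sketch of that knob, assuming cs_model is an already constructed IPAdapter_CS instance and content_img / style_img are PIL images (all three names are hypothetical):

# Sketch only: cs_model, content_img and style_img are assumed to exist.
images = cs_model.generate(
    pil_content_image=content_img,   # image whose subject/layout should be preserved
    pil_style_image=style_img,       # image whose style should be transferred
    prompt="best quality, high quality",
    content_scale=0.6,               # strength of the content tokens
    style_scale=1.2,                 # strength of the style tokens
    num_samples=1,
    seed=0,
    guidance_scale=7.5,
    num_inference_steps=30,
)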
654
-
655
-
656
- class IPAdapterXL_CS(IPAdapter_CS):
657
- """SDXL"""
658
-
659
- def generate(
660
- self,
661
- pil_content_image,
662
- pil_style_image,
663
- prompt=None,
664
- negative_prompt=None,
665
- content_scale=1.0,
666
- style_scale=1.0,
667
- num_samples=4,
668
- seed=None,
669
- content_image_embeds=None,
670
- style_image_embeds=None,
671
- num_inference_steps=30,
672
- neg_content_emb=None,
673
- neg_content_prompt=None,
674
- neg_content_scale=1.0,
675
- **kwargs,
676
- ):
677
- self.set_scale(content_scale, style_scale)
678
-
679
- num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
680
-
681
- if prompt is None:
682
- prompt = "best quality, high quality"
683
- if negative_prompt is None:
684
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
685
-
686
- if not isinstance(prompt, List):
687
- prompt = [prompt] * num_prompts
688
- if not isinstance(negative_prompt, List):
689
- negative_prompt = [negative_prompt] * num_prompts
690
-
691
- content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(pil_content_image,
692
- content_image_embeds,
693
- content_or_style_='content')
694
-
695
-
696
-
697
- style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(pil_style_image,
698
- style_image_embeds,
699
- content_or_style_='style')
700
-
701
- bs_embed, seq_len, _ = content_image_prompt_embeds.shape
702
-
703
- content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
704
- content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
705
-
706
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
707
- uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(bs_embed * num_samples, seq_len,
708
- -1)
709
- bs_style_embed, seq_style_len, _ = style_image_prompt_embeds.shape
710
- style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
711
- style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
712
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
713
- uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len,
714
- -1)
715
-
716
- with torch.inference_mode():
717
- (
718
- prompt_embeds,
719
- negative_prompt_embeds,
720
- pooled_prompt_embeds,
721
- negative_pooled_prompt_embeds,
722
- ) = self.pipe.encode_prompt(
723
- prompt,
724
- num_images_per_prompt=num_samples,
725
- do_classifier_free_guidance=True,
726
- negative_prompt=negative_prompt,
727
- )
728
- prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
729
- negative_prompt_embeds = torch.cat([negative_prompt_embeds,
730
- uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
731
- dim=1)
732
-
733
- self.generator = get_generator(seed, self.device)
734
-
735
- images = self.pipe(
736
- prompt_embeds=prompt_embeds,
737
- negative_prompt_embeds=negative_prompt_embeds,
738
- pooled_prompt_embeds=pooled_prompt_embeds,
739
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
740
- num_inference_steps=num_inference_steps,
741
- generator=self.generator,
742
- **kwargs,
743
- ).images
744
- return images
745
-
746
-
747
- class CSGO(IPAdapterXL_CS):
748
- """SDXL"""
749
-
750
- def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
751
- if content_or_style_ == 'content':
752
- if model_resampler:
753
- image_proj_model = Resampler(
754
- dim=self.pipe.unet.config.cross_attention_dim,
755
- depth=4,
756
- dim_head=64,
757
- heads=12,
758
- num_queries=num_tokens,
759
- embedding_dim=self.content_image_encoder.config.hidden_size,
760
- output_dim=self.pipe.unet.config.cross_attention_dim,
761
- ff_mult=4,
762
- ).to(self.device, dtype=torch.float16)
763
- else:
764
- image_proj_model = ImageProjModel(
765
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
766
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
767
- clip_extra_context_tokens=num_tokens,
768
- ).to(self.device, dtype=torch.float16)
769
- if content_or_style_ == 'style':
770
- if model_resampler:
771
- image_proj_model = Resampler(
772
- dim=self.pipe.unet.config.cross_attention_dim,
773
- depth=4,
774
- dim_head=64,
775
- heads=12,
776
- num_queries=num_tokens,
777
- embedding_dim=self.content_image_encoder.config.hidden_size,
778
- output_dim=self.pipe.unet.config.cross_attention_dim,
779
- ff_mult=4,
780
- ).to(self.device, dtype=torch.float16)
781
- else:
782
- image_proj_model = ImageProjModel(
783
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
784
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
785
- clip_extra_context_tokens=num_tokens,
786
- ).to(self.device, dtype=torch.float16)
787
- return image_proj_model
788
-
789
- @torch.inference_mode()
790
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):
791
- if isinstance(pil_image, Image.Image):
792
- pil_image = [pil_image]
793
- if content_or_style_ == 'style':
794
-
795
- if self.style_model_resampler:
796
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
797
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16),
798
- output_hidden_states=True).hidden_states[-2]
799
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
800
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
801
- else:
802
-
803
-
804
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
805
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
806
- image_prompt_embeds = self.style_image_proj_model(clip_image_embeds)
807
- uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(clip_image_embeds))
808
- return image_prompt_embeds, uncond_image_prompt_embeds
809
-
810
-
811
- else:
812
-
813
- if self.content_image_encoder_path is not None:
814
- clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
815
- outputs = self.content_image_encoder(clip_image.to(self.device, dtype=torch.float16),
816
- output_hidden_states=True)
817
- clip_image_embeds = outputs.last_hidden_state
818
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
819
-
820
- # uncond_clip_image_embeds = self.image_encoder(
821
- # torch.zeros_like(clip_image), output_hidden_states=True
822
- # ).last_hidden_state
823
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
824
- return image_prompt_embeds, uncond_image_prompt_embeds
825
-
826
- else:
827
- if self.content_model_resampler:
828
-
829
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
830
-
831
- clip_image = clip_image.to(self.device, dtype=torch.float16)
832
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
833
- # clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
834
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
835
- # uncond_clip_image_embeds = self.image_encoder(
836
- # torch.zeros_like(clip_image), output_hidden_states=True
837
- # ).hidden_states[-2]
838
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
839
- else:
840
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
841
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
842
- image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
843
- uncond_image_prompt_embeds = self.content_image_proj_model(torch.zeros_like(clip_image_embeds))
844
-
845
- return image_prompt_embeds, uncond_image_prompt_embeds
846
-
847
- # # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
848
- # clip_image = clip_image.to(self.device, dtype=torch.float16)
849
- # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
850
- # image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
851
- # uncond_clip_image_embeds = self.image_encoder(
852
- # torch.zeros_like(clip_image), output_hidden_states=True
853
- # ).hidden_states[-2]
854
- # uncond_image_prompt_embeds = self.content_image_proj_model(uncond_clip_image_embeds)
855
- # return image_prompt_embeds, uncond_image_prompt_embeds
856
-
857
-
858
- class IPAdapterXL(IPAdapter):
859
- """SDXL"""
860
-
861
- def generate(
862
- self,
863
- pil_image,
864
- prompt=None,
865
- negative_prompt=None,
866
- scale=1.0,
867
- num_samples=4,
868
- seed=None,
869
- num_inference_steps=30,
870
- neg_content_emb=None,
871
- neg_content_prompt=None,
872
- neg_content_scale=1.0,
873
- **kwargs,
874
- ):
875
- self.set_scale(scale)
876
-
877
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
878
-
879
- if prompt is None:
880
- prompt = "best quality, high quality"
881
- if negative_prompt is None:
882
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
883
-
884
- if not isinstance(prompt, List):
885
- prompt = [prompt] * num_prompts
886
- if not isinstance(negative_prompt, List):
887
- negative_prompt = [negative_prompt] * num_prompts
888
-
889
- if neg_content_emb is None:
890
- if neg_content_prompt is not None:
891
- with torch.inference_mode():
892
- (
893
- prompt_embeds_, # torch.Size([1, 77, 2048])
894
- negative_prompt_embeds_,
895
- pooled_prompt_embeds_, # torch.Size([1, 1280])
896
- negative_pooled_prompt_embeds_,
897
- ) = self.pipe.encode_prompt(
898
- neg_content_prompt,
899
- num_images_per_prompt=num_samples,
900
- do_classifier_free_guidance=True,
901
- negative_prompt=negative_prompt,
902
- )
903
- pooled_prompt_embeds_ *= neg_content_scale
904
- else:
905
- pooled_prompt_embeds_ = neg_content_emb
906
- else:
907
- pooled_prompt_embeds_ = None
908
-
909
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image,
910
- content_prompt_embeds=pooled_prompt_embeds_)
911
- bs_embed, seq_len, _ = image_prompt_embeds.shape
912
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
913
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
914
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
915
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
916
-
917
- with torch.inference_mode():
918
- (
919
- prompt_embeds,
920
- negative_prompt_embeds,
921
- pooled_prompt_embeds,
922
- negative_pooled_prompt_embeds,
923
- ) = self.pipe.encode_prompt(
924
- prompt,
925
- num_images_per_prompt=num_samples,
926
- do_classifier_free_guidance=True,
927
- negative_prompt=negative_prompt,
928
- )
929
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
930
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
931
-
932
- self.generator = get_generator(seed, self.device)
933
-
934
- images = self.pipe(
935
- prompt_embeds=prompt_embeds,
936
- negative_prompt_embeds=negative_prompt_embeds,
937
- pooled_prompt_embeds=pooled_prompt_embeds,
938
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
939
- num_inference_steps=num_inference_steps,
940
- generator=self.generator,
941
- **kwargs,
942
- ).images
943
-
944
- return images
945
-
946
-
947
- class IPAdapterPlus(IPAdapter):
948
- """IP-Adapter with fine-grained features"""
949
-
950
- def init_proj(self):
951
- image_proj_model = Resampler(
952
- dim=self.pipe.unet.config.cross_attention_dim,
953
- depth=4,
954
- dim_head=64,
955
- heads=12,
956
- num_queries=self.num_tokens,
957
- embedding_dim=self.image_encoder.config.hidden_size,
958
- output_dim=self.pipe.unet.config.cross_attention_dim,
959
- ff_mult=4,
960
- ).to(self.device, dtype=torch.float16)
961
- return image_proj_model
962
-
963
- @torch.inference_mode()
964
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
965
- if isinstance(pil_image, Image.Image):
966
- pil_image = [pil_image]
967
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
968
- clip_image = clip_image.to(self.device, dtype=torch.float16)
969
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
970
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
971
- uncond_clip_image_embeds = self.image_encoder(
972
- torch.zeros_like(clip_image), output_hidden_states=True
973
- ).hidden_states[-2]
974
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
975
- return image_prompt_embeds, uncond_image_prompt_embeds
976
-
977
-
978
- class IPAdapterFull(IPAdapterPlus):
979
- """IP-Adapter with full features"""
980
-
981
- def init_proj(self):
982
- image_proj_model = MLPProjModel(
983
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
984
- clip_embeddings_dim=self.image_encoder.config.hidden_size,
985
- ).to(self.device, dtype=torch.float16)
986
- return image_proj_model
987
-
988
-
989
- class IPAdapterPlusXL(IPAdapter):
990
- """SDXL"""
991
-
992
- def init_proj(self):
993
- image_proj_model = Resampler(
994
- dim=1280,
995
- depth=4,
996
- dim_head=64,
997
- heads=20,
998
- num_queries=self.num_tokens,
999
- embedding_dim=self.image_encoder.config.hidden_size,
1000
- output_dim=self.pipe.unet.config.cross_attention_dim,
1001
- ff_mult=4,
1002
- ).to(self.device, dtype=torch.float16)
1003
- return image_proj_model
1004
-
1005
- @torch.inference_mode()
1006
- def get_image_embeds(self, pil_image):
1007
- if isinstance(pil_image, Image.Image):
1008
- pil_image = [pil_image]
1009
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
1010
- clip_image = clip_image.to(self.device, dtype=torch.float16)
1011
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
1012
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
1013
- uncond_clip_image_embeds = self.image_encoder(
1014
- torch.zeros_like(clip_image), output_hidden_states=True
1015
- ).hidden_states[-2]
1016
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
1017
- return image_prompt_embeds, uncond_image_prompt_embeds
1018
-
1019
- def generate(
1020
- self,
1021
- pil_image,
1022
- prompt=None,
1023
- negative_prompt=None,
1024
- scale=1.0,
1025
- num_samples=4,
1026
- seed=None,
1027
- num_inference_steps=30,
1028
- **kwargs,
1029
- ):
1030
- self.set_scale(scale)
1031
-
1032
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
1033
-
1034
- if prompt is None:
1035
- prompt = "best quality, high quality"
1036
- if negative_prompt is None:
1037
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
1038
-
1039
- if not isinstance(prompt, List):
1040
- prompt = [prompt] * num_prompts
1041
- if not isinstance(negative_prompt, List):
1042
- negative_prompt = [negative_prompt] * num_prompts
1043
-
1044
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
1045
- bs_embed, seq_len, _ = image_prompt_embeds.shape
1046
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
1047
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1048
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
1049
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1050
-
1051
- with torch.inference_mode():
1052
- (
1053
- prompt_embeds,
1054
- negative_prompt_embeds,
1055
- pooled_prompt_embeds,
1056
- negative_pooled_prompt_embeds,
1057
- ) = self.pipe.encode_prompt(
1058
- prompt,
1059
- num_images_per_prompt=num_samples,
1060
- do_classifier_free_guidance=True,
1061
- negative_prompt=negative_prompt,
1062
- )
1063
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
1064
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
1065
-
1066
- generator = get_generator(seed, self.device)
1067
-
1068
- images = self.pipe(
1069
- prompt_embeds=prompt_embeds,
1070
- negative_prompt_embeds=negative_prompt_embeds,
1071
- pooled_prompt_embeds=pooled_prompt_embeds,
1072
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1073
- num_inference_steps=num_inference_steps,
1074
- generator=generator,
1075
- **kwargs,
1076
- ).images
1077
-
1078
- return images
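To tie the classes in this file together, a usage sketch for CSGO with an SDXL ControlNet pipeline; the SDXL base id is real, the ControlNet path, encoder directory, checkpoint and image files are placeholders, the token counts are illustrative, and the block-target dictionaries come from the utils module shown later in this diff.

# Sketch of CSGO-based content-style composition; items marked as placeholders are assumptions.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

from ip_adapter.ip_adapter import CSGO            # module path assumed
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS, resize_content

controlnet = ControlNetModel.from_pretrained(
    "path/to/sdxl_tile_controlnet", torch_dtype=torch.float16   # placeholder ControlNet
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

csgo = CSGO(
    pipe,
    image_encoder_path="models/image_encoder",    # placeholder CLIP vision encoder directory
    ip_ckpt="models/csgo.bin",                    # placeholder CSGO checkpoint
    device="cuda",
    num_content_tokens=4,                         # illustrative token counts
    num_style_tokens=32,
    target_content_blocks=BLOCKS["content"],
    target_style_blocks=BLOCKS["style"],
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_BLOCKS["content"],
    controlnet_target_style_blocks=controlnet_BLOCKS["style"],
    content_model_resampler=True,
    style_model_resampler=True,
)

content_image = Image.open("content.png").convert("RGB")   # placeholder
style_image = Image.open("style.png").convert("RGB")       # placeholder
width, height, content_image = resize_content(content_image)

images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt="best quality, high quality",
    content_scale=1.0,
    style_scale=1.0,
    num_samples=1,
    seed=42,
    num_inference_steps=50,
    # Everything below is forwarded to the SDXL ControlNet pipeline via **kwargs.
    image=content_image,
    controlnet_conditioning_scale=0.6,
    width=width,
    height=height,
)
images[0].save("stylized.png")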
 
ip_adapter/ip_adapter_resampler.py DELETED
@@ -1,158 +0,0 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
-
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- from einops import rearrange
9
- from einops.layers.torch import Rearrange
10
-
11
-
12
- # FFN
13
- def FeedForward(dim, mult=4):
14
- inner_dim = int(dim * mult)
15
- return nn.Sequential(
16
- nn.LayerNorm(dim),
17
- nn.Linear(dim, inner_dim, bias=False),
18
- nn.GELU(),
19
- nn.Linear(inner_dim, dim, bias=False),
20
- )
21
-
22
-
23
- def reshape_tensor(x, heads):
24
- bs, length, width = x.shape
25
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
- x = x.view(bs, length, heads, -1)
27
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
- x = x.transpose(1, 2)
29
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
- x = x.reshape(bs, heads, length, -1)
31
- return x
32
-
33
-
34
- class PerceiverAttention(nn.Module):
35
- def __init__(self, *, dim, dim_head=64, heads=8):
36
- super().__init__()
37
- self.scale = dim_head**-0.5
38
- self.dim_head = dim_head
39
- self.heads = heads
40
- inner_dim = dim_head * heads
41
-
42
- self.norm1 = nn.LayerNorm(dim)
43
- self.norm2 = nn.LayerNorm(dim)
44
-
45
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
-
49
- def forward(self, x, latents):
50
- """
51
- Args:
52
- x (torch.Tensor): image features
53
- shape (b, n1, D)
54
- latent (torch.Tensor): latent features
55
- shape (b, n2, D)
56
- """
57
- x = self.norm1(x)
58
- latents = self.norm2(latents)
59
-
60
- b, l, _ = latents.shape
61
-
62
- q = self.to_q(latents)
63
- kv_input = torch.cat((x, latents), dim=-2)
64
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
-
66
- q = reshape_tensor(q, self.heads)
67
- k = reshape_tensor(k, self.heads)
68
- v = reshape_tensor(v, self.heads)
69
-
70
- # attention
71
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
- out = weight @ v
75
-
76
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
-
78
- return self.to_out(out)
79
-
80
-
81
- class Resampler(nn.Module):
82
- def __init__(
83
- self,
84
- dim=1024,
85
- depth=8,
86
- dim_head=64,
87
- heads=16,
88
- num_queries=8,
89
- embedding_dim=768,
90
- output_dim=1024,
91
- ff_mult=4,
92
- max_seq_len: int = 257, # CLIP tokens + CLS token
93
- apply_pos_emb: bool = False,
94
- num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
- ):
96
- super().__init__()
97
- self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
-
99
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
-
101
- self.proj_in = nn.Linear(embedding_dim, dim)
102
-
103
- self.proj_out = nn.Linear(dim, output_dim)
104
- self.norm_out = nn.LayerNorm(output_dim)
105
-
106
- self.to_latents_from_mean_pooled_seq = (
107
- nn.Sequential(
108
- nn.LayerNorm(dim),
109
- nn.Linear(dim, dim * num_latents_mean_pooled),
110
- Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
- )
112
- if num_latents_mean_pooled > 0
113
- else None
114
- )
115
-
116
- self.layers = nn.ModuleList([])
117
- for _ in range(depth):
118
- self.layers.append(
119
- nn.ModuleList(
120
- [
121
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
- FeedForward(dim=dim, mult=ff_mult),
123
- ]
124
- )
125
- )
126
-
127
- def forward(self, x):
128
- if self.pos_emb is not None:
129
- n, device = x.shape[1], x.device
130
- pos_emb = self.pos_emb(torch.arange(n, device=device))
131
- x = x + pos_emb
132
-
133
- latents = self.latents.repeat(x.size(0), 1, 1)
134
-
135
- x = self.proj_in(x)
136
-
137
- if self.to_latents_from_mean_pooled_seq:
138
- meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
- meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
- latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
-
142
- for attn, ff in self.layers:
143
- latents = attn(x, latents) + latents
144
- latents = ff(latents) + latents
145
-
146
- latents = self.proj_out(latents)
147
- return self.norm_out(latents)
148
-
149
-
150
- def masked_mean(t, *, dim, mask=None):
151
- if mask is None:
152
- return t.mean(dim=dim)
153
-
154
- denom = mask.sum(dim=dim, keepdim=True)
155
- mask = rearrange(mask, "b n -> b n 1")
156
- masked_t = t.masked_fill(~mask, 0.0)
157
-
158
- return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
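As a quick sanity check on shapes, the Resampler above maps a variable-length sequence of encoder features onto a fixed set of query tokens in the target cross-attention width; a small sketch with illustrative dimensions (the module path follows the file name in this diff):

# Standalone shape check; all dimensions are illustrative, not prescribed values.
import torch
from ip_adapter.ip_adapter_resampler import Resampler

resampler = Resampler(
    dim=1280,            # width of the internal perceiver blocks
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=16,      # number of image tokens produced for the UNet
    embedding_dim=1664,  # assumed hidden size of the vision encoder
    output_dim=2048,     # assumed cross-attention dimension of the UNet
    ff_mult=4,
)

features = torch.randn(2, 257, 1664)   # (batch, CLIP patch tokens + CLS, embedding_dim)
tokens = resampler(features)
print(tokens.shape)                    # torch.Size([2, 16, 2048])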
 
ip_adapter/ip_adapter_utils.py DELETED
@@ -1,142 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from PIL import Image
5
-
6
- BLOCKS = {
7
- 'content': ['down_blocks'],
8
- 'style': ["up_blocks"],
9
-
10
- }
11
-
12
- controlnet_BLOCKS = {
13
- 'content': [],
14
- 'style': ["down_blocks"],
15
- }
16
-
17
-
18
- def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
19
-
20
- if width < height:
21
-
22
- if width < min_short_side:
23
- scale_factor = min_short_side / width
24
- new_width = min_short_side
25
- new_height = int(height * scale_factor)
26
- else:
27
- new_width, new_height = width, height
28
- else:
29
-
30
- if height < min_short_side:
31
- scale_factor = min_short_side / height
32
- new_width = int(width * scale_factor)
33
- new_height = min_short_side
34
- else:
35
- new_width, new_height = width, height
36
-
37
- if max(new_width, new_height) > max_long_side:
38
- scale_factor = max_long_side / max(new_width, new_height)
39
- new_width = int(new_width * scale_factor)
40
- new_height = int(new_height * scale_factor)
41
- return new_width, new_height
42
-
43
- def resize_content(content_image):
44
- max_long_side = 1024
45
- min_short_side = 1024
46
-
47
- new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],
48
- min_short_side=min_short_side, max_long_side=max_long_side)
49
- height = new_height // 16 * 16
50
- width = new_width // 16 * 16
51
- content_image = content_image.resize((width, height))
52
-
53
- return width,height,content_image
54
-
55
- attn_maps = {}
56
- def hook_fn(name):
57
- def forward_hook(module, input, output):
58
- if hasattr(module.processor, "attn_map"):
59
- attn_maps[name] = module.processor.attn_map
60
- del module.processor.attn_map
61
-
62
- return forward_hook
63
-
64
- def register_cross_attention_hook(unet):
65
- for name, module in unet.named_modules():
66
- if name.split('.')[-1].startswith('attn2'):
67
- module.register_forward_hook(hook_fn(name))
68
-
69
- return unet
70
-
71
- def upscale(attn_map, target_size):
72
- attn_map = torch.mean(attn_map, dim=0)
73
- attn_map = attn_map.permute(1,0)
74
- temp_size = None
75
-
76
- for i in range(0,5):
77
- scale = 2 ** i
78
- if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
79
- temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
80
- break
81
-
82
-     assert temp_size is not None, "temp_size cannot be None"
83
-
84
- attn_map = attn_map.view(attn_map.shape[0], *temp_size)
85
-
86
- attn_map = F.interpolate(
87
- attn_map.unsqueeze(0).to(dtype=torch.float32),
88
- size=target_size,
89
- mode='bilinear',
90
- align_corners=False
91
- )[0]
92
-
93
- attn_map = torch.softmax(attn_map, dim=0)
94
- return attn_map
95
- def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
96
-
97
- idx = 0 if instance_or_negative else 1
98
- net_attn_maps = []
99
-
100
- for name, attn_map in attn_maps.items():
101
- attn_map = attn_map.cpu() if detach else attn_map
102
- attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
103
- attn_map = upscale(attn_map, image_size)
104
- net_attn_maps.append(attn_map)
105
-
106
- net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
107
-
108
- return net_attn_maps
109
-
110
- def attnmaps2images(net_attn_maps):
111
-
112
- #total_attn_scores = 0
113
- images = []
114
-
115
- for attn_map in net_attn_maps:
116
- attn_map = attn_map.cpu().numpy()
117
- #total_attn_scores += attn_map.mean().item()
118
-
119
- normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
120
- normalized_attn_map = normalized_attn_map.astype(np.uint8)
121
- #print("norm: ", normalized_attn_map.shape)
122
- image = Image.fromarray(normalized_attn_map)
123
-
124
- #image = fix_save_attn_map(attn_map)
125
- images.append(image)
126
-
127
- #print(total_attn_scores)
128
- return images
129
- def is_torch2_available():
130
- return hasattr(F, "scaled_dot_product_attention")
131
-
132
- def get_generator(seed, device):
133
-
134
- if seed is not None:
135
- if isinstance(seed, list):
136
- generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
137
- else:
138
- generator = torch.Generator(device).manual_seed(seed)
139
- else:
140
- generator = None
141
-
142
- return generator
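Two of these helpers deserve a short illustration: resize_content snaps an image to side lengths that are multiples of 16 within the configured bounds, and get_generator accepts either a single seed, a list of per-sample seeds, or None. A sketch with a placeholder image path:

# Sketch of the resizing and seeding helpers defined above.
from PIL import Image
from ip_adapter.ip_adapter_utils import resize_content, get_generator

content = Image.open("content.png").convert("RGB")   # placeholder image
width, height, content = resize_content(content)     # both sides are now multiples of 16
print(width, height, content.size)

gen_single = get_generator(42, "cpu")                # one generator seeded with 42
gen_batch = get_generator([1, 2, 3, 4], "cpu")       # one generator per sample
gen_none = get_generator(None, "cpu")                # None lets the pipeline draw random seeds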
 
ip_adapter/resampler.py DELETED
@@ -1,158 +0,0 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
-
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- from einops import rearrange
9
- from einops.layers.torch import Rearrange
10
-
11
-
12
- # FFN
13
- def FeedForward(dim, mult=4):
14
- inner_dim = int(dim * mult)
15
- return nn.Sequential(
16
- nn.LayerNorm(dim),
17
- nn.Linear(dim, inner_dim, bias=False),
18
- nn.GELU(),
19
- nn.Linear(inner_dim, dim, bias=False),
20
- )
21
-
22
-
23
- def reshape_tensor(x, heads):
24
- bs, length, width = x.shape
25
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
- x = x.view(bs, length, heads, -1)
27
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
- x = x.transpose(1, 2)
29
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
- x = x.reshape(bs, heads, length, -1)
31
- return x
32
-
33
-
34
- class PerceiverAttention(nn.Module):
35
- def __init__(self, *, dim, dim_head=64, heads=8):
36
- super().__init__()
37
- self.scale = dim_head**-0.5
38
- self.dim_head = dim_head
39
- self.heads = heads
40
- inner_dim = dim_head * heads
41
-
42
- self.norm1 = nn.LayerNorm(dim)
43
- self.norm2 = nn.LayerNorm(dim)
44
-
45
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
-
49
- def forward(self, x, latents):
50
- """
51
- Args:
52
- x (torch.Tensor): image features
53
- shape (b, n1, D)
54
- latent (torch.Tensor): latent features
55
- shape (b, n2, D)
56
- """
57
- x = self.norm1(x)
58
- latents = self.norm2(latents)
59
-
60
- b, l, _ = latents.shape
61
-
62
- q = self.to_q(latents)
63
- kv_input = torch.cat((x, latents), dim=-2)
64
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
-
66
- q = reshape_tensor(q, self.heads)
67
- k = reshape_tensor(k, self.heads)
68
- v = reshape_tensor(v, self.heads)
69
-
70
- # attention
71
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
- out = weight @ v
75
-
76
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
-
78
- return self.to_out(out)
79
-
80
-
81
- class Resampler(nn.Module):
82
- def __init__(
83
- self,
84
- dim=1024,
85
- depth=8,
86
- dim_head=64,
87
- heads=16,
88
- num_queries=8,
89
- embedding_dim=768,
90
- output_dim=1024,
91
- ff_mult=4,
92
- max_seq_len: int = 257, # CLIP tokens + CLS token
93
- apply_pos_emb: bool = False,
94
- num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
- ):
96
- super().__init__()
97
- self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
-
99
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
-
101
- self.proj_in = nn.Linear(embedding_dim, dim)
102
-
103
- self.proj_out = nn.Linear(dim, output_dim)
104
- self.norm_out = nn.LayerNorm(output_dim)
105
-
106
- self.to_latents_from_mean_pooled_seq = (
107
- nn.Sequential(
108
- nn.LayerNorm(dim),
109
- nn.Linear(dim, dim * num_latents_mean_pooled),
110
- Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
- )
112
- if num_latents_mean_pooled > 0
113
- else None
114
- )
115
-
116
- self.layers = nn.ModuleList([])
117
- for _ in range(depth):
118
- self.layers.append(
119
- nn.ModuleList(
120
- [
121
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
- FeedForward(dim=dim, mult=ff_mult),
123
- ]
124
- )
125
- )
126
-
127
- def forward(self, x):
128
- if self.pos_emb is not None:
129
- n, device = x.shape[1], x.device
130
- pos_emb = self.pos_emb(torch.arange(n, device=device))
131
- x = x + pos_emb
132
-
133
- latents = self.latents.repeat(x.size(0), 1, 1)
134
-
135
- x = self.proj_in(x)
136
-
137
- if self.to_latents_from_mean_pooled_seq:
138
- meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
- meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
- latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
-
142
- for attn, ff in self.layers:
143
- latents = attn(x, latents) + latents
144
- latents = ff(latents) + latents
145
-
146
- latents = self.proj_out(latents)
147
- return self.norm_out(latents)
148
-
149
-
150
- def masked_mean(t, *, dim, mask=None):
151
- if mask is None:
152
- return t.mean(dim=dim)
153
-
154
- denom = mask.sum(dim=dim, keepdim=True)
155
- mask = rearrange(mask, "b n -> b n 1")
156
- masked_t = t.masked_fill(~mask, 0.0)
157
-
158
- return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
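A small illustration of masked_mean, which averages over the sequence dimension while ignoring masked-out positions (the mask shape matches the (batch, sequence) layout used above):

# masked_mean averages over dim=1, counting only positions where mask is True.
import torch
from ip_adapter.resampler import masked_mean

t = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)   # (batch=2, seq=3, dim=2)
mask = torch.tensor([[True, True, False],
                     [True, False, False]])                  # valid positions per sample

out = masked_mean(t, dim=1, mask=mask)
print(out.shape)   # torch.Size([2, 2]); each row averages only its unmasked sequence entries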
 
ip_adapter/utils.py DELETED
@@ -1,142 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from PIL import Image
5
-
6
- BLOCKS = {
7
- 'content': ['down_blocks'],
8
- 'style': ["up_blocks"],
9
-
10
- }
11
-
12
- controlnet_BLOCKS = {
13
- 'content': [],
14
- 'style': ["down_blocks"],
15
- }
16
-
17
-
- def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
-     if width < height:
-         if width < min_short_side:
-             scale_factor = min_short_side / width
-             new_width = min_short_side
-             new_height = int(height * scale_factor)
-         else:
-             new_width, new_height = width, height
-     else:
-         if height < min_short_side:
-             scale_factor = min_short_side / height
-             new_width = int(width * scale_factor)
-             new_height = min_short_side
-         else:
-             new_width, new_height = width, height
-
-     if max(new_width, new_height) > max_long_side:
-         scale_factor = max_long_side / max(new_width, new_height)
-         new_width = int(new_width * scale_factor)
-         new_height = int(new_height * scale_factor)
-     return new_width, new_height
-
-
- def resize_content(content_image):
-     max_long_side = 1024
-     min_short_side = 1024
-
-     new_width, new_height = resize_width_height(content_image.size[0], content_image.size[1],
-                                                 min_short_side=min_short_side, max_long_side=max_long_side)
-     height = new_height // 16 * 16
-     width = new_width // 16 * 16
-     content_image = content_image.resize((width, height))
-
-     return width, height, content_image
-
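For a concrete sense of the sizing rule (scale the short side up to the minimum, cap the long side at the maximum, then floor both sides to a multiple of 16), here is a small standalone check; the image is a blank placeholder and resize_content is assumed to be imported from the file above.

from PIL import Image

# A 2048x1024 content image: the short side (1024) already meets the minimum,
# so only the long-side cap applies and both sides are halved to 1024x512,
# which is already a multiple of 16.
img = Image.new("RGB", (2048, 1024))
w, h, resized = resize_content(img)
print(w, h, resized.size)  # 1024 512 (1024, 512)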
- attn_maps = {}
-
-
- def hook_fn(name):
-     def forward_hook(module, input, output):
-         if hasattr(module.processor, "attn_map"):
-             attn_maps[name] = module.processor.attn_map
-             del module.processor.attn_map
-
-     return forward_hook
-
-
- def register_cross_attention_hook(unet):
-     for name, module in unet.named_modules():
-         if name.split('.')[-1].startswith('attn2'):
-             module.register_forward_hook(hook_fn(name))
-
-     return unet
-
-
- def upscale(attn_map, target_size):
-     attn_map = torch.mean(attn_map, dim=0)
-     attn_map = attn_map.permute(1, 0)
-     temp_size = None
-
-     for i in range(0, 5):
-         scale = 2 ** i
-         if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
-             temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
-             break
-
-     assert temp_size is not None, "temp_size must not be None"
-
-     attn_map = attn_map.view(attn_map.shape[0], *temp_size)
-
-     attn_map = F.interpolate(
-         attn_map.unsqueeze(0).to(dtype=torch.float32),
-         size=target_size,
-         mode='bilinear',
-         align_corners=False
-     )[0]
-
-     attn_map = torch.softmax(attn_map, dim=0)
-     return attn_map
-
-
- def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
-     idx = 0 if instance_or_negative else 1
-     net_attn_maps = []
-
-     for name, attn_map in attn_maps.items():
-         attn_map = attn_map.cpu() if detach else attn_map
-         attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
-         attn_map = upscale(attn_map, image_size)
-         net_attn_maps.append(attn_map)
-
-     net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
-
-     return net_attn_maps
-
-
- def attnmaps2images(net_attn_maps):
-     images = []
-
-     for attn_map in net_attn_maps:
-         attn_map = attn_map.cpu().numpy()
-         normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
-         normalized_attn_map = normalized_attn_map.astype(np.uint8)
-         image = Image.fromarray(normalized_attn_map)
-         images.append(image)
-
-     return images
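Putting these helpers together, the intended flow appears to be: install this repo's attention processors (which stash an attn_map attribute when asked to), register the forward hooks, run the pipeline, then convert the averaged cross-attention maps into one grayscale image per text token. The sketch below assumes exactly that setup; with a stock pipeline and default processors the hooks collect nothing and get_net_attn_map would have no maps to stack.

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Assumes the UNet's cross-attention processors were already replaced by this
# repo's IPAttnProcessor variants configured to save attention maps.
pipe.unet = register_cross_attention_hook(pipe.unet)

image = pipe("a cat in a meadow", num_inference_steps=20).images[0]

net_attn_maps = get_net_attn_map(image.size)            # averaged over hooked layers
for i, attn_img in enumerate(attnmaps2images(net_attn_maps)):
    attn_img.save(f"attn_token_{i}.png")                 # one map per text token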
-
-
- def is_torch2_available():
-     return hasattr(F, "scaled_dot_product_attention")
-
-
- def get_generator(seed, device):
-     if seed is not None:
-         if isinstance(seed, list):
-             generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
-         else:
-             generator = torch.Generator(device).manual_seed(seed)
-     else:
-         generator = None
-
-     return generator
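get_generator accepts a single seed, a list of seeds (one torch.Generator per prompt in a batch), or None for fresh random noise. A quick usage sketch, assuming the helper above:

# One generator per seed, ready to pass as the `generator` argument of a
# diffusers pipeline call; None means the pipeline draws its own noise.
batch_gens = get_generator([42, 43], "cuda")
single_gen = get_generator(1234, "cpu")
no_gen = get_generator(None, "cuda")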