# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。 | |
相对而言,开源代码经常会重新定义Attention 类, | |
This module implements special AttnProcessor function with custom attn_processor class. | |
While other open source code always modify Attention class. | |
""" | |
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py | |
from __future__ import annotations

import time
from typing import Any, Callable, Optional
import logging

from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention_processor import (
    Attention as DiffusersAttention,
    AttnProcessor,
    AttnProcessor2_0,
)

from ..data.data_util import (
    batch_concat_two_tensor_with_index,
    batch_index_select,
    align_repeat_tensor_single_dim,
    batch_adain_conditioned_tensor,
)
from . import Model_Register

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
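
# The processors and attention classes below follow the approach described in the module
# docstring: the diffusers Attention module is kept as-is and special behavior is injected
# by swapping its processor. Illustrative sketch only (`attn` is assumed to be an existing
# Attention/IPAttention instance); `Attention.set_processor` is the standard diffusers hook:
#
#     attn.set_processor(NonParamT2ISelfReferenceXFormersAttnProcessor())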


class IPAttention(DiffusersAttention):
    r"""
    Modified Attention class with extra IP-Adapter layers, such as to_k_ip and to_v_ip.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: str | None = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: int | None = None,
        norm_num_groups: int | None = None,
        spatial_norm_dim: int | None = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 0.00001,
        rescale_output_factor: float = 1,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: AttnProcessor | None = None,
        cross_attn_temporal_cond: bool = False,
        image_scale: float = 1.0,
        ip_adapter_dim: int = None,
        need_t2i_facein: bool = False,
        facein_dim: int = None,
        need_t2i_ip_adapter_face: bool = False,
        ip_adapter_face_dim: int = None,
    ):
        super().__init__(
            query_dim,
            cross_attention_dim,
            heads,
            dim_head,
            dropout,
            bias,
            upcast_attention,
            upcast_softmax,
            cross_attention_norm,
            cross_attention_norm_num_groups,
            added_kv_proj_dim,
            norm_num_groups,
            spatial_norm_dim,
            out_bias,
            scale_qk,
            only_cross_attention,
            eps,
            rescale_output_factor,
            residual_connection,
            _from_deprecated_attn_block,
            processor,
        )
        self.cross_attn_temporal_cond = cross_attn_temporal_cond
        self.image_scale = image_scale
        # ip_adapter for the first (vision condition) frame
        if cross_attn_temporal_cond:
            self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
            self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
        # facein
        self.need_t2i_facein = need_t2i_facein
        self.facein_dim = facein_dim
        if need_t2i_facein:
            raise NotImplementedError("facein")
        # ip_adapter_face
        self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
        self.ip_adapter_face_dim = ip_adapter_face_dim
        if need_t2i_ip_adapter_face:
            self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
                ip_adapter_face_dim, query_dim, bias=False
            )
            self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
                ip_adapter_face_dim, query_dim, bias=False
            )

    def set_use_memory_efficient_attention_xformers(
        self,
        use_memory_efficient_attention_xformers: bool,
        attention_op: Callable[..., Any] | None = None,
    ):
        if (
            "XFormers" in self.processor.__class__.__name__
            or "IP" in self.processor.__class__.__name__
        ):
            pass
        else:
            return super().set_use_memory_efficient_attention_xformers(
                use_memory_efficient_attention_xformers, attention_op
            )
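

# Example (sketch, not part of the original API): constructing an IPAttention with the
# extra IP-Adapter projections enabled. All dimensions below are hypothetical.
def _example_build_ip_attention() -> IPAttention:
    attn = IPAttention(
        query_dim=320,
        cross_attention_dim=768,
        heads=8,
        dim_head=40,
        cross_attn_temporal_cond=True,  # creates to_k_ip / to_v_ip
        ip_adapter_dim=1024,  # hypothetical image-embedding dim
        processor=T2IReferencenetIPAdapterXFormersAttnProcessor(),
    )
    return attn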


class BaseIPAttnProcessor(nn.Module):
    print_idx = 0

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
    r"""
    IP-Adapter attention processor conditioned on the ref_image.
    """
    print_idx = 0

    def __init__(
        self,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.attention_op = attention_op

    def __call__(
        self,
        attn: IPAttention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        num_frames: int = None,
        sample_index: torch.LongTensor = None,
        vision_conditon_frames_sample_index: torch.LongTensor = None,
        refer_emb: torch.Tensor = None,
        vision_clip_emb: torch.Tensor = None,
        ip_adapter_scale: float = 1.0,
        face_emb: torch.Tensor = None,
        facein_scale: float = 1.0,
        ip_adapter_face_emb: torch.Tensor = None,
        ip_adapter_face_scale: float = 1.0,
        do_classifier_free_guidance: bool = False,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, key_tokens, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, key_tokens, batch_size
        )
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            # [batch*heads, 1, key_tokens] ->
            # [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            # [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states, scale=scale)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        encoder_hidden_states = align_repeat_tensor_single_dim(
            encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
        )
        key = attn.to_k(encoder_hidden_states, scale=scale)
        value = attn.to_v(encoder_hidden_states, scale=scale)
        # for facein
        if self.print_idx == 0:
            logger.debug(
                f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
            )
        if facein_scale > 0 and face_emb is not None:
            raise NotImplementedError("facein")
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query,
            key,
            value,
            attn_bias=attention_mask,
            op=self.attention_op,
            scale=attn.scale,
        )
        # ip-adapter start
        if self.print_idx == 0:
            logger.debug(
                f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
            )
        if ip_adapter_scale > 0 and vision_clip_emb is not None:
            if self.print_idx == 0:
                logger.debug(
                    f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
                )
            ip_key = attn.to_k_ip(vision_clip_emb)
            ip_value = attn.to_v_ip(vision_clip_emb)
            ip_key = align_repeat_tensor_single_dim(
                ip_key, target_length=batch_size, dim=0
            )
            ip_value = align_repeat_tensor_single_dim(
                ip_value, target_length=batch_size, dim=0
            )
            ip_key = attn.head_to_batch_dim(ip_key).contiguous()
            ip_value = attn.head_to_batch_dim(ip_value).contiguous()
            if self.print_idx == 0:
                logger.debug(
                    f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
                )
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            hidden_states_from_ip = xformers.ops.memory_efficient_attention(
                query,
                ip_key,
                ip_value,
                attn_bias=attention_mask,
                op=self.attention_op,
                scale=attn.scale,
            )
            hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
        # ip-adapter end
        # ip-adapter face start
        if self.print_idx == 0:
            logger.debug(
                f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
            )
        if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
            if self.print_idx == 0:
                logger.debug(
                    f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={ip_adapter_face_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
                )
            ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
            ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
            ip_key = align_repeat_tensor_single_dim(
                ip_key, target_length=batch_size, dim=0
            )
            ip_value = align_repeat_tensor_single_dim(
                ip_value, target_length=batch_size, dim=0
            )
            ip_key = attn.head_to_batch_dim(ip_key).contiguous()
            ip_value = attn.head_to_batch_dim(ip_value).contiguous()
            if self.print_idx == 0:
                logger.debug(
                    f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
                )
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            hidden_states_from_ip = xformers.ops.memory_efficient_attention(
                query,
                ip_key,
                ip_value,
                attn_bias=attention_mask,
                op=self.attention_op,
                scale=attn.scale,
            )
            hidden_states = (
                hidden_states + ip_adapter_face_scale * hidden_states_from_ip
            )
        # ip-adapter face end
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states, scale=scale)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        self.print_idx += 1
        return hidden_states
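

# Conceptual sketch (not part of the processor above): the IP-Adapter branch simply adds a
# second attention result, computed from image-embedding keys/values, onto the text
# cross-attention result. torch SDPA is used here for clarity instead of xformers; the
# helper and its arguments are hypothetical.
def _example_ip_adapter_mix(
    query: torch.Tensor,  # (batch*heads, query_tokens, head_dim)
    key_text: torch.Tensor,
    value_text: torch.Tensor,
    key_ip: torch.Tensor,
    value_ip: torch.Tensor,
    ip_adapter_scale: float = 1.0,
) -> torch.Tensor:
    hidden_states = F.scaled_dot_product_attention(query, key_text, value_text)
    hidden_states_from_ip = F.scaled_dot_product_attention(query, key_ip, value_ip)
    return hidden_states + ip_adapter_scale * hidden_states_from_ip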


class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
    r"""
    Reference-only attention for the vision condition (first) frames, used in T2I self_attn:
    the vis_cond tokens are concatenated into key/value.
    """
    print_idx = 0

    def __init__(
        self,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.attention_op = attention_op

    def __call__(
        self,
        attn: IPAttention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        num_frames: int = None,
        sample_index: torch.LongTensor = None,
        vision_conditon_frames_sample_index: torch.LongTensor = None,
        refer_emb: torch.Tensor = None,
        face_emb: torch.Tensor = None,
        vision_clip_emb: torch.Tensor = None,
        ip_adapter_scale: float = 1.0,
        facein_scale: float = 1.0,
        ip_adapter_face_emb: torch.Tensor = None,
        ip_adapter_face_scale: float = 1.0,
        do_classifier_free_guidance: bool = False,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, key_tokens, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, key_tokens, batch_size
        )
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            # [batch*heads, 1, key_tokens] ->
            # [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            # [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)
        # vision_cond in same unet attn start
        if (
            vision_conditon_frames_sample_index is not None and num_frames > 1
        ) or refer_emb is not None:
            batchsize_timesize = hidden_states.shape[0]
            if self.print_idx == 0:
                logger.debug(
                    f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
                )
            encoder_hidden_states = rearrange(
                hidden_states, "(b t) hw c -> b t hw c", t=num_frames
            )
            # if False:
            if vision_conditon_frames_sample_index is not None and num_frames > 1:
                ip_hidden_states = batch_index_select(
                    encoder_hidden_states,
                    dim=1,
                    index=vision_conditon_frames_sample_index,
                ).contiguous()
                if self.print_idx == 0:
                    logger.debug(
                        f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
                    )
                #
                ip_hidden_states = rearrange(
                    ip_hidden_states, "b t hw c -> b 1 (t hw) c"
                )
                ip_hidden_states = align_repeat_tensor_single_dim(
                    ip_hidden_states,
                    dim=1,
                    target_length=num_frames,
                )
                # b t hw c -> b t hw + hw c
                if self.print_idx == 0:
                    logger.debug(
                        f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
                    )
                encoder_hidden_states = torch.concat(
                    [encoder_hidden_states, ip_hidden_states], dim=2
                )
                if self.print_idx == 0:
                    logger.debug(
                        f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
                    )
            # if False:
            if refer_emb is not None:  # and num_frames > 1:
                refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
                refer_emb = align_repeat_tensor_single_dim(
                    refer_emb, target_length=num_frames, dim=1
                )
                if self.print_idx == 0:
                    logger.debug(
                        f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
                    )
                encoder_hidden_states = torch.concat(
                    [encoder_hidden_states, refer_emb], dim=2
                )
                if self.print_idx == 0:
                    logger.debug(
                        f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
                    )
            encoder_hidden_states = rearrange(
                encoder_hidden_states, "b t hw c -> (b t) hw c"
            )
        # vision_cond in same unet attn end
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states, scale=scale)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        encoder_hidden_states = align_repeat_tensor_single_dim(
            encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
        )
        key = attn.to_k(encoder_hidden_states, scale=scale)
        value = attn.to_v(encoder_hidden_states, scale=scale)
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query,
            key,
            value,
            attn_bias=attention_mask,
            op=self.attention_op,
            scale=attn.scale,
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states, scale=scale)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        self.print_idx += 1
        return hidden_states
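

# Conceptual sketch (not part of the processor above): how the reference-only processor
# extends the self-attention key/value tokens. Tokens of the vision-condition frames are
# gathered, flattened and appended to every frame's own tokens, so each frame can attend
# to them. The helper name is hypothetical; batch_index_select and
# align_repeat_tensor_single_dim are used exactly as in the processor above.
def _example_extend_kv_with_vision_cond(
    hidden_states: torch.Tensor,  # ((b t), hw, c): tokens of every frame
    num_frames: int,
    vision_conditon_frames_sample_index: torch.LongTensor,  # condition-frame indices
) -> torch.Tensor:
    states = rearrange(hidden_states, "(b t) hw c -> b t hw c", t=num_frames)
    cond = batch_index_select(states, dim=1, index=vision_conditon_frames_sample_index)
    cond = rearrange(cond, "b t hw c -> b 1 (t hw) c")
    cond = align_repeat_tensor_single_dim(cond, dim=1, target_length=num_frames)
    extended = torch.concat([states, cond], dim=2)  # b t (hw + t_cond*hw) c
    return rearrange(extended, "b t hw c -> (b t) hw c")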


class NonParamReferenceIPXFormersAttnProcessor(
    NonParamT2ISelfReferenceXFormersAttnProcessor
):
    def __init__(self, attention_op: Callable[..., Any] | None = None):
        super().__init__(attention_op)


class ReferEmbFuseAttention(IPAttention):
    """Use attention to fuse the referencenet emb into the corresponding unet latents.

    # TODO: currently only supports bt hw c fusion; later consider supporting
    # video layouts such as bhw t c and b thw c.
    residual_connection: bool = True by default, so the layer starts out with no effect
    and learns from there.
    """

    print_idx = 0

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: str | None = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: int | None = None,
        norm_num_groups: int | None = None,
        spatial_norm_dim: int | None = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 0.00001,
        rescale_output_factor: float = 1,
        residual_connection: bool = True,
        _from_deprecated_attn_block=False,
        processor: AttnProcessor | None = None,
        cross_attn_temporal_cond: bool = False,
        image_scale: float = 1,
    ):
        super().__init__(
            query_dim,
            cross_attention_dim,
            heads,
            dim_head,
            dropout,
            bias,
            upcast_attention,
            upcast_softmax,
            cross_attention_norm,
            cross_attention_norm_num_groups,
            added_kv_proj_dim,
            norm_num_groups,
            spatial_norm_dim,
            out_bias,
            scale_qk,
            only_cross_attention,
            eps,
            rescale_output_factor,
            residual_connection,
            _from_deprecated_attn_block,
            processor,
            cross_attn_temporal_cond,
            image_scale,
        )
        self.processor = None
        # zero-init so that, together with the residual connection, this layer
        # initially has no effect on the existing result
        nn.init.zeros_(self.to_out[0].weight)
        nn.init.zeros_(self.to_out[0].bias)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        num_frames: int = None,
    ) -> torch.Tensor:
        """fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
        refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor

        Args:
            hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
            encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
            attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
            temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
            scale (float, optional): _description_. Defaults to 1.0.
            num_frames (int, optional): _description_. Defaults to None.

        Returns:
            torch.Tensor: _description_
        """
        residual = hidden_states
        # start
        hidden_states = rearrange(
            hidden_states, "(b t) c h w -> b c t h w", t=num_frames
        )
        batch_size, channel, t1, height, width = hidden_states.shape
        if self.print_idx == 0:
            logger.debug(
                f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
            )
        # concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c
        encoder_hidden_states = rearrange(
            encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
        )
        encoder_hidden_states = repeat(
            encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
        )
        hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
        # bt (t2+1)hw d
        encoder_hidden_states = torch.concat(
            [encoder_hidden_states, hidden_states], dim=1
        )
        # encoder_hidden_states = align_repeat_tensor_single_dim(
        #     encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
        # )
        # end
        if self.spatial_norm is not None:
            hidden_states = self.spatial_norm(hidden_states, temb)
        _, key_tokens, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = self.prepare_attention_mask(
            attention_mask, key_tokens, batch_size
        )
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            # [batch*heads, 1, key_tokens] ->
            # [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            # [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)
        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = self.to_q(hidden_states, scale=scale)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = self.to_k(encoder_hidden_states, scale=scale)
        value = self.to_v(encoder_hidden_states, scale=scale)
        query = self.head_to_batch_dim(query).contiguous()
        key = self.head_to_batch_dim(key).contiguous()
        value = self.head_to_batch_dim(value).contiguous()
        # query: (b t) hw d
        # key/value: (b t) (t2 h w + h w) d
        hidden_states = xformers.ops.memory_efficient_attention(
            query,
            key,
            value,
            attn_bias=attention_mask,
            scale=self.scale,
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = self.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = self.to_out[0](hidden_states, scale=scale)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        hidden_states = rearrange(
            hidden_states,
            "bt (h w) c-> bt c h w",
            h=height,
            w=width,
        )
        if self.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / self.rescale_output_factor
        self.print_idx += 1
        return hidden_states
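

# Usage sketch (illustrative only, assuming a diffusers version whose attention projections
# accept the `scale` argument and that xformers is available): fusing referencenet features
# into unet latents with ReferEmbFuseAttention. All sizes below are made up.
def _example_refer_emb_fuse() -> torch.Tensor:
    batch, num_frames, channels, height, width = 1, 4, 320, 32, 32
    fuse = ReferEmbFuseAttention(query_dim=channels, heads=8, dim_head=channels // 8)
    unet_latents = torch.randn(batch * num_frames, channels, height, width)  # (b t) c h w
    refer_emb = torch.randn(batch, channels, 1, height, width)  # b c t2 h w
    # residual_connection=True plus the zero-initialized output projection mean the fused
    # result starts out identical to the input latents.
    return fuse(unet_latents, encoder_hidden_states=refer_emb, num_frames=num_frames)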