"""
This module implements mutual self-attention, an attention mechanism used in
deep learning models.

It builds on the attention blocks BasicTransformerBlock and
TemporalBasicTransformerBlock, and provides a reference-attention mechanism for
tasks such as image and video processing.
"""

from typing import Any, Dict, Optional

import torch
from einops import rearrange

from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock


def torch_dfs(model: torch.nn.Module):
    """
    Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture.

    This function recursively traverses all the children modules of a given PyTorch model and returns a list
    containing all the modules in the model's architecture. The DFS approach starts with the input model and
    explores its children modules depth-wise before backtracking and exploring other branches.

    Args:
        model (torch.nn.Module): The root module of the neural network to traverse.

    Returns:
        list: A list of all the modules in the model's architecture.
    """
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
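
# Illustrative usage sketch (not part of this module's API): torch_dfs flattens
# a model's module tree, which makes it easy to filter for specific block types.
# The toy model below is a hypothetical stand-in used only for this example.
#
#     import torch.nn as nn
#     toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
#     modules = torch_dfs(toy)  # depth-first: [toy, Linear, Sequential, ReLU, Linear]
#     linears = [m for m in modules if isinstance(m, nn.Linear)]
#     assert len(linears) == 2
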
class ReferenceAttentionControl:
    """
    This class is used to control the reference attention mechanism in a neural network model.
    It is responsible for managing the guidance and fusion blocks, and for modifying the self-attention
    and group-normalization mechanisms. The class also provides methods for registering reference hooks
    and for updating/clearing the internal state of the attention control object.

    Attributes:
        unet: The UNet model associated with this attention control object.
        mode: The operating mode of the attention control object, either 'write' or 'read'.
        do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism.
        attention_auto_machine_weight: The weight assigned to the attention auto-machine.
        gn_auto_machine_weight: The weight assigned to the group-normalization auto-machine.
        style_fidelity: The style-fidelity parameter for the attention mechanism.
        reference_attn: Whether to use reference attention in the model.
        reference_adain: Whether to use reference AdaIN in the model.
        fusion_blocks: Which blocks to fuse in the model ('midup' or 'full').
        batch_size: The batch size used for processing video frames.

    Methods:
        register_reference_hooks: Registers the reference hooks for the attention control object.
        hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block.
        update: Updates the internal state of the attention control object using the provided writer and dtype.
        clear: Clears the internal state of the attention control object.
    """

    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
    ) -> None:
        """
        Initializes the ReferenceAttentionControl class.

        Args:
            unet (torch.nn.Module): The UNet model.
            mode (str, optional): The mode of operation, either "read" or "write". Defaults to "write".
            do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False.
            attention_auto_machine_weight (float, optional): The weight for the attention auto-machine. Defaults to infinity.
            gn_auto_machine_weight (float, optional): The weight for the group-norm auto-machine. Defaults to 1.0.
            style_fidelity (float, optional): The style fidelity. Defaults to 1.0.
            reference_attn (bool, optional): Whether to use reference attention. Defaults to True.
            reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False.
            fusion_blocks (str, optional): The fusion blocks to use, either "midup" or "full". Defaults to "midup".
            batch_size (int, optional): The batch size. Defaults to 1.

        Raises:
            AssertionError: If the mode is not "read" or "write", or if the fusion
                blocks value is not "midup" or "full".
        """
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            fusion_blocks,
            batch_size=batch_size,
        )
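
    # Typical pairing (illustrative sketch; ``reference_unet`` and ``denoising_unet``
    # are hypothetical models from the surrounding pipeline, not defined here):
    #
    #     writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
    #     reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
    #
    # The writer's hooked blocks store reference features in their ``bank`` during a
    # forward pass; the reader's blocks later consume them (see ``update`` below).
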

    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        _attention_auto_machine_weight,
        _gn_auto_machine_weight,
        _style_fidelity,
        _reference_attn,
        _reference_adain,
        _dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        _fusion_blocks="midup",
    ):
        """
        Registers reference hooks for the model.

        This function is responsible for registering reference hooks in the model,
        which are used to modify the attention mechanism and group-normalization
        layers according to the selected mode ("write" stores reference features
        in each block's bank, "read" consumes them).

        Args:
            self: Reference to the instance of the class.
            mode: The mode of operation for the reference hooks ("read" or "write").
            do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance.
            _attention_auto_machine_weight: The weight for the attention auto-machine.
            _gn_auto_machine_weight: The weight for the group-normalization auto-machine.
            _style_fidelity: The style fidelity for the reference hooks.
            _reference_attn: A boolean flag indicating whether to use reference attention.
            _reference_adain: A boolean flag indicating whether to use reference AdaIN.
            _dtype: The data type for the reference hooks.
            batch_size: The batch size for the reference hooks.
            num_images_per_prompt: The number of images per prompt for the reference hooks.
            device: The device on which mask tensors are created.
            _fusion_blocks: The fusion blocks for the reference hooks.

        Returns:
            None
        """
        MODE = mode
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor(
                    [1] * batch_size * num_images_per_prompt * 16
                    + [0] * batch_size * num_images_per_prompt * 16
                )
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
            gate_msa = None
            shift_mlp = None
            scale_mlp = None
            gate_mlp = None

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=(
                        encoder_hidden_states if self.only_cross_attention else None
                    ),
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=(
                            encoder_hidden_states if self.only_cross_attention else None
                        ),
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if MODE == "read":
                    # The first stored frame of each bank entry provides the reference
                    # features; the remaining frames are kept as motion-frame features.
                    bank_fea = [
                        rearrange(
                            rearrange(
                                d,
                                "(b s) l c -> b s l c",
                                b=norm_hidden_states.shape[0] // video_length,
                            )[:, 0, :, :]
                            .repeat(1, video_length, 1, 1),
                            "b t l c -> (b t) l c",
                        )
                        for d in self.bank
                    ]
                    motion_frames_fea = [
                        rearrange(
                            d,
                            "(b s) l c -> b s l c",
                            b=norm_hidden_states.shape[0] // video_length,
                        )[:, 1:, :, :]
                        for d in self.bank
                    ]
                    modify_norm_hidden_states = torch.cat(
                        [norm_hidden_states] + bank_fea, dim=1
                    )
                    hidden_states_uc = (
                        self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=modify_norm_hidden_states,
                            attention_mask=attention_mask,
                        )
                        + hidden_states
                    )
                    if do_classifier_free_guidance:
                        hidden_states_c = hidden_states_uc.clone()
                        _uc_mask = uc_mask.clone()
                        if hidden_states.shape[0] != _uc_mask.shape[0]:
                            _uc_mask = (
                                torch.Tensor(
                                    [1] * (hidden_states.shape[0] // 2)
                                    + [0] * (hidden_states.shape[0] // 2)
                                )
                                .to(device)
                                .bool()
                            )
                        # Samples selected by the mask fall back to plain self-attention
                        # instead of attending to the reference features.
                        hidden_states_c[_uc_mask] = (
                            self.attn1(
                                norm_hidden_states[_uc_mask],
                                encoder_hidden_states=norm_hidden_states[_uc_mask],
                                attention_mask=attention_mask,
                            )
                            + hidden_states[_uc_mask]
                        )
                        hidden_states = hidden_states_c.clone()
                    else:
                        hidden_states = hidden_states_uc

                    # Cross-attention (read mode).
                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states,
                                encoder_hidden_states=encoder_hidden_states,
                                attention_mask=attention_mask,
                            )
                            + hidden_states
                        )

                    # Feed-forward (read mode).
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

                    # Temporal attention (read mode).
                    if self.unet_use_temporal_attention:
                        d = hidden_states.shape[1]
                        hidden_states = rearrange(
                            hidden_states, "(b f) d c -> (b d) f c", f=video_length
                        )
                        norm_hidden_states = (
                            self.norm_temp(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm_temp(hidden_states)
                        )
                        hidden_states = (
                            self.attn_temp(norm_hidden_states) + hidden_states
                        )
                        hidden_states = rearrange(
                            hidden_states, "(b d) f c -> (b f) d c", d=d
                        )

                    return hidden_states, motion_frames_fea

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )

                tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states.repeat(tmp, 1, 1),
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
                ]
            elif self.fusion_blocks == "full":
                attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
                ]
            attn_modules = sorted(
                attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                # Bind the hacked forward to each block so that it replaces the
                # block's original forward pass.
                if isinstance(module, BasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                if isinstance(module, TemporalBasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, TemporalBasicTransformerBlock
                    )

                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

    def update(self, writer, dtype=torch.float16):
        """
        Update the reader's attention banks from a writer.

        Copies the reference features stored in the writer's transformer blocks
        into the corresponding blocks of this (reader) object, casting them to
        the given dtype.

        Args:
            writer (ReferenceAttentionControl): The attention control object whose banks are copied.
            dtype (torch.dtype, optional): The data type to be used for the copied features. Defaults to torch.float16.

        Returns:
            None.
        """
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, TemporalBasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in (
                        torch_dfs(writer.unet.mid_block)
                        + torch_dfs(writer.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, TemporalBasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in torch_dfs(writer.unet)
                    if isinstance(module, BasicTransformerBlock)
                ]

            assert len(reader_attn_modules) == len(writer_attn_modules)
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            writer_attn_modules = sorted(
                writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]

    def clear(self):
        """
        Clears the attention bank of all reader attention modules.

        This method is used when the `reference_attn` attribute is set to `True`.
        It clears the attention bank of all reader attention modules inside the UNet
        model, selected according to the `fusion_blocks` mode.

        If `fusion_blocks` is set to "midup", it searches for reader attention modules
        in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set
        to "full", it searches for reader attention modules in the entire UNet model.

        The modules are sorted by `norm1.normalized_shape[0]` in descending order,
        matching the ordering used elsewhere in this class, and each module's `bank`
        is then cleared.
        """
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
                ]
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r in reader_attn_modules:
                r.bank.clear()
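
# End-to-end sketch (illustrative only; the writer/reader controls and the two
# UNets are assumptions about the surrounding pipeline, not defined here):
#
#     # 1. Run the reference UNet once so each hooked block fills its `bank`.
#     reference_unet(ref_latents, timestep, encoder_hidden_states=ref_embeds)
#     # 2. Copy the banks from the writer into the reader before denoising.
#     reader.update(writer, dtype=torch.float16)
#     # 3. Run the denoising UNet; its hooked blocks read from the banks.
#     denoising_unet(noisy_latents, timestep, encoder_hidden_states=embeds)
#     # 4. Reset both controls before processing the next sample.
#     reader.clear()
#     writer.clear()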