|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
import copy |
|
import math |
|
import numpy as np |
|
from typing import Iterable, Union |
|
import torch |
|
|
import torch.nn as nn |
|
from torch import Tensor |
|
from einops import rearrange, repeat |
|
|
|
from comfy.ldm.modules.diffusionmodules.util import ( |
|
zero_module, |
|
timestep_embedding, |
|
) |
|
|
|
from comfy.cli_args import args |
|
from comfy.cldm.cldm import ControlNet as ControlNetCLDM |
|
from comfy.ldm.modules.attention import (FeedForward, SpatialTransformer, attention_basic,
                                         attention_pytorch, attention_split, attention_sub_quad, default)
|
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample |
|
from comfy.model_patcher import ModelPatcher |
|
import comfy.ops

import comfy.utils
|
import comfy.model_management |
|
|
|
from .logger import logger |
|
from .utils import (BIGMAX, TimestepKeyframeGroup, disable_weight_init_clean_groupnorm, |
|
prepare_mask_batch, broadcast_image_to_extend, extend_to_batch_size) |
|
|
|
|
|
|
|
|
|
|
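# Choose the motion-module attention implementation once at import time, mirroring ComfyUI's
# global attention settings; fallback_attention_mm is swapped in at runtime if the optimized
# choice fails (see CrossAttentionMMSparse.forward).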
|
optimized_attention_mm = attention_basic |
|
fallback_attention_mm = attention_basic |
|
if comfy.model_management.xformers_enabled():

    pass  # no xformers variant is wired up for these motion modules; the checks below choose instead
|
|
|
if comfy.model_management.pytorch_attention_enabled(): |
|
optimized_attention_mm = attention_pytorch |
|
if args.use_split_cross_attention: |
|
fallback_attention_mm = attention_split |
|
else: |
|
fallback_attention_mm = attention_sub_quad |
|
else: |
|
if args.use_split_cross_attention: |
|
optimized_attention_mm = attention_split |
|
else: |
|
optimized_attention_mm = attention_sub_quad |
|
|
|
|
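# Option keys for per-region strength multipliers on SparseCtrl cond hints; the corresponding
# values are stored on SparseSettings (sparse_hint_mult, sparse_nonhint_mult, sparse_mask_mult).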
|
class SparseConst: |
|
HINT_MULT = "sparse_hint_mult" |
|
NONHINT_MULT = "sparse_nonhint_mult" |
|
MASK_MULT = "sparse_mask_mult" |
|
|
|
|
|
class SparseControlNet(ControlNetCLDM): |
|
    def __init__(self, *args, **kwargs):
|
super().__init__(*args, **kwargs) |
|
hint_channels = kwargs.get("hint_channels") |
|
operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm) |
|
device = kwargs.get("device", None) |
|
self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False) |
|
if self.use_simplified_conditioning_embedding: |
|
self.input_hint_block = TimestepEmbedSequential( |
|
zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)), |
|
) |
|
self.motion_wrapper: SparseCtrlMotionWrapper = None |
|
|
|
def set_actual_length(self, actual_length: int, full_length: int): |
|
if self.motion_wrapper is not None: |
|
self.motion_wrapper.set_video_length(video_length=actual_length, full_length=full_length) |
|
|
|
def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs): |
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) |
|
emb = self.time_embed(t_emb) |
|
|
|
|
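        # SparseCtrl ignores the noisy input latent: x is zeroed out, so the conditioning
        # signal comes entirely from the hint (plus any injected motion modules)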
|
x = torch.zeros_like(x) |
|
guided_hint = self.input_hint_block(hint, emb, context) |
|
|
|
outs = [] |
|
|
|
|
if self.num_classes is not None: |
|
assert y.shape[0] == x.shape[0] |
|
emb = emb + self.label_emb(y) |
|
|
|
h = x |
|
for module, zero_conv in zip(self.input_blocks, self.zero_convs): |
|
if guided_hint is not None: |
|
h = module(h, emb, context) |
|
h += guided_hint |
|
guided_hint = None |
|
else: |
|
h = module(h, emb, context) |
|
outs.append(zero_conv(h, emb, context)) |
|
|
|
h = self.middle_block(h, emb, context) |
|
outs.append(self.middle_block_out(h, emb, context)) |
|
|
|
return outs |
|
|
|
|
|
class SparseModelPatcher(ModelPatcher): |
|
def __init__(self, *args, **kwargs): |
|
self.model: SparseControlNet |
|
super().__init__(*args, **kwargs) |
|
|
|
def patch_model_lowvram(self, device_to=None, *args, **kwargs): |
|
patched_model = super().patch_model_lowvram(device_to, *args, **kwargs) |
|
|
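        # weights/biases of named modules are assumed to be handled by the lowvram patching
        # above; any remaining tensors (e.g. buffers like pos_encoder.pe) are patched and
        # moved to the target device manually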
|
if self.model.motion_wrapper is not None: |
|
|
|
remaining_tensors = list(self.model.motion_wrapper.state_dict().keys()) |
|
named_modules = [] |
|
for n, _ in self.model.motion_wrapper.named_modules(): |
|
named_modules.append(n) |
|
named_modules.append(f"{n}.weight") |
|
named_modules.append(f"{n}.bias") |
|
for name in named_modules: |
|
if name in remaining_tensors: |
|
remaining_tensors.remove(name) |
|
|
|
for key in remaining_tensors: |
|
self.patch_weight_to_device(key, device_to) |
|
if device_to is not None: |
|
comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).to(device_to)) |
|
|
|
return patched_model |
|
|
|
def clone(self): |
|
|
|
n = SparseModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) |
|
n.patches = {} |
|
for k in self.patches: |
|
n.patches[k] = self.patches[k][:] |
|
if hasattr(n, "patches_uuid"): |
|
self.patches_uuid = n.patches_uuid |
|
|
|
n.object_patches = self.object_patches.copy() |
|
n.model_options = copy.deepcopy(self.model_options) |
|
if hasattr(n, "model_keys"): |
|
n.model_keys = self.model_keys |
|
if hasattr(n, "backup"): |
|
self.backup = n.backup |
|
if hasattr(n, "object_patches_backup"): |
|
self.object_patches_backup = n.object_patches_backup |
|
|
|
|
|
class PreprocSparseRGBWrapper: |
|
error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input." |
|
def __init__(self, condhint: Tensor): |
|
self.condhint = condhint |
|
|
|
def movedim(self, *args, **kwargs): |
|
return self |
|
|
|
def __getattr__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
def __setattr__(self, name, value): |
|
if name != "condhint": |
|
raise AttributeError(self.error_msg) |
|
super().__setattr__(name, value) |
|
|
|
def __iter__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
def __next__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
def __len__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
def __getitem__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
def __setitem__(self, *args, **kwargs): |
|
raise AttributeError(self.error_msg) |
|
|
|
|
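# Behavior for context windows that contain no hint frames: NEAREST_HINT presumably borrows
# the closest hint frame (compare get_closest_idx in SparseMethod.get_indexes); OFF disables this.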
|
class SparseContextAware: |
|
NEAREST_HINT = "nearest_hint" |
|
OFF = "off" |
|
|
|
LIST = [NEAREST_HINT, OFF] |
|
|
|
|
|
class SparseSettings: |
|
def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False, |
|
sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT): |
|
|
|
|
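        # backwards compatibility: older workflows (e.g. Steerable-Motion) passed a raw
        # comma-separated index string instead of a SparseMethod object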
|
        if isinstance(sparse_method, str):

            logger.warning("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult the Steerable-Motion github, NOT Advanced-ControlNet.")
|
sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method)) |
|
self.sparse_method = sparse_method |
|
self.use_motion = use_motion |
|
self.motion_strength = motion_strength |
|
self.motion_scale = motion_scale |
|
self.merged = merged |
|
self.sparse_mask_mult = float(sparse_mask_mult) |
|
self.sparse_hint_mult = float(sparse_hint_mult) |
|
self.sparse_nonhint_mult = float(sparse_nonhint_mult) |
|
self.context_aware = context_aware |
|
|
|
def is_context_aware(self): |
|
return self.context_aware != SparseContextAware.OFF |
|
|
|
@classmethod |
|
def default(cls): |
|
return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True) |
|
|
|
|
|
class SparseMethod(ABC): |
|
SPREAD = "spread" |
|
INDEX = "index" |
|
def __init__(self, method: str): |
|
self.method = method |
|
|
|
@abstractmethod |
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
pass |
|
|
|
def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]: |
|
returned_idxs = self._get_indexes(hint_length, full_length) |
|
if sub_idxs is None: |
|
return returned_idxs, None |
|
|
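        # sub_idxs means a context window is active; map full-length hint indexes onto the
        # positions they occupy within the hint batch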
|
index_mapping = {} |
|
for i, value in enumerate(returned_idxs): |
|
index_mapping[value] = i |
|
def get_mapped_idxs(idxs: list[int]): |
|
return [index_mapping[idx] for idx in idxs] |
|
|
|
fitting_idxs = [] |
|
for sub_idx in sub_idxs: |
|
if sub_idx in returned_idxs: |
|
fitting_idxs.append(sub_idx) |
|
|
|
if len(fitting_idxs) > 0: |
|
return fitting_idxs, get_mapped_idxs(fitting_idxs) |
|
|
|
|
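        # no hint frame falls inside this window, so find the hint(s) closest to the window's
        # edges; since no hint is inside the window, a distance of 1 is the best achievable,
        # hence the early exit below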
|
def get_closest_idx(target_idx: int, idxs: list[int]): |
|
min_idx = -1 |
|
min_dist = BIGMAX |
|
for idx in idxs: |
|
new_dist = abs(idx-target_idx) |
|
if new_dist < min_dist: |
|
min_idx = idx |
|
min_dist = new_dist |
|
if min_dist == 1: |
|
return min_idx, min_dist |
|
return min_idx, min_dist |
|
start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs) |
|
end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs) |
|
|
|
if hint_length == 1: |
|
|
|
if start_dist == end_dist: |
|
|
|
center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]] |
|
return [center_idx], get_mapped_idxs([start_closest_idx]) |
|
|
|
if start_dist < end_dist: |
|
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx]) |
|
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx]) |
|
|
|
|
|
if start_dist == end_dist: |
|
return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx]) |
|
|
|
if start_dist < end_dist: |
|
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx]) |
|
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx]) |
|
|
|
|
|
class SparseSpreadMethod(SparseMethod): |
|
UNIFORM = "uniform" |
|
STARTING = "starting" |
|
ENDING = "ending" |
|
CENTER = "center" |
|
|
|
LIST = [UNIFORM, STARTING, ENDING, CENTER] |
|
|
|
def __init__(self, spread=UNIFORM): |
|
super().__init__(self.SPREAD) |
|
self.spread = spread |
|
|
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
|
|
if hint_length >= full_length: |
|
return list(range(full_length)) |
|
|
|
if hint_length == 1: |
|
if self.spread in [self.UNIFORM, self.STARTING]: |
|
return [0] |
|
elif self.spread == self.ENDING: |
|
return [full_length-1] |
|
elif self.spread == self.CENTER: |
|
|
|
return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]] |
|
else: |
|
raise ValueError(f"Unrecognized spread: {self.spread}") |
|
|
|
if self.spread == self.UNIFORM: |
|
return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int)) |
|
elif self.spread == self.STARTING: |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1] |
|
elif self.spread == self.ENDING: |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:] |
|
elif self.spread == self.CENTER: |
|
|
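            # with fewer than 3 spare frames there is no room to pad both ends, so fall back
            # to a starting-style spread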
|
if full_length-hint_length < 3: |
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1] |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1] |
|
        raise ValueError(f"Unrecognized spread: {self.spread}")
|
|
|
|
|
class SparseIndexMethod(SparseMethod): |
|
def __init__(self, idxs: list[int]): |
|
super().__init__(self.INDEX) |
|
self.idxs = idxs |
|
|
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
orig_hint_length = hint_length |
|
if hint_length > full_length: |
|
hint_length = full_length |
|
|
|
if len(self.idxs) < hint_length: |
|
err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images." |
|
if orig_hint_length != hint_length: |
|
err_msg = f"{err_msg} (original input images: {orig_hint_length})" |
|
raise ValueError(err_msg) |
|
|
|
idxs = self.idxs[:hint_length] |
|
new_idxs = [] |
|
real_idxs = set() |
|
for idx in idxs: |
|
if idx < 0: |
|
real_idx = full_length+idx |
|
if real_idx in real_idxs: |
|
raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.") |
|
else: |
|
real_idx = idx |
|
if real_idx in real_idxs: |
|
raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.") |
|
real_idxs.add(real_idx) |
|
new_idxs.append(real_idx) |
|
return new_idxs |
|
|
|
|
|
def get_idx_list_from_str(indexes: str) -> list[int]: |
|
idxs = [] |
|
unique_idxs = set() |
|
|
|
str_idxs = [x.strip() for x in indexes.strip().split(",")] |
|
    for str_idx in str_idxs:

        try:

            idx = int(str_idx)

        except ValueError:

            raise ValueError(f"'{str_idx}' is not a valid integer index.")

        if idx in unique_idxs:

            raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")

        idxs.append(idx)

        unique_idxs.add(idx)
|
if len(idxs) == 0: |
|
raise ValueError(f"No indexes were listed in Sparse Index Method.") |
|
return idxs |
|
|
|
|
|
|
|
|
|
class BlockType: |
|
UP = "up" |
|
DOWN = "down" |
|
MID = "mid" |
|
|
|
def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int: |
|
return get_block_max(mm_state_dict, "down_blocks") |
|
|
|
def get_up_block_max(mm_state_dict: dict[str, Tensor]) -> int: |
|
return get_block_max(mm_state_dict, "up_blocks") |
|
|
|
def get_block_max(mm_state_dict: dict[str, Tensor], block_name: str) -> int: |
|
|
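    # keys look like "<block_name>.<block_num>...."; return the largest block index present,
    # or -1 when the state dict has no such blocks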
|
biggest_block = -1 |
|
for key in mm_state_dict.keys(): |
|
if block_name in key: |
|
try: |
|
                block_num = int(key.split(".")[1])
|
if block_num > biggest_block: |
|
biggest_block = block_num |
|
except ValueError: |
|
pass |
|
return biggest_block |
|
|
|
def has_mid_block(mm_state_dict: dict[str, Tensor]): |
|
|
|
for key in mm_state_dict.keys(): |
|
if key.startswith("mid_block."): |
|
return True |
|
return False |
|
|
|
def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str=None) -> int: |
|
|
|
for key in mm_state_dict.keys(): |
|
if key.endswith("pos_encoder.pe"): |
|
return mm_state_dict[key].size(1) |
|
raise ValueError(f"No pos_encoder.pe found in SparseCtrl state_dict - {mm_name} is not a valid SparseCtrl model!") |
|
|
|
|
|
class SparseCtrlMotionWrapper(nn.Module): |
|
def __init__(self, mm_state_dict: dict[str, Tensor], ops=disable_weight_init_clean_groupnorm): |
|
super().__init__() |
|
self.down_blocks: Iterable[MotionModule] = None |
|
self.up_blocks: Iterable[MotionModule] = None |
|
self.mid_block: MotionModule = None |
|
self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, "") |
|
layer_channels = (320, 640, 1280, 1280) |
|
if get_down_block_max(mm_state_dict) > -1: |
|
self.down_blocks = nn.ModuleList([]) |
|
for c in layer_channels: |
|
self.down_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.DOWN, ops=ops)) |
|
if get_up_block_max(mm_state_dict) > -1: |
|
self.up_blocks = nn.ModuleList([]) |
|
for c in reversed(layer_channels): |
|
self.up_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.UP, ops=ops)) |
|
if has_mid_block(mm_state_dict): |
|
self.mid_block = MotionModule(1280, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops) |
|
|
|
def inject(self, unet: SparseControlNet): |
|
|
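        # a ControlNet has no up blocks, so motion modules are only injected into the
        # input (down) blocks and the middle block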
|
self._inject(unet.input_blocks, self.down_blocks) |
|
|
|
if self.mid_block is not None: |
|
self._inject([unet.middle_block], [self.mid_block]) |
|
unet.motion_wrapper = self |
|
|
|
def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList): |
|
|
|
|
|
|
|
|
|
|
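        # splice motion modules into the ControlNet blocks: each module goes right after a
        # SpatialTransformer when one exists in the block, otherwise after the first ResBlock,
        # advancing through unet blocks until every motion module is placed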
|
injection_count = 0 |
|
unet_idx = 0 |
|
|
|
per_block = len(mm_blocks[0].motion_modules) |
|
injection_goal = len(mm_blocks) * per_block |
|
|
|
while injection_count < injection_goal: |
|
|
|
mm_blk_idx, mm_vtm_idx = injection_count // per_block, injection_count % per_block |
|
|
|
st_idx = -1 |
|
res_idx = -1 |
|
|
|
for idx, component in enumerate(unet_blocks[unet_idx]): |
|
                if isinstance(component, SpatialTransformer):
|
st_idx = idx |
|
elif type(component).__name__ == "ResBlock" and res_idx < 0: |
|
res_idx = idx |
|
|
|
if st_idx >= 0: |
|
unet_blocks[unet_idx].insert(st_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx]) |
|
injection_count += 1 |
|
|
|
elif res_idx >= 0: |
|
unet_blocks[unet_idx].insert(res_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx]) |
|
injection_count += 1 |
|
|
|
unet_idx += 1 |
|
|
|
def eject(self, unet: SparseControlNet): |
|
|
|
self._eject(unet.input_blocks) |
|
|
|
self._eject([unet.middle_block]) |
|
del unet.motion_wrapper |
|
unet.motion_wrapper = None |
|
|
|
def _eject(self, unet_blocks: nn.ModuleList): |
|
|
|
for block in unet_blocks: |
|
idx_to_pop = [] |
|
for idx, component in enumerate(block): |
|
                if isinstance(component, VanillaTemporalModule):
|
idx_to_pop.append(idx) |
|
|
|
for idx in sorted(idx_to_pop, reverse=True): |
|
block.pop(idx) |
|
|
|
def set_video_length(self, video_length: int, full_length: int): |
|
self.AD_video_length = video_length |
|
if self.down_blocks is not None: |
|
for block in self.down_blocks: |
|
block.set_video_length(video_length, full_length) |
|
if self.up_blocks is not None: |
|
for block in self.up_blocks: |
|
block.set_video_length(video_length, full_length) |
|
if self.mid_block is not None: |
|
self.mid_block.set_video_length(video_length, full_length) |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
if self.down_blocks is not None: |
|
for block in self.down_blocks: |
|
block.set_scale_multiplier(multiplier) |
|
if self.up_blocks is not None: |
|
for block in self.up_blocks: |
|
block.set_scale_multiplier(multiplier) |
|
if self.mid_block is not None: |
|
self.mid_block.set_scale_multiplier(multiplier) |
|
|
|
def set_strength(self, strength: float): |
|
if self.down_blocks is not None: |
|
for block in self.down_blocks: |
|
block.set_strength(strength) |
|
if self.up_blocks is not None: |
|
for block in self.up_blocks: |
|
block.set_strength(strength) |
|
if self.mid_block is not None: |
|
self.mid_block.set_strength(strength) |
|
|
|
def reset_temp_vars(self): |
|
if self.down_blocks is not None: |
|
for block in self.down_blocks: |
|
block.reset_temp_vars() |
|
if self.up_blocks is not None: |
|
for block in self.up_blocks: |
|
block.reset_temp_vars() |
|
if self.mid_block is not None: |
|
self.mid_block.reset_temp_vars() |
|
|
|
def reset_scale_multiplier(self): |
|
self.set_scale_multiplier(None) |
|
|
|
def reset(self): |
|
self.reset_scale_multiplier() |
|
self.reset_temp_vars() |
|
|
|
|
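# Each MotionModule carries the temporal modules for one unet block: 1 for mid blocks,
# 2 for down blocks, and 3 for up blocks (apparently mirroring AnimateDiff's per-block layout).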
|
class MotionModule(nn.Module): |
|
def __init__(self, in_channels, temporal_position_encoding_max_len=24, block_type: str=BlockType.DOWN, ops=disable_weight_init_clean_groupnorm): |
|
super().__init__() |
|
if block_type == BlockType.MID: |
|
|
|
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops)]) |
|
else: |
|
|
|
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList( |
|
[ |
|
get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops), |
|
get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops) |
|
] |
|
) |
|
|
|
if block_type == BlockType.UP: |
|
self.motion_modules.append(get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops)) |
|
|
|
def set_video_length(self, video_length: int, full_length: int): |
|
for motion_module in self.motion_modules: |
|
motion_module.set_video_length(video_length, full_length) |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
for motion_module in self.motion_modules: |
|
motion_module.set_scale_multiplier(multiplier) |
|
|
|
def set_masks(self, masks: Tensor, min_val: float, max_val: float): |
|
for motion_module in self.motion_modules: |
|
motion_module.set_masks(masks, min_val, max_val) |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
for motion_module in self.motion_modules: |
|
motion_module.set_sub_idxs(sub_idxs) |
|
|
|
def set_strength(self, strength: float): |
|
for motion_module in self.motion_modules: |
|
motion_module.set_strength(strength) |
|
|
|
def reset_temp_vars(self): |
|
for motion_module in self.motion_modules: |
|
motion_module.reset_temp_vars() |
|
|
|
|
|
def get_motion_module(in_channels, temporal_position_encoding_max_len, ops=disable_weight_init_clean_groupnorm): |
|
|
|
return VanillaTemporalModule(in_channels=in_channels, attention_block_types=("Temporal_Self",), temporal_position_encoding_max_len=temporal_position_encoding_max_len, ops=ops) |
|
|
|
|
|
class VanillaTemporalModule(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels, |
|
num_attention_heads=8, |
|
num_transformer_block=1, |
|
attention_block_types=("Temporal_Self", "Temporal_Self"), |
|
cross_frame_attention_mode=None, |
|
temporal_position_encoding=True, |
|
temporal_position_encoding_max_len=24, |
|
temporal_attention_dim_div=1, |
|
zero_initialize=True, |
|
ops=disable_weight_init_clean_groupnorm, |
|
): |
|
super().__init__() |
|
self.strength = 1.0 |
|
self.temporal_transformer = TemporalTransformer3DModel( |
|
in_channels=in_channels, |
|
num_attention_heads=num_attention_heads, |
|
attention_head_dim=in_channels |
|
// num_attention_heads |
|
// temporal_attention_dim_div, |
|
num_layers=num_transformer_block, |
|
attention_block_types=attention_block_types, |
|
cross_frame_attention_mode=cross_frame_attention_mode, |
|
temporal_position_encoding=temporal_position_encoding, |
|
temporal_position_encoding_max_len=temporal_position_encoding_max_len, |
|
ops=ops, |
|
) |
|
|
|
if zero_initialize: |
|
self.temporal_transformer.proj_out = zero_module( |
|
self.temporal_transformer.proj_out |
|
) |
|
|
|
def set_video_length(self, video_length: int, full_length: int): |
|
self.temporal_transformer.set_video_length(video_length, full_length) |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
self.temporal_transformer.set_scale_multiplier(multiplier) |
|
|
|
def set_masks(self, masks: Tensor, min_val: float, max_val: float): |
|
self.temporal_transformer.set_masks(masks, min_val, max_val) |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
self.temporal_transformer.set_sub_idxs(sub_idxs) |
|
|
|
def set_strength(self, strength: float): |
|
self.strength = strength |
|
|
|
def reset_temp_vars(self): |
|
self.set_strength(1.0) |
|
self.temporal_transformer.reset_temp_vars() |
|
|
|
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None): |
|
if math.isclose(self.strength, 1.0): |
|
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask) |
|
elif math.isclose(self.strength, 0.0): |
|
return input_tensor |
|
|
|
|
|
else: |
|
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength + input_tensor*(1.0-self.strength) |
|
|
|
|
|
class TemporalTransformer3DModel(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels, |
|
num_attention_heads, |
|
attention_head_dim, |
|
num_layers, |
|
attention_block_types=( |
|
"Temporal_Self", |
|
"Temporal_Self", |
|
), |
|
dropout=0.0, |
|
norm_num_groups=32, |
|
cross_attention_dim=768, |
|
activation_fn="geglu", |
|
attention_bias=False, |
|
upcast_attention=False, |
|
cross_frame_attention_mode=None, |
|
temporal_position_encoding=False, |
|
temporal_position_encoding_max_len=24, |
|
ops=disable_weight_init_clean_groupnorm, |
|
): |
|
super().__init__() |
|
self.video_length = 16 |
|
self.full_length = 16 |
|
self.scale_min = 1.0 |
|
self.scale_max = 1.0 |
|
self.raw_scale_mask: Union[Tensor, None] = None |
|
self.temp_scale_mask: Union[Tensor, None] = None |
|
self.sub_idxs: Union[list[int], None] = None |
|
self.prev_hidden_states_batch = 0 |
|
|
|
|
|
inner_dim = num_attention_heads * attention_head_dim |
|
|
|
self.norm = ops.GroupNorm( |
|
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
self.proj_in = ops.Linear(in_channels, inner_dim) |
|
|
|
self.transformer_blocks: Iterable[TemporalTransformerBlock] = nn.ModuleList( |
|
[ |
|
TemporalTransformerBlock( |
|
dim=inner_dim, |
|
num_attention_heads=num_attention_heads, |
|
attention_head_dim=attention_head_dim, |
|
attention_block_types=attention_block_types, |
|
dropout=dropout, |
|
norm_num_groups=norm_num_groups, |
|
cross_attention_dim=cross_attention_dim, |
|
activation_fn=activation_fn, |
|
attention_bias=attention_bias, |
|
upcast_attention=upcast_attention, |
|
cross_frame_attention_mode=cross_frame_attention_mode, |
|
temporal_position_encoding=temporal_position_encoding, |
|
temporal_position_encoding_max_len=temporal_position_encoding_max_len, |
|
ops=ops, |
|
) |
|
                for _ in range(num_layers)
|
] |
|
) |
|
self.proj_out = ops.Linear(inner_dim, in_channels) |
|
|
|
def set_video_length(self, video_length: int, full_length: int): |
|
self.video_length = video_length |
|
self.full_length = full_length |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
for block in self.transformer_blocks: |
|
block.set_scale_multiplier(multiplier) |
|
|
|
def set_masks(self, masks: Tensor, min_val: float, max_val: float): |
|
self.scale_min = min_val |
|
self.scale_max = max_val |
|
self.raw_scale_mask = masks |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
self.sub_idxs = sub_idxs |
|
for block in self.transformer_blocks: |
|
block.set_sub_idxs(sub_idxs) |
|
|
|
def reset_temp_vars(self): |
|
del self.temp_scale_mask |
|
self.temp_scale_mask = None |
|
self.prev_hidden_states_batch = 0 |
|
for block in self.transformer_blocks: |
|
block.reset_temp_vars() |
|
|
|
def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]: |
|
|
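        # lazily build a scale mask shaped to match temporal attention, ending up as
        # (height*width, full_length, 1) after the reshapes below; the cached copy is reused
        # until the incoming batch size changes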
|
if self.raw_scale_mask is None: |
|
return None |
|
shape = hidden_states.shape |
|
batch, channel, height, width = shape |
|
|
|
        if self.temp_scale_mask is not None:
|
|
|
if batch == self.prev_hidden_states_batch: |
|
if self.sub_idxs is not None: |
|
return self.temp_scale_mask[:, self.sub_idxs, :] |
|
return self.temp_scale_mask |
|
|
|
del self.temp_scale_mask |
|
self.temp_scale_mask = None |
|
|
|
self.prev_hidden_states_batch = batch |
|
mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) |
|
mask = extend_to_batch_size(mask, self.full_length) |
|
|
|
if self.full_length != mask.shape[0]: |
|
mask = broadcast_image_to_extend(mask, self.full_length, 1) |
|
|
|
batch, channel, height, width = mask.shape |
|
|
|
|
|
mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) |
|
|
|
mask = mask.permute(1, 0, 2) |
|
|
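        # cond/uncond batching multiplies the hidden states batch; tile the mask to match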
|
batched_number = shape[0] // self.video_length |
|
if batched_number > 1: |
|
mask = torch.cat([mask] * batched_number, dim=0) |
|
|
|
self.temp_scale_mask = mask |
|
|
|
self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device) |
|
|
|
if self.sub_idxs is not None: |
|
return self.temp_scale_mask[:, self.sub_idxs, :] |
|
return self.temp_scale_mask |
|
|
|
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): |
|
batch, channel, height, width = hidden_states.shape |
|
residual = hidden_states |
|
scale_mask = self.get_scale_mask(hidden_states) |
|
|
|
hidden_states = self.norm(hidden_states).to(hidden_states.dtype) |
|
inner_dim = hidden_states.shape[1] |
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( |
|
batch, height * width, inner_dim |
|
) |
|
hidden_states = self.proj_in(hidden_states).to(hidden_states.dtype) |
|
|
|
|
|
for block in self.transformer_blocks: |
|
hidden_states = block( |
|
hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
video_length=self.video_length, |
|
scale_mask=scale_mask |
|
) |
|
|
|
|
|
hidden_states = self.proj_out(hidden_states) |
|
hidden_states = ( |
|
hidden_states.reshape(batch, height, width, inner_dim) |
|
.permute(0, 3, 1, 2) |
|
.contiguous() |
|
) |
|
|
|
output = hidden_states + residual |
|
|
|
return output |
|
|
|
|
|
class TemporalTransformerBlock(nn.Module): |
|
def __init__( |
|
self, |
|
dim, |
|
num_attention_heads, |
|
attention_head_dim, |
|
attention_block_types=( |
|
"Temporal_Self", |
|
"Temporal_Self", |
|
), |
|
dropout=0.0, |
|
norm_num_groups=32, |
|
cross_attention_dim=768, |
|
activation_fn="geglu", |
|
attention_bias=False, |
|
upcast_attention=False, |
|
cross_frame_attention_mode=None, |
|
temporal_position_encoding=False, |
|
temporal_position_encoding_max_len=24, |
|
ops=disable_weight_init_clean_groupnorm, |
|
): |
|
super().__init__() |
|
|
|
attention_blocks = [] |
|
norms = [] |
|
|
|
for block_name in attention_block_types: |
|
attention_blocks.append( |
|
VersatileAttention( |
|
attention_mode=block_name.split("_")[0], |
|
context_dim=cross_attention_dim |
|
if block_name.endswith("_Cross") |
|
else None, |
|
query_dim=dim, |
|
heads=num_attention_heads, |
|
dim_head=attention_head_dim, |
|
dropout=dropout, |
|
|
|
|
|
cross_frame_attention_mode=cross_frame_attention_mode, |
|
temporal_position_encoding=temporal_position_encoding, |
|
temporal_position_encoding_max_len=temporal_position_encoding_max_len, |
|
ops=ops, |
|
) |
|
) |
|
norms.append(ops.LayerNorm(dim)) |
|
|
|
self.attention_blocks: Iterable[VersatileAttention] = nn.ModuleList(attention_blocks) |
|
self.norms = nn.ModuleList(norms) |
|
|
|
self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops) |
|
self.ff_norm = ops.LayerNorm(dim) |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
for block in self.attention_blocks: |
|
block.set_scale_multiplier(multiplier) |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
for block in self.attention_blocks: |
|
block.set_sub_idxs(sub_idxs) |
|
|
|
def reset_temp_vars(self): |
|
for block in self.attention_blocks: |
|
block.reset_temp_vars() |
|
|
|
def forward( |
|
self, |
|
hidden_states, |
|
encoder_hidden_states=None, |
|
attention_mask=None, |
|
video_length=None, |
|
scale_mask=None |
|
): |
|
for attention_block, norm in zip(self.attention_blocks, self.norms): |
|
norm_hidden_states = norm(hidden_states).to(hidden_states.dtype) |
|
hidden_states = ( |
|
attention_block( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states |
|
if attention_block.is_cross_attention |
|
else None, |
|
attention_mask=attention_mask, |
|
video_length=video_length, |
|
scale_mask=scale_mask |
|
) |
|
+ hidden_states |
|
) |
|
|
|
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states |
|
|
|
output = hidden_states |
|
return output |
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
def __init__(self, d_model, dropout=0.0, max_len=24): |
|
super().__init__() |
|
self.dropout = nn.Dropout(p=dropout) |
|
position = torch.arange(max_len).unsqueeze(1) |
|
div_term = torch.exp( |
|
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) |
|
) |
|
pe = torch.zeros(1, max_len, d_model) |
|
pe[0, :, 0::2] = torch.sin(position * div_term) |
|
pe[0, :, 1::2] = torch.cos(position * div_term) |
|
self.register_buffer("pe", pe) |
|
self.sub_idxs = None |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
self.sub_idxs = sub_idxs |
|
|
|
def forward(self, x): |
|
|
|
|
|
|
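        # NOTE: sub_idxs are stored via set_sub_idxs but not applied here; the encoding is
        # always taken from position 0 up to the sequence length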
|
x = x + self.pe[:, : x.size(1)] |
|
return self.dropout(x) |
|
|
|
|
|
class CrossAttentionMMSparse(nn.Module): |
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, |
|
operations=disable_weight_init_clean_groupnorm): |
|
super().__init__() |
|
inner_dim = dim_head * heads |
|
context_dim = default(context_dim, query_dim) |
|
|
|
self.actual_attention = optimized_attention_mm |
|
self.heads = heads |
|
self.dim_head = dim_head |
|
self.scale = None |
|
|
|
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) |
|
|
|
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) |
|
|
|
def reset_attention_type(self): |
|
self.actual_attention = optimized_attention_mm |
|
|
|
def forward(self, x, context=None, value=None, mask=None, scale_mask=None): |
|
q = self.to_q(x) |
|
context = default(context, x) |
|
k: Tensor = self.to_k(context) |
|
if value is not None: |
|
v = self.to_v(value) |
|
del value |
|
else: |
|
v = self.to_v(context) |
|
|
|
|
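        # the scale multiplier and per-pixel scale mask are folded into the keys, which scales
        # the attention logits without touching the values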
|
if self.scale is not None: |
|
k *= self.scale |
|
|
|
|
|
if scale_mask is not None: |
|
k *= scale_mask |
|
|
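        # some attention backends can fail on unusual configurations; on the known CUDA
        # configuration error, permanently switch this module over to the fallback attention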
|
try: |
|
out = self.actual_attention(q, k, v, self.heads, mask) |
|
except RuntimeError as e: |
|
if str(e).startswith("CUDA error: invalid configuration argument"): |
|
self.actual_attention = fallback_attention_mm |
|
out = self.actual_attention(q, k, v, self.heads, mask) |
|
else: |
|
raise |
|
return self.to_out(out) |
|
|
|
|
|
class VersatileAttention(CrossAttentionMMSparse): |
|
def __init__( |
|
self, |
|
attention_mode=None, |
|
cross_frame_attention_mode=None, |
|
temporal_position_encoding=False, |
|
temporal_position_encoding_max_len=24, |
|
ops=disable_weight_init_clean_groupnorm, |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__(operations=ops, *args, **kwargs) |
|
assert attention_mode == "Temporal" |
|
|
|
self.attention_mode = attention_mode |
|
self.is_cross_attention = kwargs["context_dim"] is not None |
|
|
|
self.pos_encoder = ( |
|
PositionalEncoding( |
|
kwargs["query_dim"], |
|
dropout=0.0, |
|
max_len=temporal_position_encoding_max_len, |
|
) |
|
if (temporal_position_encoding and attention_mode == "Temporal") |
|
else None |
|
) |
|
|
|
def extra_repr(self): |
|
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" |
|
|
|
def set_scale_multiplier(self, multiplier: Union[float, None]): |
|
if multiplier is None or math.isclose(multiplier, 1.0): |
|
self.scale = None |
|
else: |
|
self.scale = multiplier |
|
|
|
def set_sub_idxs(self, sub_idxs: list[int]): |
|
        if self.pos_encoder is not None:
|
self.pos_encoder.set_sub_idxs(sub_idxs) |
|
|
|
def reset_temp_vars(self): |
|
self.reset_attention_type() |
|
|
|
def forward( |
|
self, |
|
hidden_states: Tensor, |
|
encoder_hidden_states=None, |
|
attention_mask=None, |
|
video_length=None, |
|
scale_mask=None, |
|
): |
|
if self.attention_mode != "Temporal": |
|
raise NotImplementedError |
|
|
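        # temporal attention: fold the spatial dimension (d tokens per frame) into the batch
        # so attention runs across frames: "(b f) d c -> (b d) f c"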
|
d = hidden_states.shape[1] |
|
hidden_states = rearrange( |
|
hidden_states, "(b f) d c -> (b d) f c", f=video_length |
|
) |
|
|
|
if self.pos_encoder is not None: |
|
hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype) |
|
|
|
encoder_hidden_states = ( |
|
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) |
|
if encoder_hidden_states is not None |
|
else encoder_hidden_states |
|
) |
|
|
|
hidden_states = super().forward( |
|
hidden_states, |
|
encoder_hidden_states, |
|
value=None, |
|
mask=attention_mask, |
|
scale_mask=scale_mask, |
|
) |
|
|
|
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) |
|
|
|
return hidden_states |
|
|