WLMotionCtrl_SVD / sgm /motionctrl /modified_svd.py
wzhouxiff's picture
init
2890711
raw
history blame
7.69 kB
from functools import partial
from typing import List, Optional, Union
import torch
from einops import rearrange, repeat
from sgm.modules.attention import checkpoint, exists
from sgm.modules.diffusionmodules.util import timestep_embedding
### VideoUnet #####
def forward_VideoUnet(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
RT: Optional[torch.Tensor] = None
):
if RT is not None:
context = {'RT': RT, 'context': context}
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
## tbd: check the role of "image_only_indicator"
num_video_frames = self.num_frames
image_only_indicator = torch.zeros(
x.shape[0]//num_video_frames, num_video_frames
).to(x.device) if image_only_indicator is None else image_only_indicator
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
for module in self.input_blocks:
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames
)
hs.append(h)
h = self.middle_block(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames
)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames
)
h = h.type(x.dtype)
return self.out(h)
### VideoTransformerBlock #####
def forward_VideoTransformerBlock(self, x, context, timesteps):
if self.checkpoint:
return checkpoint(self._forward, x, context, timesteps)
else:
return self._forward(x, context, timesteps=timesteps)
def _forward_VideoTransformerBlock_attan2(self, x, context=None, timesteps=None):
assert self.timesteps or timesteps
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
timesteps = self.timesteps or timesteps
B, S, C = x.shape
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
if isinstance(context, dict):
RT = context['RT'] # (b, t, 12)
context = context['context']
else:
RT = None
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
if self.disable_self_attn:
x = self.attn1(self.norm1(x), context=context) + x
else:
x = self.attn1(self.norm1(x)) + x
if RT is not None:
# import pdb; pdb.set_trace()
RT = RT.repeat_interleave(repeats=S, dim=0) # (b*s, t, 12)
x = torch.cat([x, RT], dim=-1)
x = self.cc_projection(x)
if self.attn2 is not None:
if self.switch_temporal_ca_to_sa:
x = self.attn2(self.norm2(x)) + x
else:
x = self.attn2(self.norm2(x), context=context) + x
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = rearrange(
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
return x
#### BasicTransformerBlock #####
def _forward_BasicTransformerBlock(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
if isinstance(context, dict):
context = context['context']
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
x = (
self.attn2(
self.norm2(x), context=context, additional_tokens=additional_tokens
)
+ x
)
x = self.ff(self.norm3(x)) + x
return x
#### SpatialVideoTransformer #####
def forward_SpatialVideoTransformer(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
if isinstance(context, dict):
RT = context['RT']
context = context['context']
else:
RT = None
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
x = block(
x,
context=spatial_context,
)
x_mix = x
x_mix = x_mix + emb
if RT is not None:
x_mix = mix_block(x_mix, context={'context': time_context, 'RT': RT}, timesteps=timesteps)
else:
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator,
)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out