# Source: MotionClone-Text-to-Video / motionclone/utils/motionclone_functions.py
# (uploaded by svjack using huggingface_hub, commit ce68674, verified)
from dataclasses import dataclass
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import Callable, List, Optional, Union
from diffusers.utils import deprecate, logging, BaseOutput
from .xformer_attention import *
from .conv_layer import *
from .util import *
from diffusers.utils.torch_utils import randn_tensor
from typing import List, Optional, Tuple, Union
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from motionclone.utils.util import video_preprocess
import einops
import torchvision.transforms as transforms
def add_noise(self, timestep, x_0, noise_pred):
    """Diffuse a clean sample forward to ``timestep``.

    With ``a = self.scheduler.alphas_cumprod[timestep]`` this returns
    ``sqrt(a) * x_0 + sqrt(1 - a) * noise_pred``.
    """
    alpha_bar = self.scheduler.alphas_cumprod[timestep]
    signal_coef = alpha_bar ** 0.5
    noise_coef = (1 - alpha_bar) ** 0.5
    return signal_coef * x_0 + noise_coef * noise_pred
@torch.no_grad()
def obtain_motion_representation(self, generator=None, motion_representation_path: str = None,
                                 duration=None, use_controlnet=False,):
    """Extract the reference motion representation from the input video and save it.

    The video at ``self.input_config.video_path`` is encoded with the VAE (distribution
    mode), noised to timestep ``self.input_config.add_noise_step`` and pushed once
    through the UNet with an empty prompt and ``only_motion_feature=True``. The
    temporal-attention maps returned by ``self.get_temp_attn_prob()`` are reduced to
    their top-1 value and index along the last axis, written with ``torch.save`` and
    the path is stored in ``self.motion_representation_path``.

    Args:
        generator: optional RNG forwarded to ``randn_tensor`` for the added noise.
        motion_representation_path: file the representation is saved to (required).
        duration: forwarded to ``video_preprocess``.
        use_controlnet: if True, also run ``self.controlnet`` conditioned on the
            frame(s) at ``self.input_config.image_index`` and feed its residuals to
            the UNet.

    Raises:
        ValueError: if ``motion_representation_path`` is None (checked before any
            expensive work is done).
    """
    if motion_representation_path is None:
        raise ValueError("`motion_representation_path` must be given: the motion representation is saved there.")

    video_data = video_preprocess(self.input_config.video_path, self.input_config.height,
                                  self.input_config.width, self.input_config.video_length, duration=duration)
    # Deterministic VAE encoding (mode of the latent distribution), then scaling.
    video_latents = self.vae.encode(video_data.to(self.vae.dtype).to(self.vae.device)).latent_dist.mode()
    video_latents = self.vae.config.scaling_factor * video_latents
    video_latents = video_latents.unsqueeze(0)
    video_latents = einops.rearrange(video_latents, "b f c h w -> b c f h w")

    uncond_input = self.tokenizer(
        [""], padding="max_length", max_length=self.tokenizer.model_max_length,
        return_tensors="pt"
    )
    step_t = int(self.input_config.add_noise_step)
    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
    noise_sampled = randn_tensor(video_latents.shape, generator=generator, device=video_latents.device, dtype=video_latents.dtype)
    noisy_latents = self.add_noise(step_t, video_latents, noise_sampled)

    down_block_additional_residuals = mid_block_additional_residual = None
    if use_controlnet:
        controlnet_image_index = self.input_config.image_index
        if self.controlnet.use_simplified_condition_embedding:
            controlnet_images = video_latents[:,:,controlnet_image_index,:,:]
        else:
            # Map the preprocessed frames from [-1, 1] to [0, 1].
            controlnet_images = (einops.rearrange(video_data.unsqueeze(0).to(self.vae.dtype).to(self.vae.device), "b f c h w -> b c f h w")+1)/2
            controlnet_images = controlnet_images[:,:,controlnet_image_index,:,:]
        # Zero condition over every frame plus a 1-channel mask; only the selected
        # frames receive the condition images and a mask value of 1.
        controlnet_cond_shape = list(controlnet_images.shape)
        controlnet_cond_shape[2] = noisy_latents.shape[2]
        controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype)
        controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
        controlnet_conditioning_mask_shape[1] = 1
        controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype)
        controlnet_cond[:,:,controlnet_image_index] = controlnet_images
        controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
        down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
            noisy_latents, step_t,
            encoder_hidden_states=uncond_embeddings,
            controlnet_cond=controlnet_cond,
            conditioning_mask=controlnet_conditioning_mask,
            conditioning_scale=self.input_config.controlnet_scale,
            guess_mode=False, return_dict=False,
        )

    # The forward pass is run only for its side effect of populating the attention
    # processors; its output is discarded.
    _ = self.unet(noisy_latents, step_t, encoder_hidden_states=uncond_embeddings, return_dict=False, only_motion_feature=True,
                  down_block_additional_residuals = down_block_additional_residuals,
                  mid_block_additional_residual = mid_block_additional_residual,)
    temp_attn_prob_control = self.get_temp_attn_prob()

    motion_representation = {}
    for key, attn_prob in temp_attn_prob_control.items():
        max_value, max_index = torch.topk(attn_prob, k=1, dim=-1)
        # Indices are stored compactly as uint8, which can only hold 0..255; for a
        # longer last axis uint8 would silently wrap, so use int64 instead. The loss
        # side casts the stored indices to int64 before gathering in either case.
        index_dtype = torch.uint8 if attn_prob.shape[-1] <= 256 else torch.int64
        motion_representation[key] = [max_value, max_index.to(index_dtype)]
    torch.save(motion_representation, motion_representation_path)
    self.motion_representation_path = motion_representation_path
def compute_temp_loss(self, temp_attn_prob_control_dict):
    """Motion-guidance loss between current and reference temporal attention.

    For each module name in ``temp_attn_prob_control_dict``, the current attention
    probabilities are gathered (along the last axis) at the reference top-1 indices
    ``self.motion_representation_dict[name][1]``. They are then compared by MSE with the
    stored reference values ``self.motion_representation_dict[name][0]``. The
    per-module losses are summed.

    Args:
        temp_attn_prob_control_dict: mapping module name -> current attention
            probabilities (as returned by ``get_temp_attn_prob``).

    Returns:
        A 0-dim tensor: the sum of the per-module MSE losses. It stays differentiable
        w.r.t. the current attention maps.

    Raises:
        RuntimeError: if ``temp_attn_prob_control_dict`` is empty.
        KeyError: if a module name has no stored reference representation.
    """
    if not temp_attn_prob_control_dict:
        raise RuntimeError(
            "compute_temp_loss received no temporal attention maps; nothing to compare "
            "against the reference motion representation."
        )
    module_losses = []
    for name, current_temp_attn_prob in temp_attn_prob_control_dict.items():
        reference_entry = self.motion_representation_dict[name]
        # Stored indices may be a compact integer dtype; gather needs int64.
        max_index = reference_entry[1].to(torch.int64).to(current_temp_attn_prob.device)
        current_motion_representation = torch.gather(current_temp_attn_prob, index=max_index, dim=-1)
        reference_motion_representation = reference_entry[0].to(
            dtype=current_motion_representation.dtype, device=current_motion_representation.device
        )
        module_losses.append(
            torch.nn.functional.mse_loss(current_motion_representation, reference_motion_representation.detach())
        )
    return torch.stack(module_losses).sum()
def sample_video(
    self,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    noisy_latents: Optional[torch.FloatTensor] = None,
    add_controlnet: bool = False,
):
    """Generate a video guided by the saved motion representation.

    Steps:
    1. Optionally prepare ControlNet condition images (image-to-video).
    2. Encode the prompt with classifier-free guidance always enabled, so
       ``self.text_embeddings`` holds the [negative/uncond, cond] embeddings.
    3. Prepare the initial latents.
    4. Load the motion representation from ``self.motion_representation_path``.
    5. Run ``single_step_video`` over every scheduler timestep.
    6. Decode the final latents.

    Args:
        eta: forwarded to ``prepare_extra_step_kwargs``.
        generator: RNG used for latent preparation and the scheduler kwargs.
        noisy_latents: optional initial latents handed to ``prepare_latents``.
        add_controlnet: condition on the images listed in
            ``self.input_config.condition_image_path_list``.

    Returns:
        Whatever ``self.decode_latents`` returns for the final latents.
    """
    self.add_controlnet = add_controlnet
    if self.add_controlnet:
        to_model_tensor = transforms.Compose([
            transforms.Resize((self.input_config.height, self.input_config.width)),
            transforms.ToTensor(),
        ])
        frames = torch.stack(
            [to_model_tensor(Image.open(p).convert("RGB")) for p in self.input_config.condition_image_path_list]
        )
        frames = frames.unsqueeze(0).to(dtype=self.vae.dtype, device=self.vae.device)
        frames = rearrange(frames, "b f c h w -> b c f h w")
        with torch.no_grad():
            if not self.controlnet.use_simplified_condition_embedding:
                # The ControlNet consumes the images directly.
                self.controlnet_images = frames
            else:
                # The ControlNet consumes VAE latents: encode every frame from [-1, 1].
                n_frames = frames.shape[2]
                flat_frames = rearrange(frames, "b c f h w -> (b f) c h w")
                frame_latents = self.vae.encode(flat_frames * 2. - 1.).latent_dist.sample() * self.vae.config.scaling_factor
                self.controlnet_images = rearrange(frame_latents, "(b f) c h w -> b c f h w", f=n_frames)

    device = self._execution_device
    # batch size 1, classifier-free guidance on.
    self.text_embeddings = self._encode_prompt(
        self.input_config.new_prompt, device, 1, True, self.input_config.negative_prompt
    )
    noisy_latents = self.prepare_latents(
        1,
        self.unet.config.in_channels,
        self.input_config.video_length,
        self.input_config.height,
        self.input_config.width,
        self.text_embeddings.dtype,
        device,
        generator,
        noisy_latents,
    )
    self.motion_representation_dict = torch.load(self.motion_representation_path)
    self.motion_scale = self.input_config.motion_guidance_weight
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    with self.progress_bar(total=self.input_config.inference_steps) as bar:
        for idx, t in enumerate(self.scheduler.timesteps):
            noisy_latents = self.single_step_video(noisy_latents, idx, t, extra_step_kwargs)
            bar.update()
    return self.decode_latents(noisy_latents)
def single_step_video(self, noisy_latents, step_index, step_t, extra_step_kwargs):
    """Run one denoising step, adding motion guidance during the early steps.

    Args:
        noisy_latents: current latents with batch size 1 (they are expanded to a
            batch of 2 for the [uncond, cond] passes).
        step_index: position of ``step_t`` in ``self.scheduler.timesteps``.
        step_t: current timestep.
        extra_step_kwargs: forwarded to ``self.scheduler.customized_step``.

    Returns:
        The detached latents for the next step.

    While ``step_index < self.input_config.guidance_steps``:
    - the unconditional and conditional UNet passes run separately;
    - the conditional pass keeps the autograd graph, so the temporal-attention loss
      (``compute_temp_loss``) can be differentiated w.r.t. the latents;
    - that gradient is passed to the scheduler as ``score``.
    Later steps run a single batched pass without gradients.
    """
    down_block_additional_residuals = mid_block_additional_residual = None
    if self.add_controlnet:
        # ControlNet residuals never need gradients.
        with torch.no_grad():
            # Zero condition tensor spanning every latent frame (axis 2, assuming the
            # "b c f h w" layout set up in sample_video) plus a 1-channel mask. Only
            # the frames at input_config.image_index get the condition images and
            # mask value 1.
            controlnet_cond_shape = list(self.controlnet_images.shape)
            controlnet_cond_shape[2] = noisy_latents.shape[2]
            controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype)
            controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
            controlnet_conditioning_mask_shape[1] = 1
            controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype)
            controlnet_image_index = self.input_config.image_index
            controlnet_cond[:,:,controlnet_image_index] = self.controlnet_images
            controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
            # Batch of 2 to match self.text_embeddings = [uncond, cond].
            down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
                noisy_latents.expand(2,-1,-1,-1,-1), step_t,
                encoder_hidden_states=self.text_embeddings,
                controlnet_cond=controlnet_cond,
                conditioning_mask=controlnet_conditioning_mask,
                conditioning_scale=self.input_config.controlnet_scale,
                guess_mode=False, return_dict=False,
            )
    # Only require grad when need to compute the gradient for guidance
    if step_index < self.input_config.guidance_steps:
        down_block_additional_residuals_uncond = down_block_additional_residuals_cond = None
        mid_block_additional_residual_uncond = mid_block_additional_residual_cond = None
        if self.add_controlnet:
            # Split the batched ControlNet residuals into the uncond (index 0) and
            # cond (index 1) halves; [[i]] keeps the batch axis.
            down_block_additional_residuals_uncond = [tensor[[0],...].detach() for tensor in down_block_additional_residuals]
            down_block_additional_residuals_cond = [tensor[[1],...].detach() for tensor in down_block_additional_residuals]
            mid_block_additional_residual_uncond = mid_block_additional_residual[[0],...].detach()
            mid_block_additional_residual_cond = mid_block_additional_residual[[1],...].detach()
        # Grad-enabled copy of the latents; the motion-loss gradient is taken w.r.t. it.
        control_latents = noisy_latents.clone().detach()
        control_latents.requires_grad = True
        control_latents = self.scheduler.scale_model_input(control_latents, step_t)
        noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)
        with torch.no_grad():
            noise_pred_uncondition = self.unet(noisy_latents, step_t, encoder_hidden_states=self.text_embeddings[[0]],
                                               down_block_additional_residuals = down_block_additional_residuals_uncond,
                                               mid_block_additional_residual = mid_block_additional_residual_uncond,).sample.to(dtype=noisy_latents.dtype)
        # Conditional pass WITH autograd. It runs after the uncond pass because
        # get_temp_attn_prob reads query/key stored on the attention processors;
        # presumably those hold only the most recent forward — TODO confirm.
        noise_pred_condition = self.unet(control_latents, step_t, encoder_hidden_states=self.text_embeddings[[1]],
                                         down_block_additional_residuals = down_block_additional_residuals_cond,
                                         mid_block_additional_residual = mid_block_additional_residual_cond,).sample.to(dtype=noisy_latents.dtype)
        temp_attn_prob_control = self.get_temp_attn_prob()
        loss_motion = self.motion_scale * self.compute_temp_loss(temp_attn_prob_control,)
        # Linear warm-up of the guidance strength over the first warm_up_steps...
        if step_index < self.input_config.warm_up_steps:
            scale = (step_index+1)/self.input_config.warm_up_steps
            loss_motion = scale*loss_motion
        # ...and linear cool-down over the last cool_up_steps guidance steps.
        if step_index > self.input_config.guidance_steps-self.input_config.cool_up_steps:
            scale = (self.input_config.guidance_steps-step_index)/self.input_config.cool_up_steps
            loss_motion = scale*loss_motion
        gradient = torch.autograd.grad(loss_motion, control_latents, allow_unused=True)[0] # [1, 4, 16, 64, 64],
        # NOTE(review): assert is stripped under `python -O`.
        assert gradient is not None, f"Step {step_index}: grad is None"
        # cond + s * (cond - uncond), i.e. the usual CFG form with scale (1 + cfg_scale).
        noise_pred = noise_pred_condition + self.input_config.cfg_scale * (noise_pred_condition - noise_pred_uncondition)
        control_latents = self.scheduler.customized_step(noise_pred, step_index, control_latents, score=gradient.detach(),
                                                         **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64]
        return control_latents.detach()
    else:
        # No guidance: one batched [uncond, cond] pass without gradients.
        with torch.no_grad():
            noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)
            noise_pred_group = self.unet(
                noisy_latents.expand(2,-1,-1,-1,-1), step_t,
                encoder_hidden_states=self.text_embeddings,
                down_block_additional_residuals = down_block_additional_residuals,
                mid_block_additional_residual = mid_block_additional_residual,
            ).sample.to(dtype=noisy_latents.dtype)
            # Same CFG combination as in the guided branch.
            noise_pred = noise_pred_group[[1]] + self.input_config.cfg_scale * (noise_pred_group[[1]] - noise_pred_group[[0]])
            noisy_latents = self.scheduler.customized_step(noise_pred, step_index, noisy_latents, score=None, **extra_step_kwargs, return_dict=False)[0] # [1, 4, 16, 64, 64]
        return noisy_latents.detach()
def get_temp_attn_prob(self, index_select=None):
    """Rebuild temporal attention probabilities from the stored query/key tensors.

    Walks ``self.unet.named_modules()``. For every module whose class name contains
    "VersatileAttention" and whose name passes
    ``classify_blocks(self.input_config.motion_guidance_blocks, name)``, the
    probabilities are recomputed from ``module.processor.query`` and
    ``module.processor.key`` via ``module.get_attention_scores``.

    Args:
        index_select: optional sequence of 0/1 flags. Each flag is repeated
            ``key.shape[0] // len(index_select)`` times, and the resulting mask picks
            which rows (dim 0) of the stored query/key are used.

    Returns:
        A dict mapping module name -> probabilities reshaped to
        ``(-1, module.heads, scores.shape[1], scores.shape[2])``.
    """
    probs_by_module = {}
    for module_path, attn in self.unet.named_modules():
        if "VersatileAttention" not in type(attn).__name__:
            continue
        if not classify_blocks(self.input_config.motion_guidance_blocks, module_path):
            continue

        stored_key = attn.processor.key
        if index_select is not None:
            keep_mask = torch.repeat_interleave(
                torch.tensor(index_select), repeats=stored_key.shape[0] // len(index_select)
            ).bool()
            picked_rows = torch.arange(stored_key.shape[0])[keep_mask]
            stored_key = stored_key[picked_rows]
        key_batched = attn.reshape_heads_to_batch_dim(stored_key).contiguous()

        stored_query = attn.processor.query
        if index_select is not None:
            stored_query = stored_query[picked_rows]
        query_batched = attn.reshape_heads_to_batch_dim(stored_query).contiguous()

        scores = attn.get_attention_scores(query_batched, key_batched, None)
        probs_by_module[module_path] = scores.reshape(-1, attn.heads, scores.shape[1], scores.shape[2])
    return probs_by_module
@torch.no_grad()
def schedule_customized_step(
    self,
    model_output: torch.FloatTensor,
    step_index: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    # Guidance parameters
    score=None,
    guidance_scale=1.0,
    indices=None, # [0]
    return_middle = False,
):
    """DDIM step (x_t -> x_{t-1}) with optional score guidance on the predicted noise.

    The current and previous timesteps are looked up by position,
    ``self.timesteps[step_index]`` and ``self.timesteps[step_index + 1]`` (or
    ``self.final_alpha_cumprod`` after the last index). This makes the step follow
    whatever, possibly non-uniform, sequence is in ``self.timesteps`` instead of
    assuming a fixed stride.

    Args:
        model_output: UNet prediction, interpreted via ``self.config.prediction_type``.
        step_index: index of the current timestep in ``self.timesteps``.
        sample: current latents x_t.
        eta: DDIM stochasticity; noise is added only when ``eta > 0``.
        use_clipped_model_output: re-derive ``pred_epsilon`` from the clipped x_0.
        generator, variance_noise: noise source when ``eta > 0`` (mutually exclusive).
        return_dict: selects the return form (see Returns).
        score: guidance gradient; applied as
            ``pred_epsilon -= guidance_scale * sqrt(1 - alpha_prod_t) * score``.
        guidance_scale: multiplier for ``score``; guidance is skipped when it is <= 0.
        indices: if given, guidance is applied only to ``pred_epsilon[indices]``.
        return_middle: if ``score`` is set, return intermediates before guidance.

    Returns:
        ``(prev_sample,)`` when ``return_dict`` is False. Otherwise the plain tuple
        ``(prev_sample, pred_original_sample, alpha_prod_t_prev)``, not an output
        object. When ``score is not None and return_middle`` it returns
        ``(pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample)``.

    Raises:
        ValueError: if ``set_timesteps`` was not run, on an unknown
            ``prediction_type``, or when both ``generator`` and ``variance_noise`` are given.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding
    # Notation (<variable name> -> <name in paper>
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"
    # Support IF models
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None
    timestep = self.timesteps[step_index]
    # 1. get previous step value (=t-1)
    # The previous timestep is read from the schedule by position rather than by
    # subtracting a fixed stride (the stock formula was:
    # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps)
    prev_timestep = self.timesteps[step_index+1] if step_index +1 <len(self.timesteps) else -1
    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        # NOTE(review): pred_epsilon aliases model_output here, so the in-place
        # guidance update with `indices` below also writes into the caller's tensor.
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
        pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )
    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )
    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)
    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64]
    # Early exit returning the intermediates before guidance is applied.
    if score is not None and return_middle:
        return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample
    # 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf
    # NOTE: pred_original_sample is not recomputed from the guided pred_epsilon;
    # only the direction term in step 7 sees the guidance.
    if score is not None and guidance_scale > 0.0:
        if indices is not None:
            assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
            pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
        else:
            assert pred_epsilon.shape == score.shape
            pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
    # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon
    # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )
        if variance_noise is None:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        variance = std_dev_t * variance_noise
        prev_sample = prev_sample + variance
    if not return_dict:
        return (prev_sample,)
    # Despite the flag name, a plain tuple is returned here.
    return prev_sample, pred_original_sample, alpha_prod_t_prev
def schedule_set_timesteps(self, num_inference_steps: int, guidance_steps: int = 0, guiduance_scale: float = 0.0, device: Union[str, torch.device] = None, timestep_spacing_type="uneven"):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        guidance_steps (`int`):
            Only for "uneven" spacing. The number of steps placed in the high-noise range
            ``[int((1 - guiduance_scale) * T), T - 1]``, where ``T = num_train_timesteps``.
            The remaining ``num_inference_steps - guidance_steps`` steps cover
            ``[0, int((1 - guiduance_scale) * T) - 1]``. Must satisfy
            ``0 <= guidance_steps <= num_inference_steps``.
        guiduance_scale (`float`):
            Only for "uneven" spacing. The fraction of the training schedule, counted
            from the noisy end, that is reserved for the guidance steps. The parameter
            name keeps its original spelling for backward compatibility.
        device (`str` or `torch.device`, *optional*):
            The device ``self.timesteps`` is moved to.
        timestep_spacing_type (`str`):
            One of "uneven" (default), "linspace", "leading" or "trailing". The last three
            follow Table 2 of https://arxiv.org/abs/2305.08891.

    Raises:
        ValueError: if ``num_inference_steps`` exceeds ``num_train_timesteps``, if
            ``guidance_steps`` is out of range for "uneven" spacing, or if
            ``timestep_spacing_type`` is not supported.
    """
    num_train_timesteps = self.config.num_train_timesteps
    if num_inference_steps > num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )
    self.num_inference_steps = num_inference_steps

    def _descending_int(values):
        # Round, reverse into descending order and convert to int64 timesteps.
        return values.round()[::-1].copy().astype(np.int64)

    if timestep_spacing_type == "uneven":
        # Assign more steps to the early (high-noise) stage, where motion guidance runs.
        if not 0 <= guidance_steps <= num_inference_steps:
            raise ValueError(
                f"`guidance_steps` ({guidance_steps}) must be between 0 and `num_inference_steps`"
                f" ({num_inference_steps}) for 'uneven' timestep spacing."
            )
        boundary = int((1 - guiduance_scale) * num_train_timesteps)
        timesteps_guidance = _descending_int(np.linspace(boundary, num_train_timesteps - 1, guidance_steps))
        timesteps_vanilla = _descending_int(np.linspace(0, boundary - 1, num_inference_steps - guidance_steps))
        timesteps = np.concatenate((timesteps_guidance, timesteps_vanilla))
    elif timestep_spacing_type == "linspace":
        timesteps = _descending_int(np.linspace(0, num_train_timesteps - 1, num_inference_steps))
    elif timestep_spacing_type == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif timestep_spacing_type == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{timestep_spacing_type} is not supported. Please make sure to choose one of "
            "'uneven', 'linspace', 'leading' or 'trailing'."
        )
    self.timesteps = torch.from_numpy(timesteps).to(device)
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Return container of ``unet_customized_forward`` when ``return_dict=True``.

    Attributes:
        sample: the network output after the final ``conv_out`` layer.
    """

    sample: torch.FloatTensor
def unet_customized_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    # support controlnet
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    only_motion_feature: bool = False,
) -> "Union[UNet3DConditionOutput, Tuple]":
    r"""UNet forward pass that limits autograd to the motion-guided up blocks.

    Up blocks with index <= N run normally, where N is parsed from the last entry of
    ``self.input_config.motion_guidance_blocks`` (e.g. "up_blocks.1" -> 1;
    NOTE(review): assumes that last entry names an up block). Later up blocks run
    under ``torch.no_grad()``. With ``only_motion_feature=True`` the pass stops
    before the first later up block and returns ``0``.

    Args:
        sample (`torch.FloatTensor`): noisy input latents; presumably
            (batch, channel, frames, height, width) — TODO confirm.
        timestep (`torch.Tensor` or `float` or `int`): timestep(s), broadcast to the batch.
        encoder_hidden_states (`torch.Tensor`): (batch, sequence_length, feature_dim) encoder hidden states.
        class_labels: required when ``self.class_embedding`` is set.
        attention_mask: optional mask, converted to an additive bias of -10000 on masked positions.
        down_block_additional_residuals / mid_block_additional_residual: ControlNet
            residuals added to the skip connections / mid-block output. 4-dim tensors
            are broadcast over axis 2.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a ``UNet3DConditionOutput`` instead of a plain tuple.
        only_motion_feature: stop after the motion-guided up blocks.

    Returns:
        ``UNet3DConditionOutput`` if ``return_dict`` is True, otherwise ``(sample,)``.
        Returns ``0`` when stopped early by ``only_motion_feature``.
    """
    import contextlib

    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers
    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None
    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)
    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])
    t_emb = self.time_proj(timesteps)
    # time_proj output may be f32 while the embedding runs in another dtype (e.g. fp16).
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)
    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # pre-process
    sample = self.conv_in(sample)

    # down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
        down_block_res_samples += res_samples

    # support controlnet
    down_block_res_samples = list(down_block_res_samples)
    if down_block_additional_residuals is not None:
        for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
            if down_block_additional_residual.dim() == 4:  # broadcast over axis 2
                down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
            down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

    # mid
    sample = self.mid_block(
        sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
    )
    # support controlnet
    if mid_block_additional_residual is not None:
        if mid_block_additional_residual.dim() == 4:  # broadcast over axis 2
            mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
        sample = sample + mid_block_additional_residual

    # up: blocks up to the last motion-guided one keep autograd, the rest do not.
    last_guided_up_block = int(self.input_config.motion_guidance_blocks[-1].split(".")[-1])
    for i, upsample_block in enumerate(self.up_blocks):
        guided = i <= last_guided_up_block
        if not guided and only_motion_feature:
            return 0
        grad_context = contextlib.nullcontext() if guided else torch.no_grad()
        with grad_context:
            is_final_block = i == len(self.up_blocks) - 1
            num_resnets = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_resnets:]
            down_block_res_samples = down_block_res_samples[:-num_resnets]
            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]
            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    if not return_dict:
        return (sample,)
    return UNet3DConditionOutput(sample=sample)