from dataclasses import dataclass
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import einops
from einops import rearrange
from PIL import Image
from typing import Callable, List, Optional, Tuple, Union

from diffusers.utils import deprecate, logging, BaseOutput
from diffusers.utils.torch_utils import randn_tensor

from motionclone.utils.util import video_preprocess
from .xformer_attention import *
from .conv_layer import *
from .util import *

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
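
# Note: the functions below take `self` as their first argument but are defined at module
# level; they are presumably bound to the MotionClone pipeline / scheduler / UNet objects
# elsewhere in the repo (e.g. as `pipeline.obtain_motion_representation`,
# `scheduler.customized_step`, the UNet's `forward`) -- this binding is an assumption, not
# shown in this file.
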
def add_noise(self, timestep, x_0, noise_pred):
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    latents_noise = alpha_prod_t ** 0.5 * x_0 + beta_prod_t ** 0.5 * noise_pred
    return latents_noise
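
# `add_noise` above is the standard closed-form forward-diffusion step,
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# with `noise_pred` playing the role of eps; it is used below to noise the reference-video
# latents up to `add_noise_step` before the temporal attention maps are recorded.
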
def obtain_motion_representation(self, generator=None, motion_representation_path: str = None,
                                 duration=None, use_controlnet=False):
    # Encode the reference video into latents and noise them to `add_noise_step`.
    video_data = video_preprocess(self.input_config.video_path, self.input_config.height,
                                  self.input_config.width, self.input_config.video_length, duration=duration)
    video_latents = self.vae.encode(video_data.to(self.vae.dtype).to(self.vae.device)).latent_dist.mode()
    video_latents = self.vae.config.scaling_factor * video_latents
    video_latents = video_latents.unsqueeze(0)
    video_latents = einops.rearrange(video_latents, "b f c h w -> b c f h w")

    uncond_input = self.tokenizer(
        [""], padding="max_length", max_length=self.tokenizer.model_max_length,
        return_tensors="pt"
    )
    step_t = int(self.input_config.add_noise_step)
    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

    noise_sampled = randn_tensor(video_latents.shape, generator=generator, device=video_latents.device, dtype=video_latents.dtype)
    noisy_latents = self.add_noise(step_t, video_latents, noise_sampled)
    down_block_additional_residuals = mid_block_additional_residual = None
    if use_controlnet:
        controlnet_image_index = self.input_config.image_index
        if self.controlnet.use_simplified_condition_embedding:
            controlnet_images = video_latents[:, :, controlnet_image_index, :, :]
        else:
            controlnet_images = (einops.rearrange(video_data.unsqueeze(0).to(self.vae.dtype).to(self.vae.device), "b f c h w -> b c f h w") + 1) / 2
            controlnet_images = controlnet_images[:, :, controlnet_image_index, :, :]

        controlnet_cond_shape = list(controlnet_images.shape)
        controlnet_cond_shape[2] = noisy_latents.shape[2]
        controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype)

        controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
        controlnet_conditioning_mask_shape[1] = 1
        controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype)

        controlnet_cond[:, :, controlnet_image_index] = controlnet_images
        controlnet_conditioning_mask[:, :, controlnet_image_index] = 1

        down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
            noisy_latents, step_t,
            encoder_hidden_states=uncond_embeddings,
            controlnet_cond=controlnet_cond,
            conditioning_mask=controlnet_conditioning_mask,
            conditioning_scale=self.input_config.controlnet_scale,
            guess_mode=False, return_dict=False,
        )
    # A single UNet pass at `step_t` records the temporal attention maps.
    _ = self.unet(noisy_latents, step_t, encoder_hidden_states=uncond_embeddings, return_dict=False, only_motion_feature=True,
                  down_block_additional_residuals=down_block_additional_residuals,
                  mid_block_additional_residual=mid_block_additional_residual,)
    temp_attn_prob_control = self.get_temp_attn_prob()

    # Keep only the top-1 value and index per attention row as the motion representation.
    motion_representation = {key: [max_value, max_index.to(torch.uint8)]
                             for key, tensor in temp_attn_prob_control.items()
                             for max_value, max_index in [torch.topk(tensor, k=1, dim=-1)]}
    torch.save(motion_representation, motion_representation_path)
    self.motion_representation_path = motion_representation_path
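
# Usage sketch (hypothetical names; assumes these helpers are bound to a MotionClone
# pipeline whose `input_config` has already been loaded):
#   pipeline.obtain_motion_representation(
#       generator=torch.Generator(device="cuda").manual_seed(42),
#       motion_representation_path="outputs/motion_representation.pt",
#   )
#   videos = pipeline.sample_video(generator=torch.Generator(device="cuda").manual_seed(42))
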
def compute_temp_loss(self, temp_attn_prob_control_dict):
    temp_attn_prob_loss = []
    for name in temp_attn_prob_control_dict.keys():
        current_temp_attn_prob = temp_attn_prob_control_dict[name]
        reference_representation_dict = self.motion_representation_dict[name]

        max_index = reference_representation_dict[1].to(torch.int64).to(current_temp_attn_prob.device)
        current_motion_representation = torch.gather(current_temp_attn_prob, index=max_index, dim=-1)
        reference_motion_representation = reference_representation_dict[0].to(dtype=current_motion_representation.dtype, device=current_motion_representation.device)

        module_attn_loss = F.mse_loss(current_motion_representation, reference_motion_representation.detach())
        temp_attn_prob_loss.append(module_attn_loss)

    loss_temp = torch.stack(temp_attn_prob_loss)
    return loss_temp.sum()
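
# In effect, for each guided temporal-attention module m the loss above is
#   L = sum_m MSE( A_m gathered at the reference argmax indices, max(A_m_ref) ),
# i.e. the current attention probabilities are pulled toward the reference video's
# per-row maxima at the same positions.
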
def sample_video(
    self,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    noisy_latents: Optional[torch.FloatTensor] = None,
    add_controlnet: bool = False,
):
    # Determine whether to use ControlNet, i.e., conditional image-to-video
    self.add_controlnet = add_controlnet
    if self.add_controlnet:
        image_transforms = transforms.Compose([
            transforms.Resize((self.input_config.height, self.input_config.width)),
            transforms.ToTensor(),
        ])

        controlnet_images = [image_transforms(Image.open(path).convert("RGB")) for path in self.input_config.condition_image_path_list]
        controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(dtype=self.vae.dtype, device=self.vae.device)
        controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")

        with torch.no_grad():
            if self.controlnet.use_simplified_condition_embedding:
                num_controlnet_images = controlnet_images.shape[2]
                controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
                controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * self.vae.config.scaling_factor
                self.controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
            else:
                self.controlnet_images = controlnet_images
    # Define call parameters; classifier-free guidance is always used.
    batch_size = 1
    do_classifier_free_guidance = True
    device = self._execution_device

    # Encode input prompt -> [uncond_embeddings, text_embeddings], shape [2, 77, 768]
    self.text_embeddings = self._encode_prompt(self.input_config.new_prompt, device, 1, do_classifier_free_guidance, self.input_config.negative_prompt)

    # Prepare latent variables
    noisy_latents = self.prepare_latents(
        batch_size,
        self.unet.config.in_channels,
        self.input_config.video_length,
        self.input_config.height,
        self.input_config.width,
        self.text_embeddings.dtype,
        device,
        generator,
        noisy_latents,
    )

    self.motion_representation_dict = torch.load(self.motion_representation_path)
    self.motion_scale = self.input_config.motion_guidance_weight

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # Optionally offload modules to save GPU memory:
    # self.vae.to(device="cpu")
    # self.text_encoder.to(device="cpu")
    # torch.cuda.empty_cache()
    # Denoising loop
    with self.progress_bar(total=self.input_config.inference_steps) as progress_bar:
        for step_index, step_t in enumerate(self.scheduler.timesteps):
            noisy_latents = self.single_step_video(noisy_latents, step_index, step_t, extra_step_kwargs)
            progress_bar.update()

    # Decode latents into videos
    video = self.decode_latents(noisy_latents)
    return video
def single_step_video(self, noisy_latents, step_index, step_t, extra_step_kwargs):
    down_block_additional_residuals = mid_block_additional_residual = None
    if self.add_controlnet:
        with torch.no_grad():
            controlnet_cond_shape = list(self.controlnet_images.shape)
            controlnet_cond_shape[2] = noisy_latents.shape[2]
            controlnet_cond = torch.zeros(controlnet_cond_shape).to(noisy_latents.device).to(noisy_latents.dtype)

            controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
            controlnet_conditioning_mask_shape[1] = 1
            controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(noisy_latents.device).to(noisy_latents.dtype)

            controlnet_image_index = self.input_config.image_index
            controlnet_cond[:, :, controlnet_image_index] = self.controlnet_images
            controlnet_conditioning_mask[:, :, controlnet_image_index] = 1

            down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
                noisy_latents.expand(2, -1, -1, -1, -1), step_t,
                encoder_hidden_states=self.text_embeddings,
                controlnet_cond=controlnet_cond,
                conditioning_mask=controlnet_conditioning_mask,
                conditioning_scale=self.input_config.controlnet_scale,
                guess_mode=False, return_dict=False,
            )
    # Gradients are only required while motion guidance is applied.
    if step_index < self.input_config.guidance_steps:
        down_block_additional_residuals_uncond = down_block_additional_residuals_cond = None
        mid_block_additional_residual_uncond = mid_block_additional_residual_cond = None
        if self.add_controlnet:
            down_block_additional_residuals_uncond = [tensor[[0], ...].detach() for tensor in down_block_additional_residuals]
            down_block_additional_residuals_cond = [tensor[[1], ...].detach() for tensor in down_block_additional_residuals]
            mid_block_additional_residual_uncond = mid_block_additional_residual[[0], ...].detach()
            mid_block_additional_residual_cond = mid_block_additional_residual[[1], ...].detach()

        control_latents = noisy_latents.clone().detach()
        control_latents.requires_grad = True
        control_latents = self.scheduler.scale_model_input(control_latents, step_t)
        noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)

        with torch.no_grad():
            noise_pred_uncondition = self.unet(noisy_latents, step_t, encoder_hidden_states=self.text_embeddings[[0]],
                                               down_block_additional_residuals=down_block_additional_residuals_uncond,
                                               mid_block_additional_residual=mid_block_additional_residual_uncond,).sample.to(dtype=noisy_latents.dtype)

        # The conditional pass runs with autograd enabled so the motion loss can be
        # differentiated w.r.t. `control_latents`.
        noise_pred_condition = self.unet(control_latents, step_t, encoder_hidden_states=self.text_embeddings[[1]],
                                         down_block_additional_residuals=down_block_additional_residuals_cond,
                                         mid_block_additional_residual=mid_block_additional_residual_cond,).sample.to(dtype=noisy_latents.dtype)

        temp_attn_prob_control = self.get_temp_attn_prob()
        loss_motion = self.motion_scale * self.compute_temp_loss(temp_attn_prob_control)

        # Linearly ramp the motion loss up during warm-up and back down during cool-down.
        if step_index < self.input_config.warm_up_steps:
            scale = (step_index + 1) / self.input_config.warm_up_steps
            loss_motion = scale * loss_motion
        if step_index > self.input_config.guidance_steps - self.input_config.cool_up_steps:
            scale = (self.input_config.guidance_steps - step_index) / self.input_config.cool_up_steps
            loss_motion = scale * loss_motion

        gradient = torch.autograd.grad(loss_motion, control_latents, allow_unused=True)[0]  # [1, 4, 16, 64, 64]
        assert gradient is not None, f"Step {step_index}: grad is None"

        noise_pred = noise_pred_condition + self.input_config.cfg_scale * (noise_pred_condition - noise_pred_uncondition)
        control_latents = self.scheduler.customized_step(noise_pred, step_index, control_latents, score=gradient.detach(),
                                                         **extra_step_kwargs, return_dict=False)[0]  # [1, 4, 16, 64, 64]
        return control_latents.detach()
    else:
        with torch.no_grad():
            noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)
            noise_pred_group = self.unet(
                noisy_latents.expand(2, -1, -1, -1, -1), step_t,
                encoder_hidden_states=self.text_embeddings,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
            ).sample.to(dtype=noisy_latents.dtype)

            noise_pred = noise_pred_group[[1]] + self.input_config.cfg_scale * (noise_pred_group[[1]] - noise_pred_group[[0]])
            noisy_latents = self.scheduler.customized_step(noise_pred, step_index, noisy_latents, score=None, **extra_step_kwargs, return_dict=False)[0]  # [1, 4, 16, 64, 64]
        return noisy_latents.detach()
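
# Both branches of `single_step_video` apply classifier-free guidance,
#   eps_hat = eps_cond + cfg_scale * (eps_cond - eps_uncond),
# and, during the first `guidance_steps` steps, additionally pass the gradient of the
# motion loss w.r.t. the latents as `score` so that `customized_step` can subtract it
# from the predicted noise (classifier-guidance style; see step 6 in that method).
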
def get_temp_attn_prob(self, index_select=None):
    attn_prob_dic = {}
    for name, module in self.unet.named_modules():
        module_name = type(module).__name__
        if "VersatileAttention" in module_name and classify_blocks(self.input_config.motion_guidance_blocks, name):
            key = module.processor.key
            if index_select is not None:
                get_index = torch.repeat_interleave(torch.tensor(index_select), repeats=key.shape[0] // len(index_select))
                index_all = torch.arange(key.shape[0])
                index_picked = index_all[get_index.bool()]
                key = key[index_picked]
            key = module.reshape_heads_to_batch_dim(key).contiguous()

            query = module.processor.query
            if index_select is not None:
                query = query[index_picked]
            query = module.reshape_heads_to_batch_dim(query).contiguous()

            attention_probs = module.get_attention_scores(query, key, None)
            attention_probs = attention_probs.reshape(-1, module.heads, attention_probs.shape[1], attention_probs.shape[2])
            attn_prob_dic[name] = attention_probs

    return attn_prob_dic
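
# The queries/keys read from `module.processor` are assumed to be cached by the custom
# attention processor imported from `xformer_attention`; `get_attention_scores` then
# recomputes softmax(Q K^T / sqrt(d)) for the temporal (VersatileAttention) modules listed
# in `motion_guidance_blocks`, returned as [batch, heads, query_len, key_len].
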
def schedule_customized_step(
    self,
    model_output: torch.FloatTensor,
    step_index: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    # Guidance parameters
    score=None,
    guidance_scale=1.0,
    indices=None,  # e.g. [0]
    return_middle=False,
):
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    # See formulas (12) and (16) of the DDIM paper (https://arxiv.org/pdf/2010.02502.pdf);
    # ideally, read the DDIM paper for an in-depth understanding.
    # Notation (<variable name> -> <name in paper>):
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"
    # Support IF models
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    timestep = self.timesteps[step_index]

    # 1. get previous step value (=t-1)
    # prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
    prev_timestep = self.timesteps[step_index + 1] if step_index + 1 < len(self.timesteps) else -1

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
        pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)  # [2, 4, 64, 64]
    if score is not None and return_middle:
        return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample

    # 6. apply guidance following formula (14) from https://arxiv.org/pdf/2105.05233.pdf
    if score is not None and guidance_scale > 0.0:
        if indices is not None:
            assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
            pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
        else:
            assert pred_epsilon.shape == score.shape
            pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score

    # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

    # 8. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )
        if variance_noise is None:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        variance = std_dev_t * variance_noise
        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return prev_sample, pred_original_sample, alpha_prod_t_prev
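
# For reference, the update implemented above is the DDIM step (eq. 12 of
# https://arxiv.org/pdf/2010.02502.pdf) applied to the guided noise estimate:
#   x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat
#             + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps_hat
#             + sigma_t * z,   z ~ N(0, I)   (only when eta > 0),
# where eps_hat is `pred_epsilon` after the optional score-based correction in step 6.
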
def schedule_set_timesteps(self, num_inference_steps: int, guidance_steps: int = 0, guiduance_scale: float = 0.0,
                           device: Union[str, torch.device] = None, timestep_spacing_type="uneven"):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        guidance_steps (`int`, *optional*, defaults to 0):
            The number of early steps reserved for motion guidance when `timestep_spacing_type="uneven"`.
        guiduance_scale (`float`, *optional*, defaults to 0.0):
            The fraction of the noise schedule (from the high-noise end) covered by the guidance steps.
        timestep_spacing_type (`str`, *optional*, defaults to `"uneven"`):
            One of `"uneven"`, `"linspace"`, `"leading"`, or `"trailing"`.
    """
    if num_inference_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )
    self.num_inference_steps = num_inference_steps

    if timestep_spacing_type == "uneven":
        # assign more steps to the early (high-noise) denoising stage for motion guidance
        timesteps_guidance = (
            np.linspace(int((1 - guiduance_scale) * self.config.num_train_timesteps), self.config.num_train_timesteps - 1, guidance_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
        timesteps_vanilla = (
            np.linspace(0, int((1 - guiduance_scale) * self.config.num_train_timesteps) - 1, num_inference_steps - guidance_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
        timesteps = np.concatenate((timesteps_guidance, timesteps_vanilla))
    # "linspace", "leading", "trailing" correspond to the annotation in Table 2 of https://arxiv.org/abs/2305.08891
    elif timestep_spacing_type == "linspace":
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
    elif timestep_spacing_type == "leading":
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_steps is a power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif timestep_spacing_type == "trailing":
        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_steps is a power of 3
        timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{timestep_spacing_type} is not supported. Please make sure to choose one of 'uneven', 'linspace',"
            " 'leading', or 'trailing'."
        )

    self.timesteps = torch.from_numpy(timesteps).to(device)
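
# Worked example (hypothetical numbers) of the "uneven" spacing above: with
# num_train_timesteps=1000, num_inference_steps=8, guidance_steps=4, and
# guiduance_scale=0.5, the first four timesteps densely cover the top half of the noise
# schedule and the remaining four cover the rest, roughly
#   [999, 833, 666, 500, 499, 333, 166, 0],
# so motion guidance acts on the high-noise steps where coarse motion is decided.
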
@dataclass
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor
def unet_customized_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    # support controlnet
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    only_motion_feature: bool = False,
) -> Union[UNet3DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy input tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.

    Returns:
        [`UNet3DConditionOutput`] or `tuple`:
            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2 ** self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None
    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)
    # `timesteps` does not contain any weights and will always return f32 tensors,
    # but the time embedding might actually be running in fp16, so we need to cast here.
    # There might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb
    # pre-process
    sample = self.conv_in(sample)

    # down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
        down_block_res_samples += res_samples

    # support controlnet
    down_block_res_samples = list(down_block_res_samples)
    if down_block_additional_residuals is not None:
        for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
            if down_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
                down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
            down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

    # mid
    sample = self.mid_block(
        sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
    )

    # support controlnet
    if mid_block_additional_residual is not None:
        if mid_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
            mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
        sample = sample + mid_block_additional_residual
    # up
    for i, upsample_block in enumerate(self.up_blocks):
        # Up blocks up to (and including) the last one listed in `motion_guidance_blocks`
        # run in the caller's autograd context, so their temporal attention maps stay
        # differentiable during motion guidance.
        if i <= int(self.input_config.motion_guidance_blocks[-1].split(".")[-1]):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )
        else:
            # Past the guided blocks: if only motion features were needed, stop early;
            # otherwise finish the remaining up blocks without tracking gradients.
            if only_motion_feature:
                return 0
            with torch.no_grad():
                is_final_block = i == len(self.up_blocks) - 1

                res_samples = down_block_res_samples[-len(upsample_block.resnets):]
                down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

                # if we have not reached the final block and need to forward the
                # upsample size, we do it here
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        upsample_size=upsample_size,
                        attention_mask=attention_mask,
                    )
                else:
                    sample = upsample_block(
                        hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                    )
    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet3DConditionOutput(sample=sample)