from dataclasses import dataclass
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import Callable, List, Optional, Union
from diffusers.utils import deprecate, logging, BaseOutput
from .xformer_attention import *
from .conv_layer import *
from .util import *
from diffusers.utils.torch_utils import randn_tensor
from typing import List, Optional, Tuple, Union
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
from motionclone.utils.util import video_preprocess
import einops
import torchvision.transforms as transforms

# NOTE(review): every function below takes ``self`` but is defined at module
# level; presumably they are bound onto pipeline / scheduler / UNet objects
# elsewhere -- confirm against the caller.
# NOTE(review): ``Image`` and ``classify_blocks`` are not imported explicitly
# here; they are expected to arrive through the star imports above.


def _build_controlnet_condition(controlnet_images, image_index, noisy_latents):
    """Place the conditioning frames into an all-zero tensor covering every frame.

    Args:
        controlnet_images: conditioning frames; axis 2 is replaced by the
            frame count of ``noisy_latents`` when sizing the output.
        image_index: index (or indices) along axis 2 where the conditioning
            frames are written.
        noisy_latents: latents whose ``shape[2]``, device and dtype define the
            output tensors.

    Returns:
        ``(controlnet_cond, conditioning_mask)``. The mask has a single
        channel (axis 1) and is 1 at ``image_index`` and 0 elsewhere.
    """
    cond_shape = list(controlnet_images.shape)
    cond_shape[2] = noisy_latents.shape[2]
    controlnet_cond = torch.zeros(
        cond_shape, device=noisy_latents.device, dtype=noisy_latents.dtype
    )

    mask_shape = list(cond_shape)
    mask_shape[1] = 1
    conditioning_mask = torch.zeros(
        mask_shape, device=noisy_latents.device, dtype=noisy_latents.dtype
    )

    controlnet_cond[:, :, image_index] = controlnet_images
    conditioning_mask[:, :, image_index] = 1
    return controlnet_cond, conditioning_mask


def add_noise(self, timestep, x_0, noise_pred):
    """Return ``sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise_pred``.

    ``a_t`` is ``self.scheduler.alphas_cumprod[timestep]``.
    """
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    return alpha_prod_t ** 0.5 * x_0 + beta_prod_t ** 0.5 * noise_pred


@torch.no_grad()
def obtain_motion_representation(
    self,
    generator=None,
    motion_representation_path: Optional[str] = None,
    duration=None,
    use_controlnet=False,
):
    """Extract and save the motion representation of the reference video.

    The video at ``self.input_config.video_path`` is VAE-encoded, noised to
    ``self.input_config.add_noise_step`` and passed through the UNet with
    ``only_motion_feature=True``. For each attention map returned by
    ``self.get_temp_attn_prob()`` the top-1 value and its index along the
    last axis are stored.

    Args:
        generator: optional ``torch.Generator`` used to sample the noise.
        motion_representation_path: destination passed to ``torch.save``;
            also stored on ``self.motion_representation_path``.
        duration: forwarded to ``video_preprocess``.
        use_controlnet: when true, ControlNet residuals conditioned on the
            frame(s) at ``self.input_config.image_index`` are fed to the UNet.
    """
    cfg = self.input_config
    video_data = video_preprocess(
        cfg.video_path, cfg.height, cfg.width, cfg.video_length, duration=duration
    )
    video_latents = self.vae.encode(
        video_data.to(self.vae.dtype).to(self.vae.device)
    ).latent_dist.mode()
    video_latents = self.vae.config.scaling_factor * video_latents
    video_latents = video_latents.unsqueeze(0)
    video_latents = einops.rearrange(video_latents, "b f c h w -> b c f h w")

    # The reference pass uses the empty (unconditional) prompt.
    uncond_input = self.tokenizer(
        [""],
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        return_tensors="pt",
    )
    step_t = int(cfg.add_noise_step)
    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
    noise_sampled = randn_tensor(
        video_latents.shape,
        generator=generator,
        device=video_latents.device,
        dtype=video_latents.dtype,
    )
    noisy_latents = self.add_noise(step_t, video_latents, noise_sampled)

    down_block_additional_residuals = mid_block_additional_residual = None
    if use_controlnet:
        controlnet_image_index = cfg.image_index
        if self.controlnet.use_simplified_condition_embedding:
            # Condition directly on the latents of the selected frame(s).
            controlnet_images = video_latents[:, :, controlnet_image_index, :, :]
        else:
            # Condition on pixels, shifted by +1 and halved.
            pixel_video = einops.rearrange(
                video_data.unsqueeze(0).to(self.vae.dtype).to(self.vae.device),
                "b f c h w -> b c f h w",
            )
            controlnet_images = (pixel_video + 1) / 2
            controlnet_images = controlnet_images[:, :, controlnet_image_index, :, :]

        controlnet_cond, controlnet_conditioning_mask = _build_controlnet_condition(
            controlnet_images, controlnet_image_index, noisy_latents
        )

        down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
            noisy_latents,
            step_t,
            encoder_hidden_states=uncond_embeddings,
            controlnet_cond=controlnet_cond,
            conditioning_mask=controlnet_conditioning_mask,
            conditioning_scale=cfg.controlnet_scale,
            guess_mode=False,
            return_dict=False,
        )

    # Only the side effect matters: the forward pass leaves query/key tensors
    # on the attention processors, which get_temp_attn_prob() reads back.
    _ = self.unet(
        noisy_latents,
        step_t,
        encoder_hidden_states=uncond_embeddings,
        return_dict=False,
        only_motion_feature=True,
        down_block_additional_residuals=down_block_additional_residuals,
        mid_block_additional_residual=mid_block_additional_residual,
    )
    temp_attn_prob_control = self.get_temp_attn_prob()

    # Indices are stored as uint8, so the last axis must be shorter than 256.
    motion_representation = {}
    for key, tensor in temp_attn_prob_control.items():
        max_value, max_index = torch.topk(tensor, k=1, dim=-1)
        motion_representation[key] = [max_value, max_index.to(torch.uint8)]

    torch.save(motion_representation, motion_representation_path)
    self.motion_representation_path = motion_representation_path


def compute_temp_loss(self, temp_attn_prob_control_dict):
    """Sum of per-module MSE between current and reference motion representations.

    For each entry, the current attention map is gathered at the reference
    top-1 indices from ``self.motion_representation_dict`` and compared with
    the reference top-1 values.

    Args:
        temp_attn_prob_control_dict: mapping of module name to attention map,
            as produced by ``get_temp_attn_prob``.

    Returns:
        Scalar tensor: the sum of the per-module losses.
    """
    module_losses = []
    for name, current_attn_prob in temp_attn_prob_control_dict.items():
        reference_value, reference_index = self.motion_representation_dict[name][:2]
        # gather() needs int64 indices on the same device as the source.
        gather_index = reference_index.to(torch.int64).to(current_attn_prob.device)
        current_repr = torch.gather(current_attn_prob, index=gather_index, dim=-1)
        reference_repr = reference_value.to(
            dtype=current_repr.dtype, device=current_repr.device
        )
        module_losses.append(
            torch.nn.functional.mse_loss(current_repr, reference_repr.detach())
        )
    return torch.stack(module_losses).sum()


def sample_video(
    self,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    noisy_latents: Optional[torch.FloatTensor] = None,
    add_controlnet: bool = False,
):
    """Run the full denoising loop and decode the result.

    Args:
        eta: forwarded to ``self.prepare_extra_step_kwargs``.
        generator: forwarded to ``prepare_latents`` and
            ``prepare_extra_step_kwargs``.
        noisy_latents: optional initial latents, forwarded to
            ``self.prepare_latents``.
        add_controlnet: when true, the images listed in
            ``self.input_config.condition_image_path_list`` are loaded and
            stored on ``self.controlnet_images`` for use at every step.

    Returns:
        The output of ``self.decode_latents`` on the final latents.
    """
    cfg = self.input_config

    # Determine whether to use controlnet, i.e. conditional image2video.
    self.add_controlnet = add_controlnet
    if self.add_controlnet:
        image_transforms = transforms.Compose([
            transforms.Resize((cfg.height, cfg.width)),
            transforms.ToTensor(),
        ])
        controlnet_images = [
            image_transforms(Image.open(path).convert("RGB"))
            for path in cfg.condition_image_path_list
        ]
        controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(
            dtype=self.vae.dtype, device=self.vae.device
        )
        controlnet_images = einops.rearrange(controlnet_images, "b f c h w -> b c f h w")

        with torch.no_grad():
            if self.controlnet.use_simplified_condition_embedding:
                # Encode the condition images to latents, one image per batch row.
                num_controlnet_images = controlnet_images.shape[2]
                controlnet_images = einops.rearrange(
                    controlnet_images, "b c f h w -> (b f) c h w"
                )
                controlnet_images = (
                    self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample()
                    * self.vae.config.scaling_factor
                )
                self.controlnet_images = einops.rearrange(
                    controlnet_images,
                    "(b f) c h w -> b c f h w",
                    f=num_controlnet_images,
                )
            else:
                self.controlnet_images = controlnet_images

    # Define call parameters; classifier-free guidance is always on.
    batch_size = 1
    do_classifier_free_guidance = True
    device = self._execution_device

    # Encode input prompt -> [uncond_embeddings, text_embeddings], e.g. [2, 77, 768].
    self.text_embeddings = self._encode_prompt(
        cfg.new_prompt, device, 1, do_classifier_free_guidance, cfg.negative_prompt
    )

    # Prepare latent variables.
    noisy_latents = self.prepare_latents(
        batch_size,
        self.unet.config.in_channels,
        cfg.video_length,
        cfg.height,
        cfg.width,
        self.text_embeddings.dtype,
        device,
        generator,
        noisy_latents,
    )

    # NOTE(review): torch.load unpickles the file; only load representation
    # files written by a trusted obtain_motion_representation() run.
    self.motion_representation_dict = torch.load(self.motion_representation_path)
    self.motion_scale = cfg.motion_guidance_weight

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # To save GPU memory, the VAE and text encoder could be moved to the CPU
    # here (followed by torch.cuda.empty_cache()); currently disabled.

    with self.progress_bar(total=cfg.inference_steps) as progress_bar:
        for step_index, step_t in enumerate(self.scheduler.timesteps):
            noisy_latents = self.single_step_video(
                noisy_latents, step_index, step_t, extra_step_kwargs
            )
            progress_bar.update()

    # Decode latents into the video.
    return self.decode_latents(noisy_latents)


def single_step_video(self, noisy_latents, step_index, step_t, extra_step_kwargs):
    """Advance the latents by one scheduler step, with motion guidance if active.

    For ``step_index < self.input_config.guidance_steps`` the conditional UNet
    pass runs with gradients enabled, the motion loss is differentiated with
    respect to the latents, and that gradient is handed to
    ``self.scheduler.customized_step`` as ``score``. Later steps are plain
    classifier-free-guidance steps under ``torch.no_grad()``.

    Args:
        noisy_latents: current latents (batch size 1; expanded to 2 for the
            batched uncond/cond passes).
        step_index: position in ``self.scheduler.timesteps``.
        step_t: the timestep value at ``step_index``.
        extra_step_kwargs: extra keyword arguments for the scheduler step.

    Returns:
        The detached latents for the next step.
    """
    cfg = self.input_config

    down_block_additional_residuals = mid_block_additional_residual = None
    if self.add_controlnet:
        with torch.no_grad():
            controlnet_cond, controlnet_conditioning_mask = _build_controlnet_condition(
                self.controlnet_images, cfg.image_index, noisy_latents
            )
            down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
                noisy_latents.expand(2, -1, -1, -1, -1),
                step_t,
                encoder_hidden_states=self.text_embeddings,
                controlnet_cond=controlnet_cond,
                conditioning_mask=controlnet_conditioning_mask,
                conditioning_scale=cfg.controlnet_scale,
                guess_mode=False,
                return_dict=False,
            )

    # Gradients are only needed while motion guidance is active.
    if step_index < cfg.guidance_steps:
        down_block_additional_residuals_uncond = down_block_additional_residuals_cond = None
        mid_block_additional_residual_uncond = mid_block_additional_residual_cond = None
        if self.add_controlnet:
            # Row 0 is the unconditional branch, row 1 the conditional one.
            down_block_additional_residuals_uncond = [
                tensor[[0], ...].detach() for tensor in down_block_additional_residuals
            ]
            down_block_additional_residuals_cond = [
                tensor[[1], ...].detach() for tensor in down_block_additional_residuals
            ]
            mid_block_additional_residual_uncond = mid_block_additional_residual[[0], ...].detach()
            mid_block_additional_residual_cond = mid_block_additional_residual[[1], ...].detach()

        control_latents = noisy_latents.clone().detach()
        control_latents.requires_grad = True

        control_latents = self.scheduler.scale_model_input(control_latents, step_t)
        noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)

        with torch.no_grad():
            noise_pred_uncondition = self.unet(
                noisy_latents,
                step_t,
                encoder_hidden_states=self.text_embeddings[[0]],
                down_block_additional_residuals=down_block_additional_residuals_uncond,
                mid_block_additional_residual=mid_block_additional_residual_uncond,
            ).sample.to(dtype=noisy_latents.dtype)

        # The conditional pass keeps its graph so the loss can be differentiated.
        noise_pred_condition = self.unet(
            control_latents,
            step_t,
            encoder_hidden_states=self.text_embeddings[[1]],
            down_block_additional_residuals=down_block_additional_residuals_cond,
            mid_block_additional_residual=mid_block_additional_residual_cond,
        ).sample.to(dtype=noisy_latents.dtype)
        temp_attn_prob_control = self.get_temp_attn_prob()

        loss_motion = self.motion_scale * self.compute_temp_loss(temp_attn_prob_control)

        # Ramp the loss weight up over the first warm_up_steps ...
        if step_index < cfg.warm_up_steps:
            scale = (step_index + 1) / cfg.warm_up_steps
            loss_motion = scale * loss_motion
        # ... and down over the last cool_up_steps of the guidance phase.
        if step_index > cfg.guidance_steps - cfg.cool_up_steps:
            scale = (cfg.guidance_steps - step_index) / cfg.cool_up_steps
            loss_motion = scale * loss_motion

        gradient = torch.autograd.grad(loss_motion, control_latents, allow_unused=True)[0]
        assert gradient is not None, f"Step {step_index}: grad is None"

        noise_pred = noise_pred_condition + cfg.cfg_scale * (
            noise_pred_condition - noise_pred_uncondition
        )

        control_latents = self.scheduler.customized_step(
            noise_pred,
            step_index,
            control_latents,
            score=gradient.detach(),
            **extra_step_kwargs,
            return_dict=False,
        )[0]
        return control_latents.detach()

    with torch.no_grad():
        noisy_latents = self.scheduler.scale_model_input(noisy_latents, step_t)
        noise_pred_group = self.unet(
            noisy_latents.expand(2, -1, -1, -1, -1),
            step_t,
            encoder_hidden_states=self.text_embeddings,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample.to(dtype=noisy_latents.dtype)
        noise_pred = noise_pred_group[[1]] + cfg.cfg_scale * (
            noise_pred_group[[1]] - noise_pred_group[[0]]
        )
        noisy_latents = self.scheduler.customized_step(
            noise_pred,
            step_index,
            noisy_latents,
            score=None,
            **extra_step_kwargs,
            return_dict=False,
        )[0]
    return noisy_latents.detach()


def get_temp_attn_prob(self, index_select=None):
    """Collect attention maps from the selected ``VersatileAttention`` modules.

    A module is selected when its class name contains ``"VersatileAttention"``
    and ``classify_blocks(self.input_config.motion_guidance_blocks, name)`` is
    truthy. The maps are recomputed from the ``query`` and ``key`` tensors
    found on ``module.processor``.

    Args:
        index_select: optional sequence of flags; each flag is repeated
            ``key.shape[0] // len(index_select)`` times and used as a boolean
            mask over the first axis of query and key.

    Returns:
        Dict of module name to attention tensor reshaped to
        ``(-1, module.heads, d1, d2)``.
    """
    attn_prob_dic = {}
    for name, module in self.unet.named_modules():
        if "VersatileAttention" not in type(module).__name__:
            continue
        if not classify_blocks(self.input_config.motion_guidance_blocks, name):
            continue

        key = module.processor.key
        query = module.processor.query
        if index_select is not None:
            get_index = torch.repeat_interleave(
                torch.tensor(index_select), repeats=key.shape[0] // len(index_select)
            )
            index_picked = torch.arange(key.shape[0])[get_index.bool()]
            key = key[index_picked]
            query = query[index_picked]

        key = module.reshape_heads_to_batch_dim(key).contiguous()
        query = module.reshape_heads_to_batch_dim(query).contiguous()

        attention_probs = module.get_attention_scores(query, key, None)
        attention_probs = attention_probs.reshape(
            -1, module.heads, attention_probs.shape[1], attention_probs.shape[2]
        )
        attn_prob_dic[name] = attention_probs
    return attn_prob_dic


@torch.no_grad()
def schedule_customized_step(
    self,
    model_output: torch.FloatTensor,
    step_index: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    # Guidance parameters
    score=None,
    guidance_scale=1.0,
    indices=None,  # [0]
    return_middle=False,
):
    """DDIM-style scheduler step indexed by ``step_index``, with optional guidance.

    The previous timestep is read from ``self.timesteps[step_index + 1]``
    rather than derived from a fixed stride, so non-uniform schedules work.

    Args:
        model_output: raw model prediction.
        step_index: position of the current step in ``self.timesteps``.
        sample: current noisy sample.
        eta: weight of the stochastic term (0 gives a deterministic step).
        use_clipped_model_output: re-derive epsilon from the clipped or
            thresholded ``x_0``.
        generator: RNG for the variance noise when ``eta > 0``.
        variance_noise: pre-sampled variance noise; mutually exclusive with
            ``generator``.
        return_dict: when false, return the 1-tuple ``(prev_sample,)``;
            otherwise return
            ``(prev_sample, pred_original_sample, alpha_prod_t_prev)``.
        score: guidance gradient; subtracted from the predicted epsilon scaled
            by ``guidance_scale * sqrt(1 - alpha_prod_t)``.
        guidance_scale: guidance strength; guidance is skipped when ``<= 0``.
        indices: optional index selecting the rows of epsilon to guide.
        return_middle: with a non-None ``score``, return
            ``(pred_epsilon, alpha_prod_t, alpha_prod_t_prev,
            pred_original_sample)`` before guidance is applied.

    Raises:
        ValueError: if ``set_timesteps`` has not been run, the prediction type
            is unknown, or both ``generator`` and ``variance_noise`` are given.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Notation:
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # Support IF models
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    timestep = self.timesteps[step_index]

    # 1. get previous step value (=t-1)
    # NOTE(review): steps 1-2 were truncated in the mangled source
    # ("... if step_index +1 = 0 else self.final_alpha_cumprod"). They are
    # reconstructed from the names used below and the standard DDIM step;
    # confirm against upstream.
    prev_timestep = self.timesteps[step_index + 1] if step_index + 1 < len(self.timesteps) else -1

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
        pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    if score is not None and return_middle:
        return pred_epsilon, alpha_prod_t, alpha_prod_t_prev, pred_original_sample

    # 6. apply guidance following formula (14) of https://arxiv.org/pdf/2105.05233.pdf
    if score is not None and guidance_scale > 0.0:
        if indices is not None:
            assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
            # Clone first: for "epsilon" prediction pred_epsilon *is* the
            # caller's model_output and must not be modified in place.
            pred_epsilon = pred_epsilon.clone()
            pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
        else:
            assert pred_epsilon.shape == score.shape
            pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score

    # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

    # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )

        if variance_noise is None:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        variance = std_dev_t * variance_noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)
    return prev_sample, pred_original_sample, alpha_prod_t_prev


def schedule_set_timesteps(
    self,
    num_inference_steps: int,
    guidance_steps: int = 0,
    guiduance_scale: float = 0.0,
    device: Union[str, torch.device] = None,
    timestep_spacing_type="uneven",
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        guidance_steps (`int`):
            With "uneven" spacing, the number of steps placed in the guided
            (high-noise) part of the schedule.
        guiduance_scale (`float`):
            With "uneven" spacing, the fraction of the training timesteps
            covered by the guided part: it spans
            `[(1 - guiduance_scale) * num_train_timesteps, num_train_timesteps - 1]`.
        device (`str` or `torch.device`):
            Device the resulting `self.timesteps` tensor is moved to.
        timestep_spacing_type (`str`):
            One of "uneven", "linspace", "leading" or "trailing".

    Raises:
        ValueError: if `num_inference_steps` exceeds the number of training
            timesteps or `timestep_spacing_type` is not supported.
    """
    num_train_timesteps = self.config.num_train_timesteps

    if num_inference_steps > num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps

    if timestep_spacing_type == "uneven":
        # Assign more steps to the early denoising stage for motion guidance.
        guidance_start = int((1 - guiduance_scale) * num_train_timesteps)
        timesteps_guidance = (
            np.linspace(guidance_start, num_train_timesteps - 1, guidance_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
        timesteps_vanilla = (
            np.linspace(0, guidance_start - 1, num_inference_steps - guidance_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
        timesteps = np.concatenate((timesteps_guidance, timesteps_vanilla))

    # "linspace", "leading", "trailing" correspond to the annotation of Table 2 of https://arxiv.org/abs/2305.08891
    elif timestep_spacing_type == "linspace":
        timesteps = (
            np.linspace(0, num_train_timesteps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
    elif timestep_spacing_type == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif timestep_spacing_type == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{timestep_spacing_type} is not supported. Please make sure to choose one of 'uneven', 'linspace',"
            " 'leading' or 'trailing'."
        )

    self.timesteps = torch.from_numpy(timesteps).to(device)


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Return container of ``unet_customized_forward`` when ``return_dict`` is true."""

    # Output of the final convolution of the UNet.
    sample: torch.FloatTensor


def unet_customized_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    # support controlnet
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    only_motion_feature: bool = False,
) -> Union[UNet3DConditionOutput, Tuple]:
    r"""
    UNet forward pass that only tracks gradients through the motion-guided up blocks.

    Up blocks with an index up to the one named by the last entry of
    `self.input_config.motion_guidance_blocks` run in the ambient grad mode;
    later up blocks run under no-grad.

    Args:
        sample (`torch.FloatTensor`): noisy input tensor. NOTE(review): 4-D
            ControlNet residuals are unsqueezed at dim 2 before being added,
            so this is presumably (batch, channel, frames, height, width) --
            confirm.
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): encoder hidden states
        class_labels: required when the model has a class embedding.
        attention_mask: optional mask; converted to an additive bias
            (0 -> -10000, 1 -> 0).
        down_block_additional_residuals: optional ControlNet residuals added
            to the down-block skip connections.
        mid_block_additional_residual: optional ControlNet residual added to
            the mid-block output.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `UNet3DConditionOutput` instead of a plain tuple.
        only_motion_feature (`bool`): when true, stop right after the last
            guided up block and return `0`; the caller is expected to read
            the attention features as a side effect.

    Returns:
        `UNet3DConditionOutput` if `return_dict` is True, otherwise a `tuple`
        whose first element is the sample tensor; the integer `0` when
        `only_motion_feature` triggers the early exit.

    Raises:
        ValueError: if the model has a class embedding and `class_labels` is None.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2 ** self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # mps does not support 64-bit dtypes, so fall back to 32-bit there.
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # time_proj does not contain any weights and will always return f32 tensors,
    # but time_embedding might actually be running in fp16, so we need to cast here.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # pre-process
    sample = self.conv_in(sample)

    # down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if getattr(downsample_block, "has_cross_attention", False):
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )
        else:
            sample, res_samples = downsample_block(
                hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
            )

        down_block_res_samples += res_samples

    # support controlnet
    down_block_res_samples = list(down_block_res_samples)
    if down_block_additional_residuals is not None:
        for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
            if down_block_additional_residual.dim() == 4:
                # broadcast a per-image residual over dim 2
                down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
            down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

    # mid
    sample = self.mid_block(
        sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
    )

    # support controlnet
    if mid_block_additional_residual is not None:
        if mid_block_additional_residual.dim() == 4:
            # broadcast a per-image residual over dim 2
            mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
        sample = sample + mid_block_additional_residual

    # up
    # Index of the last up block taking part in motion guidance, parsed from
    # the trailing ".<index>" of the last configured block name.
    last_guided_block = int(self.input_config.motion_guidance_blocks[-1].split(".")[-1])
    for i, upsample_block in enumerate(self.up_blocks):
        needs_grad = i <= last_guided_block
        if not needs_grad and only_motion_feature:
            # All guided blocks have run; the remaining blocks are not needed.
            return 0

        # Guided blocks keep the ambient grad mode; later blocks never need
        # gradients, so they run with grad disabled to save memory.
        with torch.set_grad_enabled(needs_grad and torch.is_grad_enabled()):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if getattr(upsample_block, "has_cross_attention", False):
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                )

    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet3DConditionOutput(sample=sample)