"""Utilities for RGBA (RGB + alpha) video inference with CogVideoX.

This module provides an attention processor that adds LoRA corrections on the alpha half of the
visual token sequence, an attention mask that keeps text tokens from attending to the alpha
tokens, and a `prepare_for_rgba_inference` helper that loads the RGBA LoRA weights and patches
the transformer's forward pass.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple, Union

from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from safetensors.torch import load_file


logger = logging.get_logger(__name__)
@torch.no_grad()
def decode_latents(pipe, latents):
    """Decode video latents with the pipeline's VAE and post-process them to numpy frames."""
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="np")
    return video
def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create an attention mask that blocks text tokens from attending to the alpha tokens.

    The visual sequence is assumed to consist of two equal halves: RGB tokens followed by alpha
    tokens. Text tokens may attend to text and RGB tokens, but not to the alpha half.

    Args:
        text_length: Length of the text token sequence.
        seq_length: Length of the visual token sequence (RGB and alpha halves combined).
        device: The device where the mask will be stored.
        dtype: The data type of the mask tensor.

    Returns:
        An attention mask tensor of shape `(text_length + seq_length, text_length + seq_length)`.
    """
    total_length = text_length + seq_length
    dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
    # Block the text rows (queries) from the alpha columns (the last seq_length // 2 keys).
    dense_mask[:text_length, text_length + seq_length // 2:] = False
    # `F.scaled_dot_product_attention` interprets a boolean mask as keep/block and a
    # floating-point mask as an additive bias, so `dtype` controls how the mask is applied.
    return dense_mask.to(device=device, dtype=dtype)
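# Illustrative example (small sizes chosen for clarity; not used by the module): with
# text_length=2 and seq_length=4, the visual tokens split into 2 RGB + 2 alpha positions,
# so the mask is a 6x6 grid in which the two text rows are blocked from the last two
# (alpha) columns:
#
#     mask = create_attention_mask(2, 4, torch.device("cpu"), torch.bool)
#     # mask[:2, 4:] is all False; every other entry is True.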
class RGBALoRACogVideoXAttnProcessor:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a
    rotary embedding on query and key vectors, but does not include spatial normalization. On top
    of the frozen attention projections, it adds LoRA (low-rank) corrections that affect only the
    alpha half of the visual token sequence.
    """

    def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("RGBALoRACogVideoXAttnProcessor requires PyTorch 2.0 or later.")

        self.lora_alpha = lora_alpha
        self.lora_rank = lora_rank

        # Each LoRA adapter is a pair of bias-free linear layers: A (in_dim -> mid_dim) followed
        # by B (mid_dim -> out_dim), where mid_dim is the LoRA rank.
        def create_lora_layer(in_dim, mid_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype),
                nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype),
            )

        self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)

        self.attention_mask = attention_mask
    def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
        """Adds the scaled LoRA corrections to the alpha half (last `seq_len // 2` tokens) of query, key, and value."""
        query_delta = self.to_q_lora(hidden_states).to(query.device)
        query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling

        key_delta = self.to_k_lora(hidden_states).to(key.device)
        key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling

        value_delta = self.to_v_lora(hidden_states).to(value.device)
        value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling

        return query, key, value
    def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
        """Applies rotary embeddings to query and key tensors; the RGB and alpha halves share the same positions."""
        from diffusers.models.embeddings import apply_rotary_emb

        # RGB half.
        query[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
            query[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb
        )
        # Alpha half, rotated with the same embedding as the RGB half.
        query[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
            query[:, :, text_seq_length + seq_len // 2:], image_rotary_emb
        )

        if not attn.is_cross_attention:
            key[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
                key[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb
            )
            key[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
                key[:, :, text_seq_length + seq_len // 2:], image_rotary_emb
            )

        return query, key
    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape
        seq_len = sequence_length - text_seq_length
        scaling = self.lora_alpha / self.lora_rank

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Add the LoRA corrections to the alpha half of the sequence.
        query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn)

        # Use the precomputed mask that keeps text tokens from attending to the alpha half.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection, with the LoRA correction applied to the alpha half only.
        original_hidden_states = attn.to_out[0](hidden_states)
        hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
        original_hidden_states[:, -seq_len // 2:, :] += hidden_states_delta[:, -seq_len // 2:, :] * scaling

        # Dropout.
        hidden_states = attn.to_out[1](original_hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
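# How the low-rank corrections compose with the frozen projections: each `to_*_lora` pair is a
# rank-`lora_rank` factorization (A: latent_dim -> lora_rank, B: lora_rank -> latent_dim), and
# for the alpha tokens the processor computes
#
#     out_alpha = W(x_alpha) + (lora_alpha / lora_rank) * B(A(x_alpha))
#
# where W is the frozen projection; text and RGB tokens pass through W unchanged. The
# (lora_alpha / lora_rank) factor is the `scaling` value computed in `__call__`.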
def prepare_for_rgba_inference(
    model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
    lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100,
):
    def load_lora_sequential_weights(lora_layer, lora_layers, prefix):
        lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]})
        lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]})

    rgba_weights = load_file(rgba_weights_path)
    aux_emb = rgba_weights['domain_emb']

    attention_mask = create_attention_mask(text_length, seq_length, device, dtype)
    attn_procs = {}

    for name in model.attn_processors.keys():
        attn_processor = RGBALoRACogVideoXAttnProcessor(
            device=device, dtype=dtype, attention_mask=attention_mask,
            lora_rank=lora_rank, lora_alpha=lora_alpha,
        )

        # Processor names look like "transformer_blocks.<index>.attn1.processor".
        index = name.split('.')[1]
        base_prefix = f'transformer.transformer_blocks.{index}.attn1'

        for lora_layer, prefix in [
            (attn_processor.to_q_lora, f'{base_prefix}.to_q'),
            (attn_processor.to_k_lora, f'{base_prefix}.to_k'),
            (attn_processor.to_v_lora, f'{base_prefix}.to_v'),
            (attn_processor.to_out_lora, f'{base_prefix}.to_out.0'),
        ]:
            load_lora_sequential_weights(lora_layer, rgba_weights, prefix)

        attn_procs[name] = attn_processor

    model.set_attn_processor(attn_procs)
    def custom_forward(self):
        def forward(
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            timestep: Union[int, float, torch.LongTensor],
            timestep_cond: Optional[torch.Tensor] = None,
            ofs: Optional[Union[int, float, torch.LongTensor]] = None,
            image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            attention_kwargs: Optional[Dict[str, Any]] = None,
            return_dict: bool = True,
        ):
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
                scale_lora_layers(self, lora_scale)
            else:
                if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                    logger.warning(
                        "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                    )

            batch_size, num_frames, channels, height, width = hidden_states.shape

            # 1. Time embedding
            timesteps = timestep
            t_emb = self.time_proj(timesteps)

            # The timestep projection always returns float32 tensors, so cast back to the
            # dtype of the hidden states.
            t_emb = t_emb.to(dtype=hidden_states.dtype)
            emb = self.time_embedding(t_emb, timestep_cond)

            if self.ofs_embedding is not None:
                ofs_emb = self.ofs_proj(ofs)
                ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
                ofs_emb = self.ofs_embedding(ofs_emb)
                emb = emb + ofs_emb

            # 2. Patch embedding
            hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
            hidden_states = self.embedding_dropout(hidden_states)

            text_seq_length = encoder_hidden_states.shape[1]
            encoder_hidden_states = hidden_states[:, :text_seq_length]
            hidden_states = hidden_states[:, text_seq_length:]

            # Add the learned domain embedding to the alpha half of the visual tokens.
            hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(
                hidden_states.device, dtype=hidden_states.dtype
            )

            # 3. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                    )

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            p = self.config.patch_size
            p_t = self.config.patch_size_t

            if p_t is None:
                output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
                output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
            else:
                output = hidden_states.reshape(
                    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
                )
                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

            if USE_PEFT_BACKEND:
                # Remove `lora_scale` from each PEFT layer.
                unscale_lora_layers(self, lora_scale)

            if not return_dict:
                return (output,)
            return Transformer2DModelOutput(sample=output)

        return forward

    model.forward = custom_forward(model)
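# A hedged usage sketch (not part of the original module): one way these helpers could be wired
# into a CogVideoX pipeline. The checkpoint id, the weights path, and the prompt are placeholder
# assumptions, and the pipeline is expected to be set up so that the latent sequence carries both
# the RGB and alpha halves; the default `text_length`/`seq_length` of `prepare_for_rgba_inference`
# must match the prompt padding and latent video size actually used for generation.
if __name__ == "__main__":
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

    # Attach the RGBA LoRA attention processors and the patched forward pass to the transformer.
    prepare_for_rgba_inference(
        pipe.transformer,
        rgba_weights_path="path/to/rgba_lora.safetensors",  # placeholder path
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
    )

    # Run the denoising loop in latent space, then decode and post-process to numpy frames.
    latents = pipe(prompt="a goldfish swimming, isolated subject", output_type="latent").frames
    video = decode_latents(pipe, latents)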