Video-CCAM-14B-v1.1/projector/modeling_ccam.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/07/10 19:47:01
@email: [email protected]
@description: Causal Cross-Attention Mask (CCAM)
================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from .configuration_ccam import CCAMConfig


class CCAMMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_act = config.hidden_act
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.output_size = config.output_size
if self.hidden_act == 'swiglu':
self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias)
self.act_fn = ACT2FN['silu']
else:
self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[self.hidden_act]
self.fc2 = nn.Linear(self.intermediate_size, self.output_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
if self.hidden_act == 'swiglu':
gate, up = hidden_states.chunk(2, dim=-1)
hidden_states = self.act_fn(gate) * up
else:
hidden_states = self.act_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
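
# Shape sketch for CCAMMLP's SwiGLU path (illustrative numbers, not taken from any
# config): with hidden_size=1024, intermediate_size=4096 and hidden_act='swiglu',
# fc1 maps 1024 -> 8192, the result is chunked into a 4096-dim gate and a 4096-dim
# up branch, silu(gate) * up is taken elementwise, and fc2 maps 4096 -> output_size.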


class CCAMCrossAttention(nn.Module):
    """Cross-attention layer of the CCAM projector.

    Flash Attention 2 is not supported because the mask may be neither full nor
    causal; only `eager` and `sdpa` are supported for `attn_implementation`.
    """

    def __init__(self, config):
super().__init__()
self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
self.attention_bias = config.attention_bias
self.attention_dropout = config.attention_dropout
self.cross_hidden_size = config.cross_hidden_size
self.num_key_value_heads = config.num_key_value_heads
self.attn_implementation = config._attn_implementation
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
assert self.head_dim * self.num_heads == self.hidden_size, f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).'
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)

    def forward(
self,
hidden_states: torch.Tensor, # (B, Q, C)
cross_hidden_states: torch.Tensor, # (B, L, C')
        attention_mask: torch.Tensor = None  # (Q, L) or (1, Q, L); '-inf' marks masked positions, 0 leaves them visible
) -> torch.Tensor: # (B, Q, C)
B, Q, C = hidden_states.size()
query_states = self.q_proj(hidden_states) # (B, Q, C)
key_states = self.k_proj(cross_hidden_states)
value_states = self.v_proj(cross_hidden_states)
L = key_states.size(1)
query_states = query_states.view(B, Q, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
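        # grouped-query attention: replicate each key/value head so that every
        # query head has a matching key/value head before computing attention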
if self.num_key_value_groups > 1:
key_states = key_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)
value_states = value_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)
if self.attn_implementation == 'eager':
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim ** 0.5 # (B, num_heads, Q, L)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask.view(1, 1, Q, L)
# upcast attention to fp32
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (B, num_heads, Q, head_dim)
else: # 'sdpa'
            # torch <= 2.1.0 has SDPA bugs that require q, k and v to be contiguous(); be careful with older versions
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0
)
attn_output = attn_output.transpose(1, 2).reshape(B, Q, C) # (B, Q, C)
attn_output = self.o_proj(attn_output)
return attn_output


class CCAMModel(PreTrainedModel):
"""Causal Cross-Attention Mask Projector"""
config_class = CCAMConfig
_auto_class = 'AutoModel'
_supports_sdpa = True
_no_split_modules = ['CCAMCrossAttention', 'CCAMMLP']

    def __init__(self, config):
super().__init__(config)
self.num_query = config.num_query
self.hidden_size = config.hidden_size
self.output_size = config.output_size
self.cross_hidden_size = config.cross_hidden_size
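        # learnable query tokens of shape (1, num_query, hidden_size), reused for every clip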
self.query = nn.Parameter(torch.empty(1, self.num_query, self.hidden_size).normal_(mean=.0, std=.02))
self.pre_ccam = nn.Sequential(
nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
nn.Dropout(config.dropout)
)
self.ccam = CCAMCrossAttention(config)
self.post_ccam = nn.Sequential(
nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
nn.Dropout(config.dropout),
CCAMMLP(config)
)
self.post_init()

    def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=.0, std=.02)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()

    def _get_mask(self, vision_hidden_state: torch.Tensor) -> torch.Tensor:  # (Q, T*L)
"""Compute CCAM Mask for vision hidden state
Args:
vision_hidden_state (torch.Tensor): (T, L, C)
Returns:
torch.Tensor: (Q, T*L) -inf means masked
"""
T, L, _ = vision_hidden_state.size()
dtype, device = vision_hidden_state.dtype, vision_hidden_state.device
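        # (T, T) frame-level causal mask: entry (i, j) is -inf when frame j comes after frame i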
base_mask = torch.zeros(T, T, dtype=dtype, device=device)
t = torch.arange(T, device=device)
base_mask.masked_fill_(t > t[:, None], float('-inf'))
attention_mask = torch.zeros(self.num_query, T * L, dtype=dtype, device=device)
        block_mask = torch.kron(base_mask, torch.ones(self.num_query // T, L, dtype=dtype, device=device))  # (num_query // T * T, T * L)
        attention_mask[:self.num_query // T * T] = block_mask
return attention_mask
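
    # Worked example of the mask layout (illustrative values, not from any config):
    # with T=2 frames, L=2 tokens per frame and num_query=4, base_mask is
    # [[0, -inf], [0, 0]], and its Kronecker product with ones(2, 2) yields a
    # (4, 4) mask in which queries 0-1 attend only to frame 0's tokens while
    # queries 2-3 attend to the tokens of both frames.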

    def forward(self, vision_hidden_states: list[torch.Tensor]) -> torch.Tensor:  # (B, Q, C)
        """Forward pass. Clips are processed one at a time rather than batched, to stay compatible with DeepSpeed ZeRO-3.
Args:
vision_hidden_states (list[torch.Tensor]): [(t0, L, C), (t1, L, C), ...]
Returns:
torch.Tensor: (B, Q, C)
"""
output = []
for hidden_states in vision_hidden_states:
# reshape inputs and construct ccam masks
attention_mask = self._get_mask(hidden_states) # (Q, ti * L)
# forward
x = self.pre_ccam(self.query) # (1, Q, C)
x = self.ccam(
hidden_states=x, # (1, Q, C)
cross_hidden_states=hidden_states.flatten(0, 1)[None], # (1, ti * L, C')
attention_mask=attention_mask[None] # (1, Q, ti * L)
) + x
x = self.post_ccam(x)
output.append(x)
output = torch.cat(output, dim=0)
return output
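

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes that
# CCAMConfig() provides sensible defaults for every field referenced above and
# that the module is imported as part of its package (the relative import of
# configuration_ccam means it cannot be run as a bare standalone script; use
# e.g. `python -m projector.modeling_ccam` from the repository root).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = CCAMConfig()  # assumption: default values exist for all referenced fields
    model = CCAMModel(config).eval()
    # two clips with 4 and 8 frames; 256 vision tokens per frame is an assumed value
    feats = [
        torch.randn(4, 256, config.cross_hidden_size),
        torch.randn(8, 256, config.cross_hidden_size),
    ]
    with torch.no_grad():
        out = model(feats)
    print(out.shape)  # expected: (2, config.num_query, config.output_size)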