#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/07/10 19:47:01
@email: [email protected]
@description: Causal Cross-Attention Mask (CCAM)
================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from .configuration_ccam import CCAMConfig


class CCAMMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.output_size = config.output_size
        if self.hidden_act == 'swiglu':
            self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias)
            self.act_fn = ACT2FN['silu']
        else:
            self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
            self.act_fn = ACT2FN[self.hidden_act]
        self.fc2 = nn.Linear(self.intermediate_size, self.output_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        if self.hidden_act == 'swiglu':
            gate, up = hidden_states.chunk(2, dim=-1)
            hidden_states = self.act_fn(gate) * up
        else:
            hidden_states = self.act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
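
# NOTE (illustrative sketch, not part of the original module): with
# `hidden_act == 'swiglu'`, fc1 packs the gate and up projections into a single
# matmul, so the block computes SiLU(x @ W_gate) * (x @ W_up) before fc2.
# A minimal equivalence check, assuming a hypothetical `config` with
# `hidden_act == 'swiglu'`:
#
#   mlp = CCAMMLP(config)
#   x = torch.randn(2, 16, config.hidden_size)   # (batch, tokens, hidden)
#   gate, up = mlp.fc1(x).chunk(2, dim=-1)
#   assert torch.allclose(mlp(x), mlp.fc2(F.silu(gate) * up))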


class CCAMCrossAttention(nn.Module):
    """Cross-attention layer of the CCAM projector.

    Flash Attention 2 is not supported because the CCAM mask may be neither full
    nor causal; only the `eager` and `sdpa` values of `attn_implementation` are
    supported.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.attention_bias = config.attention_bias
        self.attention_dropout = config.attention_dropout
        self.cross_hidden_size = config.cross_hidden_size
        self.num_key_value_heads = config.num_key_value_heads
        self.attn_implementation = config._attn_implementation
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        assert self.head_dim * self.num_heads == self.hidden_size, f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).'
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)
        self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,         # (B, Q, C)
        cross_hidden_states: torch.Tensor,   # (B, L, C')
        attention_mask: torch.Tensor = None  # (Q, L), '-inf' means masked, 0 means not masked
    ) -> torch.Tensor:                       # (B, Q, C)
        B, Q, C = hidden_states.size()
        query_states = self.q_proj(hidden_states)  # (B, Q, C)
        key_states = self.k_proj(cross_hidden_states)
        value_states = self.v_proj(cross_hidden_states)
        L = key_states.size(1)
        query_states = query_states.view(B, Q, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)
        if self.attn_implementation == 'eager':
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim ** 0.5  # (B, num_heads, Q, L)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask.view(1, 1, Q, L)
            # upcast attention to fp32
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)  # (B, num_heads, Q, head_dim)
        else:  # 'sdpa'
            # there are bugs in torch <=2.1.0, requiring qkv as contiguous(), be careful
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0
            )
        attn_output = attn_output.transpose(1, 2).reshape(B, Q, C)  # (B, Q, C)
        attn_output = self.o_proj(attn_output)
        return attn_output
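
# NOTE (illustrative sketch, not part of the original module): the `eager`
# branch above is the explicit form of scaled dot-product attention, so with
# dropout disabled both branches should agree on the same additive mask up to
# numerical precision. A hypothetical self-check with made-up sizes:
#
#   q = torch.randn(1, 4, 8, 16)      # (B, num_heads, Q, head_dim)
#   k = torch.randn(1, 4, 32, 16)     # (B, num_heads, L, head_dim)
#   v = torch.randn(1, 4, 32, 16)
#   mask = torch.zeros(8, 32)
#   mask[:, 16:] = float('-inf')      # queries may only see the first half of the keys
#   eager = torch.softmax(q @ k.transpose(-2, -1) / 16 ** 0.5 + mask, dim=-1) @ v
#   sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
#   assert torch.allclose(eager, sdpa, atol=1e-5)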


class CCAMModel(PreTrainedModel):
    """Causal Cross-Attention Mask Projector"""
    config_class = CCAMConfig
    _auto_class = 'AutoModel'
    _supports_sdpa = True
    _no_split_modules = ['CCAMCrossAttention', 'CCAMMLP']

    def __init__(self, config):
        super().__init__(config)
        self.num_query = config.num_query
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size
        self.cross_hidden_size = config.cross_hidden_size
        self.query = nn.Parameter(torch.empty(1, self.num_query, self.hidden_size).normal_(mean=.0, std=.02))
        self.pre_ccam = nn.Sequential(
            nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
            nn.Dropout(config.dropout)
        )
        self.ccam = CCAMCrossAttention(config)
        self.post_ccam = nn.Sequential(
            nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
            nn.Dropout(config.dropout),
            CCAMMLP(config)
        )
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=.0, std=.02)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def _get_mask(self, vision_hidden_state: torch.Tensor) -> torch.Tensor:  # (Q, T*L)
        """Compute the CCAM mask for a vision hidden state.

        Args:
            vision_hidden_state (torch.Tensor): (T, L, C)

        Returns:
            torch.Tensor: (Q, T*L), '-inf' means masked
        """
        T, L, _ = vision_hidden_state.size()
        dtype, device = vision_hidden_state.dtype, vision_hidden_state.device
        base_mask = torch.zeros(T, T, dtype=dtype, device=device)
        t = torch.arange(T, device=device)
        base_mask.masked_fill_(t > t[:, None], float('-inf'))
        attention_mask = torch.zeros(self.num_query, T * L, dtype=dtype, device=device)
        attention_mask[:self.num_query // T * T] = torch.kron(base_mask, torch.ones(self.num_query // T, L, dtype=dtype, device=device))
        return attention_mask
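
    # NOTE (illustrative sketch, not part of the original module): with T = 2
    # frames, L = 2 tokens per frame and Q = 4 queries, `base_mask` is the 2x2
    # causal mask [[0, -inf], [0, 0]], and its Kronecker product with
    # ones(Q // T, L) expands it so each group of Q // T queries attends only to
    # its own frame and earlier frames:
    #
    #   [[0., 0., -inf, -inf],
    #    [0., 0., -inf, -inf],
    #    [0., 0.,   0.,   0.],
    #    [0., 0.,   0.,   0.]]
    #
    # If Q is not divisible by T, the remaining Q - (Q // T) * T query rows are
    # left at zero, i.e. those queries attend to every frame.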

    def forward(self, vision_hidden_states: list[torch.Tensor]) -> torch.Tensor:  # (B, Q, C)
        """Forward pass. Samples are processed one by one instead of being collated
        into a single batch, for compatibility with DeepSpeed ZeRO-3.

        Args:
            vision_hidden_states (list[torch.Tensor]): [(t0, L, C), (t1, L, C), ...]

        Returns:
            torch.Tensor: (B, Q, C)
        """
        output = []
        for hidden_states in vision_hidden_states:
            # reshape inputs and construct ccam masks
            attention_mask = self._get_mask(hidden_states)  # (Q, ti * L)
            # forward
            x = self.pre_ccam(self.query)  # (1, Q, C)
            x = self.ccam(
                hidden_states=x,                                        # (1, Q, C)
                cross_hidden_states=hidden_states.flatten(0, 1)[None],  # (1, ti * L, C')
                attention_mask=attention_mask[None]                     # (1, Q, ti * L)
            ) + x
            x = self.post_ccam(x)
            output.append(x)
        output = torch.cat(output, dim=0)
        return output
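

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # It assumes CCAMConfig exposes the fields read above (num_query, hidden_size,
    # intermediate_size, output_size, cross_hidden_size, num_heads,
    # num_key_value_heads, hidden_act, ...) as keyword arguments; the concrete
    # values below are hypothetical.
    config = CCAMConfig(
        num_query=64,
        hidden_size=1024,
        intermediate_size=2752,
        output_size=4096,
        cross_hidden_size=1152,
        num_heads=16,
        num_key_value_heads=16,
        hidden_act='swiglu',
        attn_implementation='sdpa',
    )
    model = CCAMModel(config)
    # Two samples with different numbers of frames (t0=3, t1=5) and L=729 vision tokens per frame.
    vision_hidden_states = [
        torch.randn(3, 729, config.cross_hidden_size),
        torch.randn(5, 729, config.cross_hidden_size),
    ]
    with torch.no_grad():
        out = model(vision_hidden_states)
    print(out.shape)  # expected: torch.Size([2, 64, 4096])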