"""LLaMA building blocks whose dense projections are replaced by low-rank factor pairs.

Every projection in the MLP and attention modules is expressed as two stacked
bias-free ``nn.Linear`` layers (``*_v_proj`` followed by ``*_u_proj``) whose
inner dimension is derived from ``config.ratio``.
"""
import math
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers import LlamaForCausalLM

from .config_llama import SVD_LlamaConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "SVD_LlamaConfig"


class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the last dimension; length of the learned scale vector.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale ``hidden_states`` by the reciprocal RMS of its last dimension, then by ``weight``."""
        # The mean square is accumulated in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):
    """Caches the cos/sin tables used by rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: Size of the dimension being rotated (the attention module passes its head size).
            max_position_embeddings: Number of positions to precompute.
            base: Base of the geometric progression of inverse frequencies.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the non-persistent cos/sin buffers for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        # Buffers have shape [1, 1, seq_len, dim] and are excluded from the state dict.
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: Tensor used for its dtype and device: [bs, num_attention_heads, seq_len, head_size].
            seq_len: Number of positions required. When omitted, ``x.shape[-2]`` is used.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # Previously the default of None failed on the comparison below.
            seq_len = x.shape[-2]

        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, x.device)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table as returned by ``LlamaRotaryEmbedding`` ([1, 1, seq_len, dim]).
        sin: Sine table with the same layout as ``cos``.
        position_ids: Integer positions of shape [bs, seq_len] used to select table rows.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    # Replicate the tables across the batch so each sample gathers its own positions.
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SVD_LlamaMLP(nn.Module):
    """Gated MLP whose gate, up and down projections are each a pair of low-rank factors."""

    def __init__(
        self,
        config: SVD_LlamaConfig
    ):
        """
        Args:
            config: Supplies ``ratio``, ``hidden_size``, ``intermediate_size`` and ``hidden_act``.
        """
        super().__init__()
        self.ratio = config.ratio
        # Rank chosen so that a factor pair holds roughly
        # ratio * intermediate_size * hidden_size weights:
        # low_rank * (intermediate_size + hidden_size) ~= ratio * intermediate_size * hidden_size.
        low_rank = int(
            config.intermediate_size * config.hidden_size * self.ratio
            / (config.intermediate_size + config.hidden_size)
        )
        self.gate_u_proj = nn.Linear(low_rank, config.intermediate_size, bias=False)
        self.gate_v_proj = nn.Linear(config.hidden_size, low_rank, bias=False)

        self.down_u_proj = nn.Linear(low_rank, config.hidden_size, bias=False)
        self.down_v_proj = nn.Linear(config.intermediate_size, low_rank, bias=False)

        self.up_u_proj = nn.Linear(low_rank, config.intermediate_size, bias=False)
        self.up_v_proj = nn.Linear(config.hidden_size, low_rank, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Compute ``down(act(gate(x)) * up(x))`` with every projection applied as ``u(v(.))``."""
        up = self.up_u_proj(self.up_v_proj(x))
        gate = self.gate_u_proj(self.gate_v_proj(x))
        return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))


class SVD_LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SVD_LlamaConfig):
        """
        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads``,
                ``max_position_embeddings`` and ``ratio``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # At ratio == 1 each factor pair holds as many weights as a dense
        # hidden_size x hidden_size projection; there is no dense fallback in this class.
        self.ratio = config.ratio

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        low_rank = int(self.hidden_size * self.ratio/2)
        self.q_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
        self.q_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.k_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
        self.k_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.v_u_proj = nn.Linear(low_rank, self.num_heads * self.head_dim, bias=False)
        self.v_v_proj = nn.Linear(self.hidden_size, low_rank, bias=False)

        self.o_u_proj = nn.Linear(low_rank, self.hidden_size, bias=False)
        self.o_v_proj = nn.Linear(self.num_heads * self.head_dim, low_rank, bias=False)

        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to [bsz, num_heads, seq_len, head_dim] as a contiguous tensor."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape [bsz, q_len, hidden_size].
            attention_mask: Optional additive mask of shape [bsz, 1, q_len, kv_seq_len].
            position_ids: Optional positions of shape [bsz, q_len]. When omitted, the
                positions ``kv_seq_len - q_len .. kv_seq_len - 1`` are used for every sample.
            past_key_value: Optional ``(key, value)`` pair of cached states that is
                concatenated in front of the new keys and values along dim 2.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Whether to return the updated ``(key, value)`` pair.

        Returns:
            Tuple ``(attn_output, attn_weights, past_key_value)``; ``attn_weights`` is
            ``None`` unless ``output_attentions`` is set and ``past_key_value`` is
            ``None`` unless ``use_cache`` is set.

        Raises:
            ValueError: If the attention weights, the attention mask or the attention
                output do not have the expected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        # Each projection is the low-rank pair u(v(x)), then split into heads: [bsz, nh, q_len, hd].
        query_states = self.q_u_proj(self.q_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_u_proj(self.k_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_u_proj(self.v_v_proj(hidden_states)).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        if position_ids is None:
            # The signature allows None, but the rotary helper indexes position_ids;
            # default to the positions of the new tokens following any cached ones.
            position_ids = (
                torch.arange(kv_seq_len - q_len, kv_seq_len, dtype=torch.long, device=cos.device)
                .unsqueeze(0)
                .expand(bsz, -1)
            )

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            # The reported shape now matches the shape that is actually checked.
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Keep masked scores at the dtype's most negative finite value; clamp avoids
            # allocating a scalar tensor on every call.
            attn_weights = torch.clamp(attn_weights, min=torch.finfo(attn_weights.dtype).min)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: [bsz, q_len, nh * hd].
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SVD_LlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose decoder layers use the low-rank MLP and attention modules."""

    config_class = SVD_LlamaConfig

    def __init__(self, config: SVD_LlamaConfig):
        """Build the base model, then replace ``mlp`` and ``self_attn`` in every decoder layer."""
        super().__init__(config)
        for layer in self.model.layers:
            layer.mlp = SVD_LlamaMLP(config=config)
            layer.self_attn = SVD_LlamaAttention(config)