# memorizing_transformer_gpt2 / gpt2_knn_attention.py
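"""GPT-2 attention layer augmented with an external kNN memory, in the spirit of
Memorizing Transformers (Wu et al., 2022)."""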
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
class GPT2KNNAttention(GPT2Attention):
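    """GPT-2 self-attention augmented with an external kNN memory.

    Besides the usual causal attention over the local context, every query
    attends to its ``num_retrieve_memories`` nearest (key, value) pairs
    retrieved from ``knn_memory``.  The two attention outputs are combined
    with a learned per-head gate ``sigmoid(attn_comb_bias)``, and the current
    chunk's (key, value) pairs are written back to the memory at the end of
    every forward pass.
    """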
    def __init__(self, config, knn_memory, device, is_cross_attention=False, layer_idx=None, num_retrieve_memories=32):
        super().__init__(config, is_cross_attention, layer_idx)
        self.knn_memory = knn_memory
        self.device = device
        self.num_retrieve_memories = num_retrieve_memories
        self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
        self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads,))
        nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)
        # self.attn_comb_bias = nn.Parameter(torch.full((self.num_heads,), 1.0))

    def _knn_attn(self, query, key, value, mask, head_mask=None):
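        """Attend each query position to its retrieved (key, value) memories.

        ``key`` and ``value`` carry an extra next-to-last dimension holding the
        retrieved memories for every query position, so ``query`` gets a
        matching singleton dimension before the usual scaled dot-product
        attention.  ``mask`` zeroes out attention weights for missing memory
        slots (e.g. when fewer memories than requested are available).
        """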
        query = query.unsqueeze(-2)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # if not self.is_cross_attention:
        #     raise RuntimeError("KNN attention is not yet implemented for !cross_attention")
        #     # if only "normal" attention layer implements causal mask
        #     query_length, key_length = query.size(-3), key.size(-3)
        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        #     mask_value = torch.finfo(attn_weights.dtype).min
        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.knn_attn_dropout(attn_weights)

        # masking missing keys
        sh = mask.size()
        attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output.squeeze_(dim=-2)

        return attn_output

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
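        """Standard GPT-2 attention combined with attention over retrieved memories.

        The parent class computes the usual causal attention output; the kNN
        branch retrieves ``num_retrieve_memories`` entries per query from
        ``knn_memory`` and both outputs are mixed with the learned per-head gate
        ``g = sigmoid(attn_comb_bias)``.  If nothing is retrieved, only the
        gated local output ``attn_output * (1 - g)`` is returned.
        """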
        attn_output, attn_weights = super()._attn(
            query, key, value, attention_mask, head_mask)

        # Retrieve the nearest (key, value) memories for each query position.
        knn_key, knn_value, knn_mask = self.knn_memory.search(
            query, self.num_retrieve_memories)

        # Per-head gate in (0, 1) that mixes kNN attention with local attention.
        g = torch.sigmoid(self.attn_comb_bias)[:, None, None]

        if knn_key.numel() == 0:
            return attn_output * (1 - g), attn_weights

        knn_key, knn_value, knn_mask = knn_key.to(
            self.device), knn_value.to(self.device), knn_mask.to(self.device)
        knn_attn_output = self._knn_attn(
            query, knn_key, knn_value, knn_mask, head_mask)

        # combining two attentions
        attn = knn_attn_output * g + attn_output * (1 - g)

        return attn, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        raise RuntimeError(
            "KNN attention is not yet implemented for _upcast_and_reordered_attn")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
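        """Same interface and return format as ``GPT2Attention.forward``.

        Differences from the parent implementation: queries and keys are
        L2-normalized to reduce the effect of memory staleness, ``self._attn``
        additionally attends over retrieved memories, and the current chunk's
        normalized (key, value) pairs are written to ``knn_memory`` just before
        returning.
        """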
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(
                hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # normalization of queries and keys reduces the effect of staleness
        query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)

        # (key, value) of the current chunk only, to be written to the kNN memory below
        new_memories = (key, value)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            raise RuntimeError("Not implemented")
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(
            attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        # store the current chunk's normalized (key, value) pairs in the external memory
        self.knn_memory.add(*new_memories)

        return outputs  # a, present, (attentions)
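

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module).
#
# The kNN memory backend is defined elsewhere in the repository; the only
# interface this file relies on is `add(key, value)` and
# `search(query, k) -> (key, value, mask)`.  `NaiveKNNMemory` below is a
# hypothetical, brute-force stand-in with that interface, and the layer swap
# assumes a transformers version whose GPT2Block calls attention with the same
# `layer_past`-style signature used in `forward` above.
# ---------------------------------------------------------------------------
class NaiveKNNMemory:
    """Minimal in-memory stand-in for the project's knn_memory object."""

    def __init__(self):
        self.keys = None      # (batch, heads, mem_len, head_dim)
        self.values = None

    def add(self, key, value):
        # Append the chunk's (key, value) pairs along the memory-length axis.
        key, value = key.detach(), value.detach()
        if self.keys is None:
            self.keys, self.values = key, value
        else:
            self.keys = torch.cat((self.keys, key), dim=-2)
            self.values = torch.cat((self.values, value), dim=-2)

    def search(self, query, k):
        # Return empty tensors when there is nothing (or not enough) to retrieve;
        # GPT2KNNAttention._attn then falls back to purely local attention.
        if self.keys is None or self.keys.size(-2) < k:
            empty = torch.empty(0)
            return empty, empty, empty
        # Inner-product scores of every query against every stored key:
        # (batch, heads, seq, mem_len)
        scores = torch.matmul(query, self.keys.transpose(-1, -2))
        top = scores.topk(k, dim=-1).indices  # (batch, heads, seq, k)
        idx = top.unsqueeze(-1).expand(-1, -1, -1, -1, self.keys.size(-1))
        keys = self.keys.unsqueeze(2).expand(-1, -1, query.size(2), -1, -1)
        values = self.values.unsqueeze(2).expand(-1, -1, query.size(2), -1, -1)
        knn_key = torch.gather(keys, 3, idx)      # (batch, heads, seq, k, head_dim)
        knn_value = torch.gather(values, 3, idx)
        mask = torch.ones(query.size(0), k)       # every retrieved slot is valid
        return knn_key, knn_value, mask


if __name__ == "__main__":
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(n_layer=4, n_head=4, n_embd=128, n_positions=256)
    model = GPT2LMHeadModel(config).eval()

    # Swap the last block's attention for the kNN-augmented variant; which
    # layer(s) to replace is a modelling choice, not prescribed by this file.
    memory = NaiveKNNMemory()
    device = torch.device("cpu")
    model.transformer.h[-1].attn = GPT2KNNAttention(
        config, memory, device, layer_idx=config.n_layer - 1, num_retrieve_memories=8
    )

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        model(input_ids)        # first pass only fills the memory
        out = model(input_ids)  # second pass also retrieves from it
    print(out.logits.shape)     # torch.Size([2, 16, 50257])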