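"""GPT-2 self-attention augmented with an external kNN memory.

The layer runs the usual GPT-2 attention over the local context, retrieves the
nearest stored (key, value) pairs from ``knn_memory`` for every query, attends
over them separately, and blends the two outputs with a learned per-head gate.
"""
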
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


class GPT2KNNAttention(GPT2Attention):
    def __init__(self, config, knn_memory, device, is_cross_attention=False, layer_idx=None, num_retrieve_memories=32):
        super().__init__(config, is_cross_attention, layer_idx)
        self.knn_memory = knn_memory
        self.device = device
        self.num_retrieve_memories = num_retrieve_memories
        self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
        # Per-head bias that gates how much kNN attention is mixed into the local attention.
        self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads))
        nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)
        # self.attn_comb_bias = nn.Parameter(torch.full((self.num_heads,), 1.0))
    def _knn_attn(self, query, key, value, mask, head_mask=None):
        # query: (batch, heads, seq, head_dim); the retrieved key/value tensors carry an
        # extra neighbour dimension, so each query attends over its own retrieved set.
        query = query.unsqueeze(-2)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # if not self.is_cross_attention:
        #     raise RuntimeError("KNN attention is not yet implemented for !cross_attention")
        #     # if only "normal" attention layer implements causal mask
        #     query_length, key_length = query.size(-3), key.size(-3)
        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        #     mask_value = torch.finfo(attn_weights.dtype).min
        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.knn_attn_dropout(attn_weights)

        # Masking missing keys (retrieval slots that hold no real memory yet).
        sh = mask.size()
        attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        attn_output.squeeze_(dim=-2)
        return attn_output
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Standard (local) GPT-2 attention.
        attn_output, attn_weights = super()._attn(
            query, key, value, attention_mask, head_mask)

        # Retrieve the nearest stored (key, value) pairs for each query.
        knn_key, knn_value, knn_mask = self.knn_memory.search(
            query, self.num_retrieve_memories)

        # Per-head gate in (0, 1) that blends kNN attention with local attention.
        g = torch.sigmoid(self.attn_comb_bias)[:, None, None]
        if knn_key.numel() == 0:
            # Memory is still empty: only the local attention contributes.
            return attn_output * (1 - g), attn_weights

        knn_key, knn_value, knn_mask = knn_key.to(
            self.device), knn_value.to(self.device), knn_mask.to(self.device)
        knn_attn_output = self._knn_attn(
            query, knn_key, knn_value, knn_mask, head_mask)

        # combining two attentions
        attn = knn_attn_output * g + attn_output * (1 - g)
        return attn, attn_weights
    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        raise RuntimeError(
            "KNN attention is not yet implemented for _upcast_and_reordered_attn")
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(
                hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # normalization of queries and keys reduces the effect of staleness
        query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)

        # Keys/values of the current segment only (before concatenating the cache);
        # these are what gets written into the kNN memory after attention.
        new_memories = (key, value)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            raise RuntimeError("Not implemented")
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(
            attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        # Store this segment's (normalized) keys and values for future retrieval.
        self.knn_memory.add(*new_memories)

        return outputs  # a, present, (attentions)
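

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the layer above): one way to swap the
# stock GPT-2 self-attention for GPT2KNNAttention in a pretrained model. The
# kNN memory objects are assumed to expose `search(query, k) -> (keys, values,
# mask)` and `add(keys, values)`, matching how `knn_memory` is used in `_attn`
# and `forward`; their concrete implementation lives elsewhere.
# `patch_with_knn_attention`, `knn_memories`, and `layer_indices` are
# hypothetical names introduced here for illustration only.
# ---------------------------------------------------------------------------
def patch_with_knn_attention(model, knn_memories, device, layer_indices):
    """Replace the self-attention of the selected GPT-2 blocks with GPT2KNNAttention."""
    for idx in layer_indices:
        block = model.transformer.h[idx]
        knn_attn = GPT2KNNAttention(
            model.config, knn_memories[idx], device, layer_idx=idx
        ).to(device)
        # Reuse the pretrained projection weights; only `attn_comb_bias` stays
        # freshly initialized (hence strict=False).
        knn_attn.load_state_dict(block.attn.state_dict(), strict=False)
        block.attn = knn_attn
    return model


# Example (assuming `memories` maps layer index -> kNN memory object):
#   from transformers import GPT2LMHeadModel
#   model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda")
#   model = patch_with_knn_attention(model, memories, device="cuda", layer_indices=[9])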