from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention


class GPT2KNNAttention(GPT2Attention):
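    """GPT-2 self-attention augmented with retrieval from an external kNN
    memory, in the spirit of Memorizing Transformers.

    `knn_memory` is expected to expose `search(query, k)`, returning
    `(key, value, mask)` for the retrieved memories, and `add(key, value)`
    for storing the current segment's keys and values. The local attention
    output and the kNN attention output are mixed with a learned per-head
    sigmoid gate.
    """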

    def __init__(self, config, knn_memory, device, is_cross_attention=False,
                 layer_idx=None, num_retrieve_memories=32):
        super().__init__(config, is_cross_attention, layer_idx)

        # External long-term (key, value) memory; it must provide the
        # `search(query, k)` and `add(key, value)` methods used below.
        self.knn_memory = knn_memory
        self.device = device
        # Number of memories retrieved from `knn_memory` per query token.
        self.num_retrieve_memories = num_retrieve_memories
        self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
        # Learnable per-head bias; its sigmoid gates how the kNN attention
        # output is mixed with the local attention output (see `_attn`).
        self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads))
        nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)

    def _knn_attn(self, query, key, value, mask, head_mask=None):
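        """Attend over the memories retrieved for each query token.

        `query` is `(batch, heads, seq_len, head_dim)`; `key` and `value` are
        assumed to be `(batch, heads, seq_len, num_retrieved, head_dim)` and
        `mask` `(batch, num_retrieved)`, matching what `knn_memory.search`
        returns. Padded memory slots are zeroed out after the softmax.
        """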
        # Add a singleton axis so every query position attends over its own
        # retrieved memories: (batch, heads, seq_len, 1, head_dim).
        query = query.unsqueeze(-2)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # No causal mask here: the retrieved memories come from the external
        # store, not from the current sequence.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.knn_attn_dropout(attn_weights)

        # Zero out attention to invalid (padded) memory slots. The mask is
        # applied after the softmax, so the remaining weights are not
        # renormalized.
        sh = mask.size()
        attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        # Drop the singleton axis: back to (batch, heads, seq_len, head_dim).
        attn_output.squeeze_(dim=-2)

        return attn_output

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
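        """Combine local causal attention with kNN attention over memories.

        The per-head gate `sigmoid(self.attn_comb_bias)` interpolates between
        the kNN attention output and the standard GPT-2 attention output.
        """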
        # Standard (local) causal self-attention from GPT2Attention.
        attn_output, attn_weights = super()._attn(
            query, key, value, attention_mask, head_mask)
        # Retrieve the top-k (key, value) memories for every query token.
        knn_key, knn_value, knn_mask = self.knn_memory.search(
            query, self.num_retrieve_memories)
        # Per-head gate in (0, 1) mixing kNN and local attention outputs.
        g = torch.sigmoid(self.attn_comb_bias)[:, None, None]

        if knn_key.numel() == 0:
            # Empty memory: apply the same gating so the behaviour matches the
            # combined formula with a zero kNN contribution.
            return attn_output * (1 - g), attn_weights

        knn_key = knn_key.to(self.device)
        knn_value = knn_value.to(self.device)
        knn_mask = knn_mask.to(self.device)
        knn_attn_output = self._knn_attn(
            query, knn_key, knn_value, knn_mask, head_mask)

        attn = knn_attn_output * g + attn_output * (1 - g)

        return attn, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        raise NotImplementedError(
            "KNN attention is not yet implemented for _upcast_and_reordered_attn")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
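        """Same interface as `GPT2Attention.forward`.

        Differences: queries and keys are L2-normalized before attention, and
        the current segment's (key, value) pairs are written to
        `self.knn_memory` at the end of the call.
        """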
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # L2-normalize queries and keys so dot products act like cosine
        # similarities, matching how the kNN memory is queried.
        query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)
        # Only the current segment's (key, value) pairs will be added to the
        # long-term memory (done at the end of forward).
        new_memories = (key, value)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            # Not supported together with kNN attention; this raises a
            # descriptive NotImplementedError.
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(
            attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        # Store this segment's normalized keys and values in the external kNN
        # memory so later forward passes can retrieve them.
        self.knn_memory.add(*new_memories)

        return outputs
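

# Minimal usage sketch: assumes a project-supplied `knn_memory` object that
# implements the `search(query, k)` / `add(key, value)` interface used above.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    knn_memory = ...  # placeholder for the actual kNN memory implementation

    # Swap the standard attention of, e.g., the last block for the
    # kNN-augmented variant.
    layer_idx = len(model.transformer.h) - 1
    model.transformer.h[layer_idx].attn = GPT2KNNAttention(
        model.config, knn_memory, device, layer_idx=layer_idx)
    model.to(device)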