from typing import Optional

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

from AR.modules.patched_mha_with_cache_onnx import (
    multi_head_attention_forward_patched,
)
|
|
class MultiheadAttention(Module):
    r"""Multi-head attention whose forward pass delegates to the ONNX-oriented
    ``multi_head_attention_forward_patched``, which takes an external ``cache``
    argument and returns only the attention output (no attention weights).
    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]
|
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
|
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None
|
        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                # Separate q/k/v projection weights when key/value dims differ
                # from embed_dim.
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # Packed in-projection: q, k and v weights stacked along dim 0.
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                # Custom linear classes supply the packed in-projection; expose
                # their weight/bias under the standard attribute names.
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )
|
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn
|
    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)
|
    def __setstate__(self, state):
        # Checkpoints saved before this flag existed default to the packed layout.
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)
|
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tensor:
        # The ONNX export path assumes batch-first self-attention: the incoming
        # (batch, seq, embed_dim) query is transposed to (seq, batch, embed_dim)
        # and reused as key and value; the key/value arguments are ignored.
        query = key = value = query.transpose(1, 0)
        # The patched forward returns only the attention output, which is
        # transposed back to batch-first before being returned.
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        return attn_output.transpose(1, 0)
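

# A minimal, hedged usage sketch (not part of the module API): it only builds the
# layer and inspects parameter shapes. The hyperparameters below are arbitrary,
# and the forward pass is not exercised here because it expects a `cache` object
# whose layout is defined by AR.modules.patched_mha_with_cache_onnx, along with
# batch-first inputs of shape (batch, seq, embed_dim).
if __name__ == "__main__":
    attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.0, batch_first=True)
    # Packed in-projection holds q, k and v weights stacked along dim 0.
    print(attn.in_proj_weight.shape)  # torch.Size([1536, 512])
    print(attn.out_proj.weight.shape)  # torch.Size([512, 512])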
|
|