jsl_medm_ner_latest / triton_blocksparse_attention_layer.py
hashamulhaq's picture
Add files using upload-large-folder tool
f016446 verified
import math
from typing import Optional, Tuple, TypeVar
import torch.nn as nn
import torch
import triton
from functools import lru_cache
from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
# Type alias for a sparse layout: a pair of integer (long) tensors.
# NOTE(review): presumably the (row-pointer, column-index) index tensors that
# `_get_sparse_attn_mask` returns as its first element — confirm against
# triton_flash_blocksparse_attn before relying on that interpretation.
Layout = Tuple[torch.LongTensor, torch.LongTensor]
def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool
) -> Tuple[Layout, torch.Tensor]:
    """Build the sparse layout and block-level mask for the attention layer.

    Thin wrapper around ``_get_sparse_attn_mask`` that forwards every argument
    and keeps only the first two of the three values it returns.

    Args:
        n_heads: Forwarded as ``n_heads``.
        max_seq_len: Forwarded as ``q_len``.
        max_seq_len_k: Forwarded as ``N_CTX``.
        dtype: Forwarded as ``dtype``.
        device: Forwarded as ``device``.
        BLOCK: Forwarded as ``BLOCK``.
        local_blocks: Forwarded as ``local_blocks``.
        vert_stride: Forwarded as ``vert_stride``.
        homo_head: Forwarded as ``homo_head``.
        return_dense: Forwarded as ``return_dense``. Note that the third
            value returned by the helper is discarded here regardless of
            this flag.

    Returns:
        A ``(layout, block_sparse_pattern)`` pair, i.e. the first two values
        produced by ``_get_sparse_attn_mask``.
    """
    # The helper yields three values; the third one is intentionally dropped,
    # so this function's result is a 2-tuple (callers unpack two names).
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense
    )
    return layout, block_sparse_pattern
class BlockSparseAttentionLayer(nn.Module):
    """Block-sparse attention layer that dispatches to one of three kernels.

    ``forward`` picks its implementation from the arguments it receives:

    * no padding / sequence-length arguments: the op returned by
      ``get_local_strided_sparse_attention_op``;
    * 3-dim ``k``: ``blocksparse_flash_attn_varlen_fwd``;
    * 4-dim ``k``: ``blocksparse_flash_attn_padded_fwd``.

    The sparse layout and block mask used by the last two paths are built
    lazily on first use and rebuilt whenever the dtype or device of ``q``
    changes.
    """

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None
    ) -> None:
        """Store the sparsity configuration; no tensors are created here.

        Args:
            n_heads: Number of heads, forwarded to the mask/op builders.
            max_seq_len: Maximum sequence length, forwarded to the builders
                and kernels.
            sparse_block_size: Block size of the sparsity pattern.
            local_blocks: Forwarded to the mask/op builders.
            vert_stride: Forwarded to the mask/op builders.
            kernel_block_size: Block size handed to the kernels. Falls back
                to ``sparse_block_size`` when ``None`` (or any falsy value).
            homo_head: Forwarded to the mask/op builders. Head pruning is
                skipped when this is true.
            active_head_range: Optional ``(start, end)`` head indices used
                to slice the layout and mask when ``homo_head`` is false.
        """
        super().__init__()
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Internal Parameters used by the layer
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

        # TODO(bapatra): Ideally, I'd want to keep all the code for
        # forward to be handled here, and not branch for training and inference.
        # However, that refactor would need a lot of testing. For now, using the
        # training op as is, and will refactor again later.

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        """Keep only heads ``[h_start, h_end)`` of the cached mask and layout.

        Slices the leading dimension of ``_sparse_block_mask`` and of both
        tensors in ``_sparse_layout``. Must be called after the internals
        have been initialized.
        """
        self._sparse_block_mask = self._sparse_block_mask[h_start:h_end]
        # The layout is a (tensor, tensor) pair. Rebuild the pair instead of
        # assigning by index: item assignment raises TypeError on a tuple.
        layout = self._sparse_layout
        self._sparse_layout = (
            layout[0][h_start:h_end],
            layout[1][h_start:h_end],
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device
    ) -> None:
        """Build and cache the sparse layout and block mask.

        Args:
            dtype: dtype passed to the mask builder and remembered for
                cache invalidation.
            device: device passed to the mask builder and remembered for
                cache invalidation.

        Raises:
            AssertionError: If ``active_head_range`` is not a pair, if
                ``sparse_block_size`` is not a multiple of
                ``kernel_block_size``, or if ``kernel_block_size`` is not a
                power of two that is at least 16.
        """
        # Validate the configuration before touching any cached state, so a
        # failed check cannot leave the layer looking initialized.
        needs_pruning = (not self.homo_head) and (self.active_head_range is not None)
        if needs_pruning:
            assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
        assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"

        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if needs_pruning:
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        _mul = self.sparse_block_size // self.kernel_block_size
        if _mul > 1:
            # need to consider if block_m and block_n are different
            # Expand every sparse block into a (_mul x _mul) tile of kernel blocks.
            self._sparse_block_mask = torch.kron(
                self._sparse_block_mask,
                self._sparse_block_mask.new_ones(_mul, _mul),
            )
            # After expansion, zero out entries above the diagonal
            # (row index < column index) at kernel-block granularity.
            num_sparse_blocks = self._sparse_block_mask.size(-1)
            block_idx = torch.arange(
                0, num_sparse_blocks, device=self._sparse_block_mask.device
            )
            block_causal_mask = block_idx[:, None] >= block_idx[None]
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
            # NOTE(review): only the block mask is expanded here;
            # `_sparse_layout` is left as built above — confirm this is intended.

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Arguments Related to Block Attention Inference
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Arguments Related to Variable Length Inference
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run block-sparse attention over ``q``, ``k`` and ``v``.

        Args:
            q: Query tensor.
            k: Key tensor; its number of dimensions selects the kernel when
                any of the keyword-only arguments is given.
            v: Value tensor.
            sm_scale: Softmax scale forwarded to the selected kernel.
            left_paddings: Forwarded to the padded kernel (4-dim ``k``).
            seqlens: Forwarded to the padded kernel (4-dim ``k``).
            cu_seqlens_k: Required by the variable-length kernel (3-dim ``k``).
            cu_seqlens_q: Forwarded to the variable-length kernel.

        Returns:
            The tensor returned by the selected kernel.

        Raises:
            AssertionError: If autograd is enabled on the padded /
                variable-length paths, or if the arguments required by the
                selected path are missing.
            ValueError: If ``k`` is neither 3- nor 4-dimensional on those
                paths.
        """
        no_inference_args = (
            left_paddings is None
            and seqlens is None
            and cu_seqlens_k is None
            and cu_seqlens_q is None
        )
        if no_inference_args:
            # Plain path: delegate to the op built from the layer configuration.
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"

        # First set internals if they have not been set
        # (or if the dtype / device of the inputs changed since they were built).
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        # Only the rank of `k` is inspected above.
        raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')