from typing import Optional, Tuple
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
from einops import rearrange |
|
|
|
from fla.ops.delta_rule.wy_fast import fwd_prepare_T |
|
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({}, num_warps=1), |
|
triton.Config({}, num_warps=2), |
|
triton.Config({}, num_warps=4), |
|
], |
|
key=["BT", "K", "V"], |
|
) |
|
@triton.jit |
|
def chunk_transform_qk_fwd_kernel( |
|
q, |
|
k, |
|
v, |
|
beta, |
|
o, |
|
A, |
|
q_new, |
|
k_new, |
|
A_local, |
|
s_k_h, |
|
s_k_t, |
|
s_k_d, |
|
s_v_h, |
|
s_v_t, |
|
s_v_d, |
|
scale, |
|
T: tl.constexpr, |
|
K: tl.constexpr, |
|
V: tl.constexpr, |
|
BK: tl.constexpr, |
|
BV: tl.constexpr, |
|
BT: tl.constexpr, |
|
OUTPUT_ATTENTIONS: tl.constexpr, |
|
|
|
): |
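    # Each program handles one (chunk, batch*head) pair: it applies the precomputed
    # chunk-level transform A to the intra-chunk interactions, writes the intra-chunk
    # output, and emits the transformed q/k consumed by the inter-chunk kernel below.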
|
i_t, i_bh = tl.program_id(0), tl.program_id(1) |
|
|
|
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) |
|
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) |
|
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) |
|
b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty) |
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
b_v = tl.load(p_v, boundary_check=(0, 1)) |
|
|
|
p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) |
|
b_T = tl.load(p_T, boundary_check=(0, 1)) |
|
|
|
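    # causal masks: inclusive (>=) for the q-k scores, strict (>) for the k-k interactions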
o_i = tl.arange(0, BT) |
|
m_t = o_i[:, None] >= o_i[None, :] |
|
b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype) |
|
m_t = o_i[:, None] > o_i[None, :] |
|
b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype) |
|
|
|
p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, )) |
|
b_beta = tl.load(p_beta, boundary_check=(0, )) |
|
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) |
|
|
|
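    # intra-chunk attention scores after applying the chunk-level transform (optionally saved)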
b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype) |
|
|
|
if OUTPUT_ATTENTIONS: |
|
p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) |
|
tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
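    # write the intra-chunk output, then q_new and k_new with the beta-scaled key
    # corrections subtracted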
b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype) |
|
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) |
|
tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
p_q_new = tl.make_block_ptr(q_new + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) |
|
tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
p_k_new = tl.make_block_ptr(k_new + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) |
|
tl.store(p_k_new, (b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False) |
|
).to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
def chunk_transform_qk_fwd_fn(q, k, v, beta, A, scale, BT, output_attentions): |
|
B, H, T, K = k.shape |
|
q_new = torch.empty_like(q) |
|
k_new = torch.empty_like(k) |
|
o = torch.empty_like(v) |
|
grid = (triton.cdiv(T, BT), B*H) |
|
V = v.shape[-1] |
|
A_local = torch.empty_like(A) if output_attentions else None |
|
chunk_transform_qk_fwd_kernel[grid]( |
|
q, k, v, beta, o, A, q_new, k_new, A_local, |
|
q.stride(1), q.stride(2), q.stride(3), |
|
v.stride(1), v.stride(2), v.stride(3), |
|
scale=scale, |
|
T=T, |
|
K=K, |
|
V=V, |
|
BT=BT, |
|
BK=triton.next_power_of_2(K), |
|
BV=triton.next_power_of_2(V), |
|
OUTPUT_ATTENTIONS=output_attentions |
|
) |
|
return q_new, k_new, o, A_local |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({}, num_warps=1), |
|
triton.Config({}, num_warps=2), |
|
], |
|
key=["BT"], |
|
) |
|
@triton.jit |
|
def save_intra_chunk_attn( |
|
A, |
|
A_local, |
|
T: tl.constexpr, |
|
BT: tl.constexpr, |
|
): |
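    # copy each chunk's local attention block onto the block diagonal of the full T x T map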
|
i_t, i_bh = tl.program_id(0), tl.program_id(1) |
|
p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0)) |
|
p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) |
|
b_A_local = tl.load(p_A_local, boundary_check=(0, 1)) |
|
tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
@triton.heuristics({ |
|
'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None |
|
}) |
|
@triton.jit |
|
def parallel_delta_rule_fwd_kernel( |
|
q, |
|
k, |
|
k2, |
|
v, |
|
beta, |
|
o, |
|
o_new, |
|
attn, |
|
s_k_h, |
|
s_k_t, |
|
s_v_h, |
|
s_v_t, |
|
T: tl.constexpr, |
|
K: tl.constexpr, |
|
V: tl.constexpr, |
|
BT: tl.constexpr, |
|
BS: tl.constexpr, |
|
BK: tl.constexpr, |
|
BV: tl.constexpr, |
|
OUTPUT_ATTENTIONS: tl.constexpr |
|
): |
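    # Each program owns one BT-long query block of one batch*head. b_q/b_o start from the
    # transformed queries and the intra-chunk output; the two loops below add contributions
    # from earlier BS-sized sub-chunks while folding beta-scaled keys back into b_q.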
|
i_t, i_bh = tl.program_id(0), tl.program_id(1) |
|
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, 0), (BT, BK), (1, 0)) |
|
|
|
|
|
|
|
b_q = tl.zeros([BT, BK], dtype=tl.float32) |
|
b_q += tl.load(p_q, boundary_check=(0, 1)) |
|
|
|
b_o = tl.zeros([BT, BV], dtype=tl.float32) |
|
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, 0), (BT, BV), (1, 0)) |
|
b_o += tl.load(p_o, boundary_check=(0, 1)) |
|
|
|
|
|
|
|
|
|
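    # pass 1: sub-chunks inside the current BT block, processed right to left; the mask
    # keeps only query rows positioned after the key sub-chunk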
for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS): |
|
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (0, offset), (BK, BS), (0, 1)) |
|
p_k2 = tl.make_block_ptr(k2 + i_bh * s_k_h, (T, K), (s_k_t, 1), (offset, 0), (BS, BK), (1, 0)) |
|
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (offset, 0), (BS, BV), (1, 0)) |
|
p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,)) |
|
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
|
b_v = tl.load(p_v, boundary_check=(0, 1)) |
|
|
|
b_beta = tl.load(p_beta, boundary_check=(0,)) |
|
|
|
m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS) |
|
b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False) |
|
b_s = tl.where(m_s[:, None], b_s, 0) |
|
|
|
b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) |
|
b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype) |
|
b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False) |
|
|
|
if OUTPUT_ATTENTIONS: |
|
p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0)) |
|
tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
|
|
|
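    # pass 2: sub-chunks strictly before the current BT block; no causal mask is needed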
for offset in range(i_t * BT - BS, -BS, -BS): |
|
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (0, offset), (BK, BS), (0, 1)) |
|
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (offset, 0), (BS, BV), (1, 0)) |
|
p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,)) |
|
p_k2 = tl.make_block_ptr(k2 + i_bh * s_k_h, (T, K), (s_k_t, 1), (offset, 0), (BS, BK), (1, 0)) |
|
|
|
|
|
b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
|
b_v = tl.load(p_v, boundary_check=(0, 1)) |
|
|
|
b_beta = tl.load(p_beta, boundary_check=(0,)) |
|
|
|
b_s = (tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)) |
|
|
|
b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) |
|
b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype) |
|
b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype) |
|
|
|
if OUTPUT_ATTENTIONS: |
|
p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0)) |
|
tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
p_o_new = tl.make_block_ptr(o_new + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t*BT, 0), (BT, BV), (1, 0)) |
|
    tl.store(p_o_new, b_o.to(p_o_new.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
class ParallelDeltaRuleFunction(torch.autograd.Function): |
|
|
|
@staticmethod |
|
@contiguous |
|
@autocast_custom_fwd |
|
def forward(ctx, q, k, v, beta, scale, output_attentions): |
|
B, H, T, K, V = *k.shape, v.shape[-1] |
|
        assert q.shape[-1] <= 128, 'The maximum supported head dimension is 128.'
|
BT, BS = 128, 32 |
|
BK = triton.next_power_of_2(k.shape[-1]) |
|
BV = triton.next_power_of_2(v.shape[-1]) |
|
assert BT % BS == 0 |
|
|
|
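        # stage 1: per-chunk (size BS) transform, intra-chunk outputs and transformed q/k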
A = fwd_prepare_T(k, beta, BS) |
|
attn = q.new_zeros(B, H, T, T) if output_attentions else None |
|
q_new, k_new, o, A_local = chunk_transform_qk_fwd_fn(q, k, v, beta, A, scale, BS, output_attentions) |
|
|
|
num_stages = 3 if K <= 64 else 2 |
|
num_warps = 4 |
|
grid = (triton.cdiv(T, BT), B * H) |
|
o_new = torch.empty_like(o) |
|
|
|
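        # stage 2: inter-chunk pass over BT-sized blocks, consuming the transformed q/k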
parallel_delta_rule_fwd_kernel[grid]( |
|
q=q_new, |
|
k=k_new, |
|
k2=k, |
|
v=v, |
|
beta=beta, |
|
o=o, |
|
o_new=o_new, |
|
attn=attn, |
|
s_k_h=k.stride(1), |
|
s_k_t=k.stride(2), |
|
s_v_h=v.stride(1), |
|
s_v_t=v.stride(2), |
|
T=T, |
|
K=K, |
|
V=V, |
|
BT=BT, |
|
BS=BS, |
|
BK=BK, |
|
BV=BV, |
|
num_stages=num_stages, |
|
num_warps=num_warps |
|
) |
|
|
|
if output_attentions: |
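            # copy the intra-chunk attention blocks onto the diagonal of the full attention map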
|
grid = (triton.cdiv(T, BS), B * H) |
|
save_intra_chunk_attn[grid]( |
|
A=attn, A_local=A_local, T=T, BT=BS |
|
) |
|
return o_new.to(q.dtype), attn |
|
|
|
@staticmethod |
|
@contiguous |
|
@autocast_custom_bwd |
|
def backward(ctx, do, d_attn=None): |
|
raise NotImplementedError('Backward pass is not implemented. Stay tuned!') |
|
|
|
|
|
def parallel_delta_rule( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
beta: torch.Tensor, |
|
    scale: Optional[float] = None,

    output_attentions: bool = False,

    head_first: bool = True

) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
r""" |
|
Args: |
|
q (torch.Tensor): |
|
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
k (torch.Tensor): |
|
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
v (torch.Tensor): |
|
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
beta (torch.Tensor): |
|
betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. |
|
        scale (Optional[float]):
|
Scale factor for attention scores. |
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. |
|
output_attentions (bool): |
|
            Whether to output the materialized attention scores of shape `[B, H, T, T]`. Default: `False`.
|
head_first (Optional[bool]): |
|
Whether the inputs are in the head-first format. |
|
Default: `True`. |
|
|
|
Returns: |
|
o (torch.Tensor): |
|
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
attn (torch.Tensor): |
|
Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`. |
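
    Example:
        An illustrative call (requires a CUDA device; shapes assume `head_first=True`):
        >>> B, H, T, K, V = 2, 4, 512, 64, 64
        >>> q = torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda')
        >>> k = torch.nn.functional.normalize(
        ...     torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, H, T, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, H, T, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> o, attn = parallel_delta_rule(q, k, v, beta, scale=K**-0.5, output_attentions=True)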
|
""" |
|
if not head_first: |
|
q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta)) |
|
o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions) |
|
if not head_first: |
|
o = o.transpose(1, 2) |
|
return o, attn |
|
|
|
|
|
def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32): |
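    # Plain PyTorch reference (chunkwise, no Triton); used in the __main__ check below.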
|
b, h, l, d_k = q.shape |
|
q = q * (d_k ** -0.5) |
|
v = v * beta[..., None] |
|
k_beta = k * beta[..., None] |
|
|
|
q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta]) |
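    # build the per-chunk transform T = (I + strict_tril(diag(beta) @ K @ K^T))^{-1}
    # by forward substitution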
|
mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0) |
|
T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) |
|
for i in range(1, BN): |
|
T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2) |
|
T = T + torch.eye(BN, dtype=q.dtype, device=q.device) |
|
|
|
mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1) |
|
A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T |
|
o_intra = A_local @ v |
|
|
|
|
|
k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta |
|
|
|
q = q - A_local @ k_beta |
|
o_intra = A_local @ v |
|
|
|
A = torch.zeros(b, h, l, l, device=q.device) |
|
|
|
q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra]) |
|
o = torch.empty_like(v) |
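    # inter-chunk pass mirroring the Triton kernel: for each BM query block, sweep the
    # earlier BN key blocks right to left, accumulating output and updating the queries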
|
for i in range(0, l, BM): |
|
q_i = q[:, :, i:i+BM] |
|
o_i = o_intra[:, :, i:i+BM] |
|
|
|
for j in range(i + BM - 2 * BN, i-BN, -BN): |
|
k_j = k[:, :, j:j+BN] |
|
A_ij = q_i @ k_j.transpose(-1, -2) |
|
mask = torch.arange(i, i+BM) >= (j + BN) |
|
A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0) |
|
A[:, :, i:i+BM, j:j+BN] = A_ij |
|
q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] |
|
o_i += A_ij @ v[:, :, j:j+BN] |
|
|
|
for j in range(i - BN, -BN, -BN): |
|
k_j = k[:, :, j:j+BN] |
|
A_ij = q_i @ k_j.transpose(-1, -2) |
|
A[:, :, i:i+BM, j:j+BN] = A_ij |
|
q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] |
|
o_i += A_ij @ v[:, :, j:j+BN] |
|
o[:, :, i:i+BM] = o_i |
|
|
|
for i in range(0, l//BN): |
|
A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] |
|
|
|
return o, A |
|
|
|
|
|
if __name__ == "__main__": |
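    # quick numerical check of the Triton path against the naive reference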
|
B, H, T, K, V = 2, 4, 512, 64, 64 |
|
torch.set_default_dtype(torch.bfloat16) |
|
|
|
    q = torch.randn(B, H, T, K).cuda()

    k = torch.nn.functional.normalize(torch.randn(B, H, T, K).cuda(), p=2, dim=-1)

    v = torch.randn(B, H, T, V).cuda()
|
beta = torch.ones(B, H, T).cuda() |
|
|
|
output_attentions = True |
|
ref_o, ref_attn = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone()) |
|
o, attn = parallel_delta_rule(q.clone(), k.clone(), v.clone(), beta.clone(), K**-0.5, output_attentions) |
|
print((ref_o-o).abs().max()) |
|
print((ref_attn-attn).abs().max()) |
|
|