from typing import Optional, Tuple |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
from fla.ops.common.fused_recurrent import (fused_recurrent_bwd_kernel, |
|
fused_recurrent_fwd_kernel) |
|
from fla.ops.utils import chunk_global_cumsum |
|
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous |
|
|
|
|
|
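# Fused kernel for single-token (T = 1) GSA decoding.
# Each program handles one (batch, query head) pair and performs two passes:
#   1. update the (K, M) key state and read out unnormalized slot scores,
#   2. softmax over the M slots, then update the (M, V) value state to get the output.
# NG is the number of query heads sharing one key/value head (GQA-style grouping).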
@triton.jit |
|
def fused_recurrent_gsa_inference_kernel( |
|
q, |
|
k, |
|
v, |
|
s, |
|
g, |
|
o, |
|
hk0, |
|
hv0, |
|
hkt, |
|
hvt, |
|
scale, |
|
K: tl.constexpr, |
|
V: tl.constexpr, |
|
M: tl.constexpr, |
|
BK: tl.constexpr, |
|
BV: tl.constexpr, |
|
NG: tl.constexpr |
|
): |
|
i_bh = tl.program_id(0) |
|
i_bg = i_bh // NG |
|
|
|
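    # slot representation and forget gate of the current step for this head group;
    # the gate is kept in log space and exponentiated here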
b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32) |
|
b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32) |
|
b_g = tl.exp(b_g) |
|
|
|
b_ok = tl.zeros([M], dtype=tl.float32) |
|
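    # first pass over K blocks: decay the key state with exp(g), add the outer
    # product of k and s, and accumulate the unnormalized slot scores
    # ok[m] = sum_k q[k] * h_k[k, m]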
for i_k in range(tl.cdiv(K, BK)): |
|
o_k = i_k * BK + tl.arange(0, BK) |
|
|
|
p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None] |
|
|
|
mask_k = o_k < K |
|
|
|
mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :] |
|
|
|
b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32) |
|
|
|
b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale |
|
b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32) |
|
b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None] |
|
b_ok += tl.sum(b_hk * b_q[None, :], axis=1) |
|
|
|
if i_bh % NG == 0: |
|
p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None] |
|
tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk) |
|
|
|
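    # normalize the slot scores over the M slots, then second pass over V blocks:
    # decay the value state, add the outer product of s and v, and read out
    # o[v] = sum_m qv[m] * h_v[m, v]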
b_qv = tl.softmax(b_ok) |
|
for i_v in range(tl.cdiv(V, BV)): |
|
o_v = i_v * BV + tl.arange(0, BV) |
|
|
|
p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] |
|
|
|
mask_v = o_v < V |
|
|
|
mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :] |
|
|
|
b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32) |
|
|
|
b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32) |
|
b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None] |
|
b_ov = tl.sum(b_hv * b_qv[None, :], axis=1) |
|
|
|
tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v) |
|
|
|
if i_bh % NG == 0: |
|
p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] |
|
tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv) |
|
|
|
|
|
def fused_recurrent_gsa_inference( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: torch.Tensor, |
|
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
output_final_state: bool = False, |
|
scale: float = 1., |
|
head_first: bool = True |
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
if head_first: |
|
B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] |
|
else: |
|
B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] |
|
HQ = q.shape[1] if head_first else q.shape[2] |
|
BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) |
|
NG = HQ // H |
|
|
|
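    # NOTE: the inference kernel reads hk0/hv0 and writes hkt/hvt unconditionally,
    # so this decoding path is intended to be called with an initial state and
    # output_final_state=True; when NG == 1 the states are updated in place.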
hk0, hv0 = None, None |
|
if initial_state is not None: |
|
hk0, hv0 = initial_state |
|
hkt, hvt = None, None |
|
if output_final_state: |
|
if NG == 1: |
|
hkt, hvt = hk0, hv0 |
|
else: |
|
hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float) |
|
|
|
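    # output buffer and launch grid: one program per (batch, query head) pair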
o = v.new_empty(B, HQ, T, V) if head_first else v.new_empty(B, T, HQ, V) |
|
grid = (B * HQ,) |
|
fused_recurrent_gsa_inference_kernel[grid]( |
|
q, |
|
k, |
|
v, |
|
s, |
|
g, |
|
o, |
|
hk0, |
|
hv0, |
|
hkt, |
|
hvt, |
|
scale=scale, |
|
K=K, |
|
V=V, |
|
M=M, |
|
BK=BK, |
|
BV=BV, |
|
NG=NG |
|
) |
|
return o, (hkt, hvt) |
|
|
|
|
|
def fused_recurrent_gsa_fwd( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: torch.Tensor, |
|
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
output_final_state: bool = False, |
|
scale: float = 1., |
|
reverse: bool = False, |
|
offsets: Optional[torch.LongTensor] = None, |
|
head_first: bool = True |
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: |
|
if head_first: |
|
B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] |
|
else: |
|
B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] |
|
N = B if offsets is None else len(offsets) - 1 |
|
HQ = q.shape[1] if head_first else q.shape[2] |
|
if HQ != H: |
|
raise ValueError("GQA not supported yet.") |
|
|
|
BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64) |
|
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) |
|
|
|
hk0, hv0 = None, None |
|
if initial_state is not None: |
|
hk0, hv0 = initial_state |
|
hkt, hvt = None, None |
|
if output_final_state: |
|
hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float) |
|
|
|
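    # first pass: gated recurrence over the (K, M) key state (the gate acts on the
    # slot dimension via `gv`), producing unnormalized slot scores `ok`; the extra
    # leading NK dimension holds one partial result per K block and is summed below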
ok = q.new_empty(NK, *s.shape, dtype=torch.float) |
|
gk, gv = None, g |
|
grid = (NM, NK, N * H) |
|
fused_recurrent_fwd_kernel[grid]( |
|
q=q, |
|
k=k, |
|
v=s, |
|
g=None, |
|
gk=gk, |
|
gv=gv, |
|
o=ok, |
|
h0=hk0, |
|
ht=hkt, |
|
offsets=offsets, |
|
scale=scale, |
|
B=B, |
|
T=T, |
|
H=H, |
|
K=K, |
|
V=M, |
|
BK=BK, |
|
BV=BM, |
|
USE_G=False, |
|
USE_GK=False, |
|
USE_GV=True, |
|
REVERSE=reverse, |
|
HEAD_FIRST=head_first |
|
) |
|
ok = ok.sum(0) |
|
|
|
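    # normalize the slot scores over the last (slot) dimension, then run the second
    # pass: gated recurrence over the (M, V) value state (the gate acts on the slot
    # dimension via `gk`), producing the output `ov` (summed over slot blocks below)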
qv = ok.softmax(-1, dtype=torch.float) |
|
ov = q.new_empty(NM, *v.shape, dtype=torch.float) |
|
gk, gv = g, None |
|
grid = (NV, NM, N * H) |
|
fused_recurrent_fwd_kernel[grid]( |
|
q=qv, |
|
k=s, |
|
v=v, |
|
g=None, |
|
gk=gk, |
|
gv=gv, |
|
o=ov, |
|
h0=hv0, |
|
ht=hvt, |
|
offsets=offsets, |
|
scale=1., |
|
B=B, |
|
T=T, |
|
H=H, |
|
K=M, |
|
V=V, |
|
BK=BM, |
|
BV=BV, |
|
USE_G=False, |
|
USE_GK=True, |
|
USE_GV=False, |
|
REVERSE=reverse, |
|
HEAD_FIRST=head_first |
|
) |
|
ov = ov.sum(0) |
|
return ok, hkt, qv, ov, hvt |
|
|
|
|
|
def fused_recurrent_gsa_bwd( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: torch.Tensor, |
|
qv: torch.Tensor, |
|
hk0: Optional[torch.Tensor] = None, |
|
hv0: Optional[torch.Tensor] = None, |
|
ok: Optional[torch.Tensor] = None, |
|
do: Optional[torch.Tensor] = None, |
|
dhkt: Optional[torch.Tensor] = None, |
|
dhvt: Optional[torch.Tensor] = None, |
|
scale: float = 1., |
|
reverse: bool = False, |
|
offsets: Optional[torch.LongTensor] = None, |
|
head_first: bool = True |
|
) -> Tuple[torch.Tensor]: |
|
if head_first: |
|
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] |
|
else: |
|
B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] |
|
N = B if offsets is None else len(offsets) - 1 |
|
|
|
BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) |
|
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) |
|
|
|
if head_first: |
|
dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float) |
|
dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float) |
|
dv = q.new_empty(NM, B, H, T, V, dtype=torch.float) |
|
else: |
|
dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) |
|
dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float) |
|
dv = q.new_empty(NM, B, T, H, V, dtype=torch.float) |
|
    dhk0 = torch.empty_like(hk0) if hk0 is not None else None

    dhv0 = torch.empty_like(hv0) if hv0 is not None else None
|
|
|
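    # backward through the second (value) pass: propagate `do` into gradients w.r.t.
    # the normalized slot scores (dqv), the slot representations (dsv) and the values (dv)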
gk, gv = g, None |
|
grid = (NV, NM, N * H) |
|
fused_recurrent_bwd_kernel[grid]( |
|
q=qv, |
|
k=s, |
|
v=v, |
|
g=None, |
|
gk=gk, |
|
gv=gv, |
|
h0=hv0, |
|
do=do, |
|
dq=dqv, |
|
dk=dsv, |
|
dv=dv, |
|
dht=dhvt, |
|
dh0=dhv0, |
|
offsets=offsets, |
|
scale=1., |
|
B=B, |
|
T=T, |
|
H=H, |
|
K=M, |
|
V=V, |
|
BK=BM, |
|
BV=BV, |
|
USE_G=False, |
|
USE_GK=True, |
|
USE_GV=False, |
|
REVERSE=reverse, |
|
HEAD_FIRST=head_first |
|
) |
|
dqv = dqv.sum(0) |
|
dsv = dsv.sum(0) |
|
dv = dv.sum(0) |
|
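    # gradient of the log gates contributed by the value pass (where g gates the
    # slot/key dimension): a cumulative sum of dqv*qv - dsv*s running opposite to
    # the scan direction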
dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), |
|
reverse=not reverse, |
|
offsets=offsets, |
|
head_first=head_first) |
|
|
|
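    # backward through the softmax over the M slots: map dqv back to dok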
dok = qv * (dqv - (qv * dqv).sum(-1, True)) |
|
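    # backward through the first (key) pass: propagate `dok` into dq, dk and the
    # slot-representation gradient dsk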
if head_first: |
|
dq = q.new_empty(NM, B, H, T, K, dtype=torch.float) |
|
dk = q.new_empty(NM, B, H, T, K, dtype=torch.float) |
|
dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float) |
|
else: |
|
dq = q.new_empty(NM, B, T, H, K, dtype=torch.float) |
|
dk = q.new_empty(NM, B, T, H, K, dtype=torch.float) |
|
dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float) |
|
gk, gv = None, g |
|
grid = (NM, NK, N * H) |
|
fused_recurrent_bwd_kernel[grid]( |
|
q=q, |
|
k=k, |
|
v=s, |
|
g=None, |
|
gk=gk, |
|
gv=gv, |
|
h0=hk0, |
|
do=dok, |
|
dq=dq, |
|
dk=dk, |
|
dv=dsk, |
|
dht=dhkt, |
|
dh0=dhk0, |
|
offsets=offsets, |
|
scale=scale, |
|
B=B, |
|
T=T, |
|
H=H, |
|
K=K, |
|
V=M, |
|
BK=BK, |
|
BV=BM, |
|
USE_G=False, |
|
USE_GK=False, |
|
USE_GV=True, |
|
REVERSE=reverse, |
|
HEAD_FIRST=head_first |
|
) |
|
dq = dq.sum(0) |
|
dk = dk.sum(0) |
|
dsk = dsk.sum(0) |
|
|
|
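    # gradient of the log gates contributed by the key pass (where g gates the
    # slot/value dimension)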
dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), |
|
reverse=not reverse, |
|
offsets=offsets, |
|
head_first=head_first) |
|
|
|
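    # the slot representations and the gates receive contributions from both passes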
ds = dsk.add_(dsv) |
|
dg = dgk.add_(dgv) |
|
|
|
return dq, dk, dv, ds, dg, dhk0, dhv0 |
|
|
|
|
|
class FusedRecurrentGSAFunction(torch.autograd.Function): |
|
|
|
@staticmethod |
|
@contiguous |
|
@autocast_custom_fwd |
|
def forward( |
|
ctx, |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: torch.Tensor, |
|
scale: Optional[float] = None, |
|
hk0: Optional[torch.Tensor] = None, |
|
hv0: Optional[torch.Tensor] = None, |
|
output_final_state: bool = False, |
|
reverse: bool = False, |
|
offsets: Optional[torch.LongTensor] = None, |
|
head_first: bool = True |
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: |
|
T = q.shape[2] if head_first else q.shape[1] |
|
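        # single-step decoding without gradient tracking uses the dedicated inference kernel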
if T == 1 and not q.requires_grad: |
|
o, (hkt, hvt) = fused_recurrent_gsa_inference( |
|
q=q, |
|
k=k, |
|
v=v, |
|
s=s, |
|
g=g, |
|
initial_state=(hk0, hv0), |
|
output_final_state=output_final_state, |
|
scale=scale, |
|
head_first=head_first |
|
) |
|
            return o, hkt, hvt
|
ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd( |
|
q=q, |
|
k=k, |
|
v=v, |
|
s=s, |
|
g=g, |
|
initial_state=(hk0, hv0), |
|
output_final_state=output_final_state, |
|
scale=scale, |
|
reverse=reverse, |
|
offsets=offsets, |
|
head_first=head_first |
|
) |
|
ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok) |
|
ctx.scale = scale |
|
ctx.reverse = reverse |
|
ctx.offsets = offsets |
|
ctx.head_first = head_first |
|
return ov.to(q.dtype), hkt, hvt |
|
|
|
@staticmethod |
|
@contiguous |
|
@autocast_custom_bwd |
|
def backward(ctx, do, dhkt=None, dhvt=None): |
|
q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors |
|
scale = ctx.scale |
|
reverse = ctx.reverse |
|
offsets = ctx.offsets |
|
head_first = ctx.head_first |
|
|
|
|
|
if dhkt is not None or dhvt is not None: |
|
if g is not None: |
|
assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time" |
|
dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd( |
|
q=q, |
|
k=k, |
|
v=v, |
|
s=s, |
|
g=g, |
|
qv=qv, |
|
hk0=hk0, |
|
hv0=hv0, |
|
ok=ok, |
|
do=do, |
|
dhkt=dhkt, |
|
dhvt=dhvt, |
|
scale=scale, |
|
reverse=reverse, |
|
offsets=offsets, |
|
head_first=head_first |
|
) |
|
return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None, None |
|
|
|
|
|
def fused_recurrent_gsa( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: Optional[torch.Tensor] = None, |
|
    scale: Optional[float] = None,
|
initial_state: Optional[Tuple[torch.Tensor]] = None, |
|
output_final_state: Optional[bool] = False, |
|
reverse: Optional[bool] = False, |
|
offsets: Optional[torch.LongTensor] = None, |
|
head_first: bool = True |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
r""" |
|
Args: |
|
q (torch.Tensor): |
|
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
k (torch.Tensor): |
|
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
v (torch.Tensor): |
|
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
s (torch.Tensor): |
|
slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. |
|
g (torch.Tensor): |
|
            Forget gates of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`, applied along the slot dimension.
|
        scale (Optional[float]):
|
Scale factor for the attention scores. |
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. |
|
initial_state (Optional[Tuple[torch.Tensor]]): |
|
Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences. |
|
For equal-length input sequences, `N` equals the batch size `B`. |
|
Default: `None`. |
|
output_final_state (Optional[bool]): |
|
            Whether to output the final state of shape `[N, H, K, M]` and `[N, H, M, V]`.
|
Default: `False`. |
|
reverse (Optional[bool]): |
|
If `True`, process the state passing in reverse order. Default: `False`. |
|
offsets (Optional[torch.LongTensor]): |
|
Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch. |
|
For example, |
|
if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively. |
|
If provided, the inputs are concatenated and the batch size `B` is expected to be 1. |
|
Default: `None`. |
|
head_first (Optional[bool]): |
|
Whether the inputs are in the head-first format, which is not supported for variable-length inputs. |
|
Default: `True`. |
|
|
|
Returns: |
|
o (torch.Tensor): |
|
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
final_state (Tuple[torch.Tensor]): |
|
Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`. |
|
|
|
Examples:: |
|
>>> import torch |
|
>>> import torch.nn.functional as F |
|
>>> from einops import rearrange |
|
>>> from fla.ops.gsa import fused_recurrent_gsa |
|
# inputs with equal lengths |
|
>>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64 |
|
>>> q = torch.randn(B, T, H, K, device='cuda') |
|
>>> k = torch.randn(B, T, H, K, device='cuda') |
|
>>> v = torch.randn(B, T, H, V, device='cuda') |
|
>>> s = torch.randn(B, T, H, M, device='cuda') |
|
>>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda')) |
|
>>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda')) |
|
>>> o, (hk, hv) = fused_recurrent_gsa(q, k, v, s, g, |
|
initial_state=h0, |
|
output_final_state=True, |
|
head_first=False) |
|
# for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required |
|
>>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) |
|
# for a batch with 4 sequences, offsets with 5 start/end positions are expected |
|
>>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) |
|
>>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(q, k, v, s, g, |
|
initial_state=h0, |
|
output_final_state=True, |
|
offsets=offsets, |
|
head_first=False) |
|
>>> assert o.allclose(o_var.view(o.shape)) |
|
>>> assert hk.allclose(hk_var) |
|
>>> assert hv.allclose(hv_var) |
|
""" |
|
if offsets is not None: |
|
if q.shape[0] != 1: |
|
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`. "
|
f"Please flatten variable-length inputs before processing.") |
|
if head_first: |
|
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") |
|
if initial_state is not None and initial_state[0].shape[0] != len(offsets) - 1: |
|
raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " |
|
f"i.e., {len(offsets) - 1} rather than {initial_state[0].shape[0]}.") |
|
if scale is None: |
|
scale = k.shape[-1] ** -0.5 |
|
if initial_state is None: |
|
initial_state = (None, None) |
|
o, *final_state = FusedRecurrentGSAFunction.apply( |
|
q, |
|
k, |
|
v, |
|
s, |
|
g, |
|
scale, |
|
*initial_state, |
|
output_final_state, |
|
reverse, |
|
offsets, |
|
head_first |
|
) |
|
return o, final_state |
|
|