# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.fused_recurrent import (fused_recurrent_bwd_kernel,
                                            fused_recurrent_fwd_kernel)
from fla.ops.utils import chunk_global_cumsum
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


@triton.jit
def fused_recurrent_gsa_inference_kernel(
    q, k, v, s, g, o,
    hk0, hv0, hkt, hvt,
    scale,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    i_bh = tl.program_id(0)
    i_bg = i_bh // NG

    b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32)
    b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32)
    b_g = tl.exp(b_g)

    b_ok = tl.zeros([M], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)

        p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None]
        # [BK,]
        mask_k = o_k < K
        # [M, BK]
        mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :]
        # [M, BK]
        b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32)
        # [BK,]
        b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale
        b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32)
        b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]
        b_ok += tl.sum(b_hk * b_q[None, :], axis=1)

        if i_bh % NG == 0:
            p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None]
            tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk)

    b_qv = tl.softmax(b_ok)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)

        p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
        # [BV,]
        mask_v = o_v < V
        # [BV, M]
        mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :]
        # [BV, M]
        b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32)
        # [BV,]
        b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32)
        b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]
        b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)

        tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v)

        if i_bh % NG == 0:
            p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
            tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv)


def fused_recurrent_gsa_inference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    head_first: bool = True
) -> torch.Tensor:
    if head_first:
        B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    else:
        B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    HQ = q.shape[1] if head_first else q.shape[2]
    BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
    NG = HQ // H

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    hkt, hvt = None, None
    if output_final_state:
        if NG == 1:
            hkt, hvt = hk0, hv0
        else:
            hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float)

    o = v.new_empty(B, HQ, T, V) if head_first else v.new_empty(B, T, HQ, V)
    grid = (B * HQ,)
    fused_recurrent_gsa_inference_kernel[grid](
        q, k, v, s, g, o,
        hk0, hv0, hkt, hvt,
        scale=scale,
        K=K, V=V, M=M,
        BK=BK, BV=BV,
        NG=NG
    )
    return o, (hkt, hvt)
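

# A minimal, unfused PyTorch reference for the GSA recurrence computed by the fused
# kernels in this file. It is an illustrative sketch rather than part of the fla API:
# it assumes the `head_first=False` layout (`[B, T, H, *]` inputs), runs in fp32 and
# loops over time, so it is only meant for readability and small-scale checks.
def naive_recurrent_gsa_ref(
    q: torch.Tensor,  # [B, T, H, K]
    k: torch.Tensor,  # [B, T, H, K]
    v: torch.Tensor,  # [B, T, H, V]
    s: torch.Tensor,  # [B, T, H, M]
    g: torch.Tensor,  # [B, T, H, M], forget gates in log space
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    B, T, H, K = q.shape
    V, M = v.shape[-1], s.shape[-1]
    if scale is None:
        scale = K ** -0.5
    q, k, v, s, g = (x.float() for x in (q, k, v, s, g))
    # hk accumulates key/slot outer products, hv accumulates slot/value outer products
    hk = q.new_zeros(B, H, K, M)
    hv = q.new_zeros(B, H, M, V)
    if initial_state is not None and initial_state[0] is not None:
        hk = initial_state[0].float().clone()
        hv = initial_state[1].float().clone()
    o = q.new_empty(B, T, H, V)
    for t in range(T):
        q_t, k_t, v_t, s_t = q[:, t], k[:, t], v[:, t], s[:, t]
        g_t = g[:, t].exp()
        # 1) gated update of the key->slot state and the slot scores
        hk = hk * g_t[:, :, None, :] + k_t[..., None] * s_t[:, :, None, :]
        ok_t = torch.einsum('bhk,bhkm->bhm', q_t * scale, hk)
        # 2) softmax over the M slots gives the intra-slot attention weights
        qv_t = ok_t.softmax(-1)
        # 3) gated update of the slot->value state and the output
        hv = hv * g_t[..., None] + s_t[..., None] * v_t[:, :, None, :]
        o[:, t] = torch.einsum('bhm,bhmv->bhv', qv_t, hv)
    return o, (hk, hv)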

def fused_recurrent_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    if head_first:
        B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    else:
        B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    N = B if offsets is None else len(offsets) - 1
    HQ = q.shape[1] if head_first else q.shape[2]
    if HQ != H:
        raise ValueError("GQA not supported yet.")

    BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    hkt, hvt = None, None
    if output_final_state:
        hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float)

    ok = q.new_empty(NK, *s.shape, dtype=torch.float)
    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=q, k=k, v=s, g=None, gk=gk, gv=gv, o=ok, h0=hk0, ht=hkt,
        offsets=offsets, scale=scale,
        B=B, T=T, H=H, K=K, V=M, BK=BK, BV=BM,
        USE_G=False, USE_GK=False, USE_GV=True,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    ok = ok.sum(0)

    qv = ok.softmax(-1, dtype=torch.float)
    ov = q.new_empty(NM, *v.shape, dtype=torch.float)
    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=qv, k=s, v=v, g=None, gk=gk, gv=gv, o=ov, h0=hv0, ht=hvt,
        offsets=offsets, scale=1.,
        B=B, T=T, H=H, K=M, V=V, BK=BM, BV=BV,
        USE_G=False, USE_GK=True, USE_GV=False,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    ov = ov.sum(0)
    return ok, hkt, qv, ov, hvt
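

# The backward pass below mirrors the two chained recurrences of
# `fused_recurrent_gsa_fwd`: a value-side `fused_recurrent_bwd_kernel` call first maps
# `do` to `dqv`, `dsv` and `dv`; the softmax sitting between the two recurrences is
# differentiated in closed form via `dok = qv * (dqv - (qv * dqv).sum(-1, keepdim=True))`;
# a key-side call then maps `dok` to `dq`, `dk` and `dsk`. The forget-gate gradients of
# both recurrences are recovered with reversed global cumulative sums
# (`chunk_global_cumsum`), and the per-recurrence slot/gate gradients are summed into
# `ds` and `dg`.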
def fused_recurrent_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    qv: torch.Tensor,
    hk0: Optional[torch.Tensor] = None,
    hv0: Optional[torch.Tensor] = None,
    ok: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    dhkt: Optional[torch.Tensor] = None,
    dhvt: Optional[torch.Tensor] = None,
    scale: float = 1.,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor]:
    if head_first:
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
    else:
        B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
    N = B if offsets is None else len(offsets) - 1

    BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    if head_first:
        dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
        dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
        dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
    else:
        dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
        dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
        dv = q.new_empty(NM, B, T, H, V, dtype=torch.float)
    dhk0 = torch.empty_like(hk0) if hk0 is not None else None
    dhv0 = torch.empty_like(hv0) if hv0 is not None else None

    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=qv, k=s, v=v, g=None, gk=gk, gv=gv, h0=hv0,
        do=do, dq=dqv, dk=dsv, dv=dv, dht=dhvt, dh0=dhv0,
        offsets=offsets, scale=1.,
        B=B, T=T, H=H, K=M, V=V, BK=BM, BV=BV,
        USE_G=False, USE_GK=True, USE_GV=False,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    dqv = dqv.sum(0)
    dsv = dsv.sum(0)
    dv = dv.sum(0)
    dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(),
                              reverse=not reverse,
                              offsets=offsets,
                              head_first=head_first)

    dok = qv * (dqv - (qv * dqv).sum(-1, True))
    if head_first:
        dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
        dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
        dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
    else:
        dq = q.new_empty(NM, B, T, H, K, dtype=torch.float)
        dk = q.new_empty(NM, B, T, H, K, dtype=torch.float)
        dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float)

    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=q, k=k, v=s, g=None, gk=gk, gv=gv, h0=hk0,
        do=dok, dq=dq, dk=dk, dv=dsk, dht=dhkt, dh0=dhk0,
        offsets=offsets, scale=scale,
        B=B, T=T, H=H, K=K, V=M, BK=BK, BV=BM,
        USE_G=False, USE_GK=False, USE_GV=True,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dsk = dsk.sum(0)
    dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(),
                              reverse=not reverse,
                              offsets=offsets,
                              head_first=head_first)

    ds = dsk.add_(dsv)
    dg = dgk.add_(dgv)

    return dq, dk, dv, ds, dg, dhk0, dhv0


class FusedRecurrentGSAFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: Optional[float] = None,
        hk0: Optional[torch.Tensor] = None,
        hv0: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        T = q.shape[2] if head_first else q.shape[1]
        if T == 1 and not q.requires_grad:
            o, (hkt, hvt) = fused_recurrent_gsa_inference(
                q=q, k=k, v=v, s=s, g=g,
                initial_state=(hk0, hv0),
                output_final_state=output_final_state,
                scale=scale,
                head_first=head_first
            )
            return o, (hkt, hvt)
        ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd(
            q=q, k=k, v=v, s=s, g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            reverse=reverse,
            offsets=offsets,
            head_first=head_first
        )
        ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
        ctx.scale = scale
        ctx.reverse = reverse
        ctx.offsets = offsets
        ctx.head_first = head_first
        return ov.to(q.dtype), hkt, hvt

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dhkt=None, dhvt=None):
        q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
        scale = ctx.scale
        reverse = ctx.reverse
        offsets = ctx.offsets
        head_first = ctx.head_first
        # not supported yet.
        if dhkt is not None or dhvt is not None:
            if g is not None:
                assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
        dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd(
            q=q, k=k, v=v, s=s, g=g, qv=qv,
            hk0=hk0, hv0=hv0, ok=ok,
            do=do, dhkt=dhkt, dhvt=dhvt,
            scale=scale,
            reverse=reverse,
            offsets=offsets,
            head_first=head_first
        )
        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None, None


def fused_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    reverse: Optional[bool] = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, M]` and `[N, H, M, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        offsets (Optional[torch.LongTensor]):
            Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch.
            For example, if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4
            and 5 respectively. If provided, the inputs are concatenated and the batch size `B` is expected to be 1.
            Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import fused_recurrent_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = fused_recurrent_gsa(q, k, v, s, g,
                                              initial_state=h0,
                                              output_final_state=True,
                                              head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, offsets with 5 start/end positions are expected
        >>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(q, k, v, s, g,
                                                          initial_state=h0,
                                                          output_final_state=True,
                                                          offsets=offsets,
                                                          head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    if offsets is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`. "
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state[0].shape[0] != len(offsets) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(offsets) - 1} rather than {initial_state[0].shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if initial_state is None:
        initial_state = (None, None)
    o, *final_state = FusedRecurrentGSAFunction.apply(
        q, k, v, s, g, scale, *initial_state, output_final_state, reverse, offsets, head_first
    )
    return o, final_state
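

# Minimal smoke-test sketch (not part of the fla test suite): it cross-checks the fused
# path against the unfused `naive_recurrent_gsa_ref` defined above on a tiny problem.
# A CUDA device is assumed, since the Triton kernels require one.
if __name__ == '__main__':
    if torch.cuda.is_available():
        torch.manual_seed(0)
        B, T, H, K, V, M = 2, 32, 2, 16, 16, 8
        q = torch.randn(B, T, H, K, device='cuda')
        k = torch.randn(B, T, H, K, device='cuda')
        v = torch.randn(B, T, H, V, device='cuda')
        s = torch.randn(B, T, H, M, device='cuda')
        g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        o, _ = fused_recurrent_gsa(q, k, v, s, g, head_first=False)
        o_ref, _ = naive_recurrent_gsa_ref(q, k, v, s, g)
        print('max abs diff:', (o - o_ref.to(o.dtype)).abs().max().item())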