# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=['BT']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=['BT']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_local_reversed_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_z = tl.sum(b_s, axis=0)
    b_o = b_z[None] - tl.cumsum(b_s, axis=0) + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
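
# A minimal PyTorch reference for the two scalar kernels above (an
# illustrative sketch for testing, not part of the module's API; the helper
# name is hypothetical). Within each chunk of size `chunk_size`, the forward
# kernel computes an inclusive cumsum and the reversed kernel an inclusive
# suffix sum. Assumes `head_first=True` inputs of shape [B, H, T] with T
# divisible by `chunk_size`.
def _ref_chunk_local_cumsum_scalar(g: torch.Tensor, chunk_size: int, reverse: bool = False) -> torch.Tensor:
    B, H, T = g.shape
    g = g.view(B, H, T // chunk_size, chunk_size).float()
    # a reversed cumsum is flip -> cumsum -> flip within each chunk
    o = g.flip(-1).cumsum(-1).flip(-1) if reverse else g.cumsum(-1)
    return o.view(B, H, T)
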
@triton.autotune(
    configs=[
        triton.Config({'BS': 16}, num_warps=2),
        triton.Config({'BS': 16}, num_warps=4),
        triton.Config({'BS': 16}, num_warps=8),
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
    ],
    key=['S', 'BT']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[
        triton.Config({'BS': 16}, num_warps=2),
        triton.Config({'BS': 16}, num_warps=4),
        triton.Config({'BS': 16}, num_warps=8),
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
    ],
    key=['S', 'BT']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_local_reversed_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
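
# The vector kernels realize the per-chunk cumsum as a matrix product with a
# triangular mask, which maps onto tensor cores: with a lower-triangular
# all-ones matrix M, (M @ x)[i] equals sum_{j <= i} x[j], i.e. an inclusive
# cumsum over rows, while the upper-triangular mask gives the inclusive
# suffix sum. A minimal sketch of the identity (illustrative only; the helper
# name is hypothetical):
def _ref_masked_cumsum(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    T = x.shape[0]
    m = torch.ones(T, T, dtype=x.dtype, device=x.device)
    m = m.triu() if reverse else m.tril()
    # equals x.cumsum(0), or its suffix-sum counterpart when reverse=True
    return m @ x
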
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_z = tl.zeros([], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0) + b_z[None]
        b_z += tl.sum(b_s, axis=0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_global_reversed_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_z = tl.zeros([], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_zz = tl.sum(b_s, axis=0)
        b_z += b_zz
        b_o = b_s - tl.cumsum(b_s, axis=0) + b_z[None]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
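
# The global scalar kernels chain chunks through a running total `b_z`: each
# chunk stores its local cumsum shifted by the sum of all preceding chunks
# (all following chunks, in the reversed kernel). A 1-D reference sketch of
# the forward pass (illustrative; helper name hypothetical):
def _ref_chunked_global_cumsum(x: torch.Tensor, BT: int) -> torch.Tensor:
    o, z = torch.empty_like(x), x.new_zeros(())
    for i in range(0, x.shape[0], BT):
        b = x[i:i+BT]
        # local cumsum plus the carry of everything before this chunk
        o[i:i+BT] = b.cumsum(0) + z
        z = z + b.sum(0)
    return o  # equals x.cumsum(0)
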
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_global_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z += tl.sum(b_s, 0)


@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def chunk_global_reversed_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + i_bh * T * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z += tl.sum(b_s, 0)
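
# The reversed global kernels mirror the forward pass but walk chunks from
# last to first; the carry `b_z` then accumulates the total of the current
# and all later chunks, so each position receives its within-chunk suffix sum
# plus that carry. A 1-D reference sketch (illustrative; helper name
# hypothetical):
def _ref_chunked_global_reversed_cumsum(x: torch.Tensor, BT: int) -> torch.Tensor:
    o, z = torch.empty_like(x), x.new_zeros(())
    for i in range(((x.shape[0] - 1) // BT) * BT, -1, -BT):
        b = x[i:i+BT]
        z = z + b.sum(0)
        # suffix sum within the chunk, shifted by the total of later chunks
        o[i:i+BT] = z - b.cumsum(0) + b
    return o  # equals x.flip(0).cumsum(0).flip(0)
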
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True
) -> torch.Tensor:
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    if offsets is not None:
        B = len(offsets) - 1
    BT = chunk_size
    if offsets is None:
        NT = triton.cdiv(T, BT)
    else:
        if indices is None:
            indices = torch.cat([
                torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1)
                for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist())
            ])
        NT = len(indices)
    g_org, g = g, torch.empty_like(g, dtype=torch.float)
    grid = (NT, B * H)
    if reverse:
        chunk_local_reversed_cumsum_scalar_kernel[grid](
            g_org, g, offsets, indices,
            T=T, H=H, BT=BT, HEAD_FIRST=head_first
        )
    else:
        chunk_local_cumsum_scalar_kernel[grid](
            g_org, g, offsets, indices,
            T=T, H=H, BT=BT, HEAD_FIRST=head_first
        )
    return g


def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True
) -> torch.Tensor:
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    if offsets is None:
        NT = triton.cdiv(T, BT)
    else:
        if indices is None:
            indices = torch.cat([
                torch.stack([offsets.new_full((n,), i), offsets.new_tensor(range(n))], 1)
                for i, n in enumerate(triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist())
            ])
        NT = len(indices)
    g_org, g = g, torch.empty_like(g, dtype=torch.float)

    def grid(meta):
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
    # keep the cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    if reverse:
        chunk_local_reversed_cumsum_vector_kernel[grid](
            g_org, g, offsets, indices,
            T=T, H=H, S=S, BT=BT, HEAD_FIRST=head_first
        )
    else:
        chunk_local_cumsum_vector_kernel[grid](
            g_org, g, offsets, indices,
            T=T, H=H, S=S, BT=BT, HEAD_FIRST=head_first
        )
    return g


@contiguous
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True
) -> torch.Tensor:
    if offsets is not None:
        assert not head_first, "Sequences with variable lengths are not supported for head-first mode"
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first)
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first)
    else:
        raise ValueError(f"Unsupported input shape {g.shape}, "
                         f"which should be [B, H, T]/[B, H, T, S] if `head_first=True` "
                         f"or [B, T, H]/[B, T, H, S] otherwise")
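
# Example of the variable-length path (illustrative; shapes are made up).
# Sequences are packed along the time axis of a single batch row, with
# boundaries given by `offsets`; when `indices` is omitted, the per-chunk
# (sequence, chunk) pairs are derived from `offsets` as above. Requires a
# CUDA device, since the kernels are Triton-compiled:
#
#     g = torch.randn(1, 500, 4, device='cuda')             # [1, total_T, H]
#     offsets = torch.tensor([0, 128, 500], device='cuda')  # two sequences
#     o = chunk_local_cumsum(g, chunk_size=64, offsets=offsets, head_first=False)
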
" f"which should be (B, H, T, dim) if `head_first=True` " f"or (batch_size, num_heads, seq_len) otherwise") @contiguous def chunk_global_cumsum_scalar( s: torch.Tensor, dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, head_first: bool = True ) -> torch.Tensor: dtype = dtype or s.dtype if head_first: B, H, T = s.shape else: B, T, H = s.shape if offsets is not None: B = len(offsets) - 1 grid = (B * H,) z = torch.empty_like(s, dtype=dtype) if reverse: chunk_global_reversed_cumsum_scalar_kernel[grid]( s, z, offsets, T=T, H=H, HEAD_FIRST=head_first ) else: chunk_global_cumsum_scalar_kernel[grid]( s, z, offsets, T=T, H=H, HEAD_FIRST=head_first ) return z @contiguous def chunk_global_cumsum_vector( s: torch.Tensor, dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, head_first: bool = True ) -> torch.Tensor: dtype = dtype or s.dtype if head_first: B, H, T, S = s.shape else: B, T, H, S = s.shape BS = min(32, S) if offsets is not None: B = len(offsets) - 1 grid = (triton.cdiv(S, BS), B * H) z = torch.empty_like(s, dtype=dtype) if reverse: chunk_global_reversed_cumsum_vector_kernel[grid]( s, z, offsets, T=T, H=H, S=S, BS=BS, HEAD_FIRST=head_first ) else: chunk_global_cumsum_vector_kernel[grid]( s, z, offsets, T=T, H=H, S=S, BS=BS, HEAD_FIRST=head_first ) return z @contiguous def chunk_global_cumsum( s: torch.Tensor, dtype: Optional[torch.dtype] = None, reverse: bool = False, offsets: Optional[torch.Tensor] = None, head_first: bool = True ) -> torch.Tensor: if offsets is not None: assert not head_first, "Sequences with variable lengths are not supported for head-first mode" assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided" if len(s.shape) == 3: return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first) elif len(s.shape) == 4: return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first) else: raise ValueError(f"Unsupported input shape {s.shape}. " f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` " f"or [B, T, H]/[B, T, H, D] otherwise")