# -*- coding: utf-8 -*-

# Chunk-wise WY representation used by delta-rule style kernels.
#
# For every chunk of `BT` consecutive time steps this module materialises
#     A = (I + tril(diag(beta) @ K @ K^T, -1))^{-1}
# and from it
#     w = A @ (beta * k),    u = A @ (beta * v).
# `naive` at the bottom of the file is a pure-PyTorch reference of the same computation.

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous


# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in (1, 2, 4, 8, 16)],
    key=["BK"]
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def fwd_prepare_wy_repr_kernel_chunk32(
    k,
    beta,
    A,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # One program per (chunk, batch*head). Builds the [BT, BT] matrix A for its chunk in a single tile.
    # `BC` is accepted only so that both forward kernels share one launch signature; it is unused here.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # `indices` holds (sequence id, chunk id within that sequence) pairs; `offsets` holds sequence boundaries
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # accumulate (beta * k) @ k^T over the key dimension in fp32
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)

    # keep the strictly lower-triangular part, negated
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    # row-by-row forward substitution: row i is updated from the already finished rows < i
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # add the identity to complete the inverse
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in (1, 2, 4, 8, 16)],
    key=["BK"],
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def fwd_prepare_wy_repr_kernel_chunk64(
    k,
    beta,
    A,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Same result as the chunk32 kernel, but the [BT, BT] matrix is assembled from [BC, BC] sub-blocks:
    # the two diagonal blocks are inverted independently and the off-diagonal block is derived from them.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # b_A / b_A2: upper-left / lower-right diagonal blocks, b_A3: lower-left block
    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BC,), (0,))
    else:
        p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    if HEAD_FIRST:
        p_beta2 = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
    else:
        p_beta2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
    b_beta2 = tl.load(p_beta2, boundary_check=(0,))

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_k2 = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_k2 = tl.load(p_k2, boundary_check=(0, 1))
        b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
        b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
        b_A2 += tl.dot(b_kb2, tl.trans(b_k2), allow_tf32=False)
        b_A3 += tl.dot(b_kb2, tl.trans(b_k), allow_tf32=False)

    b_A = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # forward substitution on both diagonal blocks in lock-step
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = -tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    if HEAD_FIRST:
        p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    else:
        p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask: `A` is allocated with torch.empty, so the upper-right block must be zeroed explicitly
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in (1, 2, 4, 8)],
    key=["BT", "BK", "BV"],
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Per chunk: u = A @ (beta * v) and w = A @ (beta * k), tiled over the value / key dimensions.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in (1, 2, 4, 8)],
    key=["BT", "BK"],
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def fwd_recompute_w_kernel(
    k,
    beta,
    w,
    A,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Key-only variant of `fwd_recompute_w_u_kernel`: w = A @ (beta * k).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb, allow_tf32=False)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in (1, 2, 4, 8, 16)],
    key=["BT", "BK", "BV"],
)
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.jit
def bwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    offsets,
    indices,
    T: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward of w = A @ (beta * k), u = A @ (beta * v): writes dk, dv and dbeta for one chunk.
    # The launch grid is 2-D (chunk, batch*head), so exactly two program ids are read.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    # the swapped shape/strides above make this load return the chunk's A transposed
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)

    # contribution of u = A @ (beta * v)
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # contribution of w = A @ (beta * k); dk is stored here and completed in the last loop
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # propagate the gradient through the matrix inverse, keeping only the strictly lower triangle
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    # contribution of the (beta * k) @ k^T term inside A, accumulated onto the dk stored above
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)

        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    if HEAD_FIRST:
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dbeta = tl.make_block_ptr(dbeta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


def _prepare_chunk_indices(offsets: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return `(sequence id, chunk id)` pairs for every chunk of every sequence.

    For example, `offsets = [0, 100, 356]` with `chunk_size = 64` gives 2 and 4 chunks for the
    1st and 2nd sequence, i.e. `[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]`.
    The result has the dtype and device of `offsets`.
    """
    chunks_per_seq = triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()
    indices = torch.cat([torch.arange(n) for n in chunks_per_seq])
    # a chunk id of 0 marks the start of a new sequence, so its running count is the sequence id
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)


def _chunk_len(T: int, chunk_size: int) -> int:
    """Return the chunk length `BT`: `chunk_size`, shrunk for short inputs but never below 16."""
    return min(chunk_size, max(triton.next_power_of_2(T), 16))


def _resolve_chunks(
    T: int,
    BT: int,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor]
) -> Tuple[int, Optional[torch.LongTensor]]:
    """Return `(NT, indices)`: the size of the first grid axis and the chunk table to pass on.

    Without `offsets` every sequence has `cdiv(T, BT)` chunks and `indices` is passed through
    untouched; with `offsets` the table is built on demand and `NT` is its length.
    """
    if offsets is None:
        return triton.cdiv(T, BT), indices
    if indices is None:
        indices = _prepare_chunk_indices(offsets, BT)
    return len(indices), indices


def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the per-chunk matrix `A` and the derived `w`, `u`.

    Args:
        k: keys, `[B, H, T, K]` if `head_first` else `[B, T, H, K]`.
        v: values, laid out like `k` with last dimension `V`.
        beta: per-step scale, `[B, H, T]` if `head_first` else `[B, T, H]`.
        offsets: optional sequence boundaries for variable-length input.
        indices: optional precomputed chunk table (see `_prepare_chunk_indices`).
        head_first: memory layout selector.
        chunk_size: upper bound on the chunk length.

    Returns:
        `(w, u, A)` where `w` is shaped like `k`, `u` like `v`, and `A` is
        `[B, H, T, BT]` / `[B, T, H, BT]`.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = _chunk_len(T, chunk_size)
    NT, indices = _resolve_chunks(T, BT, offsets, indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
    fwd_fn[(NT, B * H)](
        k=k,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    # w and u are allocated by fwd_recompute_w_u; no need to pre-allocate them here
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, A


def fwd_prepare_T(
    k: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> torch.Tensor:
    """Compute only the per-chunk matrix `A` (no `w` / `u`).

    Arguments are as in `fwd_prepare_wy_repr`. Raises `AssertionError` if the resulting chunk
    length is not one of 16, 32 or 64.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = _chunk_len(T, chunk_size)
    assert BT in [16, 32, 64]
    NT, indices = _resolve_chunks(T, BT, offsets, indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
    fwd_fn[(NT, B * H)](
        k=k,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    return A


def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Recompute `w = A @ (beta * k)` and `u = A @ (beta * v)` from a previously computed `A`.

    Returns:
        `(w, u)`, freshly allocated with the shapes and dtypes of `k` and `v`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = _chunk_len(T, chunk_size)
    NT, indices = _resolve_chunks(T, BT, offsets, indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    u = torch.empty_like(v)
    w = torch.empty_like(k)
    fwd_recompute_w_u_kernel[(NT, B*H)](
        k,
        v,
        beta,
        w,
        u,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u


def fwd_recompute_w(
    k: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> torch.Tensor:
    """Recompute only `w = A @ (beta * k)` from a previously computed `A`."""
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = _chunk_len(T, chunk_size)
    NT, indices = _resolve_chunks(T, BT, offsets, indices)
    BK = min(triton.next_power_of_2(K), 64)

    w = torch.empty_like(k)
    fwd_recompute_w_kernel[(NT, B*H)](
        k,
        beta,
        w,
        A,
        # the kernel declares both parameters and its USE_OFFSETS heuristic reads `offsets`,
        # so they must always be forwarded (as every other launcher in this file does)
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        HEAD_FIRST=head_first
    )
    return w


def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass of the WY preparation.

    Args:
        k, v, beta, A: the forward inputs and the matrix saved by the forward pass.
        dw, du: gradients w.r.t. the forward outputs `w` and `u`.
        offsets, indices, head_first, chunk_size: as in `fwd_prepare_wy_repr`.

    Returns:
        `(dk, dv, dbeta)`, shaped like `k`, `v` and `beta`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = _chunk_len(T, chunk_size)
    NT, indices = _resolve_chunks(T, BT, offsets, indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        k,
        v,
        beta,
        A,
        dw,
        du,
        dk,
        dv,
        dbeta,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dk, dv, dbeta


class WYRepresentationPrepration(torch.autograd.Function):
    """Autograd wrapper around `fwd_prepare_wy_repr` / `bwd_prepare_wy_repr`."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        chunk_size: int = 64
    ):
        """Return `(w, u)`; `A` is kept for the backward pass. `chunk_size` must be 16, 32 or 64."""
        assert chunk_size in [16, 32, 64]

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        # NOTE(review): the table is built with `chunk_size`, while the kernels use
        # BT = min(chunk_size, max(next_power_of_2(T), 16)); the two differ only when T < chunk_size.
        indices = None
        if offsets is not None:
            indices = _prepare_chunk_indices(offsets, chunk_size)

        w, u, A = fwd_prepare_wy_repr(
            k=k,
            v=v,
            beta=beta,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        ctx.chunk_size = chunk_size
        ctx.save_for_backward(k, v, beta, A)
        return w, u

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(
        ctx,
        dw: torch.Tensor,
        du: torch.Tensor
    ):
        """Return gradients for `(k, v, beta)`; the three non-tensor inputs get `None`."""
        k, v, beta, A = ctx.saved_tensors
        dk, dv, dbeta = bwd_prepare_wy_repr(
            k=k,
            v=v,
            beta=beta,
            A=A,
            dw=dw,
            du=du,
            offsets=ctx.offsets,
            indices=ctx.indices,
            head_first=ctx.head_first,
            chunk_size=ctx.chunk_size
        )
        return dk, dv, dbeta, None, None, None


prepare_wy_repr = WYRepresentationPrepration.apply


def naive(k, v, beta, chunk_size):
    """Pure-PyTorch reference for `prepare_wy_repr` on head-first `[B, H, T, D]` inputs.

    The sequence is zero-padded to the next power of two, split into chunks of `chunk_size`,
    and each chunk is solved row by row. Returns an iterator over the two results `(w, u)`,
    cropped back to the original length.
    """
    l_org = k.shape[2]
    l_new = triton.next_power_of_2(l_org)
    # pad k, v, beta
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)

    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
    # zero the diagonal and everything above it: only strictly earlier steps contribute
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
    k_beta = k * beta[..., None]
    v = v * beta[..., None]
    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
    attn = attn * beta[..., None]
    x = attn @ v

    o = torch.zeros_like(k)
    o2 = torch.zeros_like(v)

    o[..., 0, :] = k_beta[..., 0, :].clone()
    o2[..., 0, :] = x[..., 0, :].clone()
    for i in range(1, chunk_size):
        o_i = (o[..., :i, :]).clone()
        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
        o2_i = (o2[..., :i, :]).clone()
        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))


if __name__ == "__main__":
    # Compare the Triton implementation against `naive` (outputs and gradients); requires CUDA.
    torch.set_default_dtype(torch.bfloat16)
    seq_len = 1024
    b = 4
    h = 4
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
    v = torch.randn(b, h, seq_len, 128)
    beta = torch.rand(b, h, seq_len).sigmoid()
    require_grad = True

    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
    do = torch.rand_like(k)
    do2 = torch.rand_like(v)

    o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
    if require_grad:
        o1.backward(do, retain_graph=True)
        o2.backward(do2, retain_graph=True)

        k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
        k.grad = v.grad = beta.grad = None

    # autograd.Function.apply is positional-only: (k, v, beta, offsets, head_first, chunk_size)
    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), None, True, 64)
    print((o1-o3).abs().max())
    print((o2-o4).abs().max())

    if require_grad:
        o3.backward(do, retain_graph=True)
        o4.backward(do2, retain_graph=True)
        k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
        print((k_grad2-k_grad).abs().max())
        print((v_grad2-v_grad).abs().max())
        print((beta_grad2-beta_grad).abs().max())