|
|
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
import math |
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
        triton.Config({'BLOCK_SIZE_S': 16, 'BLOCK_SIZE_W': 16}, num_warps=2),

    ],
|
key=[] |
|
) |
|
@triton.jit |
|
def afak_fwd_kernel( |
|
q_ptr, k_ptr, states_ptr, y_ptr, |
|
    B: tl.constexpr, T: tl.constexpr, S: tl.constexpr, C: tl.constexpr, W: tl.constexpr,
|
BLOCK_SIZE_S: tl.constexpr, |
|
BLOCK_SIZE_W: tl.constexpr, |
|
): |
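    # One program per (batch, timestep, score block). The score vector y[b, t] is
    # the concatenation of q·states over the S folded slots and q·k over the W
    # most recent keys. tl.arange requires C to be a power of two, and the
    # unchecked block stores assume S and W are multiples of the block sizes.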
|
|
|
b_id = tl.program_id(axis=0) |
|
t_id = tl.program_id(axis=1) |
|
sw_block_id = tl.program_id(axis=2) |
|
    num_s_blocks = tl.cdiv(S, BLOCK_SIZE_S)

    num_w_blocks = tl.cdiv(W, BLOCK_SIZE_W)
|
SW = S + W |
|
|
|
|
|
q_base = q_ptr + b_id * T * C |
|
k_base = k_ptr + b_id * T * C |
|
states_base = states_ptr + b_id * T * S * C |
|
    y_base = y_ptr + b_id * T * SW
|
|
|
|
|
q_block_ptr = tl.make_block_ptr( |
|
base=q_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, 0), |
|
block_shape=(1, 1, C), |
|
order=(0, 1, 2), |
|
) |
|
q = tl.load(q_block_ptr) |
|
|
|
if sw_block_id < num_s_blocks: |
|
s_first_id = sw_block_id * BLOCK_SIZE_S |
|
|
|
s_block_ptr = tl.make_block_ptr( |
|
base=states_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, s_first_id, 0), |
|
block_shape=(1, 1, BLOCK_SIZE_S, C), |
|
order=(0, 1, 2, 3), |
|
) |
|
s = tl.load(s_block_ptr) |
|
o = q[:, :, None, :] * s |
|
o = tl.sum(o, axis=-1) |
|
|
|
y_block_ptr = tl.make_block_ptr( |
|
base=y_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, s_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_S), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(y_block_ptr, o.to(y_block_ptr.dtype.element_ty)) |
|
else: |
|
w_first_id = (sw_block_id - num_s_blocks) * BLOCK_SIZE_W |
|
|
|
|
|
tw_offs = tl.arange(0, BLOCK_SIZE_W) |
|
c_offs = tl.arange(0, C) |
|
k_block_ptr = k_base + (t_id - W + 1 + (w_first_id + tw_offs[:, None])) * C + c_offs[None, :] |
|
        # Valid only where the key index t - W + 1 + w is non-negative.
        mask = (w_first_id + tw_offs[:, None]) > (W - t_id - 2)

        k = tl.load(k_block_ptr, mask=mask, other=0.0)
|
|
|
y = q * k[None, :] |
|
y = tl.sum(y, axis=-1) |
|
|
|
y_block_ptr = tl.make_block_ptr( |
|
base=y_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, S + w_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_W), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(y_block_ptr, y[None, :].to(y_block_ptr.dtype.element_ty)) |
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE_C': bs_c, |
|
}, num_warps=warps) |
|
for bs_c in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def afak_bwd_kernel( |
|
q_ptr, k_ptr, states_ptr, dy_ptr, dq_ptr, dk_ptr, ds_ptr, |
|
B: tl.constexpr, T: tl.constexpr, S: tl.constexpr, C: tl.constexpr, W: tl.constexpr, |
|
BLOCK_SIZE_C: tl.constexpr, |
|
): |
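    # One program per (batch, timestep, channel block). Accumulates dq from the
    # state and window branches, dk from the anti-diagonal of the window scores,
    # and ds from the state-slot scores. tl.arange requires W (and the block
    # sizes) to be powers of two.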
|
|
|
b_id = tl.program_id(axis=0) |
|
t_id = tl.program_id(axis=1) |
|
c_block_id = tl.program_id(axis=2) |
|
c_first_id = c_block_id * BLOCK_SIZE_C |
|
SW = S + W |
|
|
|
|
|
q_base = q_ptr + b_id * T * C |
|
k_base = k_ptr + b_id * T * C |
|
dy_base = dy_ptr + b_id * T * SW |
|
dq_base = dq_ptr + b_id * T * C |
|
dk_base = dk_ptr + b_id * T * C |
|
|
|
|
|
|
|
|
|
tw_offs = tl.arange(0, W) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
k_block_ptr = k_base + (t_id - W + 1 + tw_offs[:, None]) * C + c_first_id + c_offs[None, :] |
|
    mask = tw_offs[:, None] > (W - t_id - 2)

    k = tl.load(k_block_ptr, mask=mask, other=0.0)
|
|
|
dy_block_ptr = tl.make_block_ptr( |
|
base=dy_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, S), |
|
block_shape=(1, 1, W), |
|
order=(0, 1, 2), |
|
) |
|
dy = tl.load(dy_block_ptr) |
|
|
|
dqk = dy.permute(0, 2, 1) * k[None, :] |
|
dqk = tl.sum(dqk, axis=1) |
|
|
|
|
|
s_block_ptr = tl.make_block_ptr( |
|
base=states_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, 0, c_first_id), |
|
block_shape=(1, 1, S, BLOCK_SIZE_C), |
|
order=(0, 1, 2, 3), |
|
) |
|
s = tl.load(s_block_ptr) |
|
|
|
dy_block_ptr = tl.make_block_ptr( |
|
base=dy_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, 0), |
|
block_shape=(1, 1, S), |
|
order=(0, 1, 2), |
|
) |
|
dy = tl.load(dy_block_ptr) |
|
|
|
dqs = dy[:, :, :, None] * s |
|
dqs = tl.sum(dqs, axis=2) |
|
dq = dqk[None, :] + dqs |
|
|
|
dq_block_ptr = tl.make_block_ptr( |
|
base=dq_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, c_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_C), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(dq_block_ptr, dq.to(dq_block_ptr.dtype.element_ty)) |
|
|
|
|
|
|
|
q_block_ptr = tl.make_block_ptr( |
|
base=q_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, c_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_C), |
|
order=(0, 1, 2), |
|
) |
|
q = tl.load(q_block_ptr) |
|
|
|
ds = dy[:, :, :, None] * q[:, :, None, :] |
|
|
|
ds_block_ptr = tl.make_block_ptr( |
|
base=ds_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, 0, c_first_id), |
|
block_shape=(1, 1, S, BLOCK_SIZE_C), |
|
order=(0, 1, 2, 3), |
|
) |
|
tl.store(ds_block_ptr, ds.to(ds_block_ptr.dtype.element_ty)) |
|
|
|
|
|
|
|
tw_offs = tl.arange(0, W) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
q_block_ptr = q_base + (t_id + tw_offs[:, None]) * C + c_first_id + c_offs[None, :] |
|
    mask = tw_offs[:, None] < T - t_id

    q = tl.load(q_block_ptr, mask=mask, other=0.0)
|
|
|
|
|
|
|
    # Gather the anti-diagonal of window scores: dk[t] receives dy[t + w, S + W - 1 - w].
    w_offs = tl.arange(0, W)

    diag_dy_base = dy_base + t_id * SW + S + tl.flip(w_offs, 0)

    dy_block_ptr = diag_dy_base + w_offs * SW

    mask = w_offs < T - t_id

    dy = tl.load(dy_block_ptr, mask=mask, other=0.0)
|
|
|
dk = dy.reshape(W, 1) * q |
|
dk = tl.sum(dk, axis=0) |
|
|
|
dk_block_ptr = tl.make_block_ptr( |
|
base=dk_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, c_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_C), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(dk_block_ptr, dk.reshape(1, 1, BLOCK_SIZE_C).to(dk_block_ptr.dtype.element_ty)) |
|
|
|
class AttendFoldedAllKeysTriton(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, q, k, states, W): |
|
B, T, C = q.shape |
|
B, T, S, C = states.shape |
|
q = q.contiguous() |
|
k = k.contiguous() |
|
states = states.contiguous() |
|
ctx.save_for_backward(q, k, states) |
|
ctx.W = W |
|
|
|
|
|
grid = lambda meta: (B, T, triton.cdiv(S, meta['BLOCK_SIZE_S']) + triton.cdiv(W, meta['BLOCK_SIZE_W'])) |
|
|
|
|
|
y = torch.zeros((B, T, S+W), dtype=q.dtype, device=q.device).contiguous() |
|
|
|
|
|
afak_fwd_kernel[grid]( |
|
q, k, states, y, |
|
B, T, S, C, W, |
|
) |
|
|
|
return y |
|
|
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
grad_output = grad_output.contiguous() |
|
q, k, states = ctx.saved_tensors |
|
B, T, S, C = states.shape |
|
W = ctx.W |
|
|
|
|
|
grid = lambda meta: (B, T, triton.cdiv(C, meta['BLOCK_SIZE_C'])) |
|
|
|
gq = torch.zeros_like(q).contiguous() |
|
gk = torch.zeros_like(k).contiguous() |
|
gs = torch.zeros_like(states).contiguous() |
|
|
|
|
|
afak_bwd_kernel[grid]( |
|
q, k, states, grad_output, gq, gk, gs, |
|
B, T, S, C, W |
|
) |
|
|
|
return gq, gk, gs, None |
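

# A minimal eager-mode sketch of what AttendFoldedAllKeysTriton computes, handy
# for checking the kernels on small shapes. The helper name and the zero-padding
# trick for the sliding window are assumptions, not part of the kernel API.
def attend_folded_all_keys_ref(q, k, states, W):
    B, T, S, C = states.shape
    # Scores against the S folded state slots.
    y_states = torch.einsum('btc,btsc->bts', q, states)
    # Scores against the W most recent keys; pad so position t sees k[t-W+1 .. t].
    pad = torch.zeros(B, W - 1, C, dtype=k.dtype, device=k.device)
    k_win = torch.cat((pad, k), dim=1).unfold(1, W, 1).transpose(-1, -2)  # (B, T, W, C)
    y_win = torch.einsum('btc,btwc->btw', q, k_win)
    return torch.cat((y_states, y_win), dim=-1)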
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE_C': bs_c, |
|
}, num_warps=warps) |
|
for bs_c in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def afav_fwd_kernel( |
|
s_ptr, v_ptr, states_ptr, y_ptr, |
|
B: tl.constexpr, T: tl.constexpr, S: tl.constexpr, C: tl.constexpr, W: tl.constexpr, |
|
BLOCK_SIZE_C: tl.constexpr, |
|
): |
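    # One program per (batch, timestep, channel block). The output is a weighted
    # sum of the S folded states and the W windowed values, with weights taken
    # from the score vector of length S + W. tl.arange requires W to be a power
    # of two.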
|
|
|
b_id = tl.program_id(axis=0) |
|
t_id = tl.program_id(axis=1) |
|
c_block_id = tl.program_id(axis=2) |
|
c_first_id = c_block_id * BLOCK_SIZE_C |
|
SW = S + W |
|
|
|
|
|
    s_base = s_ptr + b_id * T * SW
|
v_base = v_ptr + b_id * T * C |
|
y_base = y_ptr + b_id * T * C |
|
|
|
|
|
|
|
sv_block_ptr = tl.make_block_ptr( |
|
base=s_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, S), |
|
block_shape=(1, 1, W), |
|
order=(0, 1, 2), |
|
) |
|
sv = tl.load(sv_block_ptr) |
|
|
|
|
|
tw_offs = tl.arange(0, W) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
v_block_ptr = v_base + (t_id - W + 1 + tw_offs[:, None]) * C + c_first_id + c_offs[None, :] |
|
    mask = tw_offs[:, None] > (W - t_id - 2)

    v = tl.load(v_block_ptr, mask=mask, other=0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ss_block_ptr = tl.make_block_ptr( |
|
base=s_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, 0), |
|
block_shape=(1, 1, S), |
|
order=(0, 1, 2), |
|
) |
|
ss = tl.load(ss_block_ptr) |
|
|
|
states_block_ptr = tl.make_block_ptr( |
|
base=states_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, 0, c_first_id), |
|
block_shape=(1, 1, S, BLOCK_SIZE_C), |
|
order=(0, 1, 2, 3), |
|
) |
|
states = tl.load(states_block_ptr) |
|
|
|
    y_win = tl.sum(sv.permute(0, 2, 1) * v[None, :], axis=1)
    y_state = tl.sum(ss[:, :, :, None] * states, axis=2).reshape(1, BLOCK_SIZE_C)
    y = y_win + y_state
|
|
|
|
|
y_block_ptr = tl.make_block_ptr( |
|
base=y_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, c_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_C), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(y_block_ptr, y[None, :].to(y_block_ptr.dtype.element_ty)) |
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE_S': bs_s, |
|
'BLOCK_SIZE_W': bs_w |
|
}, num_warps=warps) |
|
for bs_s in [16] |
|
for bs_w in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def afav_bwd_kernel( |
|
s_ptr, v_ptr, states_ptr, dy_ptr, ds_ptr, dv_ptr, dstates_ptr, |
|
    B: tl.constexpr, T: tl.constexpr, S: tl.constexpr, C: tl.constexpr, W: tl.constexpr,
|
BLOCK_SIZE_S: tl.constexpr, |
|
BLOCK_SIZE_W: tl.constexpr, |
|
): |
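    # Mirrors afav_fwd_kernel: window blocks produce dv (via atomics, since
    # windows overlap across timesteps) and the window part of ds; state blocks
    # produce dstates and the slot part of ds.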
|
|
|
b_id = tl.program_id(axis=0) |
|
t_id = tl.program_id(axis=1) |
|
sw_block_id = tl.program_id(axis=2) |
|
    num_s_blocks = tl.cdiv(S, BLOCK_SIZE_S)

    num_w_blocks = tl.cdiv(W, BLOCK_SIZE_W)
|
is_state = sw_block_id < num_s_blocks |
|
SW = S + W |
|
|
|
|
|
    s_base = s_ptr + b_id * T * SW

    v_base = v_ptr + b_id * T * C

    dy_base = dy_ptr + b_id * T * C

    ds_base = ds_ptr + b_id * T * SW

    # dv is over-allocated with W - 1 leading rows so that writes at negative time
    # offsets land in scratch that the wrapper slices off afterwards.
    dv_base = dv_ptr + b_id * (T + W - 1) * C + (W - 1) * C
|
|
|
if not is_state: |
|
|
|
w_first_id = (sw_block_id - num_s_blocks) * BLOCK_SIZE_W |
|
|
|
|
|
dy_block_ptr = tl.make_block_ptr( |
|
base=dy_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, 0), |
|
block_shape=(1, 1, C), |
|
order=(0, 1, 2), |
|
) |
|
dy = tl.load(dy_block_ptr) |
|
|
|
s_block_ptr = tl.make_block_ptr( |
|
base=s_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, S+w_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_W), |
|
order=(0, 1, 2), |
|
) |
|
s = tl.load(s_block_ptr) |
|
|
|
|
|
tw_offs = tl.arange(0, BLOCK_SIZE_W) |
|
c_offs = tl.arange(0, C) |
|
v_block_ptr = v_base + (t_id - W + 1 + (w_first_id + tw_offs[:, None])) * C + c_offs[None, :] |
|
        mask = (w_first_id + tw_offs[:, None]) > (W - t_id - 2)

        v = tl.load(v_block_ptr, mask=mask, other=0.0)
|
|
|
|
|
|
|
dv = dy * s.reshape(1, BLOCK_SIZE_W, 1) |
|
|
|
|
|
dsv = dy * v[None, :] |
|
dsv = tl.sum(dsv, axis=-1) |
|
|
|
|
|
dsv_block_ptr = tl.make_block_ptr( |
|
base=ds_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, S+w_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_W), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(dsv_block_ptr, dsv[None, :].to(dsv_block_ptr.dtype.element_ty)) |
|
|
|
|
|
|
|
tw_offs = tl.arange(0, BLOCK_SIZE_W) |
|
c_offs = tl.arange(0, C) |
|
dv_block_ptr = dv_base + (t_id - W + 1 + (w_first_id + tw_offs[:, None])) * C + c_offs[None, :] |
|
        mask = (w_first_id + tw_offs[:, None]) > (W - t_id - 2)

        # Overlapping windows accumulate into the same dv rows, hence atomics.
        tl.atomic_add(dv_block_ptr[None, :], dv, mask=mask[None, :])
|
else: |
|
s_first_id = sw_block_id * BLOCK_SIZE_S |
|
|
|
|
|
|
|
states_block_ptr = tl.make_block_ptr( |
|
base=states_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, s_first_id, 0), |
|
block_shape=(1, 1, BLOCK_SIZE_S, C), |
|
order=(0, 1, 2, 3), |
|
) |
|
states = tl.load(states_block_ptr) |
|
|
|
dy_block_ptr = tl.make_block_ptr( |
|
base=dy_ptr, |
|
shape=(B, T, C), |
|
strides=(T * C, C, 1), |
|
offsets=(b_id, t_id, 0), |
|
block_shape=(1, 1, C), |
|
order=(0, 1, 2), |
|
) |
|
dy = tl.load(dy_block_ptr) |
|
|
|
ss_block_ptr = tl.make_block_ptr( |
|
base=s_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, s_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_S), |
|
order=(0, 1, 2), |
|
) |
|
ss = tl.load(ss_block_ptr) |
|
|
|
|
|
dss = dy[:, :, None, :] * states |
|
dss = tl.sum(dss, axis=-1) |
|
|
|
|
|
dstates = dy[:, :, None, :] * ss[:, :, :, None] |
|
|
|
|
|
dss_block_ptr = tl.make_block_ptr( |
|
base=ds_ptr, |
|
shape=(B, T, SW), |
|
strides=(T * SW, SW, 1), |
|
offsets=(b_id, t_id, s_first_id), |
|
block_shape=(1, 1, BLOCK_SIZE_S), |
|
order=(0, 1, 2), |
|
) |
|
tl.store(dss_block_ptr, dss.to(dss_block_ptr.dtype.element_ty)) |
|
|
|
|
|
dstates_block_ptr = tl.make_block_ptr( |
|
base=dstates_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b_id, t_id, s_first_id, 0), |
|
block_shape=(1, 1, BLOCK_SIZE_S, C), |
|
order=(0, 1, 2, 3), |
|
) |
|
tl.store(dstates_block_ptr, dstates.to(dstates_block_ptr.dtype.element_ty)) |
|
|
|
class AccumulateFoldedAllValuesTriton(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, s, v, states, W): |
|
B, T, S, C = states.shape |
|
s = s.contiguous() |
|
v = v.contiguous() |
|
states = states.contiguous() |
|
ctx.save_for_backward(s, v, states) |
|
ctx.W = W |
|
|
|
|
|
grid = lambda meta: (B, T, triton.cdiv(C, meta['BLOCK_SIZE_C'])) |
|
|
|
|
|
y = torch.zeros((B, T, C), dtype=v.dtype, device=v.device).contiguous() |
|
|
|
|
|
afav_fwd_kernel[grid]( |
|
s, v, states, y, |
|
B, T, S, C, W, |
|
) |
|
|
|
return y |
|
|
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
grad_output = grad_output.contiguous() |
|
s, v, states = ctx.saved_tensors |
|
B, T, S, C = states.shape |
|
W = ctx.W |
|
|
|
|
|
grid = lambda meta: (B, T, triton.cdiv(S, meta['BLOCK_SIZE_S']) + triton.cdiv(W, meta['BLOCK_SIZE_W'])) |
|
|
|
gs = torch.zeros_like(s).contiguous() |
|
|
|
gv = torch.zeros((B, T+W-1, C), device=v.device).contiguous() |
|
gst = torch.zeros_like(states).contiguous() |
|
|
|
|
|
afav_bwd_kernel[grid]( |
|
s, v, states, grad_output, gs, gv, gst, |
|
B, T, S, C, W, |
|
) |
|
|
|
|
|
        return gs, gv[:, W-1:].to(v.dtype), gst, None
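

# A minimal eager-mode sketch of AccumulateFoldedAllValuesTriton.forward for
# small-shape checks. The helper name and padding trick are assumptions.
def accumulate_folded_all_values_ref(s, v, states, W):
    B, T, S, C = states.shape
    # Weighted sum of the S folded states.
    y_states = torch.einsum('bts,btsc->btc', s[..., :S], states)
    # Weighted sum of the W most recent values, aligned as in the kernel.
    pad = torch.zeros(B, W - 1, C, dtype=v.dtype, device=v.device)
    v_win = torch.cat((pad, v), dim=1).unfold(1, W, 1).transpose(-1, -2)  # (B, T, W, C)
    y_win = torch.einsum('btw,btwc->btc', s[..., S:], v_win)
    return y_states + y_win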
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE': bs, |
|
'BLOCK_SIZE_S': bs_s, |
|
'BLOCK_SIZE_C': bs_c |
|
}, num_warps=warps) |
|
for bs in [16] |
|
for bs_s in [16] |
|
for bs_c in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def cg2d_fwd_kernel( |
|
xg_ptr, gi_ptr, |
|
B: tl.constexpr, S: tl.constexpr, C: tl.constexpr, T: tl.constexpr, nstages: tl.constexpr, |
|
BLOCK_SIZE: tl.constexpr, |
|
|
|
BLOCK_SIZE_S: tl.constexpr, |
|
BLOCK_SIZE_C: tl.constexpr, |
|
): |
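    # Work-efficient Sklansky-style inclusive scan of the linear recurrence
    #   x[t] = x[t] + g[t] * x[t-1],  g[t] = g[t] * g[t-1]
    # applied in place along the time axis; one (batch, S-block, C-block) tile
    # per program. The wrapper's stage count assumes T is a power of two.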
|
|
|
pid = tl.program_id(axis=0) |
|
|
|
num_s_blocks = tl.cdiv(S, BLOCK_SIZE_S) |
|
num_c_blocks = tl.cdiv(C, BLOCK_SIZE_C) |
|
b = pid // (num_s_blocks * num_c_blocks) |
|
rem = pid % (num_s_blocks * num_c_blocks) |
|
s_block = rem // num_c_blocks |
|
c_block = rem % num_c_blocks |
|
|
|
|
|
s_offs = tl.arange(0, BLOCK_SIZE_S) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
s_mask = s_offs < (S - s_block * BLOCK_SIZE_S) |
|
c_mask = c_offs < (C - c_block * BLOCK_SIZE_C) |
|
s_offs = s_block * BLOCK_SIZE_S + s_offs |
|
c_offs = c_block * BLOCK_SIZE_C + c_offs |
|
|
|
|
|
xg_base = xg_ptr + b * T * S * C |
|
gi_base = gi_ptr + b * T * S |
|
|
|
|
|
|
|
offs = tl.arange(0, BLOCK_SIZE) |
|
|
|
for stage in tl.range(nstages): |
|
group_stride = 1 << stage |
|
|
|
for block_start in tl.range(0, T//2, BLOCK_SIZE): |
|
block_mask = offs < (T//2 - block_start) |
|
block_s_mask = block_mask[:, None] & s_mask[None, :] |
|
block_s_c_mask = block_mask[:, None, None] & s_mask[None, :, None] & c_mask[None, None, :] |
|
|
|
|
|
initial_indices = group_stride + ((offs + block_start) // group_stride) * group_stride * 2 |
|
t_targets = initial_indices + ((offs + block_start) % group_stride) |
|
t_adders = initial_indices - 1 |
|
|
|
xg_targets_ptr = xg_base + t_targets[:, None, None] * S * C + s_offs[None, :, None] * C + c_offs[None, None, :] |
|
xg_adders_ptr = xg_base + t_adders[:, None, None] * S * C + s_offs[None, :, None] * C + c_offs[None, None, :] |
|
gi_targets_ptr = gi_base + t_targets[:, None] * S + s_offs[None, :] |
|
gi_adders_ptr = gi_base + t_adders[:, None] * S + s_offs[None, :] |
|
|
|
xg_targets = tl.load(xg_targets_ptr, mask=block_s_c_mask) |
|
xg_adders = tl.load(xg_adders_ptr, mask=block_s_c_mask) |
|
gi_targets = tl.load(gi_targets_ptr, mask=block_s_mask) |
|
gi_adders = tl.load(gi_adders_ptr, mask=block_s_mask) |
|
|
|
|
|
xg_targets += xg_adders * gi_targets[:, :, None] |
|
|
|
gi_targets *= gi_adders |
|
|
|
tl.store(xg_targets_ptr, xg_targets.to(xg_targets_ptr.dtype.element_ty), mask=block_s_c_mask) |
|
tl.store(gi_targets_ptr, gi_targets.to(gi_targets_ptr.dtype.element_ty), mask=block_s_mask) |
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE': bs, |
|
'BLOCK_SIZE_S': bs_s, |
|
'BLOCK_SIZE_C': bs_c |
|
}, num_warps=warps) |
|
for bs in [16] |
|
for bs_s in [16] |
|
for bs_c in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def cg2d_gxg_bwd_kernel( |
|
gi_ptr, go_ptr, |
|
B: tl.constexpr, S: tl.constexpr, C: tl.constexpr, T: tl.constexpr, nstages: tl.constexpr, |
|
BLOCK_SIZE: tl.constexpr, |
|
BLOCK_SIZE_S: tl.constexpr, |
|
BLOCK_SIZE_C: tl.constexpr, |
|
): |
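    # The same scan run in reverse time order: propagates the output gradient
    # through the gating recurrence so that grad_xg[t] accumulates contributions
    # from all later timesteps.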
|
|
|
pid = tl.program_id(axis=0) |
|
num_s_blocks = tl.cdiv(S, BLOCK_SIZE_S) |
|
num_c_blocks = tl.cdiv(C, BLOCK_SIZE_C) |
|
b = pid // (num_s_blocks * num_c_blocks) |
|
rem = pid % (num_s_blocks * num_c_blocks) |
|
s_block = rem // num_c_blocks |
|
c_block = rem % num_c_blocks |
|
|
|
s_offs = tl.arange(0, BLOCK_SIZE_S) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
s_mask = s_offs < (S - s_block * BLOCK_SIZE_S) |
|
c_mask = c_offs < (C - c_block * BLOCK_SIZE_C) |
|
s_offs = s_block * BLOCK_SIZE_S + s_offs |
|
c_offs = c_block * BLOCK_SIZE_C + c_offs |
|
|
|
gi_base = gi_ptr + b * T * S |
|
go_base = go_ptr + b * T * S * C |
|
|
|
|
|
offs = tl.arange(0, BLOCK_SIZE) |
|
|
|
for stage in tl.range(nstages): |
|
group_stride = 1 << stage |
|
for block_start in tl.range(0, T//2, BLOCK_SIZE): |
|
block_mask = offs < (T//2 - block_start) |
|
block_s_mask = block_mask[:, None] & s_mask[None, :] |
|
block_s_c_mask = block_mask[:, None, None] & s_mask[None, :, None] & c_mask[None, None, :] |
|
|
|
initial_indices = T - 1 - group_stride - ((offs + block_start) // group_stride) * group_stride * 2 |
|
t_targets = initial_indices - ((offs + block_start) % group_stride) |
|
t_adders = initial_indices + 1 |
|
|
|
go_targets_ptr = go_base + t_targets[:, None, None] * S * C + s_offs[None, :, None] * C + c_offs[None, None, :] |
|
go_adders_ptr = go_base + t_adders[:, None, None] * S * C + s_offs[None, :, None] * C + c_offs[None, None, :] |
|
gi_targets_ptr = gi_base + t_targets[:, None] * S + s_offs[None, :] |
|
gi_adders_ptr = gi_base + t_adders[:, None] * S + s_offs[None, :] |
|
|
|
|
|
go_targets = tl.load(go_targets_ptr, mask=block_s_c_mask) |
|
go_adders = tl.load(go_adders_ptr, mask=block_s_c_mask) |
|
gi_targets = tl.load(gi_targets_ptr, mask=block_s_mask) |
|
gi_adders = tl.load(gi_adders_ptr, mask=block_s_mask) |
|
|
|
|
|
go_targets += go_adders * gi_targets[:, :, None] |
|
gi_targets *= gi_adders |
|
|
|
tl.store(go_targets_ptr, go_targets.to(go_targets_ptr.dtype.element_ty), mask=block_s_c_mask) |
|
tl.store(gi_targets_ptr, gi_targets.to(gi_targets_ptr.dtype.element_ty), mask=block_s_mask) |
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config({ |
|
'BLOCK_SIZE': bs, |
|
'BLOCK_SIZE_S': bs_s, |
|
'BLOCK_SIZE_C': bs_c |
|
}, num_warps=warps) |
|
for bs in [16] |
|
for bs_s in [16] |
|
for bs_c in [16] |
|
for warps in [2] |
|
], |
|
key=[] |
|
) |
|
@triton.jit |
|
def cg2d_ggi_bwd_kernel( |
|
go_ptr, y_ptr, grad_gi_ptr, |
|
B: tl.constexpr, S: tl.constexpr, C: tl.constexpr, T: tl.constexpr, |
|
BLOCK_SIZE: tl.constexpr, |
|
BLOCK_SIZE_S: tl.constexpr, |
|
BLOCK_SIZE_C: tl.constexpr |
|
): |
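    # grad_gi[b, t, s] = sum_c grad_xg[b, t, s, c] * y[b, t-1, s, c]; the wrapper
    # passes y pre-shifted by one step, and partial sums over C blocks are
    # combined with atomics.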
|
b = tl.program_id(axis=0) |
|
pid = tl.program_id(axis=1) |
|
num_t_blocks = tl.cdiv(T, BLOCK_SIZE) |
|
num_s_blocks = tl.cdiv(S, BLOCK_SIZE_S) |
|
num_c_blocks = tl.cdiv(C, BLOCK_SIZE_C) |
|
t_block = pid // (num_s_blocks * num_c_blocks) |
|
rem = pid % (num_s_blocks * num_c_blocks) |
|
s_block = rem // num_c_blocks |
|
c_block = rem % num_c_blocks |
|
|
|
t_offs = tl.arange(0, BLOCK_SIZE) |
|
s_offs = tl.arange(0, BLOCK_SIZE_S) |
|
c_offs = tl.arange(0, BLOCK_SIZE_C) |
|
t_mask = t_offs < (T - t_block * BLOCK_SIZE) |
|
s_mask = s_offs < (S - s_block * BLOCK_SIZE_S) |
|
c_mask = c_offs < (C - c_block * BLOCK_SIZE_C) |
|
t_offs = t_block * BLOCK_SIZE + t_offs |
|
s_offs = s_block * BLOCK_SIZE_S + s_offs |
|
c_offs = c_block * BLOCK_SIZE_C + c_offs |
|
|
|
|
|
|
|
|
|
|
|
grad_gi_base = grad_gi_ptr + b * T * S |
|
t_first_id = t_block * BLOCK_SIZE |
|
s_first_id = s_block * BLOCK_SIZE_S |
|
c_first_id = c_block * BLOCK_SIZE_C |
|
|
|
go_block_ptr = tl.make_block_ptr( |
|
base=go_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b, t_first_id, s_first_id, c_first_id), |
|
block_shape=(1, BLOCK_SIZE, BLOCK_SIZE_S, BLOCK_SIZE_C), |
|
order=(0, 1, 2, 3) |
|
) |
|
    go_block = tl.load(go_block_ptr, boundary_check=(1, 2, 3), padding_option='zero')
|
y_block_ptr = tl.make_block_ptr( |
|
base=y_ptr, |
|
shape=(B, T, S, C), |
|
strides=(T * S * C, S * C, C, 1), |
|
offsets=(b, t_first_id, s_first_id, c_first_id), |
|
block_shape=(1, BLOCK_SIZE, BLOCK_SIZE_S, BLOCK_SIZE_C), |
|
order=(0, 1, 2, 3) |
|
) |
|
    y_block = tl.load(y_block_ptr, boundary_check=(1, 2, 3), padding_option='zero')
|
|
|
grad_gi = go_block * y_block |
|
grad_gi = tl.sum(grad_gi, axis=-1) |
|
|
|
|
|
grad_gi_block_ptr = grad_gi_base + t_offs[:, None] * S + s_offs[None, :] |
|
grad_gi_mask = t_mask[:, None] & s_mask[None, :] |
|
tl.atomic_add(grad_gi_block_ptr[None, :], grad_gi, mask=grad_gi_mask[None, :]) |
|
|
|
class CumulativeGating2DTriton(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, xg, gi): |
|
xg = xg.contiguous() |
|
gi = gi.contiguous() |
|
        # The scan mutates xg and gi in place; keep the original gates for backward.
        orig_gi = gi.clone()
|
B, T, S, C = xg.shape |
|
|
|
|
|
grid = lambda meta: (B * triton.cdiv(S, meta['BLOCK_SIZE_S']) * triton.cdiv(C, meta['BLOCK_SIZE_C']),) |
|
|
|
|
|
        # Number of scan stages; the Sklansky scan assumes T is a power of two.
        nstages = math.ceil(math.log2(T))
|
cg2d_fwd_kernel[grid]( |
|
xg, gi, |
|
B, S, C, T, nstages, |
|
) |
|
|
|
ctx.save_for_backward(xg, orig_gi) |
|
return xg |
|
|
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
grad_output = grad_output.contiguous() |
|
y, gi = ctx.saved_tensors |
|
B, T, S, C = y.shape |
|
|
|
|
|
grid = lambda meta: (B * triton.cdiv(S, meta['BLOCK_SIZE_S']) * triton.cdiv(C, meta['BLOCK_SIZE_C']),) |
|
|
|
        # The reverse scan uses the gates shifted left by one step.
        gi = torch.cat((gi[:, 1:], torch.ones_like(gi[:, -1:])), dim=1).contiguous()

        grad_xg = grad_output.clone()

        # grad w.r.t. gi[t] pairs the propagated grad_xg[t] with the previous output y[t-1].
        y = torch.cat((torch.zeros_like(y[:, :1]), y[:, :-1]), dim=1).contiguous()

        # fp32 buffer accumulated with atomics across C blocks.
        grad_gi = torch.zeros((B, T, S), device=gi.device).contiguous()
|
|
|
|
|
nstages = math.ceil(math.log2(T)) |
|
cg2d_gxg_bwd_kernel[grid]( |
|
gi, grad_xg, |
|
B, S, C, T, nstages, |
|
) |
|
|
|
|
|
grid = lambda meta: (B, triton.cdiv(T, meta['BLOCK_SIZE']) * triton.cdiv(S, meta['BLOCK_SIZE_S']) * triton.cdiv(C, meta['BLOCK_SIZE_C'])) |
|
cg2d_ggi_bwd_kernel[grid]( |
|
grad_xg, y, grad_gi, |
|
B, S, C, T, |
|
) |
|
|
|
return grad_xg, grad_gi.to(gi.dtype) |
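

# Sequential reference for CumulativeGating2DTriton (an assumed helper for
# testing): states[t] = gi[t] * states[t-1] + xg[t], scanned along time.
def cumulative_gating_2d_ref(xg, gi):
    B, T, S, C = xg.shape
    out = torch.empty_like(xg)
    state = torch.zeros(B, S, C, dtype=xg.dtype, device=xg.device)
    for t in range(T):
        state = gi[:, t, :, None] * state + xg[:, t]
        out[:, t] = state
    return out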
|
|
|
|
|
def parallel_scan( |
|
q: torch.Tensor, |
|
k: torch.Tensor, |
|
v: torch.Tensor, |
|
s: torch.Tensor, |
|
g: torch.Tensor, |
|
window_size: int, |
|
num_heads: int, |
|
alibi: torch.Tensor, |
|
mask: torch.Tensor, |
|
    scale: Optional[float] = None,
|
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
output_final_state: Optional[bool] = False, |
|
checkpoint_level: Optional[int] = 2, |
|
offsets: Optional[torch.LongTensor] = None, |
|
head_first: Optional[bool] = True |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
r""" |
|
Args: |
|
q (torch.Tensor): |
|
queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`. |
|
k (torch.Tensor): |
|
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
GQA is performed if `H` is not equal to `HQ`. |
|
v (torch.Tensor): |
|
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
s (torch.Tensor): |
|
slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`. |
|
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
            With `g` equal to ones, this function reduces to vanilla ABC.
        window_size (int):
            Size `W` of the sliding attention window over the most recent keys/values.
        num_heads (int):
            Number of attention heads used to reshape the scores before adding `alibi`.
        alibi (torch.Tensor):
            Additive positional bias; `alibi[:, :T]` is broadcast onto the
            `[B, H, T, S+W]` attention scores.
        mask (torch.Tensor):
            Attention mask; positions where `mask[:T] == 0` are filled with `-inf`
            before the softmax.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
initial_state (Optional[Tuple[torch.Tensor]]): |
|
Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences. |
|
For equal-length input sequences, `N` equals the batch size `B`. |
|
Default: `None`. |
|
output_final_state (Optional[bool]): |
|
Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`. |
|
Default: `False`. |
|
checkpoint_level (Optional[int]): |
|
Checkpointing level; higher values will save more memories and do more recomputations during backward. |
|
Default: `2`: |
|
- Level `0`: no memory saved, no recomputation. |
|
- Level `1`: recompute the fp32 cumulative values during backward. |
|
- Level `2`: recompute the fp32 cumulative values and forward hidden states during backward. |
|
offsets (Optional[torch.LongTensor]): |
|
Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch. |
|
For example, |
|
if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively. |
|
If provided, the inputs are concatenated and the batch size `B` is expected to be 1. |
|
Default: `None`. |
|
head_first (Optional[bool]): |
|
Whether the inputs are in the head-first format, which is not supported for variable-length inputs. |
|
Default: `True`. |
|
|
|
Returns: |
|
o (torch.Tensor): |
|
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
final_state (Tuple[torch.Tensor]): |
|
Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`. |
|
`None` otherwise. |
|
|
|
Examples:: |
|
>>> import torch |
|
>>> import torch.nn.functional as F |
|
>>> from einops import rearrange |
|
        >>> from fla.ops.gsa import chunk_gsa
|
# inputs with equal lengths |
|
>>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64 |
|
>>> q = torch.randn(B, T, H, K, device='cuda') |
|
>>> k = torch.randn(B, T, H, K, device='cuda') |
|
>>> v = torch.randn(B, T, H, V, device='cuda') |
|
>>> s = torch.randn(B, T, H, M, device='cuda') |
|
>>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda')) |
|
>>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda')) |
|
>>> o, (hk, hv) = chunk_gsa(q, k, v, s, g, |
|
initial_state=h0, |
|
output_final_state=True, |
|
head_first=False) |
|
# for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required |
|
>>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) |
|
# for a batch with 4 sequences, offsets with 5 start/end positions are expected |
|
>>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) |
|
>>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g, |
|
initial_state=h0, |
|
output_final_state=True, |
|
offsets=offsets, |
|
head_first=False) |
|
>>> assert o.allclose(o_var.view(o.shape)) |
|
>>> assert hk.allclose(hk_var) |
|
>>> assert hv.allclose(hv_var) |
|
""" |
|
if offsets is not None: |
|
if q.shape[0] != 1: |
|
raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." |
|
f"Please flatten variable-length inputs before processing.") |
|
if head_first: |
|
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") |
|
if initial_state is not None and initial_state[0].shape[0] != len(offsets) - 1: |
|
raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " |
|
f"i.e., {len(offsets) - 1} rather than {initial_state[0].shape[0]}.") |
|
assert checkpoint_level in [0, 1, 2] |
|
|
|
if scale is None: |
|
scale = q.shape[-1] ** -0.5 |
|
|
|
BH, T, S = g.shape |
|
|
|
    # Fold gated slot writes into running states:
    #   states[t] = (1 - g[t]) * states[t-1] + g[t] (outer) s[t].
    sg = torch.einsum('bts,btc->btsc', g, s)
    gi = 1 - g
    states = CumulativeGating2DTriton.apply(sg, gi)

    # Scores against the S folded state slots plus the W most recent keys.
    scores = AttendFoldedAllKeysTriton.apply(q, k, states, window_size) * scale

    # Bias, mask and normalize per head.
    scores = scores.view(-1, num_heads, T, S + window_size) + alibi[:, :T]
    scores = scores.masked_fill(mask[:T] == 0, float('-inf'))
    scores = torch.softmax(scores, dim=-1).view(BH, T, S + window_size)

    # Weighted sum over folded states and windowed values.
    o = AccumulateFoldedAllValuesTriton.apply(scores, v, states, window_size)
|
|
|
    # Final-state output is not implemented for this path.
    final_state = None
|
return o, final_state |
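

if __name__ == "__main__":
    # Hedged smoke test. The shapes assumed for `alibi` and `mask` are inferred
    # from the broadcasting in `parallel_scan` and are not part of any documented
    # API; T is kept a power of two for the scan kernels.
    B, H, T, C, S, W = 2, 2, 64, 16, 16, 16
    dev = 'cuda'
    q, k, v, s = (torch.randn(B * H, T, C, device=dev) for _ in range(4))
    g = torch.sigmoid(torch.randn(B * H, T, S, device=dev))
    alibi = torch.zeros(H, T, S + W, device=dev)
    mask = torch.ones(T, S + W, device=dev)
    o, _ = parallel_scan(q, k, v, s, g, W, H, alibi, mask)
    print(o.shape)  # expected: (B * H, T, C)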
|
|