#    Copyright 2023 OpenNLPLab
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
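"""Triton kernels for lightning (linear) attention with an optional per-head
exponential decay.

The forward pass computes O = ((Q K^T) * D) V without a softmax, where D is
either a plain causal mask or a causal decay mask with
D[i, j] = exp(-s * (i - j)) for i >= j and 0 otherwise. The backward kernels
recompute D on the fly to produce dK/dV and dQ.
"""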

import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Out,
    S,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_DMODEL_QK)
    offs_e = tl.arange(0, BLOCK_DMODEL_V)
    # per-(batch, head) base offsets into q, k, v, o; indexing with
    # off_hz * stride_qh is valid because forward() makes the tensors
    # contiguous, so stride_z == H * stride_h
    off_q = (off_hz * stride_qh + offs_m[:, None] * stride_qm +
             offs_k[None, :] * stride_qk)
    off_k = (off_hz * stride_kh + offs_n[:, None] * stride_kn +
             offs_k[None, :] * stride_kk)
    off_v = (off_hz * stride_vh + offs_n[:, None] * stride_vn +
             offs_e[None, :] * stride_ve)
    off_o = (off_hz * stride_oh + offs_m[:, None] * stride_om +
             offs_e[None, :] * stride_oe)

    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize the output accumulator
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
    # loop over k, v and update accumulator
    lo = 0
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(
            k_ptrs + start_n * stride_kn,
            mask=(start_n + offs_n)[:, None] < N_CTX,
            other=0.0,
        )
        v = tl.load(
            v_ptrs + start_n * stride_vn,
            mask=(start_n + offs_n)[:, None] < N_CTX,
            other=0.0,
        )
        # -- compute qk in float32 --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        if IS_CAUSAL:
            index = offs_m[:, None] - (start_n + offs_n[None, :])
            if USE_DECAY:
                # decay mask: exp(-s * (i - j)) for i >= j, 0 otherwise
                # (exp(-inf) = 0 masks out future positions)
                S_block_ptr = S + off_h * stride_sh
                s = tl.load(S_block_ptr)
                s_index = s * index
                s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
                qk = tl.exp(s_index) * qk
            else:
                # plain causal mask: zero out future (j > i) positions
                qk = tl.where(index >= 0, qk, 0)
        acc += tl.dot(qk, v.to(qk.dtype))

    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc.to(q.dtype), mask=offs_m[:, None] < N_CTX)
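

# For intuition and small-shape testing: a naive PyTorch sketch of what
# _fwd_kernel computes. This reference helper is an addition for illustration
# (assuming non-negative decay rates in `s`); it is not part of the original
# kernel path and is never called by the code below.
def _naive_lightning_attention_ref(q, k, v, causal, s):
    qk = torch.matmul(q.float(), k.float().transpose(-2, -1))  # (Z, H, N, N)
    if causal:
        n = q.shape[-2]
        idx = torch.arange(n, device=q.device)
        diff = (idx[:, None] - idx[None, :]).float()  # i - j
        if s.numel() > 0:
            # decay mask: exp(-s * (i - j)) for i >= j, 0 otherwise
            decay = torch.where(
                diff >= 0,
                torch.exp(-s.float().reshape(-1, 1, 1) * diff),
                torch.zeros_like(diff),
            )
            qk = qk * decay
        else:
            qk = qk.masked_fill(diff < 0, 0.0)
    return torch.matmul(qk, v.float()).to(q.dtype)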


@triton.jit
def _bwd_kernel_kv(
    Q,
    K,
    V,
    S,
    DO,
    DQ,
    DK,
    DV,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)

    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_kz + off_h * stride_kh
    DV += off_z * stride_vz + off_h * stride_vh

    # start of q
    if CAUSAL:
        lo = start_n * BLOCK_M
    else:
        lo = 0
    # initialize row/col offsets
    # sequence offsets
    offs_qm = lo + tl.arange(0, BLOCK_M)
    offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # feature offset
    offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
    offs_ve = tl.arange(0, BLOCK_DMODEL_V)
    # row block index
    offs_m = tl.arange(0, BLOCK_M)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk)
    k_ptrs = K + (offs_kvn[:, None] * stride_kn +
                  offs_qkk[None, :] * stride_kk)
    v_ptrs = V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve)
    do_ptrs = DO + (offs_qm[:, None] * stride_om +
                    offs_ve[None, :] * stride_oe)
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32)
    # k and v stay in SRAM throughout
    k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
    v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
    # loop over rows
    for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
        offs_m_curr = start_m + offs_m
        # load q on-chip (k and v are already resident)
        q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
        qk = tl.dot(q, tl.trans(k))
        if CAUSAL:
            index = offs_m_curr[:, None] - offs_kvn[None, :]
            if USE_DECAY:
                # recompute the forward-pass decay mask for this tile
                S_block_ptr = S + off_h * stride_sh
                s = tl.load(S_block_ptr)
                s_index = s * index
                s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
                s = tl.exp(s_index)
                qk = qk * s
            else:
                qk = tl.where(index >= 0, qk, 0)

        p = qk
        do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
        # dv += p^T do
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # dp = do v^T
        dp = tl.dot(do, tl.trans(v).to(do.dtype))
        if CAUSAL:
            if USE_DECAY:
                dp = dp * s
            else:
                dp = tl.where(index >= 0, dp, 0)

        # dk += dp^T q
        dk += tl.dot(tl.trans(dp.to(q.dtype)), q).to(tl.float32)

        # increment pointers
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_om
    # write-back
    dv_ptrs = DV + (offs_kvn[:, None] * stride_vn +
                    offs_ve[None, :] * stride_ve)
    dk_ptrs = DK + (offs_kvn[:, None] * stride_kn +
                    offs_qkk[None, :] * stride_kk)
    tl.store(dv_ptrs, dv, mask=offs_kvn[:, None] < N_CTX)
    tl.store(dk_ptrs, dk, mask=offs_kvn[:, None] < N_CTX)
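

# Gradient relations implemented by the two backward kernels, with
# P = (Q K^T) * D and D the causal/decay mask recomputed from S:
#     dV = P^T dO
#     dP = (dO V^T) * D
#     dK = dP^T Q
#     dQ = dP K
# _bwd_kernel_kv above produces dK and dV; _bwd_kernel_q below produces dQ.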


@triton.jit
def _bwd_kernel_q(
    Q,
    K,
    V,
    S,
    DO,
    DQ,
    DK,
    DV,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_qz + off_h * stride_qh
    # feature offset
    offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
    offs_ve = tl.arange(0, BLOCK_DMODEL_V)
    # row offsets within a block
    offs_m = tl.arange(0, BLOCK_M)
    # absolute row offsets for this block
    offs_qm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # do
    do_ptrs = DO + (offs_qm[:, None] * stride_om +
                    offs_ve[None, :] * stride_oe)
    dq_ptrs = DQ + (offs_qm[:, None] * stride_qm +
                    offs_qkk[None, :] * stride_qk)

    do = tl.load(do_ptrs, mask=offs_qm[:, None] < N_CTX, other=0.0)

    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32)
    offs_m_curr = start_m * BLOCK_M + offs_m

    # loop over all k/v blocks; causality is enforced by the masks below,
    # not by the loop bounds
    for start_n in range(0, num_block):
        offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
        k_ptrs = K + (offs_kvn[:, None] * stride_kn +
                      offs_qkk[None, :] * stride_kk)
        v_ptrs = V + (offs_kvn[:, None] * stride_vn +
                      offs_ve[None, :] * stride_ve)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
        v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
        # dp = do v^T
        dp = tl.dot(do, tl.trans(v).to(do.dtype))
        if CAUSAL:
            index = offs_m_curr[:, None] - offs_kvn[None, :]
            if USE_DECAY:
                S_block_ptr = S + off_h * stride_sh
                s = tl.load(S_block_ptr)
                s_index = s * index
                s_index = tl.where(s_index >= 0, -s_index, float("-inf"))
                s = tl.exp(s_index)
                dp = dp * s
            else:
                dp = tl.where(index >= 0, dp, 0)
        # dq = dq + dp k
        dq += tl.dot(dp.to(k.dtype), k)

    tl.store(dq_ptrs, dq, mask=offs_qm[:, None] < N_CTX)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, s):
        # q, k: (Z, H, N_CTX, d); v: (Z, H, N_CTX, e)
        # s: per-head decay rates indexed by head along dim 0
        #    (pass a tensor with s.shape[0] == 0 to disable decay)
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        s = s.contiguous()
        # only supported on Ampere (compute capability 8.0) or newer
        capability = torch.cuda.get_device_capability()
        if capability[0] < 8:
            raise RuntimeError(
                "Lightning attention is currently only supported for compute capability >= 8.0"
            )
        # feature dimensions of q/k and v
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        # output takes the head dimension of v
        o = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2], v.shape[-1]),
            dtype=q.dtype,
            device=q.device,
        )

        BLOCK_M = 128
        BLOCK_N = 64
        num_warps = 4 if Lk <= 64 else 8
        num_stages = 1

        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        use_decay = s.shape[0] > 0  # an empty s disables the decay mask
        _fwd_kernel[grid](
            q,
            k,
            v,
            o,
            s,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            s.stride(0),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL_QK=Lk,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL_V=Lv,
            IS_CAUSAL=causal,
            USE_DECAY=use_decay,
            num_warps=num_warps,
            num_stages=num_stages,
        )

        ctx.save_for_backward(q, k, v, s)
        ctx.grid = grid
        ctx.BLOCK_M = BLOCK_M
        ctx.BLOCK_DMODEL_QK = Lk
        ctx.BLOCK_N = BLOCK_N
        ctx.BLOCK_DMODEL_V = Lv
        ctx.causal = causal
        ctx.use_decay = use_decay
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, s = ctx.saved_tensors
        BLOCK_M = 32
        BLOCK_N = 32
        num_warps = 4
        num_stages = 1

        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)

        grid_kv = (triton.cdiv(k.shape[2],
                               BLOCK_N), k.shape[0] * k.shape[1], 1)
        _bwd_kernel_kv[grid_kv](
            q,
            k,
            v,
            s,
            do,
            dq,
            dk,
            dv,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            do.stride(0),
            do.stride(1),
            do.stride(2),
            do.stride(3),
            s.stride(0),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            grid_kv[0],
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
            CAUSAL=ctx.causal,
            USE_DECAY=ctx.use_decay,
            num_warps=num_warps,
            num_stages=num_stages,
        )

        grid_q = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)

        _bwd_kernel_q[grid_q](
            q,
            k,
            v,
            s,
            do,
            dq,
            dk,
            dv,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            do.stride(0),
            do.stride(1),
            do.stride(2),
            do.stride(3),
            s.stride(0),
            q.shape[0],
            q.shape[1],
            q.shape[2],
            grid_q[0],
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
            CAUSAL=ctx.causal,
            USE_DECAY=ctx.use_decay,
            num_warps=num_warps,
            num_stages=num_stages,
        )

        # no gradients for the causal flag or the decay rates
        return dq.to(q.dtype), dk, dv, None, None


attention = _attention.apply
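
# `attention` is the low-level autograd entry point:
#     attention(q, k, v, causal, s)
# with q, k of head dimension d and v of head dimension e. The
# `lightning_attention` wrapper below splits large head dimensions into
# chunks the kernels can handle.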


def lightning_attention(q, k, v, causal, ed):
    # The kernels keep the whole q/k head dimension in SRAM, so a large head
    # dimension d is split into chunks of at most m features. Since q k^T is
    # linear in the feature dimension, the per-chunk outputs sum to the full
    # result.
    d = q.shape[-1]
    if d >= 128:
        m = 128
    else:
        m = 64
    # chunk boundaries: [0, m, 2m, ..., d]
    arr = [m * i for i in range(d // m + 1)]
    if arr[-1] != d:
        arr.append(d)
    n = len(arr)
    output = 0
    for i in range(n - 1):
        start, end = arr[i], arr[i + 1]
        q1 = q[..., start:end]
        k1 = k[..., start:end]
        o = attention(q1, k1, v, causal, ed)
        output = output + o

    return output
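

# Example usage: a minimal sketch. Shapes follow the kernels above; the
# decay-rate tensor `ed` is assumed to hold one non-negative rate per head
# (use a tensor with shape[0] == 0 to disable decay). On small shapes the
# output can be compared against _naive_lightning_attention_ref above.
if __name__ == "__main__":
    b, h, n, d, e = 2, 8, 1024, 128, 64
    q = torch.randn(b, h, n, d, device="cuda", dtype=torch.float16,
                    requires_grad=True)
    k = torch.randn(b, h, n, d, device="cuda", dtype=torch.float16,
                    requires_grad=True)
    v = torch.randn(b, h, n, e, device="cuda", dtype=torch.float16,
                    requires_grad=True)
    ed = torch.rand(h, device="cuda", dtype=torch.float32)  # per-head decay
    o = lightning_attention(q, k, v, True, ed)
    o.sum().backward()
    print(o.shape, q.grad.shape, k.grad.shape, v.grad.shape)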