import torch
from torch import nn
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple
from math import ceil

# Optional flash-attn backends. The import paths below assume the flash-attn
# package layout (flash_attn.flash_attn_func and flash_attn.flash_attn_triton)
# and may need adjusting for your installed version; the naive backend works
# without either.
try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
try:
    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
except ImportError:
    flash_attn_func_triton = None

class AttentionBackend(Enum):
    Naive = 0
    FlashAttentionCuda = 1
    FlashAttentionTriton = 2


global_config = {
    'attn_backend': AttentionBackend.Naive
}

@dataclass
class TransformerConfig:
    vocab_size: int = -1
    num_layers: int = -1
    num_heads: int = -1
    hidden_size: int = -1
    max_seq_len: int = -1
    root_model: Optional['ToyTransformer'] = None
    device: torch.device = torch.device('cpu')
    dtype: torch.dtype = torch.float32


def expand_attn_mask(custom_attn_mask: torch.Tensor):
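    # custom_attn_mask has shape (B, T); entries act as per-token segment ids,
    # with 0 marking padding. The result is a (B, T, T) boolean mask that lets
    # token i attend to token j only when both carry the same non-zero id and
    # j <= i (causal).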
    B, T = custom_attn_mask.shape
    mask = custom_attn_mask.unsqueeze(1).repeat((1, T, 1))
    seq_index_mask = (mask == custom_attn_mask.unsqueeze(2))
    return seq_index_mask & (torch.tril(mask) > 0)


# expand attn mask to cu_seqlens for flash attn
def expand_attn_mask_to_seq_lengths(attn_mask: torch.Tensor):
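    # attn_mask holds per-token segment ids over a (B, T) batch. Every position
    # where the id changes (or a new row begins) starts a packed sequence; the
    # returned int32 tensor lists those start offsets in the flattened (B*T)
    # token stream, followed by the total length, i.e. the cu_seqlens layout
    # used by flash-attn's variable-length interface.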
    attn_mask = attn_mask.to('cpu')
    seq_len = attn_mask.shape[0] * attn_mask.shape[1]
    disjoint_point = torch.cat([torch.tensor([[True]] * attn_mask.shape[0]), attn_mask[:, 1:] != attn_mask[:, :-1]], dim=1)
    return torch.cat([torch.nonzero(disjoint_point.view((-1,))), torch.tensor([[seq_len]])]).to(dtype=torch.int32)


# naive RoPE implementation following https://arxiv.org/pdf/2104.09864.pdf
def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
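    # Precompute per-position cos/sin tables for the pairwise rotation
    #   x'_{2i}   = x_{2i} * cos(m * theta_i) - x_{2i+1} * sin(m * theta_i)
    #   x'_{2i+1} = x_{2i} * sin(m * theta_i) + x_{2i+1} * cos(m * theta_i)
    # v1 holds the cos terms, v2 the sign-adjusted sin terms, and `indices`
    # swaps each (even, odd) pair so apply_rope_slow needs only two
    # element-wise multiplies and a gather.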
    assert dim % 2 == 0
    freqs = theta ** (-2 * torch.arange(0, dim // 2, 1.) / dim)
    freqs = torch.repeat_interleave(freqs, 2)
    v1 = torch.cos(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
    v2 = torch.sin(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
    v2 = v2 * torch.tensor([1, -1] * (dim // 2))
    indices = torch.tensor([j for i in range(0, dim, 2) for j in (i + 1, i)])
    return v1.to(device, dtype=dtype), v2.to(device, dtype=dtype), indices.to(device)


def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
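    # x has shape (batch, seq, dim). When `positions` is given (incremental
    # decoding) the cached cos/sin rows for those absolute positions are used;
    # otherwise the first seq_len rows are taken.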
    v1, v2, indices = rope_cache
    seq_len, dim = x.shape[1:]
    if positions is None:
        v1 = v1[:seq_len, :]
        v2 = v2[:seq_len, :]
    else:
        v1 = v1[positions].view((-1, dim))
        v2 = v2[positions].view((-1, dim))
    applied_x = x * v1 + (x * v2)[:, :, indices]
    return applied_x


# Optimized RoPE implementation adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py
def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
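    # Same rotation expressed with complex numbers: position m and frequency
    # theta_i become the unit phasor exp(1j * m * theta_i), stored as a
    # (seq_len, dim // 2) complex tensor. The dtype argument is unused here
    # because the cache stays complex-valued.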
    freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis.to(device)


def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
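    # Adjacent pairs (x_{2i}, x_{2i+1}) are viewed as complex numbers, rotated
    # by multiplying with the cached phasors, then flattened back to real pairs.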
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    if positions is None and x.shape[1] < rope_cache.shape[0]:
        freqs_cis = rope_cache[:x.shape[1], :]
    elif positions is not None:
        freqs_cis = rope_cache[positions, :]
    else:
        freqs_cis = rope_cache
    freqs_cis = freqs_cis.view([d if i == 1 or i == x_.ndim - 1 else 1 for i, d in enumerate(x_.shape)])

    applied_x = torch.view_as_real(x_ * freqs_cis).flatten(2)
    return applied_x.type_as(x)


# RMSNorm implementation following https://arxiv.org/pdf/1910.07467.pdf
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, dtype, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
        self.eps = eps

    def forward(self, x: torch.Tensor):
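        # Normalize by the root-mean-square over the last dimension (no mean
        # subtraction, unlike LayerNorm) and apply a learned per-channel gain.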
        x_ = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x_


class AttentionHead(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.head_size = config.hidden_size // config.num_heads
        self.dtype = config.dtype
        self.q_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.k_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.v_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
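        # x: (batch, seq, hidden_size). attn_masked_bias is an additive bias
        # (0 where attention is allowed, a large negative value where it is
        # masked) or None. kv_cache is this head's [k, v] from previous
        # decoding steps, or None.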
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # if global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton:
        # padding the position indices for alignment
        # positions = torch.tensor([kv_cache[0].shape[1]] * q.shape[1]).to(q.device) if kv_cache is not None else torch.arange(0, x.shape[1], 1).to(q.device)

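        # During incremental decoding the new token's absolute position equals
        # the number of keys already cached; for a full (prefill) pass,
        # positions default to 0..seq_len-1 inside apply_rope_fast.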
        positions = torch.tensor([kv_cache[0].shape[1]]).to(q.device) if kv_cache is not None else None
        q = apply_rope_fast(q, self.config.root_model.rope_cache, positions)
        k = apply_rope_fast(k, self.config.root_model.rope_cache, positions)

        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)

        if global_config['attn_backend'] == AttentionBackend.FlashAttentionCuda:
            q, k, v = q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)
            attn_result = flash_attn_func(q, k, v, causal=True)
            q, k, v, attn_result = q.squeeze(2), k.squeeze(2), v.squeeze(2), attn_result.squeeze(2)
        elif global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton:
            q, k, v = q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)
            attn_result = flash_attn_func_triton(q, k, v,
                                                 attn_masked_bias.unsqueeze(1) if attn_masked_bias is not None else None,
                                                 kv_cache is None)
            q, k, v, attn_result = q.squeeze(2), k.squeeze(2), v.squeeze(2), attn_result.squeeze(2)
        else:
            attn_score = (q @ k.permute(0, 2, 1) / (self.head_size ** 0.5)) + attn_masked_bias
            attn_result = torch.softmax(attn_score, dim=2) @ v

        return attn_result, [k, v]


class MultiHeadAttention(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.attn_heads = nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        head_outputs = [head(x, attn_masked_bias, kv_cache[idx] if kv_cache is not None else None) for idx, head in
                        enumerate(self.attn_heads)]
        return self.o_proj(torch.concat([o[0] for o in head_outputs], dim=2)), [o[1] for o in head_outputs]


class DecoderLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4, dtype=config.dtype)
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        mha_output, new_kv_cache = self.mha(self.ln_mha(x), attn_masked_bias, kv_cache)
        mha_output = x + mha_output
        ffn_output = self.down_proj(self.act(self.up_proj(self.ln_ffn(mha_output))))
        return mha_output + ffn_output, new_kv_cache


class ToyTransformer(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
        super().__init__()
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
                                        dtype)

        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)

        self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)

        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)
        self.to(device)

    def forward(self, seq: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
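        # seq: (batch, seq_len) token ids. attn_mask, if given, carries
        # per-token segment ids (0 = padding) and is expanded by
        # expand_attn_mask. kv_cache is the nested [layer][head][k, v]
        # structure returned by a previous call.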
        # sanity checks
        assert attn_mask is None or kv_cache is None  # No support for attn_mask and kv_cache both enabled
        if kv_cache is not None:
            assert seq.shape[0] == 1, 'kv_cache is not supported for batch inference'
        # handle flash-attn triton alignment requirement (actually only needed for backward)
        seq_length = seq.shape[1]
        if kv_cache is None and global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and seq_length % 128 != 0:
            if attn_mask is None:  # forcibly enable attn_mask due to padding
                attn_mask = torch.ones(seq.shape, device=self.device)
            pad_length = (ceil(seq_length / 128) * 128) - seq_length
            seq = nn.functional.pad(seq, (0, pad_length))
            attn_mask = nn.functional.pad(attn_mask, (0, pad_length))

        # handle attn_bias
        if global_config['attn_backend'] == AttentionBackend.FlashAttentionCuda:
            assert attn_mask is None, 'FlashAttn-Cuda does not support custom attn_mask'
            attn_masked_bias = None
        elif global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and attn_mask is None:
            attn_masked_bias = None
        elif attn_mask is not None:
            attn_masked_bias = expand_attn_mask(attn_mask)
        elif attn_mask is None and kv_cache is None:
            attn_masked_bias = expand_attn_mask(torch.ones(seq.shape, device=self.device))
        elif kv_cache is not None:
            attn_masked_bias = torch.ones((1, seq.shape[1], seq.shape[1]), dtype=torch.bool, device=self.device)
        else:
            attn_masked_bias = None

        if attn_masked_bias is not None:
            mask_zero = torch.tensor(0, dtype=self.config.dtype, device=self.device)
            mask_val = torch.tensor(torch.finfo(self.config.dtype).min / 2, dtype=self.config.dtype, device=self.device)
            attn_masked_bias = torch.where(attn_masked_bias, mask_zero, mask_val)

        hidden = self.sem_embed(seq)

        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_masked_bias, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        logits = self.lm_head(hidden)

        # remove padding for flash-attn triton
        if kv_cache is None and global_config['attn_backend'] == AttentionBackend.FlashAttentionTriton and seq_length % 128 != 0:
            logits = logits[:, :seq_length, :]
            new_kv_cache = [[[cache[:, :seq_length, :] for cache in head] for head in layer] for layer in new_kv_cache]

        return logits, new_kv_cache

    @property
    def device(self):
        return next(self.parameters()).device
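

# Minimal usage sketch with toy, made-up sizes: a full-sequence (prefill)
# forward pass followed by one step of incremental decoding that reuses the
# returned kv_cache. Assumes the default naive attention backend on CPU.
if __name__ == '__main__':
    model = ToyTransformer(vocab_size=128, num_layers=2, num_heads=4,
                           hidden_size=64, max_seq_len=256)
    tokens = torch.randint(0, 128, (1, 10))

    # prefill: logits has shape (1, 10, 128), kv_cache is [layer][head][k, v]
    logits, kv_cache = model(tokens)

    # greedy-pick the next token and decode one step with the cached keys/values
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    logits, kv_cache = model(next_token, kv_cache=kv_cache)
    print(logits.shape)  # torch.Size([1, 1, 128])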