import os
import sys

import torch

# Make the repository root importable so `hyvideo` resolves when this file
# lives in a subdirectory (e.g. tests/).
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.append(project_root)
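
# Launch under torchrun so RANK / WORLD_SIZE / LOCAL_RANK are set, e.g.:
#   torchrun --nproc_per_node=2 <this_script>.py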

# "attenion" (sic) matches the actual module filename in the HunyuanVideo repo.
from hyvideo.modules.attenion import attention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)

def init_dist():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print(
        f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
    )

    torch.cuda.set_device(local_rank)
    init_distributed_environment(rank=rank, world_size=world_size)

    # Hybrid sequence-parallel config: ulysses_degree=2 and
    # ring_degree=world_size // 2 (both fall back to 1 for a single process).

    if world_size > 1:
        # ring_degree * ulysses_degree must multiply to the sequence-parallel
        # degree, so an even world size is required here.
        assert world_size % 2 == 0, "world_size must be even for ulysses_degree=2"
        ring_degree = world_size // 2
        ulysses_degree = 2
    else:
        ring_degree = 1
        ulysses_degree = 1
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    return rank, world_size

def test_mm_double_stream_block_attention(rank, world_size):
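    # Double-stream blocks keep separate image and text token streams that are
    # concatenated only for attention; the reference path below therefore
    # attends over seq_len_img + seq_len_txt tokens at once.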
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    heads_num = 24
    head_dim = 128

    img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
    txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
    txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
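        # Broadcast from rank 0 so every rank holds identical random inputs.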
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(img_v, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(txt_v, src=0)
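        # Reference layout: text tokens appended after image tokens, matching
        # joint_strategy="rear" in the parallel path below.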
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)
        # These cu_seqlens/max_seqlen values mirror what the HunyuanVideo
        # pipeline passes; the "torch" (SDPA) mode used here does not consume
        # them, they only affect the "flash" varlen path.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"  # one of "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size
        )

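        # Parallel path: image tokens go through hybrid Ulysses + ring attention
        # while the text tokens ride along as joint tensors appended at the rear.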
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            img_q,
            img_k,
            img_v,
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy="rear",
        )

        # attention() returns (b, s, heads * head_dim), so flatten the heads.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_double_stream_block_attention Passed")

def test_mm_single_stream_block_attention(rank, world_size):
    # Single-stream blocks run one fused image+text sequence through shared
    # projections; q/k are built per-modality here and v spans the full sequence.
    device = torch.device(f"cuda:{rank}")
    dtype = torch.bfloat16
    batch_size = 1
    seq_len_img = 118800
    seq_len_txt = 256
    txt_len = seq_len_txt  # length of the text suffix sliced off below
    heads_num = 24
    head_dim = 128

    with torch.no_grad():
        img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
        txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
        v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)

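        # Broadcast from rank 0 so every rank holds identical random inputs.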
        torch.distributed.broadcast(img_q, src=0)
        torch.distributed.broadcast(img_k, src=0)
        torch.distributed.broadcast(txt_q, src=0)
        torch.distributed.broadcast(txt_k, src=0)
        torch.distributed.broadcast(v, src=0)

        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)

        # As in the double-stream test, these values mirror the HunyuanVideo
        # pipeline but are not consumed by the "torch" (SDPA) mode used here.
        cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
        max_seqlen_q = 119056
        max_seqlen_kv = 119056
        mode = "torch"  # one of "torch", "vanilla", "flash"

        original_output = attention(
            q,
            k,
            v,
            mode=mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=batch_size
        )

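        # Parallel path: slice the text suffix back off and pass it as joint
        # tensors with joint_strategy="rear", mirroring the double-stream test.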
        hybrid_seq_parallel_attn = xFuserLongContextAttention()
        hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
            None,
            q[:, :-txt_len, :, :],
            k[:, :-txt_len, :, :],
            v[:, :-txt_len, :, :],
            dropout_p=0.0,
            causal=False,
            joint_tensor_query=q[:, -txt_len:, :, :],
            joint_tensor_key=k[:, -txt_len:, :, :],
            joint_tensor_value=v[:, -txt_len:, :, :],
            joint_strategy="rear",
        )
        # attention() returns (b, s, heads * head_dim), so flatten the heads.
        b, s, _, _ = hybrid_seq_parallel_output.shape
        hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)

        assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"

        torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
        print("test_mm_single_stream_block_attention Passed")

if __name__ == "__main__":
    rank, world_size = init_dist()
    test_mm_double_stream_block_attention(rank, world_size)
    test_mm_single_stream_block_attention(rank, world_size)
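
    # Tear down the process group so torchrun exits cleanly.
    torch.distributed.destroy_process_group()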