import argparse
import sys
import torch
import torch.nn as nn
from PIL import Image
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    SiglipImageProcessor,
    SiglipVisionModel,
)
from transformers import TextStreamer

# Preferred compute device: the default CUDA device when one is available,
# otherwise the CPU.
# NOTE(review): the functions visible in this file shadow this name with their
# own local `device`; kept at module level in case other code imports it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def tokenizer_image_token(prompt, tokenizer, image_token_index=-200):
    """Tokenize ``prompt``, replacing every ``<image>`` marker with a sentinel id.

    The text between ``<image>`` markers is tokenized segment by segment, and the
    segments are joined with ``image_token_index`` between them.  A tokenizer that
    prepends a BOS token does so for every segment it encodes, so the BOS is kept
    only on the first segment.  On later segments the leading token is dropped
    only when it actually equals ``tokenizer.bos_token_id``.  Otherwise, a
    tokenizer that adds no BOS would lose a real token.

    Args:
        prompt: Text containing zero or more ``<image>`` markers.
        tokenizer: Callable returning an object with an ``input_ids`` sequence;
            its ``bos_token_id`` attribute is read if present.
        image_token_index: Sentinel id inserted where each marker was.

    Returns:
        1-D ``torch.long`` tensor of token ids.
    """
    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    input_ids = []
    for position, segment in enumerate(prompt.split("<image>")):
        segment_ids = list(tokenizer(segment).input_ids)
        if position > 0:
            input_ids.append(image_token_index)
            # Only the first segment keeps its BOS; strip it from later ones,
            # but never strip a token that is not actually the BOS.
            if (
                segment_ids
                and bos_token_id is not None
                and segment_ids[0] == bos_token_id
            ):
                segment_ids = segment_ids[1:]
        input_ids.extend(segment_ids)
    return torch.tensor(input_ids, dtype=torch.long)


def process_tensors(input_ids, image_features, embedding_layer, image_token_index=-200):
    """Splice image features into the token embeddings of a prompt.

    The first ``image_token_index`` found in ``input_ids`` marks where the image
    goes.  The tokens before it are embedded, then come ``image_features``, then
    the embeddings of the tokens after it.  The sentinel itself is dropped.

    Args:
        input_ids: 2-D integer tensor of token ids (it is indexed as ``[:, i]``).
            NOTE(review): assumes batch size 1.  The column of the first sentinel
            found (in any row) is used to split every row.
        image_features: Tensor concatenated along dim 1 between the two text
            parts; the result is placed on its device.
        embedding_layer: Module mapping token ids to embeddings.
        image_token_index: Sentinel id marking the image position.

    Returns:
        ``(embeddings, attention_mask)``: the concatenated embeddings and an
        all-ones ``torch.long`` mask of shape ``embeddings.shape[:2]``.

    Raises:
        ValueError: if ``input_ids`` contains no ``image_token_index``.
    """
    positions = (input_ids == image_token_index).nonzero(as_tuple=True)[1]
    if positions.numel() == 0:
        raise ValueError(
            f"input_ids contains no image token ({image_token_index}); "
            "the prompt must include an <image> marker"
        )
    split_index = int(positions[0])

    target_device = image_features.device
    # Embed the text on each side of the sentinel, then move it next to the image.
    embeds_before = embedding_layer(input_ids[:, :split_index]).to(target_device)
    embeds_after = embedding_layer(input_ids[:, split_index + 1 :]).to(target_device)

    concatenated_embeddings = torch.cat(
        [embeds_before, image_features, embeds_after], dim=1
    )
    attention_mask = torch.ones(
        concatenated_embeddings.shape[:2], dtype=torch.long, device=target_device
    )
    return concatenated_embeddings, attention_mask


def initialize_models(
    llm_path=r"E:\Workspace\BAP\LlamaVision\llama-3-8b-Instruct",
    vision_path=r"E:\Workspace\BAP\LlamaVision\siglip-so400m-patch14-384",
):
    """Load the Llama-3 language model, the SigLIP vision tower and their preprocessors.

    The defaults point at local checkouts; the originals on the Hub are
    ``unsloth/llama-3-8b-Instruct`` and ``google/siglip-so400m-patch14-384``.
    Either argument can be a local directory or a Hub id.

    The default paths are raw strings.  As plain strings, sequences like ``\\W``
    and ``\\s`` are invalid escapes that only survive by accident and raise
    SyntaxWarning on Python 3.12+.  The resulting path values are unchanged.

    Args:
        llm_path: Location of the causal LM and its tokenizer.
        vision_path: Location of the SigLIP vision model and its image processor.

    Returns:
        ``(tokenizer, model, vision_model, processor)``.  The LM is loaded in
        float16 with ``device_map="auto"``; the vision model is loaded in float16
        and moved to the CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(llm_path, use_fast=True)
    model = LlamaForCausalLM.from_pretrained(
        llm_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Freeze the transformer body; the model is only used for inference here.
    for param in model.base_model.parameters():
        param.requires_grad = False

    vision_model = SiglipVisionModel.from_pretrained(
        vision_path, torch_dtype=torch.float16
    ).to("cpu")
    processor = SiglipImageProcessor.from_pretrained(vision_path)

    return tokenizer, model, vision_model, processor


class ProjectionModule(nn.Module):
    """Two-layer MLP projecting vision features into the language-model space.

    Layers: Linear(mm_hidden_size -> hidden_size), GELU,
    Linear(hidden_size -> hidden_size).  They are held in ``self.model``, so the
    parameter names are ``model.0.*`` and ``model.2.*``.  A strict
    ``load_state_dict`` needs checkpoint keys that match these names exactly.
    """

    def __init__(self, mm_hidden_size, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (last dim ``mm_hidden_size``) to last dim ``hidden_size``."""
        projected = self.model(x)
        return projected


def load_projection_module(
    mm_hidden_size=1152,
    hidden_size=4096,
    device="cpu",
    checkpoint_path="./mm_projector.bin",
):
    """Build a ProjectionModule and load its weights from a state-dict file.

    Args:
        mm_hidden_size: Input feature size of the projector (vision side).
        hidden_size: Output feature size of the projector (language-model side).
        device: Device the module is moved to after loading.
        checkpoint_path: State-dict file.  Every ``"mm_projector."`` substring in
            its keys is removed before a strict ``load_state_dict``.

    Returns:
        The loaded projector, converted to float16 and placed on ``device``.
    """
    projection_module = ProjectionModule(mm_hidden_size, hidden_size)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; the module is moved to `device` afterwards anyway.
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = {k.replace("mm_projector.", ""): v for k, v in state_dict.items()}
    projection_module.load_state_dict(state_dict)
    return projection_module.to(device).half()


def _read_user_turn():
    """Read the next user message; on empty input or EOF print a notice and exit."""
    try:
        q = input("\nuser: ")
    except EOFError:
        q = ""
    if not q:
        print("no input detected. exiting.")
        sys.exit()
    return q


def _embed_image(image, vision_model, processor, projection_module):
    """Encode a PIL image with the vision tower and projector, on the CPU.

    The second-to-last hidden layer of the vision model is projected.  The
    result is returned on the CPU.
    """
    image_inputs = processor(
        images=[image],
        return_tensors="pt",
        do_resize=True,
        size={"height": 384, "width": 384},
    ).to("cpu")
    pixel_values = image_inputs["pixel_values"].squeeze(0)

    image_forward_outs = vision_model(
        pixel_values.to(device="cpu", dtype=torch.float16).unsqueeze(0),
        output_hidden_states=True,
    )
    image_features = image_forward_outs.hidden_states[-2]
    return projection_module(image_features).to("cpu")


def answer_question(
    image_path, tokenizer, model, vision_model, processor, projection_module
):
    """Chat interactively about one image, streaming each reply to stdout.

    The first question is read from stdin and the projected image features are
    spliced into its prompt embeddings; a reply is then generated.  After each
    reply, the next user turn is read, appended to the running embedding
    context, and generation runs again.  Empty input or EOF prints a notice and
    ends the process via ``sys.exit()``.

    Side effect: sets ``tokenizer.eos_token`` to ``"<|eot_id|>"``.

    Args:
        image_path: Path of the image to discuss.
        tokenizer: Tokenizer paired with ``model``.
        model: Causal LM used with ``inputs_embeds``.
        vision_model: Vision tower producing hidden states for the image.
        processor: Image processor producing ``pixel_values``.
        projection_module: Maps vision features to the LM embedding size.
    """
    # Decode the image eagerly and close the underlying file handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # The eot id is also passed as pad_token_id below.
    # NOTE(review): generate() stops on model.generation_config's EOS ids, not on
    # tokenizer.eos_token. Confirm that config includes <|eot_id|>.
    tokenizer.eos_token = "<|eot_id|>"

    q = _read_user_turn()
    question = "<image>" + q

    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    input_ids = (
        tokenizer_image_token(prompt, tokenizer)
        .unsqueeze(0)
        .to(model.device)
    )

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        projected_embeddings = _embed_image(
            image, vision_model, processor, projection_module
        )

        embedding_layer = model.get_input_embeddings()
        new_embeds, attn_mask = process_tensors(
            input_ids, projected_embeddings, embedding_layer
        )
        device = model.device
        attn_mask = attn_mask.to(device)
        new_embeds = new_embeds.to(device)

        model_kwargs = {
            "do_sample": True,
            "temperature": 0.2,
            "max_new_tokens": 2000,
            "use_cache": True,
            "streamer": streamer,
            "pad_token_id": tokenizer.eos_token_id
        }

        while True:
            print('assistant: ')
            # NOTE(review): with only inputs_embeds given, generate() is expected to
            # return just the newly generated tokens. Confirm for the installed
            # transformers version.
            generated_ids = model.generate(
                inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
            )[0]

            # Drop whatever special token ended generation (eot, end_of_text, or
            # none at the length limit) and close the turn explicitly below.
            reply = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            # Exits the process on empty input/EOF, like the first prompt does.
            q = _read_user_turn()

            # Terminate both the assistant turn and the new user turn with
            # <|eot_id|>, matching the structure of the initial prompt.
            new_text = (
                reply
                + "<|eot_id|>"
                + "<|start_header_id|>user<|end_header_id|>\n\n"
                + q
                + "<|eot_id|>"
                + "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
            # add_special_tokens=False: no BOS in the middle of the conversation.
            new_input_ids = tokenizer(
                new_text, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(device)
            new_embeddings = embedding_layer(new_input_ids)

            new_embeds = torch.cat([new_embeds, new_embeddings], dim=1)
            # Same dtype as the mask built by process_tensors.
            attn_mask = torch.ones(
                new_embeds.shape[:2], dtype=torch.long, device=device
            )


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Answer questions based on an image")
    cli.add_argument("-i", "--image", required=True, help="Path to the image file")
    cli_args = cli.parse_args()

    # Load every model up front, then hand control to the interactive chat loop.
    tokenizer, model, vision_model, processor = initialize_models()
    projector = load_projection_module()

    answer_question(
        image_path=cli_args.image,
        tokenizer=tokenizer,
        model=model,
        vision_model=vision_model,
        processor=processor,
        projection_module=projector,
    )