import time
import os
import torch
from typing import Callable

from dartrs.v2 import (
    V2Model,
    MixtralModel,
    MistralModel,
    compose_prompt,
    LengthTag,
    AspectRatioTag,
    RatingTag,
    IdentityTag,
)
from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config


import gradio as gr
from gradio.components import Component

try:
    import spaces
except ImportError:

    class spaces:
        """No-op stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Decorator that leaves the decorated function unchanged.

            Works both bare (``@spaces.GPU``) and called with options
            (``@spaces.GPU(...)``); any options are ignored.
            """
            # Bare usage: the function to decorate arrives as the sole
            # positional argument, so hand it straight back.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]
            # Parameterised usage: return a decorator that does nothing.
            return lambda func: func


from output import UpsamplingOutput

# Optional access token read from the HF_TOKEN environment variable
# (None when unset). Passed as `auth_token` by `prepare_models`.
HF_TOKEN = os.getenv("HF_TOKEN", None)

# Registry of the available v2 models, keyed by model name.
# Each entry provides:
#   "repo":  repository id handed to `from_pretrained` in `prepare_models`
#   "type":  variant label ("sft" for both entries; presumably the training
#            type — it is not read in the visible part of this file)
#   "class": dartrs model class whose `from_pretrained` loads the model
V2_ALL_MODELS = {
    "dart-v2-moe-sft": {
        "repo": "p1atdev/dart-v2-moe-sft",
        "type": "sft",
        "class": MixtralModel,
    },
    "dart-v2-sft": {
        "repo": "p1atdev/dart-v2-sft",
        "type": "sft",
        "class": MistralModel,
    },
}


def prepare_models(model_config: dict):
    """Load the tokenizer and model described by one registry entry.

    Args:
        model_config: Mapping with a "repo" key (repository id) and a
            "class" key (model class exposing ``from_pretrained``).

    Returns:
        A dict holding the loaded objects under "tokenizer" and "model".
    """
    repo_id = model_config["repo"]

    # The tokenizer is loaded first, then the model, both with HF_TOKEN.
    return {
        "tokenizer": DartTokenizer.from_pretrained(repo_id, auth_token=HF_TOKEN),
        "model": model_config["class"].from_pretrained(
            repo_id, auth_token=HF_TOKEN
        ),
    }


def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Tokenize ``tags`` and rejoin them with ", ", dropping ``<|unk|>`` tokens."""
    known_tokens = (
        token for token in tokenizer.tokenize(tags) if token != "<|unk|>"
    )
    return ", ".join(known_tokens)


@torch.no_grad()
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    """Run one generation pass for ``prompt`` with gradients disabled.

    Args:
        model: Model whose ``generate`` method is called.
        tokenizer: Tokenizer forwarded to ``get_generation_config``.
        prompt: Prompt forwarded to ``get_generation_config``.
        ban_token_ids: Token ids forwarded as ``ban_token_ids``.

    Returns:
        Whatever ``model.generate`` returns for the built configuration.
    """
    # Fixed sampling settings used for every request.
    generation_config = get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
        ban_token_ids=ban_token_ids,
    )

    return model.generate(generation_config)


def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    """Build the counted people tags for ``noun``.

    The list starts with the singular ``1{noun}``, continues with the
    plural ``{n}{noun}s`` for n from ``minimum + 1`` through ``maximum``,
    and ends with the open-ended ``{maximum + 1}+{noun}s``.
    """
    tags = [f"1{noun}"]
    tags.extend(f"{count}{noun}s" for count in range(minimum + 1, maximum + 1))
    tags.append(f"{maximum+1}+{noun}s")
    return tags


# Every tag treated as a "people" tag: counted girl/boy/other tags plus
# the literal "no humans".
PEOPLE_TAGS = [
    *_people_tag("girl"),
    *_people_tag("boy"),
    *_people_tag("other"),
    "no humans",
]


def gen_prompt_text(output: UpsamplingOutput):
    """Assemble the final comma-separated prompt from an upsampling result.

    Parts are emitted in this order: people tags found in the general
    tags (e.g. 1girl), character tags, copyright tags, the remaining
    general tags, the upsampled tags, and the rating tag. Each part is
    stripped, and parts that are empty after stripping are dropped.
    """
    general = [tag.strip() for tag in output.general_tags.split(",")]

    # Split the general tags into people tags and everything else,
    # keeping the original relative order within each group.
    leading_people = [tag for tag in general if tag in PEOPLE_TAGS]
    remaining_general = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered_parts = [
        *leading_people,
        output.character_tags,
        output.copyright_tags,
        *remaining_general,
        output.upsampled_tags,
        output.rating_tag,
    ]

    stripped_parts = (part.strip() for part in ordered_parts)
    return ", ".join(part for part in stripped_parts if part != "")


def elapsed_time_format(elapsed_time: float) -> str:
    """Render a duration as ``Elapsed: <value> seconds`` with two decimals."""
    return "Elapsed: {:.2f} seconds".format(elapsed_time)


def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    """Wrap ``upsampler`` so its result is returned as UI output values.

    Args:
        upsampler: Callable that returns an ``UpsamplingOutput``.

    Returns:
        A function that forwards its positional arguments to ``upsampler``
        and returns a 4-tuple: the prompt text, the formatted elapsed
        time, and two ``gr.update(interactive=True)`` updates.
    """

    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)

        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            # Two components are switched to interactive; which components
            # receive these updates is decided by the caller's wiring.
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output


class V2UI:
    """Holds the loaded v2 model/tokenizer and the generation callback."""

    # Name of the currently loaded model; None until `on_generate` loads one.
    model_name: str | None = None
    # Model and tokenizer belonging to `model_name`; assigned in `on_generate`.
    model: V2Model
    tokenizer: DartTokenizer

    # NOTE(review): mutable class-level default — this one list object is
    # shared by every instance unless reassigned per instance; confirm that
    # is intended.
    input_components: list[Component] = []
    # Declared only; not assigned in the visible part of the class.
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        """Upsample tags with the selected model and time the generation.

        Args:
            model_name: Key into ``V2_ALL_MODELS`` selecting the model.
            copyright_tags: Copyright tags forwarded to ``compose_prompt``.
            character_tags: Character tags forwarded to ``compose_prompt``.
            general_tags: General tags forwarded to ``compose_prompt``.
            rating_tag: Rating condition forwarded to ``compose_prompt``.
            aspect_ratio_tag: Aspect-ratio condition for ``compose_prompt``.
            length_tag: Length condition forwarded to ``compose_prompt``.
            identity_tag: Identity condition forwarded to ``compose_prompt``.
            ban_tags: Text that is stripped and encoded by the tokenizer;
                the resulting token ids are banned during generation.
            *args: Extra positional arguments; ignored.

        Returns:
            An ``UpsamplingOutput`` carrying the generated tags, the inputs
            as given, and the generation time in seconds.

        Raises:
            KeyError: If ``model_name`` is not a key of ``V2_ALL_MODELS``.
        """
        # (Re)load only when the requested model differs from the cached one.
        # The cached name is updated last so a failed load is retried.
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(V2_ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # NOTE: tag normalization through `normalize_tags` is currently
        # disabled; the copyright, character and general tags are used
        # exactly as entered.

        ban_token_ids = self.tokenizer.encode(ban_tags.strip())

        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.perf_counter() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )