from typing import Tuple, Union, Dict, List

from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
import penman
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import MBartForConditionalGeneration, AutoConfig


@st.cache_resource(show_spinner=False)
def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
    model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
    for better performance.

    :param multilingual: whether to load the multilingual model or not
    :param src_lang: source language
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model, and tokenizer wrapper
    """
    model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
    if not multilingual:
        if src_lang == "English":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en"
        elif src_lang == "Spanish":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es"
        elif src_lang == "Dutch":
            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl"
        else:
            raise ValueError(f"Language {src_lang} not supported")

    # The tokenizer's src_lang is overridden with the correct language in `translate`
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")

    config = AutoConfig.from_pretrained(model_name)
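    # Start decoding from the special AMR token (added by the tokenizer wrapper) rather than a language code,
    # so the decoder directly generates a linearized AMR graph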
    config.decoder_start_token_id = tok_wrapper.amr_token_id

    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tok_wrapper.tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tok_wrapper.tokenizer))

    # Convert to BetterTransformer for faster inference with optimized attention kernels
    model = BetterTransformer.transform(model, keep_original_model=False)

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization not supported on CUDA
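        # Note: dynamic quantization only replaces module types that have quantized equivalents;
        # in practice mainly the nn.Linear layers are converted here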
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    return model, tok_wrapper


def translate(texts: Union[str, List[str]], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
    potential keyword-arguments, which can include arguments such as max length, logits processors, etc.

    :param texts: source text to translate (potentially a batch)
    :param src_lang: source language
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: potential keyword arguments for the generation process
    :return: the translation (linearized AMR graph)
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]
    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    generated["sequences"] = generated["sequences"].cpu()
    generated["sequences_scores"] = generated["sequences_scores"].cpu()
    best_scoring_results = {"graph": [], "status": []}
    # The per-sample selection below assumes all beams are returned, i.e. num_return_sequences == num_beams
    beam_size = gen_kwargs["num_beams"]

    # Select the best item from the beam: the sequence with best status and highest score
    for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
        sequences = generated["sequences"][sample_idx: sample_idx + beam_size]
        scores = generated["sequences_scores"][sample_idx: sample_idx + beam_size].tolist()
        outputs = tok_wrapper.batch_decode_amr_ids(sequences)
        statuses = outputs["status"]
        graphs = outputs["graph"]
        zipped = zip(statuses, scores, graphs)
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
        best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
        best_scoring_results["graph"].append(best[2])
        best_scoring_results["status"].append(best[0])

    # Returns dictionary with "graph" and "status" keys
    return best_scoring_results


LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}
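

if __name__ == "__main__":
    # Minimal usage sketch (not part of the Streamlit app), assuming the checkpoints can be downloaded.
    # The generation settings are illustrative; `translate` expects `num_beams` and, for its per-sample beam
    # selection to work, all beams must be returned (num_return_sequences == num_beams).
    model, tok_wrapper = get_resources(multilingual=False, src_lang="English")
    results = translate(
        "The boy wants to go to New York.",
        "English",
        model,
        tok_wrapper,
        num_beams=5,
        num_return_sequences=5,
        max_length=512,
    )
    for graph, status in zip(results["graph"], results["status"]):
        print(status.name)
        print(penman.encode(graph))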