from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch

import math
import os
import re

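# Workaround for the "duplicate OpenMP runtime" crash that can occur when both
# faiss and torch load their own copies of libomp (common on macOS/conda setups).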
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
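    """Load ScitoricsBERT and its tokenizer once, with attention outputs enabled."""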
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
    model.eval()
    
    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
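    """Load the 858k-sentence corpus; the CSV provides "sentence" and "file_id" columns."""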
    sentence_df = pd.read_csv("sentence_data_858k.csv.gz")
    
    return sentence_df


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
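    """Load precomputed sentence embeddings and build an IVF-PQ FAISS index over them."""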
    npz_comp = np.load("sentence_embeddings_858k.npz")
    sentence_embeddings = npz_comp["arr_0"]

    # Normalize in place so L2 distance ranks results identically to cosine similarity
    faiss.normalize_L2(sentence_embeddings)
    N, D = sentence_embeddings.shape  # 857,610 vectors of dimension 768
    Xt = sentence_embeddings[:100000]  # training subset for the quantizers
    X = sentence_embeddings

    # PQ parameters
    M = 16  # number of sub-vectors; typically 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector code; 8 bits encodes each sub-vector in one byte
    # IVF parameter
    nlist = int(math.sqrt(N))  # number of cells (space partitions); sqrt(N) is a common choice
    # HNSW parameter
    hnsw_m = 32  # number of neighbors per HNSW node; 32 is a typical value
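
    # Rough footprint, for intuition: each vector compresses to M * nbits = 128 bits
    # = 16 bytes of PQ codes, so 857,610 vectors take ~13.7 MB of codes, versus
    # ~2.6 GB (857,610 * 768 * 4 B) for the raw float32 embeddings.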

    # Build an IVF-PQ index whose coarse quantizer is an HNSW graph
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train the quantizers on the subset, then add every vector
    index.train(Xt)
    index.add(X)

    # Runtime search parameter: number of IVF cells visited per query
    # (higher values improve recall at the cost of speed)
    index.nprobe = 8

    return sentence_embeddings, index


@st.cache(allow_output_mutation=True)
def formulaic_phrase_extraction(sentences, model, tokenizer):
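    """Highlight formulaic phrases: words whose [CLS] attention at a fixed layer
    exceeds a threshold are wrapped in a colored <font> tag."""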
    THRESHOLD = 0.01  # minimum [CLS] attention weight for a token to count as formulaic
    LAYER = 10  # attention layer whose [CLS] attention is inspected

    output_sentences = []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences, 
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        attention = outputs.attentions  # tuple of per-layer (batch, heads, seq, seq) tensors

        # Average over heads, then take each sentence's [CLS] row, i.e. how strongly
        # [CLS] attends to every token position in that sentence.
        cls_attentions = attention[LAYER].mean(dim=1)[:, 0, :]

        for sentence, cls_attention in zip(sentences, cls_attentions):
            tokens = tokenizer.tokenize(sentence)
            # Skip [CLS] and keep exactly len(tokens) scores so that [SEP] and any
            # padding positions never leak into the mask.
            check_bool_arr = list((cls_attention > THRESHOLD).numpy())[1:len(tokens) + 1]

            cur_tokens = tokens.copy()

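            # Merge WordPiece continuations ("##...") back into whole words, AND-ing
            # their attention flags: a word is highlighted only if all of its pieces
            # cross the threshold. Repeat until no sub-word tokens remain.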
            while True:
                flg = False

                for idx, token in enumerate(cur_tokens):
                    if token.startswith("##"):
                        flg = True
                        back_token = token.replace("##", "")
                        front_token = cur_tokens.pop(idx-1)
                        cur_tokens[idx-1] = front_token + back_token

                        back_bool_val = check_bool_arr[idx]
                        front_bool_val = check_bool_arr.pop(idx-1)
                        check_bool_arr[idx-1] = front_bool_val and back_bool_val

                if not flg:
                    break

            result = " ".join([f'<font color="coral">{original_word}</font>' if b else original_word for (b, original_word) in zip(check_bool_arr, sentence.split())])
            output_sentences.append(result)

    return output_sentences


@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
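    """Embed the query via its [CLS] vector, search the FAISS index for the top_k
    nearest sentences, and drop any sentence containing an excluded word."""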
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)
        # Use the last-layer [CLS] vector as the query embedding, L2-normalized
        # to match the normalized vectors stored in the index.
        query_embeddings = outputs.last_hidden_state[0, 0, :].detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
    retrieved_sentences = []
    retrieved_paper_ids = []

    # Compile the exclusion pattern once; re.escape treats each word literally
    # rather than as a regular expression.
    pat = None
    if len(exclude_word_list) > 0:
        pat = re.compile("|".join(re.escape(w) for w in exclude_word_list))

    for idx in ids[0]:
        if idx == -1:  # FAISS pads the result with -1 when fewer than k vectors match
            continue

        cur_sentence = sentence_df.loc[idx, "sentence"]
        cur_link = f"https://aclanthology.org/{sentence_df.loc[idx, 'file_id']}"

        if pat is None or not pat.search(cur_sentence):
            retrieved_sentences.append(cur_sentence)
            retrieved_paper_ids.append(cur_link)

    if phrase_annotated:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)

    return retrieved_sentences, retrieved_paper_ids


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()


    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upper bound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "good, result")

    phrase_annotation = st.checkbox('Include phrase annotation')

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=phrase_annotation)

        result_table_markdown = "|  sentence  |  source link  |\n|:---|:---|\n"

        for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids):
            # Escape literal "|" characters so a sentence cannot break the table layout
            safe_sentence = retrieved_sentence.replace("|", "\\|")
            result_table_markdown += f"| {safe_sentence} | {retrieved_paper_id} |\n"

        st.markdown(result_table_markdown, unsafe_allow_html=True)

    st.markdown("---\n#### How this works")

    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging the attention patterns within ScitoricsBERT.")