import gradio as gr

from langchain import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
import pinecone

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

#OPENAI_API_KEY = ""
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_TEMP  = 1
OPENAI_API_LINK = "[OpenAI API Key](https://platform.openai.com/account/api-keys)"
OPENAI_LINK = "[OpenAI](https://openai.com)"

PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX", '3gpp-r16')

PINECONE_LINK  = "[Pinecone](https://www.pinecone.io)"
LANGCHAIN_LINK = "[LangChain](https://python.langchain.com/en/latest/index.html)"

EMBEDDING_MODEL = os.environ.get("PINECONE_INDEX", "sentence-transformers/all-mpnet-base-v2")

# return top-k text chunks from vector store
TOP_K_DEFAULT = 15
TOP_K_MAX = 30
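# minimum similarity score for a retrieved chunk to be kept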
SCORE_DEFAULT = 0.3


BUTTON_MIN_WIDTH = 215

LLM_NULL = "LLM-UNLOAD-critical"
LLM_DONE = "LLM-LOADED-9cf"

DB_NULL = "DB-UNLOAD-critical"
DB_DONE = "DB-LOADED-9cf"

FORK_BADGE = "Fork-HuggingFace Space-9cf"


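# Build a shields.io badge URL from a "label-message-color" string and a logo name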
def get_logo(inputs, logo) -> str:
    return f"""https://img.shields.io/badge/{inputs}?style=flat&logo={logo}&logoColor=white"""

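# Wrap the badge in a floating <img> tag for the status bar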
def get_status(inputs, logo, pos) -> str:
    return f"""<img
    src   = "{get_logo(inputs, logo)}";
    style = "margin: 0 auto;float:{pos};border: 2px solid transparent;";
    >"""
    

KEY_INIT   = "Initialize Model"
KEY_SUBMIT = "Submit"
KEY_CLEAR  = "Clear"

MODEL_NULL = get_status(LLM_NULL, "openai", "right")
MODEL_DONE = get_status(LLM_DONE, "openai", "right")

DOCS_NULL = get_status(DB_NULL, "processingfoundation", "right")
DOCS_DONE = get_status(DB_DONE, "processingfoundation", "right")

TAB_1 = "Chatbot"
TAB_2 = "Details"
TAB_3 = "Database"



FAVICON = './icon.svg'

LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]


DOC_1 = '3GPP'
DOC_2 = 'HTTP2'

DOC_SUPPORTED = [DOC_1]
DOC_DEFAULT   = [DOC_1]
DOC_LABEL = "Reference Docs"


MODEL_WARNING = f"Please paste your **{OPENAI_API_LINK}** and then **{KEY_INIT}**"

DOCS_WARNING = f"""Database Unloaded
Please check your **{TAB_3}** config and then **{KEY_INIT}**
Or you could uncheck **{DOC_LABEL}** to ask LLM directly"""


webui_title = """
# OpenAI Chatbot Based on Vector Database
"""

dup_link = f'''<a href="https://huggingface.co/spaces/ShawnAI/3GPP-ChatBot?duplicate=true"
style="display:grid; width: 200px;">
<img src="{get_logo(FORK_BADGE, "addthis")}"></a>'''

init_message = f"""This demonstration website is based on \
**{OPENAI_LINK}** with **{LANGCHAIN_LINK}** and **{PINECONE_LINK}**
    1. Insert your **{OPENAI_API_LINK}** and click  `{KEY_INIT}`
    2. Insert your **Question** and click  `{KEY_SUBMIT}`
"""

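# Prompt used when retrieved document chunks are passed in as context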
PROMPT_DOC = PromptTemplate(
    input_variables=["context", "chat_history", "question"],
    template="""Context:
##
{context}
##

Chat History:
##
{chat_history}
##

Question:
{question}

Optional:
Don't use a standalone clause/figure name in the answer; expand it with the corresponding TS name from the metadata

Desired format:
Clause/figure name: <dot_separated_numbers>
TS name: [\w\.]

Answer:"""
)

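# Fallback prompt for answering from chat history alone, without retrieved context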
PROMPT_BASE = PromptTemplate(
    input_variables=['question', "chat_history"],
    template="""Chat History:
##
{chat_history}
##

Question:
##
{question}
##

Answer:"""
)

#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

def init_model(api_key, emb_name, db_api_key, db_env, db_index):
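    """Create the LLM clients and, if configured, the Pinecone vector store.

    Returns a 6-tuple matching init_output:
    (api_key, status_html, llm_dict, unused, vector_db, chatbot_reset).
    """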
    try:
        if not (api_key and api_key.startswith("sk-") and len(api_key) > 50):
            return None, MODEL_NULL+DOCS_NULL, None, None, None, None

        llm_dict = {}
        for llm_name in LLM_LIST:
            if llm_name == "gpt-3.5-turbo":
                llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
                                                temperature = OPENAI_TEMP,
                                                openai_api_key = api_key
                                                )
            else:
                llm_dict[llm_name] = OpenAI(model_name=llm_name,
                                            temperature = OPENAI_TEMP,
                                            openai_api_key = api_key)
                    
        if not (emb_name and db_api_key and db_env and db_index):
            return api_key, MODEL_DONE+DOCS_NULL, llm_dict, None, None, None
            
        embeddings = HuggingFaceEmbeddings(model_name=emb_name)

        pinecone.init(api_key     = db_api_key,
                      environment = db_env)
        db = Pinecone.from_existing_index(index_name = db_index,
                                          embedding  = embeddings)

        return api_key, MODEL_DONE+DOCS_DONE, llm_dict, None, db, None
        
    except Exception as e:
        print(e)
        return None, MODEL_NULL+DOCS_NULL, None, None, None, None


def get_chat_history(inputs) -> str:
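    """Flatten chatbot history pairs into a Q/A transcript for the prompt."""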
    res = []
    for human, ai in inputs:
        res.append(f"Q: {human}\nA: {ai}")
    return "\n".join(res)

def remove_duplicates(documents, score_min):
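    """Keep each chunk once, dropping any whose score is below score_min."""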
    seen_content = set()
    unique_documents = []
    for (doc, score) in documents:
        if (doc.page_content not in seen_content) and (score >= score_min):
            seen_content.add(doc.page_content)
            unique_documents.append(doc)
    return unique_documents

def doc_similarity(query, db, top_k, score):
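    """Search the vector store and return deduplicated, score-filtered chunks."""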
    docs = db.similarity_search_with_score(query = query,
                                           k=top_k)
    #docsearch = db.as_retriever(search_kwargs={'k':top_k})
    #docs = docsearch.get_relevant_documents(query)
    udocs = remove_duplicates(docs, score)
    return udocs

def user(user_message, history):
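    """Append the new question to the chat history and clear the input box."""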
    return "", history+[[user_message, None]]

def bot(box_message, ref_message,
        llm_dropdown, llm_dict, doc_list,
        db, top_k, score):
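    """Answer the latest question, grounding it in retrieved chunks when enabled."""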

    # 0 is user question, 1 is bot response
    question = box_message[-1][0]
    history  = box_message[:-1]
    
    if not llm_dict:
        box_message[-1][1] = MODEL_WARNING
        return box_message, "", ""

    if not ref_message:
        ref_message = question
        details = f"Q: {question}"
    else:
        details = f"Q: {question}\nR: {ref_message}"
        
        
    llm = llm_dict[llm_dropdown]
    
    if DOC_1 in doc_list:
        if not db:
            box_message[-1][1] = DOCS_WARNING
            return box_message, "", ""
        
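        # If dedup/score filtering trimmed the results, widen the search once to refill up to top_k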
        docs = doc_similarity(ref_message, db, top_k, score)
        delta_top_k = top_k - len(docs)

        if delta_top_k > 0:
            docs = doc_similarity(ref_message, db, top_k+delta_top_k, score)

        prompt = PROMPT_DOC
        #chain = load_qa_chain(llm, chain_type="stuff")
   
    else:
        prompt = PROMPT_BASE
        docs = []
    
    chain = LLMChain(llm = llm,
                     prompt = prompt,
                     output_key = 'output_text')

    # Join page contents so the prompt is given plain text rather than Document reprs
    context = "\n\n".join(doc.page_content for doc in docs)

    all_output = chain({"question": question,
                        "context": context,
                        "chat_history": get_chat_history(history)
                       })

    bot_message = all_output['output_text']

    source = "".join([f"""<details> <summary>{doc.metadata["source"]}</summary>
{doc.page_content}

</details>""" for i, doc in enumerate(docs)])

    box_message[-1][1] = bot_message
    return box_message, "", [[details, bot_message + '\n\nMetadata:\n' + source]]

#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------

with gr.Blocks(
    title = TAB_1,
    theme = "Base",
    css = """.bigbox {
    min-height:250px;
}
""") as demo:
    llm = gr.State()
    chain_2 = gr.State() # not in use
    vector_db = gr.State()
    gr.Markdown(webui_title)
    gr.Markdown(dup_link)
    gr.Markdown(init_message)
    
    with gr.Row():
        with gr.Column(scale=10):
            llm_api_textbox = gr.Textbox(
                label = "OpenAI API Key",
                # show_label = False,
                value = OPENAI_API_KEY,
                placeholder = "Paste Your OpenAI API Key (sk-...) and Hit ENTER",
                lines=1,
                type='password')
            
        with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
            
            init = gr.Button(KEY_INIT) #.style(full_width=False)
            model_statusbox = gr.HTML(MODEL_NULL+DOCS_NULL)
    
    with gr.Tab(TAB_1):
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(elem_classes="bigbox")
            #with gr.Column(scale=1):
            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
                doc_check = gr.CheckboxGroup(choices = DOC_SUPPORTED,
                                             value   = DOC_DEFAULT,
                                             label   = DOC_LABEL,
                                             interactive=True)
                llm_dropdown = gr.Dropdown(LLM_LIST,
                                           value=LLM_LIST[0],
                                           multiselect=False,
                                           interactive=True,
                                           label="LLM Selection",
                                           )
        with gr.Row():
            with gr.Column(scale=10):
                query = gr.Textbox(label="Question:",
                                   lines=2)
                ref = gr.Textbox(label="Reference(optional):")

            with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):

                clear = gr.Button(KEY_CLEAR)
                submit = gr.Button(KEY_SUBMIT,variant="primary")
                

    with gr.Tab(TAB_2):
        with gr.Row():
            with gr.Column():
                top_k = gr.Slider(1,
                                  TOP_K_MAX,
                                  value=TOP_K_DEFAULT,
                                  step=1,
                                  label="Vector similarity top_k",
                                  interactive=True)
            with gr.Column():
                score = gr.Slider(0.01,
                                  0.99,
                                  value=SCORE_DEFAULT,
                                  step=0.01,
                                  label="Vector similarity score",
                                  interactive=True)
        detail_panel = gr.Chatbot(label="Related Docs")
    
    with gr.Tab(TAB_3):
        with gr.Row():
            emb_textbox = gr.Textbox(
                label = "Embedding Model",
                # show_label = False,
                value = EMBEDDING_MODEL,
                placeholder = "Paste Your Embedding Model Repo on HuggingFace",
                lines=1,
                interactive=True,
                type='text')
        with gr.Accordion("Pinecone Database for "+DOC_1):
            with gr.Row():
                db_api_textbox = gr.Textbox(
                    label = "Pinecone API Key",
                    # show_label = False,
                    value = PINECONE_KEY,
                    placeholder = "Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='password')
            with gr.Row():
                db_env_textbox = gr.Textbox(
                    label = "Pinecone Environment",
                    # show_label = False,
                    value = PINECONE_ENV,
                    placeholder = "Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='text')
                db_index_textbox = gr.Textbox(
                    label = "Pinecone Index",
                    # show_label = False,
                    value = PINECONE_INDEX,
                    placeholder = "Paste Your Pinecone Index (xxxx) and Hit ENTER",
                    lines=1,
                    interactive=True,
                    type='text')

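    # Shared input/output lists: init runs on ENTER in the key box or via the button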
    init_input  = [llm_api_textbox, emb_textbox, db_api_textbox, db_env_textbox, db_index_textbox]
    init_output = [llm_api_textbox, model_statusbox,
                   llm, chain_2,
                   vector_db, chatbot]
                
    llm_api_textbox.submit(init_model, init_input, init_output)
    init.click(init_model, init_input, init_output)
    
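    # On submit: echo the question into the chat, then run bot() to generate the answer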
    submit.click(user,
                 [query, chatbot],
                 [query, chatbot],
                 queue=False).then(
        bot,
        [chatbot, ref,
         llm_dropdown, llm, doc_check,
         vector_db, top_k, score],
        [chatbot, ref, detail_panel]
    )
    
    clear.click(lambda: (None,None,None), None, [query, ref, chatbot], queue=False)

#----------------------------------------------------------------------------------------------------------
#----------------------------------------------------------------------------------------------------------
    
if __name__ == "__main__":
    demo.launch(share        = False,
                inbrowser    = True,
                favicon_path = FAVICON)