import chromadb
from llama_index.core.base.embeddings.base import similarity
#from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex
from llama_index.core import StorageContext, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import load_index_from_storage
import os
from dotenv import load_dotenv
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler, CBEventType
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SimilarityPostprocessor
import time
import gradio as gr
from llama_index.core.memory import ChatMemoryBuffer
from llama_parse import LlamaParse
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.chat_engine import CondenseQuestionChatEngine


# Load variables from the env file, then read the two API keys by name.
# A key that is not set comes back as None.
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
LLAMAINDEX_API_KEY = os.environ.get("LLAMAINDEX_API_KEY")

# Create a debug handler (print_trace_on_end=True) and install a callback
# manager wrapping it as the global llama_index callback manager.
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
Settings.callback_manager = callback_manager = CallbackManager([llama_debug])

# LLM: a Groq-hosted model, registered as the global default.
Settings.llm = llm = Groq(model="llama3-8b-8192")

# Embedding model: a HuggingFace model, registered as the global default.
Settings.embed_model = embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
)

# Chunking: sentence splitter with 2048-sized chunks and an overlap of 50,
# registered as the only global transformation.
splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
Settings.transformations = [splitter]

# LlamaParse client; used further down as the extractor for ".pdf" files.
parser = LlamaParse(
    result_type="markdown",  # "markdown" and "text" are available
    api_key=LLAMAINDEX_API_KEY,
    verbose=True,
)

# Create the vector index, or reload it when a persisted copy already exists
# in ./vectordb.
if os.path.exists("./vectordb"):
    # Persist directory is present: load the stored index instead of
    # re-reading and re-embedding the documents under ./data.
    print("Index Exists!")
    storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
    index = load_index_from_storage(storage_context)
else:
    def filename_fn(filename):
        """Return the metadata dict attached to a loaded file."""
        return {"file_name": filename}

    required_exts = [".pdf", ".docx"]
    # Only ".pdf" is mapped to the LlamaParse parser; no extractor is
    # registered here for ".docx".
    file_extractor = {".pdf": parser}
    reader = SimpleDirectoryReader(
        input_dir="./data",
        file_extractor=file_extractor,
        required_exts=required_exts,
        recursive=True,
        file_metadata=filename_fn,
    )
    documents = reader.load_data()
    # Interpolate the count into the message; passing it as a second print()
    # argument would leave the "%d" placeholder unformatted.
    print("index creating with `%d` documents" % len(documents))
    index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embed_model,
        transformations=[splitter],
    )
    # Persist so the next run takes the load branch above.
    index.storage_context.persist(persist_dir="./vectordb")

"""
#create document summary index
if os.path.exists("./docsummarydb"):
    print("Index Exists!")
    storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb")
    doc_index = load_index_from_storage(storage_context)
else:
    filename_fn = lambda filename: {"file_name": filename}
    required_exts = [".pdf",".docx"]
    reader = SimpleDirectoryReader(
        input_dir="./data",
        required_exts=required_exts,
        recursive=True,
        file_metadata=filename_fn
    )
    documents = reader.load_data()
    print("index creating with `%d` documents", len(documents))
    
    response_synthesizer = get_response_synthesizer(
        response_mode="tree_summarize", use_async=True
    )
    doc_index = DocumentSummaryIndex.from_documents(
        documents,
        llm = llm,
        transformations = [splitter],
        response_synthesizer = response_synthesizer,
        show_progress = True
    )
    doc_index.storage_context.persist(persist_dir="./docsummarydb")
"""
"""
retriever = DocumentSummaryIndexEmbeddingRetriever(
    doc_index,
    similarity_top_k=5,
)
"""

# Retriever over the vector index, returning the top 10 most similar nodes.
# Disabled MMR options, kept for reference:
#   vector_store_query_mode="mmr"
#   vector_store_kwargs={"mmr_threshold": 0.4}
retriever = VectorIndexRetriever(index=index, similarity_top_k=10)

# set up response synthesizer
# Built once and passed to the query engine below, so this name refers to the
# synthesizer that actually answers queries (tree_summarize mode, verbose).
response_synthesizer = get_response_synthesizer(
    response_mode="tree_summarize",
    verbose=True,
)

# NOTE: customising the prompts worsened the result, so the template below is
# kept only as an inert string literal and is not passed to anything.
"""
# set up prompt template
qa_prompt_tmpl = (
    "Context information from multiple sources is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the information from multiple sources and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)
qa_prompt = PromptTemplate(qa_prompt_tmpl)
"""
# setting up query engine
# Retrieved nodes are passed through a similarity postprocessor with a cutoff
# of 0.53 before synthesis.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.53)],
    response_synthesizer=response_synthesizer,
)
# Print the prompts the engine is using (debug aid).
print(query_engine.get_prompts())

#response = query_engine.query("What happens if the distributor wants its own warehouse for pizzahood?")
#print(response)


# Prompt that asks the LLM to rewrite a follow-up message into a standalone
# question, given the chat history and the new message.
custom_prompt = PromptTemplate(
    """\
    Given a conversation (between Human and Assistant) and a follow up message from Human, \
    rewrite the message to be a standalone question that captures all relevant context \
    from the conversation. If you are unsure, ask for more information.
    <Chat History>
    {chat_history}
    <Follow Up Message>
    {question}
    <Standalone question>
    """
)

# list of `ChatMessage` objects used to seed the conversation
custom_chat_history = [
    ChatMessage(
        role=MessageRole.USER,
        content="Hello assistant.",
    ),
    ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."),
]

# Seed the memory buffer with the initial history directly.
# NOTE(review): CondenseQuestionChatEngine.from_defaults appears to use
# `chat_history` only when it has to build a default memory, so passing an
# explicit `memory` would otherwise drop the seed messages -- confirm against
# the installed llama_index version.
memory = ChatMemoryBuffer.from_defaults(
    token_limit=10000,
    chat_history=custom_chat_history,
)

chat_engine = CondenseQuestionChatEngine.from_defaults(
    query_engine=query_engine,
    condense_question_prompt=custom_prompt,
    chat_history=custom_chat_history,
    verbose=True,
    memory=memory,
)


# gradio with streaming support
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="⏎ for sending",
            placeholder="Ask me something",)
    clear = gr.Button("Delete")

    def user(user_message, history):
        """Queue the submitted message in the chat and clear the textbox.

        Returns an empty string (the new textbox value) and a new history
        list with ``[user_message, None]`` appended; the ``None`` slot is
        filled in by ``bot``.
        """
        return "", history + [[user_message, None]]

    def bot(history):
        """Answer the latest user message and stream the reply into history.

        Reads the user message from the last history entry, sends it to
        ``query_engine`` together with a fixed answering instruction, and
        yields the history after each appended character of the reply.
        """
        user_message = history[-1][0]
        instructions = (
            "Let's think step by step to get the correct answer. "
            "If you cannot provide an answer, say you don't know."
        )
        # Alternative kept for reference (uses the condense-question engine):
        #bot_message = chat_engine.chat(user_message)
        # Separate the question from the instruction so the two sentences do
        # not run together in the query text.
        bot_message = query_engine.query(f"{user_message}\n{instructions}")
        # Guard against a missing response so the loop below never iterates
        # over None.
        answer = bot_message.response or ""
        history[-1][1] = ""
        if not answer:
            # Still emit one update so the pending slot is replaced.
            yield history
            return
        for character in answer:
            history[-1][1] += character
            # Short pause per character to produce the typing effect.
            time.sleep(0.01)
            yield history

    # On submit: first record the user message, then stream the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # The button resets the chatbot component to None.
    clear.click(lambda: None, None, chatbot, queue=False)
# demo.queue()
demo.launch(share=False)