from ragatouille import RAGPretrainedModel
import subprocess
import json
import spaces
import firebase_admin
from firebase_admin import credentials, firestore
import logging
from pathlib import Path
from time import perf_counter
from datetime import datetime
import gradio as gr
from jinja2 import Environment, FileSystemLoader
import numpy as np
from sentence_transformers import CrossEncoder
from huggingface_hub import InferenceClient
from os import getenv

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever


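# Column names of the LanceDB table imported from backend.semantic_search, plus the HF Inference API token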
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
proj_dir = Path(__file__).parent
# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=HF_TOKEN)
# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))

# Load the templates directly from the environment
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
#___________________
# service_account_key='firebase.json'
# # Create a Certificate object from the service account info
# cred = credentials.Certificate(service_account_key)
# # Initialize the Firebase Admin 
# firebase_admin.initialize_app(cred)

# # # Create a reference to the Firestore database
# db = firestore.client()
# #db usage
# collection_name = 'Nirvachana'  # Replace with your collection name
# field_name = 'message_count'  # Replace with your field name for count
# Examples
examples = ['Tabulate the differences between cells and tissues',
            'What are cell organelles?',
            'Frame 5 short questions and 5 MCQs from the chapter on Tissues',
            'Suggest creative and engaging ideas to teach students the chapter on Metals and Non-Metals']



# def get_and_increment_value_count(db , collection_name, field_name):
#     """
#     Retrieves a value count from the specified Firestore collection and field,
#     increments it by 1, and updates the field with the new value."""
#     collection_ref = db.collection(collection_name)
#     doc_ref = collection_ref.document('count_doc')  # Assuming a dedicated document for count

#     # Use a transaction to ensure consistency across reads and writes
#     try:
#         with db.transaction() as transaction:
#             # Get the current value count (or initialize to 0 if it doesn't exist)
#             current_count_doc = doc_ref.get()
#             current_count_data = current_count_doc.to_dict()
#             if current_count_data:
#                 current_count = current_count_data.get(field_name, 0)
#             else:
#                 current_count = 0
#             # Increment the count
#             new_count = current_count + 1
#             # Update the document with the new count
#             transaction.set(doc_ref, {field_name: new_count})
#             return new_count
#     except Exception as e:
#         print(f"Error retrieving and updating value count: {e}")
#         return None  # Indicate error
        
# def update_count_html():
#     usage_count = get_and_increment_value_count(db ,collection_name, field_name)
#     count_html = gr.HTML(value=f"""
#     <div style="display: flex; justify-content: flex-end;">
#         <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
#         <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
#     </div>
# """)
#     return count_html
    
# def store_message(db,query,answer,cross_encoder):
#     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
#     # Create a new document reference with a dynamic document name based on timestamp
#     new_completion= db.collection('Nirvachana').document(f"chatlogs_{timestamp}")
#     new_completion.set({
#         'query': query,
#         'answer':answer,
#         'created_time': firestore.SERVER_TIMESTAMP,
#         'embedding': cross_encoder,
#         'title': 'Expenditure observer bot'
#     })


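# Append the user's message to the chat history and clear/lock the textbox while the bot responds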
def add_text(history, text):
    history = [] if history is None else history
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


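# Main RAG pipeline: retrieve documents (either via the ColBERT index or LanceDB search plus
# cross-encoder reranking), render the prompt from the Jinja templates, and stream the LLM
# response into the last chat-history entry.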
def bot(history, cross_encoder):
    top_rerank = 25
    top_k_rank = 10
    query = history[-1][0]

    if not query:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.warning('Retrieving documents...')
    
    # if COLBERT RAGATATOUILLE PROCEDURE  : 
    if cross_encoder == '(HIGH ACCURATE) ColBERT':
        gr.Warning('Retrieving using ColBERT.. The first query will take a minute for the model to load.. please wait')
        RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
        RAG_db = RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
        documents_full = RAG_db.search(query, k=top_k_rank)

        documents = [item['content'] for item in documents_full]
        # Create Prompt
        prompt = template.render(documents=documents, query=query)
        prompt_html = template_html.render(documents=documents, query=query)
    
        generate_fn = generate_hf
    
        history[-1][1] = ""
        for character in generate_fn(prompt, history[:-1]):
            history[-1][1] = character
            yield history, prompt_html
        print('Final history is ',history)
        #store_message(db,history[-1][0],history[-1][1],cross_encoder)
    else:
        # Retrieve documents relevant to query
        document_start = perf_counter()
    
        query_vec = retriever.encode(query)
        logger.warning('Finished query vec')

        documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
        documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
        logger.warning('Finished search')
        logger.warning(f'start cross encoder {len(documents)}')
        # Retrieve documents relevant to query
        query_doc_pair = [[query, doc] for doc in documents]
        if cross_encoder == '(FAST) MiniLM-L6v2':
            cross_encoder1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        elif cross_encoder == '(ACCURATE) BGE reranker':
            cross_encoder1 = CrossEncoder('BAAI/bge-reranker-base')
        else:
            # Default to the BGE reranker so cross_encoder1 is always defined
            cross_encoder1 = CrossEncoder('BAAI/bge-reranker-base')

        cross_scores = cross_encoder1.predict(query_doc_pair)
        sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
        logger.warning(f'Finished cross encoder {len(documents)}')
        
        documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
        logger.warning(f'num documents {len(documents)}')
    
        document_time = perf_counter() - document_start
        logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
    
        # Create Prompt
        prompt = template.render(documents=documents, query=query)
        prompt_html = template_html.render(documents=documents, query=query)
    
        generate_fn = generate_hf
    
        history[-1][1] = ""
        for character in generate_fn(prompt, history[:-1]):
            history[-1][1] = character            
            yield history, prompt_html
        print('Final history is ',history)
        #store_message(db,history[-1][0],history[-1][1],cross_encoder)

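# Prompt template for the quiz generator: asks the model for 10 MCQs built only from the retrieved
# documents, returned in a flat JSON schema ("Q#", "Q#:C1".."Q#:C4", "A#":"Q#:C#").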
def system_instructions(question_difficulty, topic, documents_str):
    return f"""<s> [INST] You are a great teacher and your task is to create 10 questions, each with 4 choices, at a {question_difficulty} difficulty about the requested topic "{topic}", using only the documents given below: {documents_str}. Then create the answers. Index everything in JSON format: the questions as "Q#":"" to "Q#":"", the four choices as "Q#:C1":"" to "Q#:C4":"", and the answers as "A#":"Q#:C#" to "A#":"Q#:C#". [/INST]"""


#with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFUL CBSE -</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
            </div>""", elem_id='heading')
        
            gr.HTML(value="""
            <p style="font-family: sans-serif; font-size: 16px;">
              A free Artificial Intelligence Chatbot assistant trained on CBSE Class 9 Science Notes to engage and help students and teachers of Puducherry.
            </p>
            """, elem_id='Sub-heading')
            #usage_count = get_and_increment_value_count(db,collection_name, field_name)
            gr.HTML(value="""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by K M Ramyasri, TGT, GHS Suthukeny. Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1')

        with gr.Column(scale=3):
            gr.Image(value='logo.png',height=200,width=200)

    
   
    chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                           'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
            )

    with gr.Row():
        txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter text and press enter",
                container=False,
                )
        txt_btn = gr.Button(value="Submit text", scale=1)

    cross_encoder = gr.Radio(choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'], value='(ACCURATE) BGE reranker', label="Reranker", info="Only the first query to ColBERT may take a little time")

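    # Displays the retrieved context rendered with template_html.j2 alongside the chat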
    prompt_html = gr.HTML()
    # Turn off interactivity while generating if you click
    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Turn off interactivity while generating if you hit enter
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Examples
    gr.Examples(examples, txt)


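# Shared state holding the loaded ColBERT index for the quiz tab (populated by load_model below)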
RAG_db=gr.State()

with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
    def load_model():
        RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
        RAG_db.value = RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
        return 'Ready to Go!!'
    with gr.Column(scale=4):
        gr.HTML("""
    <center>
      <h1><span style="color: purple;">AI NANBAN</span> - CBSE Class Quiz Maker</h1>
      <h2>AI-powered Learning Game</h2>
      <i>⚠️ Students can create a quiz from any topic or CBSE chapter! ⚠️</i>
    </center>
    """)
        #gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
    with gr.Column(scale=2):
        load_btn = gr.Button("Click to Load!🚀")
        load_text=gr.Textbox()
        load_btn.click(load_model,[],load_text)
        
   
    topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic from CBSE notes")

    with gr.Row():
        radio = gr.Radio(
            ["easy", "average", "hard"], label="How difficult should the quiz be?"
        )


    generate_quiz_btn = gr.Button("Generate Quiz!🚀")
    quiz_msg=gr.Textbox()

    question_radios = [gr.Radio(visible=False) for _ in range(10)]

    print(question_radios)

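    # Generate the quiz: retrieve passages for the topic from the ColBERT index, prompt the Mixtral
    # client for 10 MCQs in JSON, and retry up to 3 times if the response cannot be parsed into
    # 10 valid questions.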
    @spaces.GPU
    @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg]+question_radios, api_name="generate_quiz")
    def generate_quiz(question_difficulty, topic):
        top_k_rank=10
        RAG_db_=RAG_db.value
        documents_full=RAG_db_.search(topic,k=top_k_rank)
    
        

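        # Sampling parameters for the Mixtral text-generation call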
        generate_kwargs = dict(
            temperature=0.2,
            max_new_tokens=4000,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            seed=42,
        )
        question_radio_list = []
        count=0
        while count<=3:
            try:
                documents=[item['content'] for item in documents_full]
                document_summaries = [f"[DOCUMENT {i+1}]: {summary}{count}" for i, summary in enumerate(documents)]
                documents_str='\n'.join(document_summaries)
                formatted_prompt = system_instructions(
                    question_difficulty, topic,documents_str)
                print(formatted_prompt)
                response = client.text_generation(
                    formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False,
                )
                output_json = json.loads(response)
                
        
                print(response)
                print('output json', output_json)
        
                global quiz_data
        
                quiz_data = output_json
        
                
        
                for question_num in range(1, 11):
                    question_key = f"Q{question_num}"
                    answer_key = f"A{question_num}"
        
                    question = quiz_data.get(question_key)
                    answer = quiz_data.get(quiz_data.get(answer_key))
        
                    if not question or not answer:
                        continue
        
                    choice_keys = [f"{question_key}:C{i}" for i in range(1, 5)]
                    choice_list = []
                    for choice_key in choice_keys:
                        choice = quiz_data.get(choice_key, "Choice not found")
                        choice_list.append(f"{choice}")
        
                    radio = gr.Radio(choices=choice_list, label=question,
                                     visible=True, interactive=True)
        
                    question_radio_list.append(radio)
                if len(question_radio_list)==10:
                    break
                else:
                    print('10 questions not generated . So trying again!')
                    count+=1
                    continue
            except Exception as e:
                count += 1
                print(f"Exception occurred: {e}")
                if count == 3:
                    print('Retry exhausted')
                    gr.Warning('Sorry. Please try with another topic!')
                    break
                else:
                    print(f"Trying again.. attempt {count}...please wait")
                    continue

        print('Question radio list ' , question_radio_list)

        return ['Quiz Generated!']+ question_radio_list

    check_button = gr.Button("Check Score")

    score_textbox = gr.Markdown()

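    # Compare the selected choices against the stored answer key in quiz_data and return a score message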
    @check_button.click(inputs=question_radios, outputs=score_textbox)
    def compare_answers(*user_answers):
        user_answer_list = list(user_answers)

        answers_list = []

        for question_num in range(1, 11):
            answer_key = f"A{question_num}"
            answer = quiz_data.get(quiz_data.get(answer_key))
            if not answer:
                break
            answers_list.append(answer)

        score = 0

        for item in user_answer_list:
            if item in answers_list:
                score += 1
        if score > 7:
            message = f"### Excellent! You got {score} out of 10!"
        elif score > 5:
            message = f"### Good! You got {score} out of 10!"
        else:
            message = f"### You got {score} out of 10! Don't worry. You can prepare well and do better next time!"

        return message



demo = gr.TabbedInterface([CHATBOT,QUIZBOT], ["AI ChatBot", "AI Nanban-Quizbot"])

demo.queue()
demo.launch(debug=True)