from celebbot import CelebBot
import streamlit as st
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from utils import *


@st.cache_resource
def get_seq2seq_model(model_id):
    return AutoModelForSeq2SeqLM.from_pretrained(model_id)

@st.cache_resource
def get_auto_model(model_id):
    return AutoModel.from_pretrained(model_id)

@st.cache_resource
def get_tokenizer(model_id):
    return AutoTokenizer.from_pretrained(model_id)

@st.cache_data
def get_celeb_data(fpath):
    with open(fpath) as json_file:
        return json.load(json_file)

@st.cache_resource
def preprocess_text(name, gender, text, model_id):
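    lname = name.split(" ")[-1]
    # Possessive forms use the typographic apostrophe, e.g. "Obama’s", "Jones’".
    lnames = lname + "’s" if not lname.endswith("s") else lname + "’"
    names = name + "’s" if not name.endswith("s") else name + "’"
    # re.escape guards against regex metacharacters in names (e.g. "Jr.");
    # (?<!\w)/(?!\w) replace \b so that patterns ending in "’" still match
    # when followed by a space or punctuation.
    lname_regex = re.compile(rf"(?<!\w){re.escape(lname)}(?!\w)")
    name_regex = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)")
    lnames_regex = re.compile(rf"(?<!\w){re.escape(lnames)}(?!\w)")
    names_regex = re.compile(rf"(?<!\w){re.escape(names)}(?!\w)")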
    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)
    text = re.sub(names_regex, "my", text)
    text = re.sub(lnames_regex, "my", text)
    text = re.sub(name_regex, "I", text)
    text = re.sub(lname_regex, "I", text)
    # Split the first-person knowledge text into sentences with spaCy
    spacy_model = spacy.load(model_id)
    texts = [i.text.strip() for i in spacy_model(text).sents]
    return spacy_model, texts
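
# Illustrative sketch (not called by the app): the name-based part of the
# third-person -> first-person rewrite that preprocess_text performs, reduced
# so it runs without spaCy or the pronoun regexes imported from utils. The
# function name `_to_first_person` is hypothetical. Possessives are replaced
# before bare names, so "Obama’s" becomes "my" rather than "I’s".
def _to_first_person(name, text):
    lname = name.split(" ")[-1]

    def possessive(s):
        # Typographic apostrophe, matching the forms built in preprocess_text
        return s + "’" if s.endswith("s") else s + "’s"

    def word(s):
        # Whole-word match that also works when s ends in a non-word char
        return re.compile(rf"(?<!\w){re.escape(s)}(?!\w)")

    for pattern, repl in [
        (word(possessive(name)), "my"),
        (word(possessive(lname)), "my"),
        (word(name), "I"),
        (word(lname), "I"),
    ]:
        text = pattern.sub(repl, text)
    return text

# Example: _to_first_person("Barack Obama", "Barack Obama said Obama’s plan worked.")
# returns "I said my plan worked."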

def main():
    hide_footer()
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "QA_model_path" not in st.session_state:
        st.session_state["QA_model_path"] = "google/flan-t5-base"
    if "sentTr_model_path" not in st.session_state:
        st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
    if "start_chat" not in st.session_state:
        st.session_state["start_chat"] = False

    model_list = ["base", "large", "xl", "xxl"]

    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    celeb_data = get_celeb_data("data.json")

    # Sidebar widgets for choosing the celebrity and the Flan-T5 model size
    celeb_name = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox("Choose your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"

    #     submitted = st.form_submit_button(label="Start Chatting")
    # if submitted:
    #     st.session_state["start_chat"] = True

    # if st.session_state["start_chat"]:

    celeb_bot = CelebBot(celeb_name,
                         get_tokenizer(st.session_state["QA_model_path"]),
                         get_seq2seq_model(st.session_state["QA_model_path"]),
                         get_tokenizer(st.session_state["sentTr_model_path"]),
                         get_auto_model(st.session_state["sentTr_model_path"]),
                         *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm")
                         )

    prompt = st.chat_input("Say something")
    if prompt:
        celeb_bot.text = prompt
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": prompt})

        # Generate the assistant's answer to the prompt
        response = celeb_bot.question_answer()

        # Synthesize speech without autoplay; playback is handled by the
        # hidden HTML <audio> element embedded in the message below
        b64 = celeb_bot.text_to_speech(autoplay=False)
        md = f"""
        <p>{response}</p>
        <audio controls autoplay style="display:none;">
        <source src="data:audio/wav;base64,{b64}" type="audio/wav">
        Your browser does not support the audio element.
        </audio>
        """
        # Display assistant response in chat message container
        st.chat_message("assistant").markdown(
            md,
            unsafe_allow_html=True,
        )
        # Add assistant response to chat history
        st.session_state["messages"].append({"role": "assistant", "content": response})

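# Illustrative sketch (not called by the app): how the assistant message that
# main() renders is assembled, i.e. the answer text plus a hidden HTML <audio>
# element whose source is a base64 data URI. It assumes, as main() does, that
# the base64 string encodes a WAV file. Unlike main(), it HTML-escapes the
# answer so characters such as "<" or "&" display literally. The name
# `_audio_html` is hypothetical.
def _audio_html(response, b64):
    import html  # local import keeps this sketch self-contained

    return (
        f"<p>{html.escape(response)}</p>\n"
        '<audio controls autoplay style="display:none;">\n'
        f'<source src="data:audio/wav;base64,{b64}" type="audio/wav">\n'
        "Your browser does not support the audio element.\n"
        "</audio>"
    )

# Usage: given raw WAV bytes, b64 = base64.b64encode(wav_bytes).decode("ascii"),
# then st.chat_message("assistant").markdown(_audio_html(text, b64), unsafe_allow_html=True)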

if __name__ == "__main__":
    main()