import datetime
import numpy as np
import torch
import torch.nn.functional as F
import os 
import json
import speech_recognition as sr
import re
import time
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import pickle
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
import run_tts

# Build the AI
class CelebBot:
    """Voice chatbot that answers questions in the persona of a celebrity.

    The current utterance lives in ``self.text``. It is filled by
    ``speech_to_text`` or assigned directly. ``question_answer`` replaces it
    with the bot's reply, which ``text_to_speech`` can then voice via
    ``run_tts``.

    Args:
        name: Celebrity name. Used for the wake phrase ("hey <name>"), the
            persona prompt, and the TTS voice identifier.
        QA_tokenizer, QA_model: Tokenizer/model pair used to generate answers
            through ``QA_model.generate``. Presumably a seq2seq model
            (``AutoModelForSeq2SeqLM`` is imported by this file); confirm
            with the caller.
        sentTr_tokenizer, sentTr_model: Tokenizer/model pair used to embed
            sentences for knowledge retrieval.
        spacy_model: Passed through unchanged to ``run_tts.tts``.
        knowledge_sents: Sequence of knowledge sentences about the celebrity.
    """

    def __init__(self, name, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents):
        self.name = name
        print("--- starting up", self.name, "---")
        self.text = ""
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model

        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model

        self.all_knowledge = knowledge_sents
        # (snapshot of all_knowledge, its embeddings); filled lazily by
        # _knowledge_embeddings so the knowledge base is not re-embedded
        # on every question.
        self._knowledge_cache = None

    def speech_to_text(self):
        """Record one utterance from the default microphone into ``self.text``.

        Transcription uses ``recognize_google``. If the speech cannot be
        understood or the service request fails, ``self.text`` is set to ``""``
        and a message is printed instead of raising.
        """
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            time.sleep(1)

            print("listening")
            audio = recognizer.listen(mic)
            try:
                self.text = recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                # Audio was captured but no speech could be made out.
                self.text = ""
                print("me -->  No audio recognized")
            except sr.RequestError as err:
                # The recognition service was unreachable or returned an error.
                self.text = ""
                print(f"me -->  Speech recognition request failed: {err}")

    def wake_up(self, text):
        """Return True if *text* contains the wake phrase "hey <name>".

        Both sides are lower-cased, so a capitalised ``self.name`` still matches
        a transcript in any case.
        """
        return f"hey {self.name}".lower() in text.lower()

    def text_to_speech(self, autoplay=True):
        """Voice ``self.text`` via ``run_tts.tts`` and return its result.

        The voice identifier is ``self.name`` with spaces replaced by
        underscores.
        """
        return run_tts.tts(self.text, self.name.replace(" ", "_"), self.spacy_model, autoplay)

    def sentence_embeds_inference(self, texts: list):
        """Embed *texts* with the sentence-embedding model.

        Token embeddings are mean-pooled over positions where the attention mask
        is set, then L2-normalised along dim 1 so that each row has unit length.

        Returns:
            A torch tensor with one row per input text.
        """
        def _mean_pooling(model_output, attention_mask):
            # The first element of model_output holds the per-token embeddings.
            token_embeddings = model_output[0]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # The clamp avoids division by zero for a row that is all padding.
            return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        encoded_input = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = self.sentTr_model(**encoded_input)

        sentence_embeddings = _mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def _knowledge_embeddings(self):
        """Return embeddings of ``self.all_knowledge``, recomputed only when it changes."""
        snapshot = tuple(self.all_knowledge)
        # getattr keeps this working for instances built without __init__.
        cache = getattr(self, "_knowledge_cache", None)
        if cache is None or cache[0] != snapshot:
            cache = (snapshot, self.sentence_embeds_inference(list(snapshot)))
            self._knowledge_cache = cache
        return cache[1]

    def retrieve_knowledge_assertions(self, top_k=8):
        """Return the knowledge sentences most similar to the current question.

        The question is embedded as "<name>, <text>" and scored by cosine
        similarity against every sentence in ``self.all_knowledge``. The best
        ``min(top_k, len(all_knowledge))`` sentences are joined with single
        spaces, in their original knowledge-base order.

        Returns ``""`` when there is no knowledge or ``top_k`` is not positive.
        """
        n_knowledge = len(self.all_knowledge)
        if n_knowledge == 0 or top_k <= 0:
            return ''

        question_embeddings = self.sentence_embeds_inference([self.name + ', ' + self.text])
        knowledge_embeddings = self._knowledge_embeddings()
        # cosine_similarity yields an (n_knowledge, 1) array; flatten it to 1-D scores.
        similarity = cosine_similarity(knowledge_embeddings.cpu(), question_embeddings.cpu()).ravel()

        k = min(top_k, n_knowledge)
        # argpartition selects the k best without a full sort. Sorting the
        # selected indices restores knowledge-base order.
        top_idx = np.sort(np.argpartition(similarity, -k)[-k:])
        return ' '.join(self.all_knowledge[i] for i in top_idx)

    def question_answer(self, instruction1='', knowledge=''):
        """Replace ``self.text`` with the bot's reply and return the reply.

        - Empty ``self.text``: it is returned unchanged.
        - Text containing the wake phrase: a fixed greeting.
        - Text mentioning "you", "your" or the celebrity name: answered in
          persona, with retrieved knowledge sentences as context.
        - Anything else: answered from commonsense, with the caller-supplied
          *knowledge* (default ``""``) as context.

        *instruction1* is kept for backward compatibility; it is always
        overwritten before use.
        """
        if not self.text:
            return self.text

        if self.wake_up(self.text):
            self.text = f"Hello I am {self.name} the AI, what can I do for you?"
            return self.text

        # re.escape keeps regex metacharacters in the name from altering the pattern.
        mentions_celebrity = re.search(
            rf'\b(you|your|{re.escape(self.name)})\b', self.text, flags=re.IGNORECASE
        )
        if mentions_celebrity is not None:
            instruction1 = f'You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
            knowledge = self.retrieve_knowledge_assertions()
        else:
            instruction1 = 'You need to answer the question based on commonsense.'

        query = f"Context: {instruction1} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
        input_ids = self.QA_tokenizer(query, return_tensors="pt").input_ids
        outputs = self.QA_model.generate(input_ids, max_length=1024)
        self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.text

    @staticmethod
    def action_time():
        """Return the current local time as a phrase like "it's 14:05"."""
        return f"it's {datetime.datetime.now().time().strftime('%H:%M')}"

    @staticmethod
    def save_kb(kb, filename):
        """Pickle *kb* to *filename*, overwriting any existing file."""
        with open(filename, "wb") as f:
            pickle.dump(kb, f)

    @staticmethod
    def load_kb(filename):
        """Load and return an object previously written by ``save_kb``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load files
        you created yourself, never untrusted or downloaded ones.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)