import datetime
import numpy as np
import torch
import torch.nn.functional as F
import os 
import json
import speech_recognition as sr
import re
import time
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import pickle
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
import run_tts
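

# --- Illustrative sketch (not used by CelebBot) ---
# A minimal, standalone version of the prompt assembly done in
# CelebBot.question_answer, assuming the same "Context / Question / Answer"
# layout. `build_prompt` and its parameters are hypothetical names introduced
# only for illustration. In this sketch the knowledge text is included only
# when the question refers to the celebrity.
import re


def build_prompt(name, question, knowledge=""):
    """Assemble a prompt string for a seq2seq question-answering model."""
    # re.escape keeps characters such as "." in a name from acting as regex syntax.
    about_celebrity = re.search(
        rf"\b(you|your|{re.escape(name)})\b", question, flags=re.IGNORECASE
    )
    if about_celebrity:
        instruction = (
            f"You are a celebrity named {name}. "
            "You need to answer the question based on knowledge and commonsense."
        )
    else:
        instruction = "You need to answer the question based on commonsense."
        knowledge = ""
    return f"Context: {instruction} {knowledge}\n\nQuestion: {question}\n\nAnswer:"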

# Build the AI
class CelebBot():
    def __init__(self, name, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents):
        self.name = name
        print("--- starting up", self.name, "---")
        self.text = ""
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model

        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model

        self.all_knowledge = knowledge_sents

    def speech_to_text(self):
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            # flag = input("Are you ready to record?\nProceed (Y/n)")
    
            # try:
            #     assert flag=='Y'
            # except:
            #     self.text = ""
            #     print(f"me -->  Permission denied")
            time.sleep(1)

            print("listening")
            audio = recognizer.listen(mic)
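            try:
                self.text = recognizer.recognize_google(audio)
            except (sr.UnknownValueError, sr.RequestError):
                # The speech was unintelligible or the recognition service could not be reached.
                self.text = ""
                print("me -->  No audio recognized")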


    def wake_up(self, text):
        # Lowercase both sides; otherwise a capitalized name never matches the lowercased text.
        return "hey " + self.name.lower() in text.lower()

    def text_to_speech(self, autoplay=True):
        return run_tts.tts(self.text, "_".join(self.name.split(" ")), self.spacy_model, autoplay)

    def sentence_embeds_inference(self, texts: list):
        def _mean_pooling(model_output, attention_mask):
            token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Tokenize sentences
        encoded_input = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.sentTr_model(**encoded_input)

        # Perform pooling
        sentence_embeddings = _mean_pooling(model_output, encoded_input['attention_mask'])

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings
    
    def retrieve_knowledge_assertions(self):
        question_embeddings = self.sentence_embeds_inference([self.name + ', ' + self.text])

        all_knowledge_embeddings = self.sentence_embeds_inference(self.all_knowledge)
        similarity = cosine_similarity(all_knowledge_embeddings.cpu(), question_embeddings.cpu())
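        # Flatten the (N, 1) similarity matrix into a length-N vector
        similarity = np.reshape(similarity, (1, -1))[0]
        K = min(8, len(self.all_knowledge))
        # Indices of the K most similar sentences, sorted back into their original order
        top_K = np.sort(np.argpartition(similarity, -K)[-K:])
        all_knowledge_assertions = np.array(self.all_knowledge)[top_K]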

        # similarities = np.array(similarity)[top_K]

        # print(*list(zip(all_knowledge_assertions, similarities)), sep='\n')

        return ' '.join(all_knowledge_assertions)

    def question_answer(self, instruction1='', knowledge=''):
        if self.text != "":
            ## wake up
            if self.wake_up(self.text):
                self.text = f"Hello I am {self.name} the AI, what can I do for you?"
            ## have a conversation
            else:
                # Escape the name so characters such as "." are matched literally.
                if re.search(rf'\b(you|your|{re.escape(self.name)})\b', self.text, flags=re.IGNORECASE) is not None:
                    instruction1 = f'You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'

                    knowledge = self.retrieve_knowledge_assertions()
                else:
                    instruction1 = 'You need to answer the question based on commonsense.'
                query = f"Context: {instruction1} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
                input_ids = self.QA_tokenizer(query, return_tensors="pt").input_ids
                outputs = self.QA_model.generate(input_ids, max_length=1024)
                self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)

            #     instruction2 = f'[Instruction] You are a celebrity named {self.name}. You need to answer the question based on knowledge'
            #     query = f"{instruction2} [knowledge] {self.text} {answer} [question] {self.name}, {self.text}"
            #     input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids
            #     outputs = self.QA_model.generate(input_ids, max_length=1024)
            #     self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
            
        return self.text

    @staticmethod
    def action_time():
        return f"it's {datetime.datetime.now().time().strftime('%H:%M')}"

    @staticmethod 
    def save_kb(kb, filename):
        with open(filename, "wb") as f:
            pickle.dump(kb, f)

    @staticmethod
    def load_kb(filename):
        # pickle can execute arbitrary code while loading; only open trusted files.
        with open(filename, "rb") as f:
            return pickle.load(f)
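

# --- Illustrative sketch (not part of CelebBot's interface) ---
# A minimal, dependency-free version of the top-K selection performed in
# retrieve_knowledge_assertions: pick the K highest-scoring sentences, then
# restore their original order before they are joined. The helper name
# `select_top_k` and its arguments are hypothetical, introduced only for
# illustration; the method above uses NumPy instead.
import heapq


def select_top_k(sentences, scores, k=8):
    """Return the k best-scoring sentences, kept in their original order."""
    k = min(k, len(sentences))
    # Indices of the k largest scores (heapq returns them ordered by score).
    best = heapq.nlargest(k, range(len(sentences)), key=lambda i: scores[i])
    # Sorting the indices mirrors np.sort(np.argpartition(...)) in the method above.
    return [sentences[i] for i in sorted(best)]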