Spaces:
Runtime error
Runtime error
import datetime | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import speech_recognition as sr | |
import re | |
import time | |
import pickle | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Build the AI | |
class CelebBot:
    """Conversational agent that answers questions in the voice of a celebrity.

    An answer is produced by a generative QA model prompted with an instruction,
    knowledge sentences retrieved by sentence-embedding similarity, optional chat
    history, and the current question held in ``self.text``. The generated answer
    replaces ``self.text``.
    """

    # Extra surnames / spellings that should also count as the celebrity's name,
    # keyed by a substring of the lower-cased ``self.name``.
    _NAME_ALIASES = {
        "bundchen": "bündchen",
        "beyonce": "beyoncé",
        "adele": "adkins",
        "katy perry": "hudson",
        "lady gaga": "germanotta",
        "michelle obama": "robinson",
        "natalie portman": "hershlag",
        "rihanna": "fenty",
        "victoria beckham": "adams",
    }

    # Third-person -> first-person rewrites applied whatever the gender.
    _SHARED_PRONOUNS = {"they": "we", "their": "our"}

    # Gender-specific rewrites. Female "her" is ambiguous (my/me) and is
    # resolved from the parse in third_to_first_person.
    _GENDERED_PRONOUNS = {
        "M": {"he": "I", "him": "me", "his": "my", "himself": "myself"},
        "F": {"she": "I", "herself": "myself"},
    }

    # spaCy dependency labels treated as "name used as subject" / "as object".
    _SUBJECT_DEPS = ("nsubj", "nsubjpass")
    _OBJECT_DEPS = ("dobj", "dative")
    _POSSESSIVE_CLITICS = ("'s", "’s")

    def __init__(self, name, gender, QA_tokenizer, QA_model, sentTr_tokenizer, sentTr_model, spacy_model, knowledge_sents, top_k = 8):
        """Store the models and knowledge used to answer questions.

        Args:
            name: Celebrity name, e.g. "Michelle Obama".
            gender: "M" or "F"; selects which pronouns are rewritten.
            QA_tokenizer, QA_model: Tokenizer/model pair used for answer generation.
            sentTr_tokenizer, sentTr_model: Tokenizer/model pair used for sentence
                embeddings in knowledge retrieval.
            spacy_model: Callable spaCy pipeline used to parse knowledge sentences.
            knowledge_sents: Sequence of third-person knowledge sentences.
            top_k: Maximum number of knowledge sentences retrieved per question.
        """
        self.name = name
        self.gender = gender
        print("--- starting up", self.name, self.gender, "---")
        self.text = ""
        self.QA_tokenizer = QA_tokenizer
        self.QA_model = QA_model
        self.sentTr_tokenizer = sentTr_tokenizer
        self.sentTr_model = sentTr_model
        self.spacy_model = spacy_model
        self.all_knowledge = knowledge_sents
        self.top_k = top_k
        # (knowledge tuple, embeddings) pair filled lazily by _knowledge_embeddings().
        self._knowledge_cache = None

    def speech_to_text(self):
        """Record one utterance from the microphone and store its transcript in ``self.text``.

        Transcription uses ``recognizer.recognize_google``. If the speech cannot
        be understood, or the recognition request fails, ``self.text`` is set
        to "" and a message is printed.
        """
        recognizer = sr.Recognizer()
        with sr.Microphone() as mic:
            recognizer.adjust_for_ambient_noise(mic, duration=1)
            time.sleep(1)
            print("listening")
            audio = recognizer.listen(mic)
        try:
            self.text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            self.text = ""
            print("me --> No audio recognized")
        except sr.RequestError as err:
            # The recognition service could not be reached or rejected the request.
            self.text = ""
            print(f"me --> Speech recognition request failed: {err}")

    def text_to_speech(self, autoplay=True):
        """Synthesise ``self.text`` through the project's ``run_tts`` module.

        The celebrity name with spaces replaced by underscores is passed as the
        second argument (presumably a voice/speaker id — see run_tts). Returns
        whatever ``run_tts.tts`` returns.
        """
        import run_tts  # project-local module, imported at call time
        return run_tts.tts(self.text, "_".join(self.name.split(" ")), self.spacy_model, autoplay)

    def sentence_embeds_inference(self, texts: list):
        """Embed ``texts`` with the sentence-transformer model.

        Token embeddings are mean-pooled over non-padding positions and then
        L2-normalised, so cosine similarity between rows is a dot product.

        Args:
            texts: Strings to embed.

        Returns:
            Tensor with one row per input text, on ``self.sentTr_model.device``.
        """
        def _mean_pooling(model_output, attention_mask):
            # model_output[0] holds the per-token embeddings.
            token_embeddings = model_output[0]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # The clamp avoids division by zero for an all-padding row.
            return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        device = self.sentTr_model.device
        encoded = self.sentTr_tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        # Move every tensor the tokenizer produced (e.g. token_type_ids from
        # BERT-style tokenizers), not only input_ids/attention_mask; otherwise
        # a GPU-resident model receives a mix of CPU and GPU tensors.
        encoded = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in encoded.items()
        }
        with torch.no_grad():
            model_output = self.sentTr_model(**encoded)
        embeddings = _mean_pooling(model_output, encoded['attention_mask'])
        return F.normalize(embeddings, p=2, dim=1)

    def _knowledge_embeddings(self):
        """Return embeddings of ``self.all_knowledge``, reusing them while the sentences are unchanged."""
        key = tuple(self.all_knowledge)
        cache = getattr(self, "_knowledge_cache", None)
        if cache is None or cache[0] != key:
            cache = (key, self.sentence_embeds_inference(list(key)))
            self._knowledge_cache = cache
        return cache[1]

    def retrieve_knowledge_assertions(self, change_person=True):
        """Return the knowledge sentences most similar to the question in ``self.text``.

        Up to ``self.top_k`` sentences are selected by cosine similarity and
        returned in their original knowledge-base order, joined by spaces.

        Args:
            change_person: If True, rewrite each sentence into first person.

        Returns:
            The selected sentences as one string; "" if nothing can be selected.
        """
        k = min(self.top_k, len(self.all_knowledge))
        if k <= 0:
            # Empty knowledge base or top_k <= 0. Without this guard,
            # argpartition(..., -0)[-0:] would select every sentence.
            return ""
        question_embeddings = self.sentence_embeds_inference([self.text])
        knowledge_embeddings = self._knowledge_embeddings()
        similarity = cosine_similarity(knowledge_embeddings.cpu(), question_embeddings.cpu()).ravel()
        # argpartition finds the k best indices unordered; sorting restores KB order.
        top_indices = np.sort(np.argpartition(similarity, -k)[-k:])
        assertions = [self.all_knowledge[i] for i in top_indices]
        if change_person:
            assertions = [self.third_to_first_person(sent) for sent in assertions]
        return " ".join(assertions)

    def third_to_first_person(self, text):
        """Rewrite a third-person sentence about the celebrity into first person.

        The following are replaced:
        - pronouns for ``self.gender`` ("M" or "F");
        - "they"/"their" (always);
        - mentions of the celebrity's name, i.e. any word of ``self.name`` or a
          known alias. Mentions used as subject, object or possessor ("X's")
          become "I", "me" or "my".

        Multi-word mentions such as "Barack Obama" collapse into a single
        pronoun. The original spacing between tokens is preserved.

        Args:
            text: Sentence to transform.

        Returns:
            The transformed sentence.
        """
        # Normalise non-breaking spaces (common in scraped text) before parsing.
        text = text.replace("\u00a0", " ")
        lowered_name = self.name.lower()
        possible_names = set(lowered_name.split())
        possible_names.update(
            alias for key, alias in self._NAME_ALIASES.items() if key in lowered_name
        )
        pronouns = {**self._SHARED_PRONOUNS, **self._GENDERED_PRONOUNS.get(self.gender, {})}

        doc = self.spacy_model(text)
        pieces = []  # [word, trailing_whitespace] pairs; exactly one per token processed

        def replace_name_mention(word, space):
            # Drop preceding parts of a multi-word name ("Barack" in "Barack
            # Obama") so the whole mention becomes a single pronoun.
            while pieces and pieces[-1][0].lower() in possible_names:
                pieces.pop()
            pieces.append([word, space])

        for i, token in enumerate(doc):
            lower = token.text.lower()
            space = token.whitespace_
            is_name = lower in possible_names
            if lower in pronouns:
                pieces.append([pronouns[lower], space])
            elif self.gender == "F" and lower == "her":
                # Possessive "her album" -> "my"; objective "met her" -> "me".
                is_possessive = token.tag_ == "PRP$" or token.dep_ == "poss"
                pieces.append(["my" if is_possessive else "me", space])
            elif is_name and token.dep_ in self._SUBJECT_DEPS:
                replace_name_mention("I", space)
            elif (token.text in self._POSSESSIVE_CLITICS and i > 0
                  and doc[i - 1].text.lower() in possible_names):
                # "Obama's" -> "my": discard the output emitted for the name token.
                pieces.pop()
                replace_name_mention("my", space)
            elif is_name and token.dep_ in self._OBJECT_DEPS:
                replace_name_mention("me", space)
            else:
                pieces.append([token.text, space])
        return "".join(word + trailing for word, trailing in pieces)

    def question_answer(self, instruction='', knowledge='', chat_his=''):
        """Generate an answer to the question in ``self.text`` and store it there.

        Args:
            instruction: System-style instruction. If empty, a default persona
                and safety instruction naming the celebrity is used.
            knowledge: Context sentences. If empty and there is a question,
                knowledge is retrieved from the knowledge base. It stays third
                person when the question names the celebrity and is rewritten to
                first person otherwise.
            chat_his: Chat-history text, inserted verbatim before "Question:".

        Returns:
            The decoded answer (also assigned to ``self.text``).
        """
        if not instruction:
            instruction = (
                f"Your name is {self.name}. Always answer as helpfully as possible, while being safe. "
                "Your answers should not include any harmful, unethical, racist, sexist, toxic, "
                "dangerous, or illegal content. Please ensure that your responses are socially "
                "unbiased and positive in nature. If a question does not make any sense, or is not "
                "factually coherent, explain why instead of answering something not correct. "
                "If you don't know the answer to a question, please don't share false information."
            )
        if self.text and not knowledge:
            # re.escape: the name is data, not a regex pattern.
            name_pattern = rf"\b({re.escape(self.name)})\b"
            mentions_name = re.search(name_pattern, self.text, flags=re.IGNORECASE) is not None
            knowledge = self.retrieve_knowledge_assertions(change_person=not mentions_name)
        query = f"Context: {instruction} {knowledge}\n\nChat History: {chat_his}Question: {self.text}\n\nAnswer:"
        input_ids = self.QA_tokenizer(query, truncation=False, return_tensors="pt").input_ids.to(self.QA_model.device)
        outputs = self.QA_model.generate(input_ids, max_length=1024, min_length=8, repetition_penalty=2.5)
        self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self.text
def action_time():
    """Return the current local time as a short phrase, e.g. "it's 14:05"."""
    now = datetime.datetime.now()
    return "it's " + now.strftime("%H:%M")
def save_kb(kb, filename):
    """Pickle ``kb`` into ``filename``, overwriting any existing file."""
    with open(filename, mode="wb") as out_file:
        pickle.dump(kb, out_file)
def load_kb(filename):
    """Unpickle and return the object stored in ``filename`` (e.g. by ``save_kb``)."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(filename, mode="rb") as in_file:
        return pickle.load(in_file)