enigma-escape / nemo.py
myselfshravan
chore: Update API key retrieval method
5cfcb46
import requests
from typing import Union
import re
from collections import namedtuple
from sentence_transformers import SentenceTransformer, util
import streamlit as st
import os
# Static definition of one game level: display name, system prompt
# (`description`), player hint, and the secret `phrase` to elicit.
# NOTE(review): EnigmaEscape also reads `.model` and `.max_token` from the
# selected level, which this namedtuple does not define — confirm what type
# callers actually pass as `levels`.
Level = namedtuple("Level", "name description hint phrase")
class EnigmaEscape:
    """Word game: get the LLM to say a secret phrase without saying it yourself.

    Each level supplies a system prompt (``description``) and a target
    ``phrase``. A player's question is refused if it contains a forbidden
    verb or any word of the phrase. Otherwise it is sent to the Anyscale
    chat-completions endpoint, and the reply is scanned for a span that is
    semantically close to the phrase.

    Use as a context manager to reuse one HTTP session across calls::

        with EnigmaEscape(levels) as game:
            game.set_level(0)
            result = game.chat("...")
    """

    API_ENDPOINT = "https://api.endpoints.anyscale.com/v1/chat/completions"
    # Read once, when the class body executes (i.e. at import time).
    API_KEY = os.getenv("anyscale_api")
    # Questions containing any of these as a case-insensitive *substring* are
    # refused outright (so e.g. "alter" also catches "alternative").
    forbids = [
        "replace",
        "swap",
        "change",
        "modify",
        "alter",
        "substitute",
    ]
    # Cosine-similarity score a reply window must exceed to count as the phrase.
    MATCH_THRESHOLD = 0.82
    # Token separator shared by both validators: anything not ASCII alphanumeric.
    _NON_ALNUM = re.compile(r"[^a-zA-Z0-9]")

    def __init__(self, levels: "list[Level]"):
        """Load the sentence-embedding model and store the available levels.

        :param levels: level definitions; each must expose ``description``,
            ``phrase``, ``model`` and ``max_token`` attributes.
        """
        self.valid_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        self.levels = levels
        self.level = None
        self.messages = []
        self.session: Union["requests.Session", None] = None

    def __enter__(self):
        """Open a pooled HTTP session used by subsequent :meth:`chat` calls."""
        self.session = requests.Session()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session; later chats fall back to one-off requests."""
        if self.session is not None:
            self.session.close()
            self.session = None

    @classmethod
    def _tokenize(cls, text: str):
        """Split ``text`` into its non-empty runs of ASCII letters and digits."""
        return [w for w in cls._NON_ALNUM.split(text) if w]

    def body(self, que: str):
        """Build the chat-completions request payload for question ``que``.

        Overwrites the last (user) message in ``self.messages``, so each
        request carries only the system prompt plus the latest question;
        earlier turns are not kept.
        """
        self.messages[-1] = {
            "role": "user",
            "content": que,
        }
        return {
            "model": self.level.model,
            "messages": self.messages,
            "temperature": 0.5,
            "max_tokens": self.level.max_token,
        }

    def set_level(self, lev: int):
        """Select level ``lev`` and reset the conversation to its system prompt.

        The trailing empty user message is a placeholder that :meth:`body`
        overwrites with each question.
        """
        self.level = self.levels[lev]
        self.messages = [
            {"role": "system", "content": self.level.description},
            {"role": "user", "content": ""},
        ]

    def chat(self, que: str):
        """Screen ``que``, send it to the model, and classify the reply.

        :returns: a dict with ``content`` and ``type``:
            ``"warning"`` if ``que`` contains a forbidden word (no request made);
            ``"error"`` if ``que`` contains a word of the phrase (no request made);
            ``"success"`` (plus ``tokens``, the prompt token count) if the reply
            matches the phrase per :meth:`resp_validate`;
            ``"info"`` otherwise.
        :raises requests.HTTPError: if the API answers with an error status.
        """
        if any(w in que.lower() for w in self.forbids):
            return {
                "content": "I wont fall for that",
                "type": "warning",
            }
        if self.regx_validate(que):
            return {
                "content": "You cannot include the phrase in your question",
                "type": "error",
            }
        payload = self.body(que)
        # Outside a ``with`` block there is no session; use a one-off request
        # instead of failing with AttributeError on None.
        http = self.session if self.session is not None else requests
        response = http.post(
            self.API_ENDPOINT,
            headers={"Authorization": f"Bearer {self.API_KEY}"},
            json=payload,
        )
        # Report auth/quota/server errors directly, not as a KeyError on "choices".
        response.raise_for_status()
        resp = response.json()
        content = resp["choices"][0]["message"]["content"]
        if self.resp_validate(content):
            return {
                "content": content,
                "type": "success",
                "tokens": resp["usage"]["prompt_tokens"],
            }
        return {
            "content": content,
            "type": "info",
        }

    def resp_validate(self, que: str):
        """Return True if some span of ``que`` is semantically close to the phrase.

        ``que`` (the model's reply) is split into alphanumeric tokens. Every
        window with as many tokens as the phrase is embedded and compared with
        the phrase by cosine similarity. The check passes when the best window
        scores above :attr:`MATCH_THRESHOLD`.
        """
        tokens_que = self._tokenize(que)
        token_len = len(self._tokenize(self.level.phrase))
        embed_len = len(tokens_que) - token_len + 1
        if embed_len <= 0:
            # Reply has fewer tokens than the phrase: there is no window to score.
            return False
        windows = [" ".join(tokens_que[i: i + token_len]) for i in range(embed_len)]
        embeds = self.valid_model.encode(
            [self.level.phrase] + windows,
            convert_to_tensor=True,
        )
        # One batched comparison of the phrase against every window -> shape (1, embed_len).
        score = util.pytorch_cos_sim(embeds[0], embeds[1:]).max().item()
        return score > self.MATCH_THRESHOLD

    def regx_validate(self, que: str):
        """Return True if ``que`` contains any word of the secret phrase.

        Question and phrase are tokenized the same way (split on non-alphanumeric
        characters) and compared case-insensitively as whole tokens. So
        "Open Sesame" blocks "open the door" but not "reopen".
        """
        que_words = {w.lower() for w in self._tokenize(que)}
        return any(w.lower() in que_words for w in self._tokenize(self.level.phrase))
if __name__ == "__main__":
    ...  # No standalone entry point; this module is meant to be imported.