Spaces:
Sleeping
Sleeping
import requests | |
from typing import Union | |
import re | |
from collections import namedtuple | |
from sentence_transformers import SentenceTransformer, util | |
import streamlit as st | |
import os | |
# Immutable record describing one game level.
# `description` is sent to the model as the system message and `phrase` is the
# target text that replies are compared against (see EnigmaEscape).
# NOTE(review): EnigmaEscape.body also reads `level.model` and
# `level.max_token`, which are not fields of this tuple - confirm where those
# attributes are supposed to come from.
Level = namedtuple("Level", "name description hint phrase")
class EnigmaEscape:
    """Relay player questions to a chat-completions endpoint and judge the reply.

    For the active level, the level's ``description`` is sent as the system
    message and the player's question as the single user message.  A question
    is refused locally when it contains a forbidden word or any word of the
    level's secret ``phrase``.  A reply counts as a win when some window of
    its words is semantically close enough to the phrase.

    The object can be used as a context manager, which opens and closes the
    underlying ``requests.Session``.  ``chat`` also opens a session on demand
    when the object is used without ``with``.

    NOTE(review): ``body`` reads ``level.model`` and ``level.max_token``;
    confirm that the level objects passed in actually provide them.
    """

    API_ENDPOINT = "https://api.endpoints.anyscale.com/v1/chat/completions"
    # Read once, when the class body executes; None if the variable is unset.
    API_KEY = os.getenv("anyscale_api")
    # A question containing any of these substrings is refused outright.
    forbids = [
        "replace",
        "swap",
        "change",
        "modify",
        "alter",
        "substitute",
    ]
    # Seconds to wait on the API so a stalled connection cannot hang the caller.
    REQUEST_TIMEOUT = 60
    # Minimum cosine similarity for a reply window to count as the phrase.
    MATCH_THRESHOLD = 0.82
    # A "word" is a maximal run of ASCII letters/digits; all else separates.
    _WORD = re.compile(r"[a-zA-Z0-9]+")

    def __init__(self, levels: "list[Level]"):
        """Load the sentence-embedding model and store the available levels.

        Args:
            levels: Level definitions selectable through ``set_level``.
        """
        self.valid_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        self.levels = levels
        self.level = None
        self.messages = []
        self.session: Union["requests.Session", None] = None

    def __enter__(self):
        """Open a fresh HTTP session and return this object."""
        self.session = requests.Session()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session, if one was opened; never suppress errors."""
        if self.session is not None:
            self.session.close()

    def set_level(self, lev: int):
        """Activate level number ``lev`` and reset the conversation.

        Args:
            lev: Index into the ``levels`` list given at construction.

        Raises:
            IndexError: If ``lev`` is out of range.
        """
        self.level = self.levels[lev]
        # Slot 0 is the system prompt; slot 1 is a placeholder that ``body``
        # overwrites with each new question.
        self.messages = [
            {
                "role": "system",
                "content": self.level.description,
            },
            {
                "role": "user",
                "content": "",
            },
        ]

    def body(self, que: str):
        """Build the JSON payload for one question.

        The last message is replaced rather than appended to, so each request
        carries only the system prompt plus the current question.

        Args:
            que: The player's question.

        Returns:
            dict: Payload with ``model``, ``messages``, ``temperature`` and
            ``max_tokens`` keys.
        """
        self.messages[-1] = {
            "role": "user",
            "content": que,
        }
        return {
            "model": self.level.model,
            "messages": self.messages,
            "temperature": 0.5,
            "max_tokens": self.level.max_token,
        }

    def chat(self, que: str):
        """Screen a question, send it to the model and classify the reply.

        Args:
            que: The player's question.

        Returns:
            dict: Always has ``content`` and ``type``.  ``type`` is
            ``"warning"`` (forbidden word), ``"error"`` (question contains a
            phrase word), ``"success"`` (reply matches the phrase; also has
            ``tokens``, the prompt token count reported by the API) or
            ``"info"`` (any other reply).

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.RequestException: On network failure or timeout.
        """
        lowered = que.lower()
        if any(word in lowered for word in self.forbids):
            return {
                "content": "I wont fall for that",
                "type": "warning",
            }
        if self.regx_validate(que):
            return {
                "content": "You cannot include the phrase in your question",
                "type": "error",
            }

        # Allow use outside a ``with`` block instead of failing on None.post.
        if self.session is None:
            self.session = requests.Session()

        response = self.session.post(
            self.API_ENDPOINT,
            headers={"Authorization": f"Bearer {self.API_KEY}"},
            json=self.body(que),
            timeout=self.REQUEST_TIMEOUT,
        )
        # Surface API failures as HTTP errors, not a KeyError on "choices".
        response.raise_for_status()
        resp = response.json()

        content = resp["choices"][0]["message"]["content"]
        if self.resp_validate(content):
            return {
                "content": content,
                "type": "success",
                "tokens": resp["usage"]["prompt_tokens"],
            }
        return {
            "content": content,
            "type": "info",
        }

    def resp_validate(self, que: str):
        """Tell whether a model reply contains something close to the phrase.

        Slides a window, as many words wide as the phrase, across the reply
        and compares each window's embedding with the phrase's embedding.

        Args:
            que: The text to check (the model's reply).

        Returns:
            bool: True when the best cosine similarity exceeds
            ``MATCH_THRESHOLD``.
        """
        reply_words = self._WORD.findall(que)
        width = len(self._WORD.findall(self.level.phrase))
        window_count = len(reply_words) - width + 1
        if window_count <= 0:
            # Reply is shorter than the phrase: nothing to compare, and no
            # reason to run the encoder.
            return False

        windows = [
            " ".join(reply_words[start:start + width])
            for start in range(window_count)
        ]
        embeds = self.valid_model.encode(
            [self.level.phrase] + windows,
            convert_to_tensor=True,
        )
        # One call scores the phrase (row 0) against every window at once.
        scores = util.pytorch_cos_sim(embeds[0], embeds[1:])
        return scores.max().item() > self.MATCH_THRESHOLD

    def regx_validate(self, que: str):
        """Tell whether the question contains any word of the secret phrase.

        Question and phrase words are normalised the same way (lower-cased,
        non-alphanumerics treated as separators), so the check is
        case-insensitive and ignores punctuation.

        Args:
            que: The player's question.

        Returns:
            bool: True if at least one whitespace-separated word of the
            phrase occurs in the question as a whole word.
        """
        # Space padding turns a whole-word test into a plain substring test.
        padded = f" {self._flatten(que)} "
        for word in self.level.phrase.split():
            flat = self._flatten(word)
            if flat and f" {flat} " in padded:
                return True
        return False

    @classmethod
    def _flatten(cls, text: str) -> str:
        """Return the lower-cased alphanumeric words of ``text``, space-joined."""
        return " ".join(token.lower() for token in cls._WORD.findall(text))
# Placeholder entry point: running this file directly does nothing.
if __name__ == "__main__":
    pass