Spaces:
Sleeping
Sleeping
File size: 3,529 Bytes
48f6350 5cfcb46 48f6350 5cfcb46 48f6350 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import requests
from typing import Union
import re
from collections import namedtuple
from sentence_transformers import SentenceTransformer, util
import streamlit as st
import os
Level = namedtuple("Level", ["name", "description", "hint", "phrase"])
class EnigmaEscape:
API_ENDPOINT = "https://api.endpoints.anyscale.com/v1/chat/completions"
API_KEY = os.getenv("anyscale_api")
forbids = [
"replace",
"swap",
"change",
"modify",
"alter",
"substitute",
]
def body(self, que: str):
self.messages[-1] = {
"role": "user",
"content": que,
}
return {
"model": self.level.model,
"messages": self.messages,
"temperature": 0.5,
"max_tokens": self.level.max_token,
}
def __enter__(self):
self.session = requests.Session()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.session.close()
def __init__(self, levels: list[Level]):
self.valid_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
self.levels = levels
self.level = None
self.messages = []
self.session: Union["requests.Session", None] = None
def set_level(self, lev: int):
self.level = self.levels[lev]
self.messages = [
{
"role": "system",
"content": self.level.description,
},
{
"role": "user",
"content": "",
},
]
def chat(self, que: str):
if any(w in que.lower() for w in self.forbids):
return {
"content": "I wont fall for that",
"type": "warning",
}
if self.regx_validate(que):
return {
"content": "You cannot include the phrase in your question",
"type": "error",
}
body = self.body(que)
resp = self.session.post(
self.API_ENDPOINT,
headers={"Authorization": f"Bearer {self.API_KEY}"},
json=body
).json()
content = resp["choices"][0]["message"]["content"]
if self.resp_validate(content):
return {
"content": content,
"type": "success",
"tokens": resp["usage"]["prompt_tokens"],
}
return {
"content": content,
"type": "info",
}
def resp_validate(self, que: str):
tokens_que = [w for w in re.split(r"[^a-zA-Z0-9]", que) if w]
token_len = len([w for w in re.split(r'[^a-zA-Z0-9]', self.level.phrase) if w])
embed_len = len(tokens_que) - token_len + 1
embeds = self.valid_model.encode(
[self.level.phrase] + [" ".join(tokens_que[i: i + token_len]) for i in range(embed_len)],
convert_to_tensor=True
)
score = 0
for i in range(embed_len):
score = max(util.pytorch_cos_sim(embeds[0], embeds[i + 1]).item(), score)
# print(score, " ".join(tokens_que[i: i + token_len]), self.level.phrase, sep=" | ")
return score > 0.82
def regx_validate(self, que: str):
que_flat = re.sub(r'[^a-zA-Z0-9]', ' ', que).lower()
for word in self.level.phrase.split():
if bool(re.search(f"(^| ){word}( |$)", que_flat)):
return True
return False
if __name__ == '__main__':
pass
|