File size: 3,529 Bytes
48f6350
 
 
 
 
 
5cfcb46
48f6350
 
 
 
 
 
5cfcb46
48f6350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import requests
from typing import Union
import re
from collections import namedtuple
from sentence_transformers import SentenceTransformer, util
import streamlit as st
import os

# One game level.
#   name        - level name.
#   description - used as the system prompt by ``EnigmaEscape.set_level``.
#   hint        - not read by ``EnigmaEscape``.
#   phrase      - the target phrase the player must make the model say.
#   model       - model id sent in the API request by ``EnigmaEscape.body``.
#   max_token   - ``max_tokens`` sent in the API request by ``EnigmaEscape.body``.
# ``model`` and ``max_token`` default to None so existing four-field
# ``Level(...)`` constructions keep working. A real model id should be
# supplied before the level is played (the API presumably rejects a null model).
Level = namedtuple(
    "Level",
    ["name", "description", "hint", "phrase", "model", "max_token"],
    defaults=(None, None),
)


class EnigmaEscape:
    """Chat game against a chat-completion API with a per-level target phrase.

    Each level supplies a system prompt (``Level.description``) and a target
    ``Level.phrase``. A player's question is rejected before reaching the model
    if it contains a forbidden word (see ``forbids``) or any word of the
    phrase (:meth:`regx_validate`). A model reply counts as a success when
    some run of its words is semantically close to the phrase
    (:meth:`resp_validate`).

    Use the instance as a context manager to share one ``requests.Session``
    across API calls. Outside a ``with`` block, each call is a one-off
    ``requests.post``.
    """

    API_ENDPOINT = "https://api.endpoints.anyscale.com/v1/chat/completions"
    # Read once, when the class is defined, from the environment.
    API_KEY = os.getenv("anyscale_api")
    # Seconds to wait for the completion API. Without a timeout ``requests``
    # can block indefinitely.
    REQUEST_TIMEOUT = 60
    # Minimum cosine similarity between the phrase and a window of the reply
    # for the reply to count as a success.
    SIMILARITY_THRESHOLD = 0.82
    # Questions containing any of these substrings (case-insensitive) are
    # rejected without calling the API.
    forbids = [
        "replace",
        "swap",
        "change",
        "modify",
        "alter",
        "substitute",
    ]

    def __init__(self, levels: list[Level]):
        """Load the embedding model and store the selectable levels.

        Args:
            levels: levels that :meth:`set_level` indexes into.
        """
        self.valid_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        self.levels = levels
        self.level = None
        self.messages = []
        self.session: Union["requests.Session", None] = None

    def __enter__(self):
        self.session = requests.Session()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.session is not None:
            self.session.close()
        # Drop the closed session so later calls fall back to requests.post.
        self.session = None

    def set_level(self, lev: int):
        """Select level ``lev`` and reset the conversation.

        The conversation becomes the level's system prompt plus one empty
        user message, which :meth:`body` later overwrites.

        Raises:
            IndexError: if ``lev`` is not a valid index into ``levels``.
        """
        self.level = self.levels[lev]
        self.messages = [
            {
                "role": "system",
                "content": self.level.description,
            },
            {
                "role": "user",
                "content": "",
            },
        ]

    def body(self, que: str):
        """Build the chat-completion request payload for ``que``.

        Overwrites the last (user) message of the conversation with ``que``.
        After :meth:`set_level` each request therefore carries only the system
        prompt plus the latest question. Requires :meth:`set_level` to have
        been called.
        """
        self.messages[-1] = {
            "role": "user",
            "content": que,
        }
        return {
            "model": self.level.model,
            "messages": self.messages,
            "temperature": 0.5,
            "max_tokens": self.level.max_token,
        }

    def chat(self, que: str):
        """Validate ``que``, send it to the model and classify the reply.

        Returns:
            A dict with ``content`` and ``type`` keys:

            - ``"warning"``: ``que`` contains a forbidden word (no API call).
            - ``"error"``: ``que`` contains a word of the phrase (no API call).
            - ``"success"``: the reply matches the phrase. Also includes
              ``tokens``, the prompt token count reported by the API.
            - ``"info"``: any other reply.

        Raises:
            requests.RequestException: on network errors or timeout.
            KeyError: if the API response lacks ``choices`` (e.g. an error
                payload).
        """
        lowered = que.lower()
        if any(w in lowered for w in self.forbids):
            return {
                "content": "I wont fall for that",
                "type": "warning",
            }
        if self.regx_validate(que):
            return {
                "content": "You cannot include the phrase in your question",
                "type": "error",
            }
        payload = self.body(que)
        # Reuse the pooled session inside a ``with`` block. Otherwise use the
        # module-level helper, which has the same ``post`` signature.
        http = self.session if self.session is not None else requests
        resp = http.post(
            self.API_ENDPOINT,
            headers={"Authorization": f"Bearer {self.API_KEY}"},
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        ).json()
        content = resp["choices"][0]["message"]["content"]
        if self.resp_validate(content):
            return {
                "content": content,
                "type": "success",
                "tokens": resp["usage"]["prompt_tokens"],
            }
        return {
            "content": content,
            "type": "info",
        }

    def resp_validate(self, que: str):
        """Return True if some run of words in ``que`` is similar to the phrase.

        ``que`` (a model reply) and the phrase are split into ASCII alphanumeric
        tokens. Every window of ``len(phrase tokens)`` consecutive reply tokens
        is embedded and compared with the phrase by cosine similarity. The
        best score must exceed ``SIMILARITY_THRESHOLD``.

        Returns False when the reply has fewer tokens than the phrase, or when
        the phrase has no alphanumeric tokens at all.
        """
        tokens_que = [w for w in re.split(r"[^a-zA-Z0-9]", que) if w]
        token_len = len([w for w in re.split(r"[^a-zA-Z0-9]", self.level.phrase) if w])
        embed_len = len(tokens_que) - token_len + 1
        # No comparable window: skip the expensive embedding step entirely.
        # An empty phrase would otherwise compare "" with "" and always match.
        if token_len == 0 or embed_len <= 0:
            return False
        windows = [" ".join(tokens_que[i: i + token_len]) for i in range(embed_len)]
        embeds = self.valid_model.encode(
            [self.level.phrase] + windows,
            convert_to_tensor=True,
        )
        # One (1, embed_len) similarity row instead of a per-window loop.
        score = util.pytorch_cos_sim(embeds[0], embeds[1:]).max().item()
        return score > self.SIMILARITY_THRESHOLD

    def regx_validate(self, que: str):
        """Return True if ``que`` contains any whole word of the level phrase.

        Both texts are normalised identically: every non-ASCII-alphanumeric
        character becomes a separator and everything is lower-cased. A phrase
        word such as ``"Secret!"`` therefore still catches ``"secret"`` in the
        question, and phrase punctuation is never interpreted as regex syntax.
        """
        def _words(text: str) -> set:
            # Lower-cased ASCII alphanumeric tokens of ``text``.
            return set(re.sub(r"[^a-zA-Z0-9]", " ", text).lower().split())

        return not _words(self.level.phrase).isdisjoint(_words(que))


if __name__ == "__main__":
    # Running this module directly does nothing; it only provides the
    # ``Level`` and ``EnigmaEscape`` definitions.
    ...