Husnain commited on
Commit
f41554f
·
unverified ·
1 Parent(s): e9f6b1e

💎 [Feature] Enable gpt-3.5 in chat_api

Browse files
Files changed (1) hide show
  1. networks/openai_streamer.py +219 -0
networks/openai_streamer.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import re
4
+ import tiktoken
5
+ import uuid
6
+
7
+ from curl_cffi import requests
8
+ from tclogger import logger
9
+
10
+ from constants.envs import PROXIES
11
+ from constants.headers import OPENAI_GET_HEADERS, OPENAI_POST_DATA
12
+ from constants.models import TOKEN_LIMIT_MAP, TOKEN_RESERVED
13
+
14
+ from messagers.message_outputer import OpenaiStreamOutputer
15
+
16
+
17
+ class OpenaiRequester:
18
+ def __init__(self):
19
+ self.init_requests_params()
20
+
21
+ def init_requests_params(self):
22
+ self.api_base = "https://chat.openai.com/backend-anon"
23
+ self.api_me = f"{self.api_base}/me"
24
+ self.api_models = f"{self.api_base}/models"
25
+ self.api_chat_requirements = f"{self.api_base}/sentinel/chat-requirements"
26
+ self.api_conversation = f"{self.api_base}/conversation"
27
+ self.uuid = str(uuid.uuid4())
28
+ self.requests_headers = copy.deepcopy(OPENAI_GET_HEADERS)
29
+ extra_headers = {
30
+ "Oai-Device-Id": self.uuid,
31
+ }
32
+ self.requests_headers.update(extra_headers)
33
+
34
+ def log_request(self, url, method="GET"):
35
+ logger.note(f"> {method}:", end=" ")
36
+ logger.mesg(f"{url}", end=" ")
37
+
38
+ def log_response(self, res: requests.Response, stream=False, verbose=False):
39
+ status_code = res.status_code
40
+ status_code_str = f"[{status_code}]"
41
+
42
+ if status_code == 200:
43
+ logger_func = logger.success
44
+ else:
45
+ logger_func = logger.warn
46
+
47
+ logger_func(status_code_str)
48
+
49
+ if verbose:
50
+ if stream:
51
+ if not hasattr(self, "content_offset"):
52
+ self.content_offset = 0
53
+
54
+ for line in res.iter_lines():
55
+ line = line.decode("utf-8")
56
+ line = re.sub(r"^data:\s*", "", line)
57
+ if re.match(r"^\[DONE\]", line):
58
+ logger.success("\n[Finished]")
59
+ break
60
+ line = line.strip()
61
+ if line:
62
+ try:
63
+ data = json.loads(line, strict=False)
64
+ message_role = data["message"]["author"]["role"]
65
+ message_status = data["message"]["status"]
66
+ if (
67
+ message_role == "assistant"
68
+ and message_status == "in_progress"
69
+ ):
70
+ content = data["message"]["content"]["parts"][0]
71
+ delta_content = content[self.content_offset :]
72
+ self.content_offset = len(content)
73
+ logger_func(delta_content, end="")
74
+ except Exception as e:
75
+ logger.warn(e)
76
+ else:
77
+ logger_func(res.json())
78
+
79
+ def get_models(self):
80
+ self.log_request(self.api_models)
81
+ res = requests.get(
82
+ self.api_models,
83
+ headers=self.requests_headers,
84
+ proxies=PROXIES,
85
+ timeout=10,
86
+ impersonate="chrome120",
87
+ )
88
+ self.log_response(res)
89
+
90
+ def auth(self):
91
+ self.log_request(self.api_chat_requirements, method="POST")
92
+ res = requests.post(
93
+ self.api_chat_requirements,
94
+ headers=self.requests_headers,
95
+ proxies=PROXIES,
96
+ timeout=10,
97
+ impersonate="chrome120",
98
+ )
99
+ self.chat_requirements_token = res.json()["token"]
100
+ self.log_response(res)
101
+
102
+ def transform_messages(self, messages: list[dict]):
103
+ def get_role(role):
104
+ if role in ["system", "user", "assistant"]:
105
+ return role
106
+ else:
107
+ return "system"
108
+
109
+ new_messages = [
110
+ {
111
+ "author": {"role": get_role(message["role"])},
112
+ "content": {"content_type": "text", "parts": [message["content"]]},
113
+ "metadata": {},
114
+ }
115
+ for message in messages
116
+ ]
117
+ return new_messages
118
+
119
+ def chat_completions(self, messages: list[dict], verbose=False):
120
+ extra_headers = {
121
+ "Accept": "text/event-stream",
122
+ "Openai-Sentinel-Chat-Requirements-Token": self.chat_requirements_token,
123
+ }
124
+ requests_headers = copy.deepcopy(self.requests_headers)
125
+ requests_headers.update(extra_headers)
126
+
127
+ post_data = copy.deepcopy(OPENAI_POST_DATA)
128
+ extra_data = {
129
+ "messages": self.transform_messages(messages),
130
+ "websocket_request_id": str(uuid.uuid4()),
131
+ }
132
+ post_data.update(extra_data)
133
+
134
+ self.log_request(self.api_conversation, method="POST")
135
+ s = requests.Session()
136
+ res = s.post(
137
+ self.api_conversation,
138
+ headers=requests_headers,
139
+ json=post_data,
140
+ proxies=PROXIES,
141
+ timeout=10,
142
+ impersonate="chrome120",
143
+ stream=True,
144
+ )
145
+ if verbose:
146
+ self.log_response(res, stream=True, verbose=True)
147
+ return res
148
+
149
+
150
+ class OpenaiStreamer:
151
+ def __init__(self):
152
+ self.model = "gpt-3.5"
153
+ self.message_outputer = OpenaiStreamOutputer(owned_by="openai", model="gpt-3.5")
154
+ self.tokenizer = tiktoken.get_encoding("cl100k_base")
155
+
156
+ def count_tokens(self, messages: list[dict]):
157
+ token_count = sum(
158
+ len(self.tokenizer.encode(message["content"])) for message in messages
159
+ )
160
+ logger.note(f"Prompt Token Count: {token_count}")
161
+ return token_count
162
+
163
+ def check_token_limit(self, messages: list[dict]):
164
+ token_limit = TOKEN_LIMIT_MAP[self.model]
165
+ token_redundancy = int(
166
+ token_limit - TOKEN_RESERVED - self.count_tokens(messages)
167
+ )
168
+ if token_redundancy <= 0:
169
+ raise ValueError(f"Prompt exceeded token limit: {token_limit}")
170
+ return True
171
+
172
+ def chat_response(self, messages: list[dict]):
173
+ self.check_token_limit(messages)
174
+ requester = OpenaiRequester()
175
+ requester.auth()
176
+ return requester.chat_completions(messages, verbose=False)
177
+
178
+ def chat_return_generator(self, stream_response: requests.Response):
179
+ content_offset = 0
180
+ is_finished = False
181
+
182
+ for line in stream_response.iter_lines():
183
+ line = line.decode("utf-8")
184
+ line = re.sub(r"^data:\s*", "", line)
185
+ line = line.strip()
186
+
187
+ if not line:
188
+ continue
189
+
190
+ if re.match(r"^\[DONE\]", line):
191
+ content_type = "Finished"
192
+ delta_content = ""
193
+ logger.success("\n[Finished]")
194
+ is_finished = True
195
+ else:
196
+ content_type = "Completions"
197
+ try:
198
+ data = json.loads(line, strict=False)
199
+ message_role = data["message"]["author"]["role"]
200
+ message_status = data["message"]["status"]
201
+ if message_role == "assistant" and message_status == "in_progress":
202
+ content = data["message"]["content"]["parts"][0]
203
+ if not len(content):
204
+ continue
205
+ delta_content = content[content_offset:]
206
+ content_offset = len(content)
207
+ logger.success(delta_content, end="")
208
+ else:
209
+ continue
210
+ except Exception as e:
211
+ logger.warn(e)
212
+
213
+ output = self.message_outputer.output(
214
+ content=delta_content, content_type=content_type
215
+ )
216
+ yield output
217
+
218
+ if not is_finished:
219
+ yield self.message_outputer.output(content="", content_type="Finished")