Husnain commited on
Commit
e9f6b1e
·
unverified ·
1 Parent(s): b5f45b3

Delete networks/message_streamer.py

Browse files
Files changed (1) hide show
  1. networks/message_streamer.py +0 -201
networks/message_streamer.py DELETED
@@ -1,201 +0,0 @@
1
- import json
2
- import re
3
- import requests
4
-
5
- from tiktoken import get_encoding as tiktoken_get_encoding
6
- from transformers import AutoTokenizer
7
-
8
- from constants.models import (
9
- MODEL_MAP,
10
- STOP_SEQUENCES_MAP,
11
- TOKEN_LIMIT_MAP,
12
- TOKEN_RESERVED,
13
- )
14
- from messagers.message_outputer import OpenaiStreamOutputer
15
- from utils.logger import logger
16
- from utils.enver import enver
17
-
18
-
19
- class MessageStreamer:
20
-
21
- def __init__(self, model: str):
22
- if model in MODEL_MAP.keys():
23
- self.model = model
24
- else:
25
- self.model = "default"
26
- self.model_fullname = MODEL_MAP[self.model]
27
- self.message_outputer = OpenaiStreamOutputer()
28
-
29
- if self.model == "gemma-7b":
30
- # this is not wrong, as repo `google/gemma-7b-it` is gated and must authenticate to access it
31
- # so I use mistral-7b as a fallback
32
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
33
- else:
34
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
35
-
36
- def parse_line(self, line):
37
- line = line.decode("utf-8")
38
- line = re.sub(r"data:\s*", "", line)
39
- data = json.loads(line)
40
- try:
41
- content = data["token"]["text"]
42
- except:
43
- logger.err(data)
44
- return content
45
-
46
- def count_tokens(self, text):
47
- tokens = self.tokenizer.encode(text)
48
- token_count = len(tokens)
49
- logger.note(f"Prompt Token Count: {token_count}")
50
- return token_count
51
-
52
- def chat_response(
53
- self,
54
- prompt: str = None,
55
- temperature: float = 0.5,
56
- top_p: float = 0.95,
57
- max_new_tokens: int = None,
58
- api_key: str = None,
59
- use_cache: bool = False,
60
- ):
61
- # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
62
- # curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_token":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
63
- self.request_url = (
64
- f"https://api-inference.huggingface.co/models/{self.model_fullname}"
65
- )
66
- self.request_headers = {
67
- "Content-Type": "application/json",
68
- }
69
-
70
- if api_key:
71
- logger.note(
72
- f"Using API Key: {api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"
73
- )
74
- self.request_headers["Authorization"] = f"Bearer {api_key}"
75
-
76
- if temperature is None or temperature < 0:
77
- temperature = 0.0
78
- # temperature must 0 < and < 1 for HF LLM models
79
- temperature = max(temperature, 0.01)
80
- temperature = min(temperature, 0.99)
81
- top_p = max(top_p, 0.01)
82
- top_p = min(top_p, 0.99)
83
-
84
- token_limit = int(
85
- TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
86
- )
87
- if token_limit <= 0:
88
- raise ValueError("Prompt exceeded token limit!")
89
-
90
- if max_new_tokens is None or max_new_tokens <= 0:
91
- max_new_tokens = token_limit
92
- else:
93
- max_new_tokens = min(max_new_tokens, token_limit)
94
-
95
- # References:
96
- # huggingface_hub/inference/_client.py:
97
- # class InferenceClient > def text_generation()
98
- # huggingface_hub/inference/_text_generation.py:
99
- # class TextGenerationRequest > param `stream`
100
- # https://huggingface.co/docs/text-generation-inference/conceptual/streaming#streaming-with-curl
101
- # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
102
- self.request_body = {
103
- "inputs": prompt,
104
- "parameters": {
105
- "temperature": temperature,
106
- "top_p": top_p,
107
- "max_new_tokens": max_new_tokens,
108
- "return_full_text": False,
109
- },
110
- "options": {
111
- "use_cache": use_cache,
112
- },
113
- "stream": True,
114
- }
115
-
116
- if self.model in STOP_SEQUENCES_MAP.keys():
117
- self.stop_sequences = STOP_SEQUENCES_MAP[self.model]
118
- # self.request_body["parameters"]["stop_sequences"] = [
119
- # self.STOP_SEQUENCES[self.model]
120
- # ]
121
-
122
- logger.back(self.request_url)
123
- enver.set_envs(proxies=True)
124
- stream_response = requests.post(
125
- self.request_url,
126
- headers=self.request_headers,
127
- json=self.request_body,
128
- proxies=enver.requests_proxies,
129
- stream=True,
130
- )
131
- status_code = stream_response.status_code
132
- if status_code == 200:
133
- logger.success(status_code)
134
- else:
135
- logger.err(status_code)
136
-
137
- return stream_response
138
-
139
- def chat_return_dict(self, stream_response):
140
- # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format
141
- final_output = self.message_outputer.default_data.copy()
142
- final_output["choices"] = [
143
- {
144
- "index": 0,
145
- "finish_reason": "stop",
146
- "message": {
147
- "role": "assistant",
148
- "content": "",
149
- },
150
- }
151
- ]
152
- logger.back(final_output)
153
-
154
- final_content = ""
155
- for line in stream_response.iter_lines():
156
- if not line:
157
- continue
158
- content = self.parse_line(line)
159
-
160
- if content.strip() == self.stop_sequences:
161
- logger.success("\n[Finished]")
162
- break
163
- else:
164
- logger.back(content, end="")
165
- final_content += content
166
-
167
- if self.model in STOP_SEQUENCES_MAP.keys():
168
- final_content = final_content.replace(self.stop_sequences, "")
169
-
170
- final_content = final_content.strip()
171
- final_output["choices"][0]["message"]["content"] = final_content
172
- return final_output
173
-
174
- def chat_return_generator(self, stream_response):
175
- is_finished = False
176
- line_count = 0
177
- for line in stream_response.iter_lines():
178
- if line:
179
- line_count += 1
180
- else:
181
- continue
182
-
183
- content = self.parse_line(line)
184
-
185
- if content.strip() == self.stop_sequences:
186
- content_type = "Finished"
187
- logger.success("\n[Finished]")
188
- is_finished = True
189
- else:
190
- content_type = "Completions"
191
- if line_count == 1:
192
- content = content.lstrip()
193
- logger.back(content, end="")
194
-
195
- output = self.message_outputer.output(
196
- content=content, content_type=content_type
197
- )
198
- yield output
199
-
200
- if not is_finished:
201
- yield self.message_outputer.output(content="", content_type="Finished")