Etash Guha
commited on
Commit
·
32bc229
1
Parent(s):
8d320a4
pease
Browse files- generators/model.py +2 -3
generators/model.py
CHANGED
@@ -123,6 +123,7 @@ class Samba():
|
|
123 |
|
124 |
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
125 |
resps = []
|
|
|
126 |
for i in range(num_comps):
|
127 |
payload = {
|
128 |
"inputs": [dataclasses.asdict(message) for message in messages],
|
@@ -145,19 +146,17 @@ class Samba():
|
|
145 |
"Content-Type": "application/json"
|
146 |
}
|
147 |
post_response = requests.post(url, json=payload, headers=headers, stream=True)
|
148 |
-
|
149 |
response_text = ""
|
150 |
for line in post_response.iter_lines():
|
151 |
if line.startswith(b"data: "):
|
152 |
data_str = line.decode('utf-8')[6:]
|
153 |
try:
|
154 |
line_json = json.loads(data_str)
|
155 |
-
content = line_json.get("stream_token", "")
|
156 |
if content:
|
157 |
response_text += content
|
158 |
except json.JSONDecodeError as e:
|
159 |
pass
|
160 |
-
resps.append(response_text)
|
161 |
|
162 |
if num_comps == 1:
|
163 |
return resps[0]
|
|
|
123 |
|
124 |
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
125 |
resps = []
|
126 |
+
|
127 |
for i in range(num_comps):
|
128 |
payload = {
|
129 |
"inputs": [dataclasses.asdict(message) for message in messages],
|
|
|
146 |
"Content-Type": "application/json"
|
147 |
}
|
148 |
post_response = requests.post(url, json=payload, headers=headers, stream=True)
|
|
|
149 |
response_text = ""
|
150 |
for line in post_response.iter_lines():
|
151 |
if line.startswith(b"data: "):
|
152 |
data_str = line.decode('utf-8')[6:]
|
153 |
try:
|
154 |
line_json = json.loads(data_str)
|
155 |
+
content = line_json['0'].get("stream_token", "")
|
156 |
if content:
|
157 |
response_text += content
|
158 |
except json.JSONDecodeError as e:
|
159 |
pass
|
|
|
160 |
|
161 |
if num_comps == 1:
|
162 |
return resps[0]
|