Etash Guha commited on
Commit
32bc229
·
1 Parent(s): 8d320a4
Files changed (1) hide show
  1. generators/model.py +2 -3
generators/model.py CHANGED
@@ -123,6 +123,7 @@ class Samba():
123
 
124
  def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
125
  resps = []
 
126
  for i in range(num_comps):
127
  payload = {
128
  "inputs": [dataclasses.asdict(message) for message in messages],
@@ -145,19 +146,17 @@ class Samba():
145
  "Content-Type": "application/json"
146
  }
147
  post_response = requests.post(url, json=payload, headers=headers, stream=True)
148
-
149
  response_text = ""
150
  for line in post_response.iter_lines():
151
  if line.startswith(b"data: "):
152
  data_str = line.decode('utf-8')[6:]
153
  try:
154
  line_json = json.loads(data_str)
155
- content = line_json.get("stream_token", "")
156
  if content:
157
  response_text += content
158
  except json.JSONDecodeError as e:
159
  pass
160
- resps.append(response_text)
161
 
162
  if num_comps == 1:
163
  return resps[0]
 
123
 
124
  def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
125
  resps = []
126
+
127
  for i in range(num_comps):
128
  payload = {
129
  "inputs": [dataclasses.asdict(message) for message in messages],
 
146
  "Content-Type": "application/json"
147
  }
148
  post_response = requests.post(url, json=payload, headers=headers, stream=True)
 
149
  response_text = ""
150
  for line in post_response.iter_lines():
151
  if line.startswith(b"data: "):
152
  data_str = line.decode('utf-8')[6:]
153
  try:
154
  line_json = json.loads(data_str)
155
+ content = line_json['0'].get("stream_token", "")
156
  if content:
157
  response_text += content
158
  except json.JSONDecodeError as e:
159
  pass
 
160
 
161
  if num_comps == 1:
162
  return resps[0]