Tuchuanhuhuhu committed
Commit 67d913f · 1 Parent(s): 9aeba67

llama: support streaming output

Files changed (1):
  modules/models.py +17 -30
modules/models.py CHANGED

@@ -342,6 +342,7 @@ class LLaMA_Client(BaseLLMModel):
     def _get_llama_style_input(self):
         history = [x["content"] for x in self.history]
         context = "\n".join(history)
+        context += "\nOutput:"
         return context

     def get_answer_at_once(self):
@@ -359,40 +360,26 @@ class LLaMA_Client(BaseLLMModel):
         )

         response = output_dataset.to_dict()["instances"][0]["text"]
-
-        try:
-            index = response.index(self.end_string)
-        except ValueError:
-            response += self.end_string
-            index = response.index(self.end_string)
-
-        response = response[: index + 1]
         return response, len(response)

     def get_answer_stream_iter(self):
         context = self._get_llama_style_input()
-
-        input_dataset = self.dataset.from_dict(
-            {"type": "text_only", "instances": [{"text": context}]}
-        )
-
-        output_dataset = self.inferencer.inference(
-            model=self.model,
-            dataset=input_dataset,
-            max_new_tokens=self.max_generation_token,
-            temperature=self.temperature,
-        )
-
-        response = output_dataset.to_dict()["instances"][0]["text"]
-
-        try:
-            index = response.index(self.end_string)
-        except ValueError:
-            response += self.end_string
-            index = response.index(self.end_string)
-
-        response = response[: index + 1]
-        yield response
+        partial_text = ""
+        for i in range(self.max_generation_token):
+            input_dataset = self.dataset.from_dict(
+                {"type": "text_only", "instances": [{"text": context+partial_text}]}
+            )
+            output_dataset = self.inferencer.inference(
+                model=self.model,
+                dataset=input_dataset,
+                max_new_tokens=1,
+                temperature=self.temperature,
+            )
+            response = output_dataset.to_dict()["instances"][0]["text"]
+            if response == "":
+                break
+            partial_text += response
+            yield partial_text


 class ModelManager:
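
For reference, the new get_answer_stream_iter yields the full text accumulated so far (not a delta) after each one-token inference call. A minimal sketch of consuming it from a console, assuming an already-constructed LLaMA_Client instance bound to the illustrative name client:

def print_stream(client):
    """Print the reply incrementally as the model streams it."""
    printed = ""
    for partial_text in client.get_answer_stream_iter():
        # Each yield is the whole accumulated reply, so print only
        # the newly generated suffix.
        print(partial_text[len(printed):], end="", flush=True)
        printed = partial_text
    print()

Note the trade-off in the streamed path: each iteration re-runs inference over the prompt plus everything generated so far with max_new_tokens=1, so an n-token reply costs n full inference calls, and the loop stops only on an empty completion or after max_generation_token iterations.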
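
The commit also drops the old end_string truncation from both paths. If that behavior were still wanted on top of streaming, one hedged way to layer it back over the new generator (mirroring the removed response[: index + 1] slice, which kept the text up to and including the first character of end_string) would be:

def stream_until_end_string(client):
    """Yield partial replies, truncating at client.end_string."""
    for partial_text in client.get_answer_stream_iter():
        index = partial_text.find(client.end_string)
        if index != -1:
            # Same slice the removed code used: keep up to and
            # including the first character of end_string.
            yield partial_text[: index + 1]
            return
        yield partial_text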