Hjgugugjhuhjggg committed on
Commit 3b07d00 · verified · 1 Parent(s): 24570ec

Update app.py

Files changed (1):
  app.py  +28 -26
app.py CHANGED
@@ -191,11 +191,13 @@ async def generate_stream(model, tokenizer, input_text,
     async def stream():
         past_key_values = None
         input_ids = None
-        async for token,past_key_values_response,input_ids_response in stream_text(model, tokenizer, input_text,
+        async for token,past_key_values_response,input_ids_response, is_end in stream_text(model, tokenizer, input_text,
                                              generation_config, stop_sequences,
                                              device,pad_token_id, max_model_length, max_new_tokens, past_key_values, input_ids):
             past_key_values = past_key_values_response
             input_ids = input_ids_response
+            if is_end:
+                break
             yield token
     return stream()
 
@@ -283,25 +285,24 @@ async def stream_text(model, tokenizer, input_text,
 
     output_text = ""
     stop_criteria = StoppingCriteriaList([StopOnSequencesCriteria(stop_sequences, tokenizer)])
-
-    while True:
-
-        if input_ids is None:
-            encoded_input = tokenizer(
+
+    if input_ids is None:
+        encoded_input = tokenizer(
             input_text, return_tensors="pt",
             truncation=True,
             padding = "max_length",
             max_length=max_model_length
         ).to(device)
-            input_ids = encoded_input.input_ids
-        else:
-            encoded_input = {
-                "input_ids":input_ids,
-                "past_key_values": past_key_values
-            }
-
+        input_ids = encoded_input.input_ids
+    else:
+        encoded_input = {
+            "input_ids":input_ids,
+            "past_key_values": past_key_values
+        }
+
+    while True:
         outputs = model.generate(
-            **encoded_input,
+            **encoded_input,
             do_sample=generation_config.do_sample,
             max_new_tokens=generation_config.max_new_tokens,
             temperature=generation_config.temperature,
@@ -313,7 +314,7 @@ async def stream_text(model, tokenizer, input_text,
             return_dict_in_generate=True,
             pad_token_id=pad_token_id if pad_token_id is not None else None,
             stopping_criteria = stop_criteria,
-
+
         )
 
         new_text = tokenizer.decode(
@@ -324,17 +325,19 @@ async def stream_text(model, tokenizer, input_text,
         output_text += new_text
 
         stop_index = find_stop(output_text, stop_sequences)
-
-        if stop_index != -1:
-            final_output = output_text[:stop_index]
+
+        is_end = False
+        if stop_index != -1 or (hasattr(outputs, "sequences") and outputs.sequences[0][-1] == tokenizer.eos_token_id):
+            final_output = output_text[:stop_index] if stop_index != -1 else output_text
 
             for text in final_output.split():
                 yield json.dumps({"text": text, "is_end": False}) + "\n", \
                     outputs.past_key_values if hasattr(outputs, "past_key_values") else None , \
-                    outputs.sequences if hasattr(outputs, "sequences") else None
-            yield json.dumps({"text": "", "is_end": True}) + "\n", \
+                    outputs.sequences if hasattr(outputs, "sequences") else None, True
+
+            yield json.dumps({"text": "", "is_end": True}) + "\n",\
                 outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                outputs.sequences if hasattr(outputs, "sequences") else None
+                outputs.sequences if hasattr(outputs, "sequences") else None, True
             break
         else:
 
@@ -345,19 +348,19 @@ async def stream_text(model, tokenizer, input_text,
             chunk = tokens[i:i + max_new_tokens]
             chunk_text = " ".join(chunk)
             for text in chunk_text.split():
-                yield json.dumps({"text": text, "is_end": False}) + "\n", \
+                yield json.dumps({"text": text, "is_end": False}) + "\n", \
                     outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                    outputs.sequences if hasattr(outputs, "sequences") else None
+                    outputs.sequences if hasattr(outputs, "sequences") else None, False
 
         if len(new_text) == 0:
 
             for text in output_text.split():
                 yield json.dumps({"text": text, "is_end": False}) + "\n", \
                     outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                    outputs.sequences if hasattr(outputs, "sequences") else None
+                    outputs.sequences if hasattr(outputs, "sequences") else None, True
             yield json.dumps({"text": "", "is_end": True}) + "\n",\
                 outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                outputs.sequences if hasattr(outputs, "sequences") else None
+                outputs.sequences if hasattr(outputs, "sequences") else None, True
             break
 
         past_key_values = outputs.past_key_values if hasattr(outputs, "past_key_values") else None
@@ -366,7 +369,6 @@ async def stream_text(model, tokenizer, input_text,
         output_text = ""
 
 
-
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
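
The substantive change in stream_text is its end-of-stream test: generation now finishes when a stop sequence matches (stop_index != -1) or when the last generated id is the tokenizer's EOS token, and every yielded tuple carries that verdict as a fourth is_end element, which generate_stream uses to break out of its consumer loop. A hypothetical helper distilling that condition (illustrative only, not part of the commit):

    # Distills the commit's end-of-stream test into one predicate.
    # `outputs` is the result of model.generate(..., return_dict_in_generate=True).
    def generation_finished(stop_index, outputs, tokenizer):
        hit_stop = stop_index != -1  # find_stop located a stop sequence
        hit_eos = (
            hasattr(outputs, "sequences")
            and outputs.sequences[0][-1] == tokenizer.eos_token_id  # ends with EOS
        )
        return hit_stop or hit_eos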
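
On the wire, each yielded token is one NDJSON line of the form {"text": ..., "is_end": ...}, so a client reads lines until is_end is true. A sketch of such a client using httpx; the endpoint URL and payload shape are assumptions, since the diff only shows the generator side:

    import json
    import httpx

    def read_generation(url, payload):
        # Collect the whitespace-split words streamed by the server until it
        # sends the {"text": "", "is_end": true} terminator line.
        parts = []
        with httpx.stream("POST", url, json=payload, timeout=None) as response:
            for line in response.iter_lines():
                if not line:
                    continue              # skip blank keep-alive lines
                msg = json.loads(line)    # one JSON object per line
                if msg.get("is_end"):
                    break                 # server marked end-of-stream
                parts.append(msg["text"])
        return " ".join(parts)

Because the server splits output_text on whitespace before yielding, joining the received words with single spaces reconstructs the text (original spacing is not preserved).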