Update app.py
app.py CHANGED
@@ -191,11 +191,13 @@ async def generate_stream(model, tokenizer, input_text,
     async def stream():
         past_key_values = None
         input_ids = None
-        async for token, past_key_values_response, input_ids_response in stream_text(model, tokenizer, input_text,
+        async for token, past_key_values_response, input_ids_response, is_end in stream_text(model, tokenizer, input_text,
                                                                                      generation_config, stop_sequences,
                                                                                      device, pad_token_id, max_model_length, max_new_tokens, past_key_values, input_ids):
             past_key_values = past_key_values_response
             input_ids = input_ids_response
+            if is_end:
+                break
             yield token
     return stream()
 
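Review note: this hunk threads an is_end flag through the tuples yielded by stream_text, so the outer stream() generator can break out cleanly instead of waiting for the producer to exhaust itself. A minimal, self-contained sketch of the same pattern, with a dummy three-field producer standing in for the real stream_text (whose tuples also carry past_key_values and input_ids):

import asyncio

async def stream_text():
    # Dummy stand-in for the real stream_text: yields (token, state, is_end),
    # ending with an explicit end-of-stream sentinel.
    for tok in ("Hello", "world"):
        yield tok, None, False
    yield "", None, True

async def stream():
    async for token, state, is_end in stream_text():
        if is_end:
            break  # the same early exit this hunk adds
        yield token

async def main():
    print([tok async for tok in stream()])  # ['Hello', 'world']

asyncio.run(main())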
@@ -283,25 +285,24 @@ async def stream_text(model, tokenizer, input_text,
 
     output_text = ""
     stop_criteria = StoppingCriteriaList([StopOnSequencesCriteria(stop_sequences, tokenizer)])
-
-
-
-    if input_ids is None:
-        encoded_input = tokenizer(
+
+    if input_ids is None:
+        encoded_input = tokenizer(
             input_text, return_tensors="pt",
             truncation=True,
             padding="max_length",
             max_length=max_model_length
         ).to(device)
-
-
-
-
-
-
-
+        input_ids = encoded_input.input_ids
+    else:
+        encoded_input = {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values
+        }
+
+    while True:
         outputs = model.generate(
-
+            **encoded_input,
             do_sample=generation_config.do_sample,
             max_new_tokens=generation_config.max_new_tokens,
             temperature=generation_config.temperature,
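Review note: the substance of the fix is here. The old code called model.generate(...) with no input tensors at all; the new code builds encoded_input either by tokenizing the prompt fresh or by reusing the input_ids and past_key_values carried over from the previous call, then unpacks it with **encoded_input. A hedged sketch of the same cache-reuse pattern in isolation, assuming a recent transformers release where generate() accepts past_key_values as a passthrough kwarg and populates it on the output when return_dict_in_generate=True:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

encoded = tokenizer("The quick brown fox", return_tensors="pt")
input_ids = encoded.input_ids
past_key_values = None

for _ in range(3):  # three short continuation steps over one growing sequence
    kwargs = {"input_ids": input_ids}
    if past_key_values is not None:
        kwargs["past_key_values"] = past_key_values  # reuse the KV cache
    outputs = model.generate(
        **kwargs,
        max_new_tokens=8,
        do_sample=False,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Feed the full sequence and the cache back in on the next step.
    input_ids = outputs.sequences
    past_key_values = getattr(outputs, "past_key_values", None)
    print(tokenizer.decode(outputs.sequences[0][-8:]))

Whether outputs.past_key_values is actually populated depends on the transformers version and on use_cache staying enabled, which is presumably why the diff guards every access with hasattr.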
@@ -313,7 +314,7 @@ async def stream_text(model, tokenizer, input_text,
             return_dict_in_generate=True,
             pad_token_id=pad_token_id if pad_token_id is not None else None,
             stopping_criteria=stop_criteria,
-
+
         )
 
         new_text = tokenizer.decode(
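Review note: the stop_criteria wired into generate() above comes from StopOnSequencesCriteria, which is defined elsewhere in app.py and not shown in this diff. A plausible minimal reimplementation, assuming it stops generation as soon as any stop string appears in the decoded text (an assumed sketch, not the Space's actual class):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequencesCriteria(StoppingCriteria):
    def __init__(self, stop_sequences, tokenizer):
        self.stop_sequences = stop_sequences or []
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode the current sequence and stop once any stop string shows up.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(s in text for s in self.stop_sequences)

# Usage, mirroring the line in the diff:
# stop_criteria = StoppingCriteriaList([StopOnSequencesCriteria(["\nUser:"], tokenizer)])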
@@ -324,17 +325,19 @@ async def stream_text(model, tokenizer, input_text,
         output_text += new_text
 
         stop_index = find_stop(output_text, stop_sequences)
-
-
-
+
+        is_end = False
+        if stop_index != -1 or (hasattr(outputs, "sequences") and outputs.sequences[0][-1] == tokenizer.eos_token_id):
+            final_output = output_text[:stop_index] if stop_index != -1 else output_text
 
             for text in final_output.split():
                 yield json.dumps({"text": text, "is_end": False}) + "\n", \
                       outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                      outputs.sequences if hasattr(outputs, "sequences") else None
-
+                      outputs.sequences if hasattr(outputs, "sequences") else None, True
+
+            yield json.dumps({"text": "", "is_end": True}) + "\n", \
                   outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                  outputs.sequences if hasattr(outputs, "sequences") else None
+                  outputs.sequences if hasattr(outputs, "sequences") else None, True
             break
         else:
 
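Review note: the new is_end branch keys off find_stop, another helper that is not shown in this diff. Presumably it returns the earliest index of any stop sequence in the accumulated text, or -1 when none matches; a hypothetical minimal version:

def find_stop(text: str, stop_sequences) -> int:
    # Earliest occurrence of any stop sequence, or -1 if none is present.
    positions = [text.find(s) for s in (stop_sequences or []) if s]
    positions = [p for p in positions if p != -1]
    return min(positions) if positions else -1

assert find_stop("hello\nUser: hi", ["\nUser:"]) == 5
assert find_stop("no stop here", ["\nUser:"]) == -1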
@@ -345,19 +348,19 @@ async def stream_text(model, tokenizer, input_text,
                 chunk = tokens[i:i + max_new_tokens]
                 chunk_text = " ".join(chunk)
                 for text in chunk_text.split():
-
+                    yield json.dumps({"text": text, "is_end": False}) + "\n", \
                           outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                          outputs.sequences if hasattr(outputs, "sequences") else None
+                          outputs.sequences if hasattr(outputs, "sequences") else None, False
 
             if len(new_text) == 0:
 
                 for text in output_text.split():
                     yield json.dumps({"text": text, "is_end": False}) + "\n", \
                           outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                          outputs.sequences if hasattr(outputs, "sequences") else None
+                          outputs.sequences if hasattr(outputs, "sequences") else None, True
                 yield json.dumps({"text": "", "is_end": True}) + "\n", \
                       outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                      outputs.sequences if hasattr(outputs, "sequences") else None
+                      outputs.sequences if hasattr(outputs, "sequences") else None, True
                 break
 
         past_key_values = outputs.past_key_values if hasattr(outputs, "past_key_values") else None
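Review note: every yield in stream_text emits a four-tuple of (newline-delimited JSON chunk, past_key_values, sequences, is_end); only the JSON lines reach the HTTP client, one {"text": ..., "is_end": ...} object per line. A sketch of consuming such an NDJSON stream from the client side, with a hypothetical Space URL and route standing in for the real endpoint:

import json
import requests

# Hypothetical endpoint; substitute the Space's actual URL and route.
url = "https://example.hf.space/generate"
with requests.post(url, json={"input_text": "Hi"}, stream=True) as r:
    parts = []
    for line in r.iter_lines():
        if not line:
            continue  # skip keep-alive blank lines
        msg = json.loads(line)
        if msg["is_end"]:
            break  # the {"text": "", "is_end": true} sentinel
        parts.append(msg["text"])
print(" ".join(parts))

Splitting on whitespace server-side and rejoining with single spaces client-side does not preserve the model's original spacing; that is a design trade-off of this word-by-word framing, not an artifact of the transport.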
@@ -366,7 +369,6 @@ async def stream_text(model, tokenizer, input_text,
         output_text = ""
 
 
-
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try: