Hjgugugjhuhjggg committed
Commit b7effa9 · verified · 1 Parent(s): 66c68f4

Update app.py

Files changed (1): app.py (+11, -9)
app.py CHANGED
@@ -19,6 +19,7 @@ import base64
 from huggingface_hub import login
 from botocore.exceptions import NoCredentialsError
 from functools import lru_cache
+from typing import AsyncGenerator
 
 
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
@@ -130,9 +131,9 @@ async def generate(request: GenerateRequest):
     )
     if stream:
         return StreamingResponse(
-            stream_text(model, tokenizer, input_text,
+            stream_json_responses(stream_text(model, tokenizer, input_text,
                         generation_config, stop_sequences,
-                        device),
+                        device)),
             media_type="text/plain"
         )
     else:
@@ -163,9 +164,10 @@ class StopOnSequences(StoppingCriteria):
                 return True
         return False
 
+
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
-                      device):
+                      device) -> AsyncGenerator[dict, None]:
 
     encoded_input = tokenizer(
         input_text, return_tensors="pt",
@@ -198,11 +200,10 @@ async def stream_text(model, tokenizer, input_text,
             skip_special_tokens=True
         )
 
-        if len(new_text) == 0:
+        if not new_text:
             if not stop_criteria(outputs.sequences, None):
-                yield {"text": output_text, "is_end": False}
-
-            yield {"text": "", "is_end": True}
+                yield {"text": output_text, "is_end": False}
+                yield {"text": "", "is_end": True}
             break
 
         output_text += new_text
@@ -220,10 +221,11 @@ async def stream_text(model, tokenizer, input_text,
         output_text = ""
 
 
-async def stream_json_responses(generator):
+async def stream_json_responses(generator: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
     async for data in generator:
         yield json.dumps(data) + "\n"
-
+
+
 async def generate_text(model, tokenizer, input_text,
                         generation_config, stop_sequences,
                         device):
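
For context on why the wrapping matters: Starlette's StreamingResponse encodes each yielded chunk as text, so the dicts produced by stream_text cannot be streamed directly; stream_json_responses serializes each one into a newline-delimited JSON line first. Below is a minimal client-side sketch of consuming the resulting stream, assuming the app listens on localhost:7860 and that /generate accepts this payload shape (neither detail is shown in the diff):

# Client sketch (not part of the commit). The URL and payload keys are
# assumptions; adjust them to the app's actual request model.
import json
import requests

response = requests.post(
    "http://localhost:7860/generate",
    json={"input_text": "Hello", "stream": True},
    stream=True,
)
for line in response.iter_lines():
    if not line:
        continue  # iter_lines() can yield empty keep-alive lines
    chunk = json.loads(line)  # one complete JSON object per line
    print(chunk["text"], end="", flush=True)
    if chunk["is_end"]:
        break

Emitting one JSON object per line keeps every chunk independently parseable, so a client never has to buffer or re-frame a partially received JSON document.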