Hjgugugjhuhjggg committed on
Commit 2e0bd60 · verified · 1 parent: fbbc32a

Update app.py

Files changed (1)
  1. app.py +18 -12
app.py CHANGED
@@ -41,7 +41,7 @@ class GenerateRequest(BaseModel):
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 200
+    max_new_tokens: int = 200 # this will be limited to 10
     stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
@@ -130,7 +130,7 @@ async def generate(request: GenerateRequest):
     input_text = request.input_text
     task_type = request.task_type
     temperature = request.temperature
-    max_new_tokens = request.max_new_tokens
+    max_new_tokens = request.max_new_tokens #This value will be used to constraint the output
     stream = request.stream
     top_p = request.top_p
     top_k = request.top_k
@@ -147,7 +147,7 @@ async def generate(request: GenerateRequest):
     if "text-to-text" == task_type:
         generation_config = GenerationConfig(
             temperature=temperature,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=min(max_new_tokens,10), # Constrain max_new_tokens to 10
             top_p=top_p,
             top_k=top_k,
             repetition_penalty=repetition_penalty,
@@ -158,7 +158,7 @@ async def generate(request: GenerateRequest):
         return StreamingResponse(
             stream_text(model, tokenizer, input_text,
                         generation_config, stop_sequences,
-                        device),
+                        device, max_length=10),
             media_type="text/plain"
         )
     else:
@@ -172,7 +172,7 @@ async def generate(request: GenerateRequest):
 
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
-                      device, max_length=2048):
+                      device, max_length):
     encoded_input = tokenizer(
         input_text, return_tensors="pt",
         truncation=True, max_length=max_length
@@ -186,6 +186,7 @@ async def stream_text(model, tokenizer, input_text,
     generation_config.max_new_tokens = min(
         remaining_tokens, generation_config.max_new_tokens
     )
+
 
     def find_stop(output_text, stop_sequences):
         for seq in stop_sequences:
@@ -222,23 +223,28 @@ async def stream_text(model, tokenizer, input_text,
 
         if stop_index != -1:
             final_output = output_text[:stop_index]
+            chunked_output = [final_output[i:i+10]
+                              for i in range(0, len(final_output), 10)]
 
-            for char in final_output:
-                yield json.dumps({"text": char, "is_end": False}) + "\n"
-
+            for chunk in chunked_output:
+                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
 
             yield json.dumps({"text": "", "is_end": True}) + "\n"
             break
 
         else:
+            chunked_output = [new_text[i:i+10]
+                              for i in range(0, len(new_text), 10)]
 
-            for char in new_text:
-                yield json.dumps({"text": char, "is_end": False}) + "\n"
+            for chunk in chunked_output:
+                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
 
         if len(output_text) >= generation_config.max_new_tokens:
-            for char in output_text:
-                yield json.dumps({"text": char, "is_end": False}) + "\n"
+            chunked_output = [output_text[i:i+10]
+                              for i in range(0, len(output_text), 10)]
 
+            for chunk in chunked_output:
+                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
             yield json.dumps({"text": "", "is_end": True}) + "\n"
             break
 
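After this commit the endpoint clamps generation to 10 new tokens and streams the output as newline-delimited JSON events of up to 10 characters each, instead of one character at a time. Below is a minimal client sketch for consuming that stream, assuming the handler is exposed as POST /generate on a locally running Space; the URL, port, and route name are assumptions, and only the request fields come from GenerateRequest.

import json
import requests

# Hypothetical request body; field names and defaults come from GenerateRequest.
payload = {
    "input_text": "Hello, world",
    "task_type": "text-to-text",
    "temperature": 1.0,
    "max_new_tokens": 200,  # the server now clamps this to 10
    "stream": True,
    "top_p": 1.0,
    "top_k": 50,
}

# Each response line is one JSON event of the form
# {"text": "<chunk of up to 10 chars>", "is_end": false},
# and the stream ends with {"text": "", "is_end": true}.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        event = json.loads(line)
        if event["is_end"]:
            break
        print(event["text"], end="", flush=True)

Note that max_length=10 is also passed straight into the tokenizer call inside stream_text, so the prompt itself is truncated to 10 tokens before generation.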