Hjgugugjhuhjggg committed on
Commit
c8741b0
·
verified ·
1 Parent(s): e77c20c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -42,7 +42,7 @@ class GenerateRequest(BaseModel):
42
  input_text: str = ""
43
  task_type: str
44
  temperature: float = 1.0
45
- max_new_tokens: int = 200 # this will be limited to 10
46
  stream: bool = True
47
  top_p: float = 1.0
48
  top_k: int = 50
@@ -146,7 +146,7 @@ async def generate(request: GenerateRequest):
146
  input_text = request.input_text
147
  task_type = request.task_type
148
  temperature = request.temperature
149
- max_new_tokens = request.max_new_tokens # This value will be used to constraint the output
150
  stream = request.stream
151
  top_p = request.top_p
152
  top_k = request.top_k
@@ -162,7 +162,7 @@ async def generate(request: GenerateRequest):
162
  if "text-to-text" == task_type:
163
  generation_config = GenerationConfig(
164
  temperature=temperature,
165
- max_new_tokens=min(max_new_tokens, 10), # Constrain max_new_tokens to 10
166
  top_p=top_p,
167
  top_k=top_k,
168
  repetition_penalty=repetition_penalty,
@@ -173,7 +173,7 @@ async def generate(request: GenerateRequest):
173
  return StreamingResponse(
174
  stream_text(model, tokenizer, input_text,
175
  generation_config, stop_sequences,
176
- device, max_length=10),
177
  media_type="text/plain"
178
  )
179
  else:
@@ -187,21 +187,14 @@ async def generate(request: GenerateRequest):
187
 
188
  async def stream_text(model, tokenizer, input_text,
189
  generation_config, stop_sequences,
190
- device, max_length):
 
 
191
  encoded_input = tokenizer(
192
  input_text, return_tensors="pt",
193
  truncation=True, max_length=max_length
194
  ).to(device)
195
- input_length = encoded_input["input_ids"].shape[1]
196
- remaining_tokens = max_length - input_length
197
-
198
- if remaining_tokens <= 0:
199
- yield ""
200
 
201
- generation_config.max_new_tokens = min(
202
- remaining_tokens, generation_config.max_new_tokens
203
- )
204
-
205
 
206
  def find_stop(output_text, stop_sequences):
207
  for seq in stop_sequences:
@@ -210,10 +203,21 @@ async def stream_text(model, tokenizer, input_text,
210
  return last_index + len(seq)
211
 
212
  return -1
213
-
214
  output_text = ""
215
-
216
  while True:
 
 
 
 
 
 
 
 
 
 
 
 
217
  outputs = model.generate(
218
  **encoded_input,
219
  do_sample=generation_config.do_sample,
@@ -231,7 +235,7 @@ async def stream_text(model, tokenizer, input_text,
231
  outputs.sequences[0][len(encoded_input["input_ids"][0]):],
232
  skip_special_tokens=True
233
  )
234
-
235
  output_text += new_text
236
 
237
  stop_index = find_stop(output_text, stop_sequences)
@@ -244,26 +248,30 @@ async def stream_text(model, tokenizer, input_text,
244
  yield json.dumps({"text": text, "is_end": False}) + "\n"
245
  yield json.dumps({"text": "", "is_end": True}) + "\n"
246
  break
247
-
248
  else:
249
  for chunk in [new_text[i:i+10] for i in range(0, len(new_text), 10)]:
250
  for text in chunk.split():
251
  yield json.dumps({"text": text, "is_end": False}) + "\n"
252
 
 
 
 
 
 
 
 
 
 
253
 
254
- if len(output_text) >= generation_config.max_new_tokens:
255
-
256
  for chunk in [output_text[i:i+10] for i in range(0, len(output_text), 10)]:
257
  for text in chunk.split():
258
  yield json.dumps({"text": text, "is_end": False}) + "\n"
259
-
260
  yield json.dumps({"text": "", "is_end": True}) + "\n"
261
  break
262
-
263
- encoded_input = tokenizer(
264
- output_text, return_tensors="pt",
265
- truncation=True, max_length=max_length
266
- ).to(device)
267
 
268
  @app.post("/generate-image")
269
  async def generate_image(request: GenerateRequest):
 
42
  input_text: str = ""
43
  task_type: str
44
  temperature: float = 1.0
45
+ max_new_tokens: int = 10
46
  stream: bool = True
47
  top_p: float = 1.0
48
  top_k: int = 50
 
146
  input_text = request.input_text
147
  task_type = request.task_type
148
  temperature = request.temperature
149
+ max_new_tokens = request.max_new_tokens
150
  stream = request.stream
151
  top_p = request.top_p
152
  top_k = request.top_k
 
162
  if "text-to-text" == task_type:
163
  generation_config = GenerationConfig(
164
  temperature=temperature,
165
+ max_new_tokens=max_new_tokens,
166
  top_p=top_p,
167
  top_k=top_k,
168
  repetition_penalty=repetition_penalty,
 
173
  return StreamingResponse(
174
  stream_text(model, tokenizer, input_text,
175
  generation_config, stop_sequences,
176
+ device),
177
  media_type="text/plain"
178
  )
179
  else:
 
187
 
188
  async def stream_text(model, tokenizer, input_text,
189
  generation_config, stop_sequences,
190
+ device):
191
+ max_length=10 #Define the max length to cut the text and generate another response
192
+
193
  encoded_input = tokenizer(
194
  input_text, return_tensors="pt",
195
  truncation=True, max_length=max_length
196
  ).to(device)
 
 
 
 
 
197
 
 
 
 
 
198
 
199
  def find_stop(output_text, stop_sequences):
200
  for seq in stop_sequences:
 
203
  return last_index + len(seq)
204
 
205
  return -1
206
+
207
  output_text = ""
 
208
  while True:
209
+
210
+ input_length = encoded_input["input_ids"].shape[1]
211
+ remaining_tokens = max_length - input_length
212
+
213
+ if remaining_tokens <=0:
214
+ yield json.dumps({"text": "", "is_end": True}) + "\n"
215
+ break
216
+
217
+ generation_config.max_new_tokens = min(
218
+ remaining_tokens, generation_config.max_new_tokens
219
+ )
220
+
221
  outputs = model.generate(
222
  **encoded_input,
223
  do_sample=generation_config.do_sample,
 
235
  outputs.sequences[0][len(encoded_input["input_ids"][0]):],
236
  skip_special_tokens=True
237
  )
238
+
239
  output_text += new_text
240
 
241
  stop_index = find_stop(output_text, stop_sequences)
 
248
  yield json.dumps({"text": text, "is_end": False}) + "\n"
249
  yield json.dumps({"text": "", "is_end": True}) + "\n"
250
  break
 
251
  else:
252
  for chunk in [new_text[i:i+10] for i in range(0, len(new_text), 10)]:
253
  for text in chunk.split():
254
  yield json.dumps({"text": text, "is_end": False}) + "\n"
255
 
256
+
257
+ if len(output_text) >= max_length:
258
+
259
+ encoded_input = tokenizer(
260
+ output_text, return_tensors="pt",
261
+ truncation=True, max_length=max_length
262
+ ).to(device)
263
+
264
+ output_text = ""
265
 
266
+ elif len(output_text) < max_length and len(new_text) == 0:
267
+
268
  for chunk in [output_text[i:i+10] for i in range(0, len(output_text), 10)]:
269
  for text in chunk.split():
270
  yield json.dumps({"text": text, "is_end": False}) + "\n"
271
+
272
  yield json.dumps({"text": "", "is_end": True}) + "\n"
273
  break
274
+
 
 
 
 
275
 
276
  @app.post("/generate-image")
277
  async def generate_image(request: GenerateRequest):