Hjgugugjhuhjggg committed
Commit a5d4be8 · verified · 1 Parent(s): 282a362

Update app.py

Files changed (1):
  1. app.py +29 -38
app.py CHANGED
@@ -130,32 +130,36 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
+    encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
 
     def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
-        return decoded_output in stop_sequences
+        decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        for stop in stop_sequences:
+            if decoded_output.endswith(stop):
+                return True
+        return False
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
     token_buffer = []
-    outputs = model.generate(
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True,
-        streamer=None  # Ensure streamer is None for manual token processing
-    )
-
-    for output in outputs.sequences:
-        for token_id in output:
+    output_ids = encoded_input.input_ids
+    while True:
+        outputs = model.generate(
+            output_ids,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            stopping_criteria=stopping_criteria,
+            output_scores=True,
+            return_dict_in_generate=True
+        )
+        new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
+
+        for token_id in new_token_ids:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             token_buffer.append(token)
             if len(token_buffer) >= 10:
@@ -167,26 +171,13 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
                 yield "".join(token_buffer)
                 token_buffer = []
 
-        if stop_sequences and any(stop in tokenizer.decode(output, skip_special_tokens=True) for stop in stop_sequences):
-            return
+        if stop_criteria(outputs.sequences, None):
+            break
 
-        encoded_input = tokenizer.build_inputs_with_special_tokens(output)
-        encoded_input = {'input_ids': torch.tensor([encoded_input]).to(device)}
+        if len(new_token_ids) < generation_config.max_new_tokens:
+            break
 
-        outputs = model.generate(
-            **encoded_input,
-            do_sample=generation_config.do_sample,
-            max_new_tokens=generation_config.max_new_tokens,
-            temperature=generation_config.temperature,
-            top_p=generation_config.top_p,
-            top_k=generation_config.top_k,
-            repetition_penalty=generation_config.repetition_penalty,
-            num_return_sequences=generation_config.num_return_sequences,
-            stopping_criteria=stopping_criteria,
-            output_scores=True,
-            return_dict_in_generate=True,
-            streamer=None
-        )
+        output_ids = outputs.sequences
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
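A note on the stopping criterion: the commit passes a bare closure to StoppingCriteriaList, which works because the list simply calls each entry, but the documented transformers pattern is to subclass StoppingCriteria. A minimal, self-contained sketch of the same endswith-based stop check as a subclass; the model name, prompt, and stop strings below are placeholders, not taken from the app:

from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class StopOnSequences(StoppingCriteria):
    """Stop generation once the decoded text ends with any stop string."""

    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the full sequence so multi-token stop strings are caught,
        # mirroring the endswith() check introduced in this commit.
        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(decoded.endswith(stop) for stop in self.stop_sequences)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

criteria = StoppingCriteriaList([StopOnSequences(tokenizer, ["\n\n"])])
encoded = tokenizer("Hello, world", return_tensors="pt")
out = model.generate(**encoded, max_new_tokens=40, stopping_criteria=criteria)
print(tokenizer.decode(out[0], skip_special_tokens=True))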
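For the streaming itself, the new loop re-invokes model.generate() and slices off the freshly generated tokens each round (the old code pinned streamer=None for this manual processing). A hedged alternative sketch using transformers' built-in TextIteratorStreamer, which yields decoded text while a single generate() call runs; again the model name and prompt are placeholders, and this is not the app's code:

from threading import Thread
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          TextIteratorStreamer)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

encoded = tokenizer("Hello, world", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                skip_special_tokens=True)

# generate() blocks until finished, so it runs in a worker thread while
# the caller iterates over the streamer and forwards chunks to the client.
thread = Thread(target=model.generate,
                kwargs={**encoded, "max_new_tokens": 40, "streamer": streamer})
thread.start()
for chunk in streamer:  # decoded text pieces, emitted as they are generated
    print(chunk, end="", flush=True)
thread.join()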