Hjgugugjhuhjggg committed on
Commit 116d7b7 · verified · 1 Parent(s): 8bcaa35

Update app.py

Files changed (1): app.py +23 -15
app.py CHANGED
@@ -130,7 +130,10 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
-    encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
+    # Get the maximum model input length
+    max_model_length = model.config.max_position_embeddings
+
+    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
 
     def stop_criteria(input_ids, scores):
         decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
@@ -144,20 +147,25 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
     token_buffer = []
     output_ids = encoded_input.input_ids
     while True:
-        outputs = model.generate(
-            output_ids,
-            do_sample=generation_config.do_sample,
-            max_new_tokens=generation_config.max_new_tokens,
-            temperature=generation_config.temperature,
-            top_p=generation_config.top_p,
-            top_k=generation_config.top_k,
-            repetition_penalty=generation_config.repetition_penalty,
-            num_return_sequences=generation_config.num_return_sequences,
-            stopping_criteria=stopping_criteria,
-            output_scores=True,
-            return_dict_in_generate=True,
-            pad_token_id=tokenizer.pad_token_id
-        )
+        try:
+            outputs = model.generate(
+                output_ids,
+                do_sample=generation_config.do_sample,
+                max_new_tokens=generation_config.max_new_tokens,
+                temperature=generation_config.temperature,
+                top_p=generation_config.top_p,
+                top_k=generation_config.top_k,
+                repetition_penalty=generation_config.repetition_penalty,
+                num_return_sequences=generation_config.num_return_sequences,
+                stopping_criteria=stopping_criteria,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=tokenizer.pad_token_id
+            )
+        except IndexError as e:
+            print(f"IndexError during generation: {e}")
+            break  # Exit the loop if there's an index error
+
         new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
 
         for token_id in new_token_ids:
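
Taken together, the two hunks cap the prompt at the model's context window and guard the generation loop against an IndexError. Below is a minimal standalone sketch of that behavior, assuming the torch and transformers packages; the model id and the prompt are illustrative placeholders rather than part of the commit, and the generation arguments are reduced to the essentials.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # placeholder model for a quick local test
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# First hunk: truncate the prompt to the model's maximum input length,
# so generate() never sees more positions than the embeddings support.
max_model_length = model.config.max_position_embeddings
encoded_input = tokenizer(
    "A deliberately oversized prompt. " * 1000,
    return_tensors="pt",
    max_length=max_model_length,
    truncation=True,
).to(device)

# Second hunk: treat an IndexError from generate() as the end of the
# stream instead of letting it crash the request handler.
try:
    outputs = model.generate(
        encoded_input.input_ids,
        max_new_tokens=20,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )
    new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
    print(tokenizer.decode(new_token_ids, skip_special_tokens=True))
except IndexError as e:
    print(f"IndexError during generation: {e}")

Note that truncation drops the tail of the prompt silently; an alternative worth weighing is rejecting oversized prompts with a 4xx response so callers know their input was discarded.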