Hjgugugjhuhjggg committed
Commit 6e7eb77 · verified · 1 Parent(s): 9de7b93

Update app.py

Files changed (1)
  1. app.py +39 -18
app.py CHANGED
@@ -6,7 +6,7 @@ from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
     pipeline,
-    AutoModelForSeq2SeqLM,  # Changed AutoModelForCausalLM to AutoModelForSeq2SeqLM
+    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
@@ -69,7 +69,7 @@ class S3ModelLoader:
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)  # Changed AutoModelForCausalLM
+            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)
             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
@@ -80,7 +80,7 @@ class S3ModelLoader:
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)  # Changed AutoModelForCausalLM
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
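For reference, a minimal standalone sketch of the loading pattern these hunks settle on: AutoModelForSeq2SeqLM plus the pad-token fallback. The load_seq2seq helper and the checkpoint id are illustrative placeholders, not part of app.py.

# Minimal sketch of the seq2seq loading pattern shown in the hunks above.
# load_seq2seq and the checkpoint id are illustrative placeholders.
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

def load_seq2seq(model_name: str):
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)

    # Same fallback as in S3ModelLoader: reuse the EOS token for padding
    # when the tokenizer defines no pad token of its own.
    if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id

    return model, tokenizer

model, tokenizer = load_seq2seq("google/flan-t5-small")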
@@ -135,6 +135,7 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500,
                             detail=f"Internal server error: {str(e)}")
 
+
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
                       device, chunk_delay, max_length=2048):
@@ -159,38 +160,58 @@ async def stream_text(model, tokenizer, input_text,
                 return last_index + len(seq)
 
         return -1
-
 
     output_text = ""
 
     while True:
         outputs = model.generate(
-            **encoded_input,
-            do_sample=generation_config.do_sample,
-            max_new_tokens=generation_config.max_new_tokens,
-            temperature=generation_config.temperature,
-            top_p=generation_config.top_p,
-            top_k=generation_config.top_k,
-            repetition_penalty=generation_config.repetition_penalty,
-            num_return_sequences=generation_config.num_return_sequences,
-            output_scores=True,
-            return_dict_in_generate=True,
+            **encoded_input,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            output_scores=True,
+            return_dict_in_generate=True,
         )
 
         new_text = tokenizer.decode(outputs.sequences[0][len(encoded_input["input_ids"][0]):], skip_special_tokens=True)
 
         output_text += new_text
 
-        yield new_text
-        await asyncio.sleep(chunk_delay)
-
+
 
         stop_index = find_stop(output_text, stop_sequences)
+
         if stop_index != -1:
-            yield output_text[:stop_index]
+            final_output = output_text[:stop_index]
+
+
+
+            chunked_output = [final_output[i:i+10] for i in range(0, len(final_output), 10)]
+
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
+
             break
+
+        else:
+            chunked_output = [new_text[i:i+10] for i in range(0, len(new_text), 10)]
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
+
 
         if len(output_text) >= generation_config.max_new_tokens:
+
+            chunked_output = [output_text[i:i+10] for i in range(0, len(output_text), 10)]
+
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
             break
 
         encoded_input = tokenizer(output_text,
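The new streaming behavior above yields fixed-size chunks with a pause between them instead of emitting each decoded result in one piece. A minimal sketch of that chunked-yield pattern, with illustrative names and settings (stream_chunks, chunk_size, and the example values are not part of app.py):

# Minimal sketch of the chunked-yield pattern introduced in stream_text:
# slice text into fixed-size pieces and pause between yields.
# stream_chunks, chunk_size and chunk_delay here are illustrative placeholders.
import asyncio

async def stream_chunks(text: str, chunk_size: int = 10, chunk_delay: float = 0.05):
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    for chunk in chunks:
        yield chunk
        await asyncio.sleep(chunk_delay)

async def demo():
    async for piece in stream_chunks("streamed model output"):
        print(piece, end="", flush=True)

asyncio.run(demo())

In a FastAPI service, an async generator like stream_text is typically handed to a streaming response; the hunk does not show that wiring, so this sketch is only context for the chunking change itself.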