Spaces:
Hjgugugjhuhjggg
committed on
Update app.py
app.py
CHANGED
@@ -6,7 +6,7 @@ from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
     pipeline,
-    AutoModelForSeq2SeqLM,
+    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
@@ -69,7 +69,7 @@ class S3ModelLoader:
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)
+            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)
             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
@@ -80,7 +80,7 @@ class S3ModelLoader:
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
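The two hunks above touch the AutoModelForSeq2SeqLM loading path and keep the pad-token fallback. A minimal, self-contained sketch of that loading-plus-fallback pattern follows; the model id is only a placeholder for illustration, not the Space's actual configuration.

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder model id for illustration only; the Space resolves its own names.
model_name = "google/flan-t5-small"

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)

# Same fallback as in the diff: if the tokenizer has an EOS token but no pad
# token, reuse the config's pad token id or the EOS id for padding.
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id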
@@ -135,6 +135,7 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500,
                             detail=f"Internal server error: {str(e)}")
 
+
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
                       device, chunk_delay, max_length=2048):
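This hunk only adds a blank line before stream_text, but its context shows the endpoint's error handling. A hedged sketch of that try/except-to-HTTP-500 pattern follows; the route path and the request model's fields are assumptions for illustration, not the app's actual definitions.

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Simplified stand-in for the app's GenerateRequest; the field name is assumed.
class GenerateRequest(BaseModel):
    input_text: str

@app.post("/generate")  # assumed route path
async def generate(request: GenerateRequest):
    try:
        # Stand-in for the real generation call.
        return {"generated_text": request.input_text.upper()}
    except Exception as e:
        # Same shape as the diff: surface unexpected failures as HTTP 500.
        raise HTTPException(status_code=500,
                            detail=f"Internal server error: {str(e)}")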
@@ -159,38 +160,58 @@ async def stream_text(model, tokenizer, input_text,
                 return last_index + len(seq)
 
         return -1
-
 
     output_text = ""
 
     while True:
         outputs = model.generate(
-
-
-
-
-
-
-
-
-
-
+            **encoded_input,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            output_scores=True,
+            return_dict_in_generate=True,
         )
 
         new_text = tokenizer.decode(outputs.sequences[0][len(encoded_input["input_ids"][0]):], skip_special_tokens=True)
 
         output_text += new_text
 
-
-        await asyncio.sleep(chunk_delay)
-
+
 
         stop_index = find_stop(output_text, stop_sequences)
+
         if stop_index != -1:
-
+            final_output = output_text[:stop_index]
+
+
+
+            chunked_output = [final_output[i:i+10] for i in range(0, len(final_output), 10)]
+
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
+
             break
+
+        else:
+            chunked_output = [new_text[i:i+10] for i in range(0, len(new_text), 10)]
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
+
 
         if len(output_text) >= generation_config.max_new_tokens:
+
+            chunked_output = [output_text[i:i+10] for i in range(0, len(output_text), 10)]
+
+            for chunk in chunked_output:
+                yield chunk
+                await asyncio.sleep(chunk_delay)
             break
 
         encoded_input = tokenizer(output_text,
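The new loop body yields the generated text in fixed 10-character chunks with an asyncio.sleep between them. Below is a self-contained sketch of just that chunked-streaming idea, stripped of the model code; the sample text and the driver are for illustration only.

import asyncio

async def stream_in_chunks(text, chunk_size=10, chunk_delay=0.05):
    # Split the accumulated text into fixed-size pieces, as in the diff's
    # [text[i:i+10] for i in range(0, len(text), 10)] list comprehension.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    for chunk in chunks:
        yield chunk
        await asyncio.sleep(chunk_delay)

async def main():
    async for piece in stream_in_chunks("The quick brown fox jumps over the lazy dog."):
        print(piece, end="", flush=True)
    print()

asyncio.run(main())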
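The context lines around the stop check (return last_index + len(seq), return -1, and output_text[:stop_index]) suggest that find_stop returns the index just past the last stop sequence it finds, or -1 when none is present. The sketch below is consistent with those lines but is a reconstruction for illustration, not the file's actual find_stop.

def find_stop(text, stop_sequences):
    # Assumed behavior: return the index just past the last occurrence of a
    # stop sequence, so output_text[:stop_index] keeps the stop sequence
    # itself; return -1 when no stop sequence is found.
    for seq in stop_sequences:
        last_index = text.rfind(seq)
        if last_index != -1:
            return last_index + len(seq)
    return -1

output_text = "Hello world</s> trailing tokens"
stop_index = find_stop(output_text, ["</s>"])
if stop_index != -1:
    final_output = output_text[:stop_index]
    print(final_output)  # prints: Hello world</s>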