Hjgugugjhuhjggg committed
Commit 6de156a · verified · 1 Parent(s): f7e7ec1

Update app.py

Files changed (1):
  app.py  +195 -129
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
@@ -10,7 +10,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteria,
-    StoppingCriteriaList,
+    StoppingCriteriaList
 )
 import boto3
 import uvicorn
@@ -20,6 +20,7 @@ import json
 from huggingface_hub import login
 import base64
 from botocore.exceptions import NoCredentialsError
+import re
 
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
@@ -44,10 +45,10 @@ class GenerateRequest(BaseModel):
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 3
-    stream: bool = True  # Set default stream to True
+    max_new_tokens: int = 3
+    stream: bool = True
     top_p: float = 1.0
-    top_k: int = 50
+    top_k: int = 50
     repetition_penalty: float = 1.0
     num_return_sequences: int = 1
     do_sample: bool = True
@@ -93,33 +94,44 @@ class S3ModelLoader:
             )
 
             tokenizer = AutoTokenizer.from_pretrained(
-                s3_uri, config=config, local_files_only=False
+                s3_uri, config=config, local_files_only=False, padding_side="left"
            )
 
-            if tokenizer.eos_token_id is not None and \
-               tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id \
-                    or tokenizer.eos_token_id
-            model_cache[model_name] = (model, tokenizer)
-            return model, tokenizer
+            eos_token_id = tokenizer.eos_token_id
+            pad_token_id = tokenizer.pad_token_id
+            eos_token = tokenizer.eos_token
+            pad_token = tokenizer.pad_token
+            padding = tokenizer.padding_side
+
+            if eos_token_id is not None and pad_token_id is None:
+                pad_token_id = config.pad_token_id or eos_token_id
+                tokenizer.pad_token_id = pad_token_id
+
+            model_cache[model_name] = (model, tokenizer, eos_token_id,
+                                       pad_token_id, eos_token, pad_token, padding)
+            return model, tokenizer, eos_token_id, pad_token_id, eos_token, pad_token, padding
         except (EnvironmentError, NoCredentialsError):
             try:
                 config = AutoConfig.from_pretrained(
                     model_name, token=HUGGINGFACE_HUB_TOKEN
                 )
                 tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
+                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, padding_side="left"
                 )
 
                 model = AutoModelForCausalLM.from_pretrained(
                     model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
                 )
 
-                if tokenizer.eos_token_id is not None and \
-                   tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = config.pad_token_id \
-                        or tokenizer.eos_token_id
+                eos_token_id = tokenizer.eos_token_id
+                pad_token_id = tokenizer.pad_token_id
+                eos_token = tokenizer.eos_token
+                pad_token = tokenizer.pad_token
+                padding = tokenizer.padding_side
+
+                if eos_token_id is not None and pad_token_id is None:
+                    pad_token_id = config.pad_token_id or eos_token_id
+                    tokenizer.pad_token_id = pad_token_id
 
 
                 model.save_pretrained(s3_uri)
@@ -135,10 +147,22 @@ class S3ModelLoader:
                 )
 
                 tokenizer = AutoTokenizer.from_pretrained(
-                    s3_uri, config=config, local_files_only=False
+                    s3_uri, config=config, local_files_only=False, padding_side="left"
                 )
-                model_cache[model_name] = (model, tokenizer)
-                return model, tokenizer
+
+                eos_token_id = tokenizer.eos_token_id
+                pad_token_id = tokenizer.pad_token_id
+                eos_token = tokenizer.eos_token
+                pad_token = tokenizer.pad_token
+                padding = tokenizer.padding_side
+
+                if eos_token_id is not None and pad_token_id is None:
+                    pad_token_id = config.pad_token_id or eos_token_id
+                    tokenizer.pad_token_id = pad_token_id
+
+                model_cache[model_name] = (model, tokenizer, eos_token_id,
+                                           pad_token_id, eos_token, pad_token, padding)
+                return model, tokenizer, eos_token_id, pad_token_id, eos_token, pad_token, padding
             except Exception as e:
                 raise HTTPException(
                     status_code=500, detail=f"Error loading model: {e}"
@@ -146,6 +170,37 @@ class S3ModelLoader:
 
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
+class StopOnSequencesCriteria(StoppingCriteria):
+    def __init__(self, stop_sequences, tokenizer):
+        self.stop_sequences = stop_sequences
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+
+        decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        for seq in self.stop_sequences:
+            if seq in decoded_text:
+                return True
+        return False
+
+async def generate_stream(model, tokenizer, input_text,
+                          generation_config, stop_sequences,
+                          device, pad_token_id, max_model_length,
+                          max_new_tokens):
+    async def stream():
+        past_key_values = None
+        input_ids = None
+        async for token, past_key_values_response, input_ids_response, is_end in stream_text(
+                model, tokenizer, input_text,
+                generation_config, stop_sequences,
+                device, pad_token_id, max_model_length, max_new_tokens,
+                past_key_values, input_ids):
+            past_key_values = past_key_values_response
+            input_ids = input_ids_response
+            if is_end:
+                break
+            yield token
+    return stream()
+
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
@@ -154,7 +209,7 @@ async def generate(request: GenerateRequest):
         task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
-        stream = request.stream
+        stream = True
         top_p = request.top_p
         top_k = request.top_k
         repetition_penalty = request.repetition_penalty
@@ -162,10 +217,10 @@ async def generate(request: GenerateRequest):
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
 
-        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+        model, tokenizer, eos_token_id, pad_token_id, eos_token, pad_token, padding = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
-
+
         if "text-to-text" == task_type:
             generation_config = GenerationConfig(
                 temperature=temperature,
@@ -175,20 +230,29 @@ async def generate(request: GenerateRequest):
                 repetition_penalty=repetition_penalty,
                 do_sample=do_sample,
                 num_return_sequences=num_return_sequences,
-                eos_token_id = tokenizer.eos_token_id
+                pad_token_id=pad_token_id if pad_token_id is not None else None
             )
-            if stream:
-                return StreamingResponse(
-                    stream_text(model, tokenizer, input_text,
-                                generation_config, stop_sequences,
-                                device),
-                    media_type="text/plain"
-                )
-            else:
-                result = await generate_text(model, tokenizer, input_text,
-                                             generation_config, stop_sequences,
-                                             device)
-                return JSONResponse({"text": result, "is_end": True})
+
+            max_model_length = 3
+            input_text = input_text[:max_model_length]
+
+            streams = [
+                generate_stream(model, tokenizer, input_text,
+                                generation_config, stop_sequences,
+                                device, pad_token_id, max_model_length, max_new_tokens)
+                for _ in range(num_return_sequences)
+            ]
+
+            async def stream_response():
+                for stream in asyncio.as_completed(streams):
+                    async for chunk in await stream:
+                        yield chunk
+
+            return StreamingResponse(
+                stream_response(),
+                media_type="text/plain"
+            )
+
         else:
            return HTTPException(status_code=400, detail="Task type not text-to-text")
 
@@ -197,110 +261,112 @@ async def generate(request: GenerateRequest):
             status_code=500, detail=f"Internal server error: {str(e)}"
         )
 
-class StopOnSequences(StoppingCriteria):
-    def __init__(self, stop_sequences, tokenizer):
-        self.stop_sequences = stop_sequences
-        self.tokenizer = tokenizer
-        self.stop_ids = [tokenizer.encode(seq, add_special_tokens=False) for seq in stop_sequences]
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-
-        decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-
-        for stop_sequence in self.stop_sequences:
-            if stop_sequence in decoded_text:
-                return True
-        return False
 
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
-                      device):
-
-    encoded_input = tokenizer(
-        input_text, return_tensors="pt",
-        truncation=True
-    ).to(device)
-
-    stop_criteria = StopOnSequences(stop_sequences, tokenizer)
-    stopping_criteria = StoppingCriteriaList([stop_criteria])
-
-    output_text = ""
-
-    while True:
-
-        outputs = await asyncio.to_thread(model.generate,
-                                          **encoded_input,
-                                          do_sample=generation_config.do_sample,
-                                          max_new_tokens=generation_config.max_new_tokens,
-                                          temperature=generation_config.temperature,
-                                          top_p=generation_config.top_p,
-                                          top_k=generation_config.top_k,
-                                          repetition_penalty=generation_config.repetition_penalty,
-                                          num_return_sequences=generation_config.num_return_sequences,
-                                          output_scores=True,
-                                          return_dict_in_generate=True,
-                                          stopping_criteria=stopping_criteria
-                                          )
-
-        new_text = tokenizer.decode(
-            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
-            skip_special_tokens=True
-        )
-
-        if len(new_text) == 0:
-            if not stop_criteria(outputs.sequences, None):
-                for text in output_text.split():
-                    yield json.dumps({"text": text, "is_end": False}) + "\n"
-            yield json.dumps({"text": "", "is_end": True}) + "\n"
-            break
-
-        output_text += new_text
-
-        for text in new_text.split():
-            yield json.dumps({"text": text, "is_end": False}) + "\n"
-
-        if stop_criteria(outputs.sequences, None):
-            yield json.dumps({"text": "", "is_end": True}) + "\n"
-            break
-
-        encoded_input = tokenizer(
-            output_text, return_tensors="pt",
-            truncation=True
-        ).to(device)
-        output_text = ""
-
-
-async def generate_text(model, tokenizer, input_text,
-                        generation_config, stop_sequences,
-                        device):
-    encoded_input = tokenizer(
-        input_text, return_tensors="pt",
-        truncation=True
-    ).to(device)
-
-    stop_criteria = StopOnSequences(stop_sequences, tokenizer)
-    stopping_criteria = StoppingCriteriaList([stop_criteria])
-
-    outputs = await asyncio.to_thread(model.generate,
-                                      **encoded_input,
-                                      do_sample=generation_config.do_sample,
-                                      max_new_tokens=generation_config.max_new_tokens,
-                                      temperature=generation_config.temperature,
-                                      top_p=generation_config.top_p,
-                                      top_k=generation_config.top_k,
-                                      repetition_penalty=generation_config.repetition_penalty,
-                                      num_return_sequences=generation_config.num_return_sequences,
-                                      output_scores=True,
-                                      return_dict_in_generate=True,
-                                      stopping_criteria=stopping_criteria
-                                      )
-
-    generated_text = tokenizer.decode(
-        outputs.sequences[0], skip_special_tokens=True
-    )
-
-    return generated_text
+                      device, pad_token_id, max_model_length, max_new_tokens,
+                      past_key_values, input_ids):
+
+    stop_regex = re.compile(r'[\.\?\!\n]+')
+
+    def find_stop(output_text, stop_sequences):
+        for seq in stop_sequences:
+            if seq in output_text:
+                last_index = output_text.rfind(seq)
+                return last_index + len(seq)
+
+        match = stop_regex.search(output_text)
+        if match:
+            return match.end()
+
+        return -1
+
+    output_text = ""
+    stop_criteria = StoppingCriteriaList([StopOnSequencesCriteria(stop_sequences, tokenizer)])
+
+    if input_ids is None:
+        encoded_input = tokenizer(
+            input_text, return_tensors="pt",
+            truncation=True,
+            padding="max_length",
+            max_length=max_model_length
+        ).to(device)
+        input_ids = encoded_input.input_ids
+    else:
+        encoded_input = {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values
+        }
+
+    while True:
+
+        outputs = model.generate(
+            **encoded_input,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            output_scores=True,
+            return_dict_in_generate=True,
+            pad_token_id=pad_token_id if pad_token_id is not None else None,
+            stopping_criteria=stop_criteria,
+        )
+
+        new_text = tokenizer.decode(
+            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
+            skip_special_tokens=True
+        )
+
+        output_text += new_text
+
+        stop_index = find_stop(output_text, stop_sequences)
+
+        is_end = False
+        if stop_index != -1 or (hasattr(outputs, "sequences") and outputs.sequences[0][-1] == tokenizer.eos_token_id):
+            final_output = output_text[:stop_index] if stop_index != -1 else output_text
+
+            for text in final_output.split():
+                yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
+                    outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
+                    outputs.sequences if hasattr(outputs, "sequences") else None, True
+
+            yield json.dumps({"text": "", "is_end": True, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
+                outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
+                outputs.sequences if hasattr(outputs, "sequences") else None, True
+            break
+        else:
+
+            tokens = new_text.split()
+
+            for i in range(0, len(tokens), max_new_tokens):
+                chunk = tokens[i:i + max_new_tokens]
+                chunk_text = " ".join(chunk)
+                for text in chunk_text.split():
+                    yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
+                        outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
+                        outputs.sequences if hasattr(outputs, "sequences") else None, False
+
+        if len(new_text) == 0:
+
+            for text in output_text.split():
+                yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
+                    outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
+                    outputs.sequences if hasattr(outputs, "sequences") else None, True
+            yield json.dumps({"text": "", "is_end": True, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
+                outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
+                outputs.sequences if hasattr(outputs, "sequences") else None, True
+            break
+
+        past_key_values = outputs.past_key_values if hasattr(outputs, "past_key_values") else None
+        input_ids = outputs.sequences if hasattr(outputs, "sequences") else None
+
+        output_text = ""
 
 
  @app.post("/generate-image")
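For reference, a minimal client sketch (not part of the commit) for consuming the streaming /generate endpoint this diff introduces. It assumes the service is reachable at a local port and that GenerateRequest also exposes a model_name field (not visible in the hunks above); the other request fields mirror the model shown in this diff, and the response is read as the newline-delimited JSON objects ("text", "is_end") that stream_text yields:

import json
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port, adjust to your deployment

payload = {
    "model_name": "gpt2",            # assumed field and value, not shown in this diff
    "input_text": "Hello",
    "task_type": "text-to-text",
    "temperature": 1.0,
    "max_new_tokens": 3,
    "top_p": 1.0,
    "top_k": 50,
    "repetition_penalty": 1.0,
    "num_return_sequences": 1,
    "do_sample": True,
    "stop_sequences": ["\n"],
}

# The endpoint streams one JSON object per line until a chunk with is_end=True arrives.
with requests.post(f"{BASE_URL}/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["text"], end=" ", flush=True)
        if chunk.get("is_end"):
            break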