Hjgugugjhuhjggg committed
Commit 2e9ad50 · verified · 1 Parent(s): 6de156a

Update app.py

Files changed (1): app.py (+135 −203)
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
@@ -10,17 +10,15 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteria,
-    StoppingCriteriaList
+    StoppingCriteriaList,
 )
 import boto3
 import uvicorn
-import asyncio
 from transformers import pipeline
 import json
 from huggingface_hub import login
 import base64
 from botocore.exceptions import NoCredentialsError
-import re
 
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
@@ -45,10 +43,10 @@ class GenerateRequest(BaseModel):
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 3
-    stream: bool = True
+    max_new_tokens: int = 3
+    stream: bool = True  # Set default stream to True
     top_p: float = 1.0
-    top_k: int = 50
+    top_k: int = 50
     repetition_penalty: float = 1.0
     num_return_sequences: int = 1
     do_sample: bool = True
@@ -79,7 +77,7 @@ class S3ModelLoader:
         return f"s3://{self.bucket_name}/" \
                f"{model_name.replace('/', '-')}"
 
-    async def load_model_and_tokenizer(self, model_name):
+    def load_model_and_tokenizer(self, model_name):
         if model_name in model_cache:
             return model_cache[model_name]
 
@@ -94,44 +92,33 @@ class S3ModelLoader:
             )
 
             tokenizer = AutoTokenizer.from_pretrained(
-                s3_uri, config=config, local_files_only=False, padding_side="left"
+                s3_uri, config=config, local_files_only=False
            )
-
-            eos_token_id = tokenizer.eos_token_id
-            pad_token_id = tokenizer.pad_token_id
-            eos_token = tokenizer.eos_token
-            pad_token = tokenizer.pad_token
-            padding = tokenizer.padding_side
-
-            if eos_token_id is not None and pad_token_id is None:
-                pad_token_id = config.pad_token_id or eos_token_id
-                tokenizer.pad_token_id = pad_token_id
 
-            model_cache[model_name] = (model, tokenizer,eos_token_id,
-                                       pad_token_id,eos_token,pad_token,padding)
-            return model, tokenizer,eos_token_id,pad_token_id,eos_token,pad_token,padding
+            if tokenizer.eos_token_id is not None and \
+               tokenizer.pad_token_id is None:
+                tokenizer.pad_token_id = config.pad_token_id \
+                    or tokenizer.eos_token_id
+            model_cache[model_name] = (model, tokenizer)
+            return model, tokenizer
         except (EnvironmentError, NoCredentialsError):
             try:
                 config = AutoConfig.from_pretrained(
                     model_name, token=HUGGINGFACE_HUB_TOKEN
                 )
                 tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, padding_side="left"
+                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
                 )
 
                 model = AutoModelForCausalLM.from_pretrained(
                     model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
                 )
-
-                eos_token_id = tokenizer.eos_token_id
-                pad_token_id = tokenizer.pad_token_id
-                eos_token = tokenizer.eos_token
-                pad_token = tokenizer.pad_token
-                padding = tokenizer.padding_side
 
-                if eos_token_id is not None and pad_token_id is None:
-                    pad_token_id = config.pad_token_id or eos_token_id
-                    tokenizer.pad_token_id = pad_token_id
+
+                if tokenizer.eos_token_id is not None and \
+                   tokenizer.pad_token_id is None:
+                    tokenizer.pad_token_id = config.pad_token_id \
+                        or tokenizer.eos_token_id
 
 
                 model.save_pretrained(s3_uri)
@@ -147,22 +134,10 @@ class S3ModelLoader:
                 )
 
                 tokenizer = AutoTokenizer.from_pretrained(
-                    s3_uri, config=config, local_files_only=False, padding_side="left"
+                    s3_uri, config=config, local_files_only=False
                 )
-
-                eos_token_id = tokenizer.eos_token_id
-                pad_token_id = tokenizer.pad_token_id
-                eos_token = tokenizer.eos_token
-                pad_token = tokenizer.pad_token
-                padding = tokenizer.padding_side
-
-                if eos_token_id is not None and pad_token_id is None:
-                    pad_token_id = config.pad_token_id or eos_token_id
-                    tokenizer.pad_token_id = pad_token_id
-
-                model_cache[model_name] = (model, tokenizer,eos_token_id,
-                                           pad_token_id,eos_token,pad_token,padding)
-                return model, tokenizer,eos_token_id,pad_token_id,eos_token,pad_token,padding
+                model_cache[model_name] = (model, tokenizer)
+                return model, tokenizer
             except Exception as e:
                 raise HTTPException(
                     status_code=500, detail=f"Error loading model: {e}"
@@ -170,46 +145,15 @@ class S3ModelLoader:
 
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
-class StopOnSequencesCriteria(StoppingCriteria):
-    def __init__(self, stop_sequences, tokenizer):
-        self.stop_sequences = stop_sequences
-        self.tokenizer = tokenizer
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-
-        decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-
-        for seq in self.stop_sequences:
-            if seq in decoded_text:
-                return True
-        return False
-
-async def generate_stream(model, tokenizer, input_text,
-                          generation_config, stop_sequences,
-                          device, pad_token_id, max_model_length,
-                          max_new_tokens):
-    async def stream():
-        past_key_values = None
-        input_ids = None
-        async for token, past_key_values_response, input_ids_response, is_end in stream_text(model, tokenizer, input_text,
-                generation_config, stop_sequences,
-                device, pad_token_id, max_model_length, max_new_tokens, past_key_values, input_ids):
-            past_key_values = past_key_values_response
-            input_ids = input_ids_response
-            if is_end:
-                break
-            yield token
-    return stream()
-
 @app.post("/generate")
-async def generate(request: GenerateRequest):
+def generate(request: GenerateRequest):
     try:
         model_name = request.model_name
         input_text = request.input_text
         task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
-        stream = True
+        stream = request.stream
         top_p = request.top_p
         top_k = request.top_k
         repetition_penalty = request.repetition_penalty
@@ -217,10 +161,10 @@ async def generate(request: GenerateRequest):
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
 
-        model, tokenizer, eos_token_id, pad_token_id, eos_token, pad_token, padding = await model_loader.load_model_and_tokenizer(model_name)
+        model, tokenizer = model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
-
+
         if "text-to-text" == task_type:
             generation_config = GenerationConfig(
                 temperature=temperature,
@@ -230,29 +174,20 @@ async def generate(request: GenerateRequest):
                 repetition_penalty=repetition_penalty,
                 do_sample=do_sample,
                 num_return_sequences=num_return_sequences,
-                pad_token_id=pad_token_id if pad_token_id is not None else None
-            )
-
-            max_model_length = 3
-            input_text = input_text[:max_model_length]
-
-            streams = [
-                generate_stream(model, tokenizer, input_text,
-                                generation_config, stop_sequences,
-                                device, pad_token_id, max_model_length, max_new_tokens)
-                for _ in range(num_return_sequences)
-            ]
-
-
-            async def stream_response():
-                for stream in asyncio.as_completed(streams):
-                    async for chunk in await stream:
-                        yield chunk
-            return StreamingResponse(
-                stream_response(),
-                media_type="text/plain"
+                eos_token_id=tokenizer.eos_token_id
             )
-
+            if stream:
+                return StreamingResponse(
+                    stream_text(model, tokenizer, input_text,
+                                generation_config, stop_sequences,
+                                device),
+                    media_type="text/plain"
+                )
+            else:
+                result = generate_text(model, tokenizer, input_text,
+                                       generation_config, stop_sequences,
+                                       device)
+                return JSONResponse({"text": result, "is_end": True})
         else:
             return HTTPException(status_code=400, detail="Task type not text-to-text")
 
@@ -261,116 +196,113 @@ async def generate(request: GenerateRequest):
             status_code=500, detail=f"Internal server error: {str(e)}"
         )
 
 
-async def stream_text(model, tokenizer, input_text,
-                      generation_config, stop_sequences,
-                      device, pad_token_id, max_model_length, max_new_tokens,
-                      past_key_values, input_ids):
-
-
-    stop_regex = re.compile(r'[\.\?\!\n]+')
-
-    def find_stop(output_text, stop_sequences):
-        for seq in stop_sequences:
-            if seq in output_text:
-                last_index = output_text.rfind(seq)
-                return last_index + len(seq)
-
-        match = stop_regex.search(output_text)
-        if match:
-            return match.end()
-
-        return -1
+class StopOnSequences(StoppingCriteria):
+    def __init__(self, stop_sequences, tokenizer):
+        self.stop_sequences = stop_sequences
+        self.tokenizer = tokenizer
+        self.stop_ids = [tokenizer.encode(seq, add_special_tokens=False) for seq in stop_sequences]
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+
+        decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        for stop_sequence in self.stop_sequences:
+            if stop_sequence in decoded_text:
+                return True
+        return False
+
+def stream_text(model, tokenizer, input_text,
+                generation_config, stop_sequences,
+                device):
+
+    encoded_input = tokenizer(
+        input_text, return_tensors="pt",
+        truncation=True
+    ).to(device)
+
+    stop_criteria = StopOnSequences(stop_sequences, tokenizer)
+    stopping_criteria = StoppingCriteriaList([stop_criteria])
 
     output_text = ""
-    stop_criteria = StoppingCriteriaList([StopOnSequencesCriteria(stop_sequences, tokenizer)])
-
-    if input_ids is None:
-        encoded_input = tokenizer(
-            input_text, return_tensors="pt",
-            truncation=True,
-            padding="max_length",
-            max_length=max_model_length
-        ).to(device)
-        input_ids = encoded_input.input_ids
-    else:
-        encoded_input = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values
-        }
 
     while True:
-
-        outputs = model.generate(
-            **encoded_input,
-            do_sample=generation_config.do_sample,
-            max_new_tokens=generation_config.max_new_tokens,
-            temperature=generation_config.temperature,
-            top_p=generation_config.top_p,
-            top_k=generation_config.top_k,
-            repetition_penalty=generation_config.repetition_penalty,
-            num_return_sequences=generation_config.num_return_sequences,
-            output_scores=True,
-            return_dict_in_generate=True,
-            pad_token_id=pad_token_id if pad_token_id is not None else None,
-            stopping_criteria=stop_criteria,
-        )
-
-        new_text = tokenizer.decode(
-            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
-            skip_special_tokens=True
-        )
-
-        output_text += new_text
-
-        stop_index = find_stop(output_text, stop_sequences)
-
-        is_end = False
-        if stop_index != -1 or (hasattr(outputs, "sequences") and outputs.sequences[0][-1] == tokenizer.eos_token_id):
-            final_output = output_text[:stop_index] if stop_index != -1 else output_text
-
-            for text in final_output.split():
-                yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
-                    outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                    outputs.sequences if hasattr(outputs, "sequences") else None, True
-
-            yield json.dumps({"text": "", "is_end": True, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
-                outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                outputs.sequences if hasattr(outputs, "sequences") else None, True
-            break
-        else:
-
-            tokens = new_text.split()
 
-            for i in range(0, len(tokens), max_new_tokens):
-                chunk = tokens[i:i + max_new_tokens]
-                chunk_text = " ".join(chunk)
-                for text in chunk_text.split():
-                    yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
-                        outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                        outputs.sequences if hasattr(outputs, "sequences") else None, False
-
-            if len(new_text) == 0:
-
-                for text in output_text.split():
-                    yield json.dumps({"text": text, "is_end": False, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
-                        outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                        outputs.sequences if hasattr(outputs, "sequences") else None, True
-                yield json.dumps({"text": "", "is_end": True, "temperature": generation_config.temperature, "top_p": generation_config.top_p, "top_k": generation_config.top_k}) + "\n", \
-                    outputs.past_key_values if hasattr(outputs, "past_key_values") else None, \
-                    outputs.sequences if hasattr(outputs, "sequences") else None, True
-                break
-
-        past_key_values = outputs.past_key_values if hasattr(outputs, "past_key_values") else None
-        input_ids = outputs.sequences if hasattr(outputs, "sequences") else None
-
-        output_text = ""
-
+        outputs = model.generate(
+            **encoded_input,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            output_scores=True,
+            return_dict_in_generate=True,
+            stopping_criteria=stopping_criteria
+        )
+
+        new_text = tokenizer.decode(
+            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
+            skip_special_tokens=True
+        )
+
+        if len(new_text) == 0:
+            if not stop_criteria(outputs.sequences, None):
+                for text in output_text.split():
+                    yield json.dumps({"text": text, "is_end": False}) + "\n"
+            yield json.dumps({"text": "", "is_end": True}) + "\n"
+            break
+
+        output_text += new_text
+
+        for text in new_text.split():
+            yield json.dumps({"text": text, "is_end": False}) + "\n"
+
+        if stop_criteria(outputs.sequences, None):
+            yield json.dumps({"text": "", "is_end": True}) + "\n"
+            break
+
+        encoded_input = tokenizer(
+            output_text, return_tensors="pt",
+            truncation=True
+        ).to(device)
+        output_text = ""
 
 
+def generate_text(model, tokenizer, input_text,
+                  generation_config, stop_sequences,
+                  device):
+    encoded_input = tokenizer(
+        input_text, return_tensors="pt",
+        truncation=True
+    ).to(device)
+
+    stop_criteria = StopOnSequences(stop_sequences, tokenizer)
+    stopping_criteria = StoppingCriteriaList([stop_criteria])
+
+    outputs = model.generate(
+        **encoded_input,
+        do_sample=generation_config.do_sample,
+        max_new_tokens=generation_config.max_new_tokens,
+        temperature=generation_config.temperature,
+        top_p=generation_config.top_p,
+        top_k=generation_config.top_k,
+        repetition_penalty=generation_config.repetition_penalty,
+        num_return_sequences=generation_config.num_return_sequences,
+        output_scores=True,
+        return_dict_in_generate=True,
+        stopping_criteria=stopping_criteria
+    )
+
+
+    generated_text = tokenizer.decode(
+        outputs.sequences[0], skip_special_tokens=True
+    )
+
+    return generated_text
 
 @app.post("/generate-image")
-async def generate_image(request: GenerateRequest):
+def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -398,7 +330,7 @@ async def generate_image(request: GenerateRequest):
 
 
 @app.post("/generate-text-to-speech")
-async def generate_text_to_speech(request: GenerateRequest):
+def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -429,7 +361,7 @@ async def generate_text_to_speech(request: GenerateRequest):
 
 
 @app.post("/generate-video")
-async def generate_video(request: GenerateRequest):
+def generate_video(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
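Review note: after this commit, /generate branches on the request's stream flag, returning newline-delimited JSON chunks via StreamingResponse when it is true and a single JSONResponse otherwise. A minimal client sketch follows; it is not part of the commit, and the localhost:8000 base URL and the "gpt2" model name are illustrative assumptions (the diff does not show the uvicorn.run call).

# Client sketch for the new stream flag on /generate (illustrative only).
import json
import requests

payload = {
    "model_name": "gpt2",            # assumed model; any causal LM the loader accepts
    "input_text": "Hello, world",
    "task_type": "text-to-text",
    "max_new_tokens": 3,
    "stream": True,
    "stop_sequences": ["\n"],
}

# stream=True: the endpoint yields newline-delimited JSON chunks.
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            chunk = json.loads(line)
            print(chunk["text"], end=" ", flush=True)
            if chunk["is_end"]:
                break

# stream=False: the endpoint returns one JSON object instead.
payload["stream"] = False
print(requests.post("http://localhost:8000/generate", json=payload).json())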
 
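The new StopOnSequences criterion decodes the running sequence on every generation step and halts once any stop string appears; the stop_ids list it precomputes is currently unused. A self-contained sketch of the same pattern, assuming the small public checkpoint "sshleifer/tiny-gpt2" purely for illustration:

# Standalone sketch of the stop-sequence pattern used in this commit.
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class StopOnSequences(StoppingCriteria):
    def __init__(self, stop_sequences, tokenizer):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode the full sequence so far and stop once any stop string shows up.
        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(seq in decoded for seq in self.stop_sequences)

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # assumed toy model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tokenizer("Once upon a time", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=20,
    stopping_criteria=StoppingCriteriaList([StopOnSequences(["."], tokenizer)]),
)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Decoding the whole sequence each step costs time proportional to the sequence length per generated token; a token-level comparison against precomputed stop IDs would be cheaper, which may be what the unused stop_ids field was intended for.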