Hjgugugjhuhjggg committed on
Commit
3c51859
·
verified ·
1 Parent(s): 7ece340

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -1,14 +1,18 @@
1
  import os
2
  from fastapi import FastAPI, HTTPException, Depends
3
  from fastapi.responses import JSONResponse
4
- from pydantic import BaseModel, field_validator
5
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline
6
  import boto3
7
  import uvicorn
8
  import soundfile as sf
9
  import imageio
10
- from typing import Dict, Optional
11
  import torch # Import torch
 
 
 
 
12
 
13
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
14
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -41,7 +45,7 @@ class GenerateRequest(BaseModel):
41
  repetition_penalty: float = 1.1
42
  num_return_sequences: int = 1
43
  do_sample: bool = True
44
- stop_sequences: list[str] = []
45
  no_repeat_ngram_size: int = 2
46
  continuation_id: Optional[str] = None
47
 
@@ -84,6 +88,7 @@ class S3ModelLoader:
84
  tokenizer.pad_token_id = tokenizer.eos_token_id
85
  return model, tokenizer
86
  except Exception as e:
 
87
  raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
88
 
89
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
@@ -94,6 +99,7 @@ async def get_model_and_tokenizer(model_name: str):
94
  try:
95
  return await model_loader.load_model_and_tokenizer(model_name)
96
  except Exception as e:
 
97
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
98
 
99
  @app.post("/generate")
@@ -142,6 +148,7 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
142
  except HTTPException as http_err:
143
  raise http_err
144
  except Exception as e:
 
145
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
146
 
147
  def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
@@ -152,6 +159,7 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
152
 
153
  class CustomStoppingCriteria(StoppingCriteria): # Inherit directly from StoppingCriteria
154
  def __init__(self, stop_sequences, tokenizer):
 
155
  self.stop_sequences = stop_sequences
156
  self.tokenizer = tokenizer
157
 
@@ -162,7 +170,8 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
162
  return True
163
  return False
164
 
165
- stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))
 
166
 
167
  outputs = model.generate(
168
  encoded_input.input_ids,
@@ -179,7 +188,8 @@ async def load_pipeline_from_s3(task, model_name):
179
  try:
180
  return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN) # Include token if needed
181
  except Exception as e:
182
- raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
 
183
 
184
  @app.post("/generate-image")
185
  async def generate_image(request: GenerateRequest):
@@ -198,6 +208,7 @@ async def generate_image(request: GenerateRequest):
198
  except HTTPException as http_err:
199
  raise http_err
200
  except Exception as e:
 
201
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
202
 
203
  @app.post("/generate-text-to-speech")
@@ -217,6 +228,7 @@ async def generate_text_to_speech(request: GenerateRequest):
217
  except HTTPException as http_err:
218
  raise http_err
219
  except Exception as e:
 
220
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
221
 
222
  @app.post("/generate-video")
@@ -236,7 +248,14 @@ async def generate_video(request: GenerateRequest):
236
  except HTTPException as http_err:
237
  raise http_err
238
  except Exception as e:
 
239
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
240
 
 
 
 
 
 
 
241
  if __name__ == "__main__":
242
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  from fastapi import FastAPI, HTTPException, Depends
3
  from fastapi.responses import JSONResponse
4
+ from pydantic import BaseModel, field_validator, ValidationError
5
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline, StoppingCriteria
6
  import boto3
7
  import uvicorn
8
  import soundfile as sf
9
  import imageio
10
+ from typing import Dict, Optional, List
11
  import torch # Import torch
12
+ import logging
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
 
17
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
18
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 
45
  repetition_penalty: float = 1.1
46
  num_return_sequences: int = 1
47
  do_sample: bool = True
48
+ stop_sequences: List[str] = []
49
  no_repeat_ngram_size: int = 2
50
  continuation_id: Optional[str] = None
51
 
 
88
  tokenizer.pad_token_id = tokenizer.eos_token_id
89
  return model, tokenizer
90
  except Exception as e:
91
+ logging.error(f"Error loading model from S3: {e}")
92
  raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
93
 
94
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
99
  try:
100
  return await model_loader.load_model_and_tokenizer(model_name)
101
  except Exception as e:
102
+ logging.error(f"Error loading model: {e}")
103
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
104
 
105
  @app.post("/generate")
 
148
  except HTTPException as http_err:
149
  raise http_err
150
  except Exception as e:
151
+ logging.error(f"Internal server error: {str(e)}")
152
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
153
 
154
  def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
 
159
 
160
  class CustomStoppingCriteria(StoppingCriteria): # Inherit directly from StoppingCriteria
161
  def __init__(self, stop_sequences, tokenizer):
162
+ super().__init__() # call parent constructor
163
  self.stop_sequences = stop_sequences
164
  self.tokenizer = tokenizer
165
 
 
170
  return True
171
  return False
172
 
173
+ if stop_sequences: # Only add if stop_sequences is not empty
174
+ stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))
175
 
176
  outputs = model.generate(
177
  encoded_input.input_ids,
 
188
  try:
189
  return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN) # Include token if needed
190
  except Exception as e:
191
+ logging.error(f"Error loading {task} model from S3: {e}")
192
+ raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
193
 
194
  @app.post("/generate-image")
195
  async def generate_image(request: GenerateRequest):
 
208
  except HTTPException as http_err:
209
  raise http_err
210
  except Exception as e:
211
+ logging.error(f"Internal server error: {str(e)}")
212
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
213
 
214
  @app.post("/generate-text-to-speech")
 
228
  except HTTPException as http_err:
229
  raise http_err
230
  except Exception as e:
231
+ logging.error(f"Internal server error: {str(e)}")
232
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
233
 
234
  @app.post("/generate-video")
 
248
  except HTTPException as http_err:
249
  raise http_err
250
  except Exception as e:
251
+ logging.error(f"Internal server error: {str(e)}")
252
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
253
 
254
+ # Adding exception handling for Pydantic validation
+ # NOTE(review): this registers a handler for pydantic's ValidationError.
+ # FastAPI wraps request-body validation failures in
+ # fastapi.exceptions.RequestValidationError, so this handler presumably only
+ # fires for ValidationError raised manually inside endpoint code — confirm,
+ # and consider registering RequestValidationError as well so failed request
+ # parsing is also logged by this path.
255
+ @app.exception_handler(ValidationError)
256
+ async def validation_exception_handler(request, exc):
257
+ logging.error(f"Validation Error: {exc}")
258
+ return JSONResponse({"detail": exc.errors()}, status_code=422)
259
+
260
  if __name__ == "__main__":
261
  uvicorn.run(app, host="0.0.0.0", port=7860)