Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,18 @@
|
|
1 |
import os
|
2 |
from fastapi import FastAPI, HTTPException, Depends
|
3 |
from fastapi.responses import JSONResponse
|
4 |
-
from pydantic import BaseModel, field_validator
|
5 |
-
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline
|
6 |
import boto3
|
7 |
import uvicorn
|
8 |
import soundfile as sf
|
9 |
import imageio
|
10 |
-
from typing import Dict, Optional
|
11 |
import torch # Import torch
|
|
|
|
|
|
|
|
|
12 |
|
13 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
14 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
@@ -41,7 +45,7 @@ class GenerateRequest(BaseModel):
|
|
41 |
repetition_penalty: float = 1.1
|
42 |
num_return_sequences: int = 1
|
43 |
do_sample: bool = True
|
44 |
-
stop_sequences:
|
45 |
no_repeat_ngram_size: int = 2
|
46 |
continuation_id: Optional[str] = None
|
47 |
|
@@ -84,6 +88,7 @@ class S3ModelLoader:
|
|
84 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
85 |
return model, tokenizer
|
86 |
except Exception as e:
|
|
|
87 |
raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
|
88 |
|
89 |
model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
|
@@ -94,6 +99,7 @@ async def get_model_and_tokenizer(model_name: str):
|
|
94 |
try:
|
95 |
return await model_loader.load_model_and_tokenizer(model_name)
|
96 |
except Exception as e:
|
|
|
97 |
raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
|
98 |
|
99 |
@app.post("/generate")
|
@@ -142,6 +148,7 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
142 |
except HTTPException as http_err:
|
143 |
raise http_err
|
144 |
except Exception as e:
|
|
|
145 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
146 |
|
147 |
def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
|
@@ -152,6 +159,7 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
|
|
152 |
|
153 |
class CustomStoppingCriteria(StoppingCriteria): # Inherit directly from StoppingCriteria
|
154 |
def __init__(self, stop_sequences, tokenizer):
|
|
|
155 |
self.stop_sequences = stop_sequences
|
156 |
self.tokenizer = tokenizer
|
157 |
|
@@ -162,7 +170,8 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
|
|
162 |
return True
|
163 |
return False
|
164 |
|
165 |
-
|
|
|
166 |
|
167 |
outputs = model.generate(
|
168 |
encoded_input.input_ids,
|
@@ -179,7 +188,8 @@ async def load_pipeline_from_s3(task, model_name):
|
|
179 |
try:
|
180 |
return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN) # Include token if needed
|
181 |
except Exception as e:
|
182 |
-
|
|
|
183 |
|
184 |
@app.post("/generate-image")
|
185 |
async def generate_image(request: GenerateRequest):
|
@@ -198,6 +208,7 @@ async def generate_image(request: GenerateRequest):
|
|
198 |
except HTTPException as http_err:
|
199 |
raise http_err
|
200 |
except Exception as e:
|
|
|
201 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
202 |
|
203 |
@app.post("/generate-text-to-speech")
|
@@ -217,6 +228,7 @@ async def generate_text_to_speech(request: GenerateRequest):
|
|
217 |
except HTTPException as http_err:
|
218 |
raise http_err
|
219 |
except Exception as e:
|
|
|
220 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
221 |
|
222 |
@app.post("/generate-video")
|
@@ -236,7 +248,14 @@ async def generate_video(request: GenerateRequest):
|
|
236 |
except HTTPException as http_err:
|
237 |
raise http_err
|
238 |
except Exception as e:
|
|
|
239 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
if __name__ == "__main__":
|
242 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
import os
|
2 |
from fastapi import FastAPI, HTTPException, Depends
|
3 |
from fastapi.responses import JSONResponse
|
4 |
+
from pydantic import BaseModel, field_validator, ValidationError
|
5 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline, StoppingCriteria
|
6 |
import boto3
|
7 |
import uvicorn
|
8 |
import soundfile as sf
|
9 |
import imageio
|
10 |
+
from typing import Dict, Optional, List
|
11 |
import torch # Import torch
|
12 |
+
import logging
|
13 |
+
|
14 |
+
# Configure logging
|
15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
16 |
|
17 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
18 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
|
45 |
repetition_penalty: float = 1.1
|
46 |
num_return_sequences: int = 1
|
47 |
do_sample: bool = True
|
48 |
+
stop_sequences: List[str] = []
|
49 |
no_repeat_ngram_size: int = 2
|
50 |
continuation_id: Optional[str] = None
|
51 |
|
|
|
88 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
89 |
return model, tokenizer
|
90 |
except Exception as e:
|
91 |
+
logging.error(f"Error loading model from S3: {e}")
|
92 |
raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
|
93 |
|
94 |
model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
|
|
|
99 |
try:
|
100 |
return await model_loader.load_model_and_tokenizer(model_name)
|
101 |
except Exception as e:
|
102 |
+
logging.error(f"Error loading model: {e}")
|
103 |
raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
|
104 |
|
105 |
@app.post("/generate")
|
|
|
148 |
except HTTPException as http_err:
|
149 |
raise http_err
|
150 |
except Exception as e:
|
151 |
+
logging.error(f"Internal server error: {str(e)}")
|
152 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
153 |
|
154 |
def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
|
|
|
159 |
|
160 |
class CustomStoppingCriteria(StoppingCriteria): # Inherit directly from StoppingCriteria
|
161 |
def __init__(self, stop_sequences, tokenizer):
|
162 |
+
super().__init__() # call parent constructor
|
163 |
self.stop_sequences = stop_sequences
|
164 |
self.tokenizer = tokenizer
|
165 |
|
|
|
170 |
return True
|
171 |
return False
|
172 |
|
173 |
+
if stop_sequences: # Only add if stop_sequences is not empty
|
174 |
+
stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))
|
175 |
|
176 |
outputs = model.generate(
|
177 |
encoded_input.input_ids,
|
|
|
188 |
try:
|
189 |
return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN) # Include token if needed
|
190 |
except Exception as e:
|
191 |
+
logging.error(f"Error loading {task} model from S3: {e}")
|
192 |
+
raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
|
193 |
|
194 |
@app.post("/generate-image")
|
195 |
async def generate_image(request: GenerateRequest):
|
|
|
208 |
except HTTPException as http_err:
|
209 |
raise http_err
|
210 |
except Exception as e:
|
211 |
+
logging.error(f"Internal server error: {str(e)}")
|
212 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
213 |
|
214 |
@app.post("/generate-text-to-speech")
|
|
|
228 |
except HTTPException as http_err:
|
229 |
raise http_err
|
230 |
except Exception as e:
|
231 |
+
logging.error(f"Internal server error: {str(e)}")
|
232 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
233 |
|
234 |
@app.post("/generate-video")
|
|
|
248 |
except HTTPException as http_err:
|
249 |
raise http_err
|
250 |
except Exception as e:
|
251 |
+
logging.error(f"Internal server error: {str(e)}")
|
252 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
253 |
|
254 |
+
# Adding exception handling for Pydantic validation
|
255 |
+
@app.exception_handler(ValidationError)
async def validation_exception_handler(request, exc):
    """Turn an uncaught pydantic ``ValidationError`` into a 422 JSON response.

    Args:
        request: The incoming request. Unused here, but part of the
            exception-handler signature.
        exc: The ``ValidationError`` that propagated out of a route.

    Returns:
        A ``JSONResponse`` with status code 422 whose ``detail`` key holds
        the JSON-safe list of validation errors.
    """
    # Function-scope import keeps this block self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.encoders import jsonable_encoder

    logging.error("Validation Error: %s", exc)
    # exc.errors() may contain objects json.dumps cannot serialize (for
    # example the original exception stored under "ctx" when a validator
    # raises ValueError). Encoding first keeps the 422 from turning into a
    # 500 while the response body is rendered.
    # NOTE(review): request-body validation failures are normally raised by
    # FastAPI as RequestValidationError -- confirm this handler is actually
    # reached for the cases it is meant to cover.
    return JSONResponse({"detail": jsonable_encoder(exc.errors())}, status_code=422)
|
259 |
+
|
260 |
if __name__ == "__main__":
|
261 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|