Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -38,7 +38,6 @@ class DownloadModelRequest(BaseModel):
|
|
38 |
model_id: str
|
39 |
pipeline_task: str
|
40 |
input_text: str
|
41 |
-
revision: str = "main"
|
42 |
|
43 |
@field_validator('model_id')
|
44 |
def validate_model_id(cls, value):
|
@@ -75,9 +74,14 @@ class S3DirectStream:
|
|
75 |
logger.info(f"File {key} does not exist in S3.")
|
76 |
return False
|
77 |
|
78 |
-
def load_model_from_stream(self, model_prefix
|
79 |
try:
|
80 |
-
logger.info(f"Loading model {model_prefix}
|
|
|
|
|
|
|
|
|
|
|
81 |
if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
|
82 |
any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
|
83 |
logger.info(f"Model {model_prefix} found in S3. Loading...")
|
@@ -86,7 +90,7 @@ class S3DirectStream:
|
|
86 |
logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
|
87 |
self.download_and_upload_to_s3(model_prefix, revision)
|
88 |
logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
|
89 |
-
return self.load_model_from_stream(model_prefix
|
90 |
except HTTPException as e:
|
91 |
logger.error(f"Error loading model: {e}")
|
92 |
return None
|
@@ -165,12 +169,15 @@ class S3DirectStream:
|
|
165 |
|
166 |
logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
|
167 |
|
|
|
168 |
def _get_model_files(self, model_prefix, revision="main"):
|
169 |
index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
|
170 |
try:
|
171 |
index_response = requests.get(index_url)
|
172 |
index_response.raise_for_status()
|
|
|
173 |
index_content = index_response.text
|
|
|
174 |
model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
|
175 |
return model_files
|
176 |
except requests.exceptions.RequestException as e:
|
@@ -198,19 +205,27 @@ class S3DirectStream:
|
|
198 |
logger.error(f"Error downloading from {url}: Status code {response.status_code}")
|
199 |
raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
@app.post("/predict/")
|
203 |
async def predict(model_request: DownloadModelRequest):
|
204 |
try:
|
205 |
logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
|
206 |
model_id = model_request.model_id
|
207 |
-
revision = model_request.revision
|
208 |
task = model_request.pipeline_task
|
209 |
input_text = model_request.input_text
|
210 |
|
211 |
streamer = S3DirectStream(S3_BUCKET_NAME)
|
212 |
logger.info("Loading model and tokenizer...")
|
213 |
-
model = streamer.load_model_from_stream(model_id
|
214 |
|
215 |
if model is None:
|
216 |
logger.error(f"Failed to load model {model_id}")
|
|
|
38 |
model_id: str
|
39 |
pipeline_task: str
|
40 |
input_text: str
|
|
|
41 |
|
42 |
@field_validator('model_id')
|
43 |
def validate_model_id(cls, value):
|
|
|
74 |
logger.info(f"File {key} does not exist in S3.")
|
75 |
return False
|
76 |
|
77 |
+
def load_model_from_stream(self, model_prefix):
|
78 |
try:
|
79 |
+
logger.info(f"Loading model {model_prefix}...")
|
80 |
+
revision = self._get_latest_revision(model_prefix)
|
81 |
+
if revision is None:
|
82 |
+
logger.error(f"Could not determine revision for {model_prefix}")
|
83 |
+
raise ValueError(f"Could not determine revision for {model_prefix}")
|
84 |
+
|
85 |
if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
|
86 |
any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
|
87 |
logger.info(f"Model {model_prefix} found in S3. Loading...")
|
|
|
90 |
logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
|
91 |
self.download_and_upload_to_s3(model_prefix, revision)
|
92 |
logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
|
93 |
+
return self.load_model_from_stream(model_prefix)
|
94 |
except HTTPException as e:
|
95 |
logger.error(f"Error loading model: {e}")
|
96 |
return None
|
|
|
169 |
|
170 |
logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
|
171 |
|
172 |
+
|
173 |
def _get_model_files(self, model_prefix, revision="main"):
|
174 |
index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
|
175 |
try:
|
176 |
index_response = requests.get(index_url)
|
177 |
index_response.raise_for_status()
|
178 |
+
logger.info(f"Hugging Face API Response: Status Code = {index_response.status_code}, Headers = {index_response.headers}")
|
179 |
index_content = index_response.text
|
180 |
+
logger.info(f"Index content: {index_content}")
|
181 |
model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
|
182 |
return model_files
|
183 |
except requests.exceptions.RequestException as e:
|
|
|
205 |
logger.error(f"Error downloading from {url}: Status code {response.status_code}")
|
206 |
raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
|
207 |
|
208 |
+
def _get_latest_revision(self, model_prefix):
    """Return the latest commit sha for *model_prefix* on the Hugging Face Hub.

    Parameters:
        model_prefix: repo id of the model on the Hub (e.g. "org/model-name").

    Returns:
        The head commit sha as a string, or None when the lookup fails
        (network error, unknown repo) so the caller can handle the missing
        revision explicitly.
    """
    try:
        api = HfApi()
        model_info = api.model_info(model_prefix)
        # BUGFIX: ModelInfo exposes the head commit as `sha`. The previous
        # `model_info.default_revision` attribute does not exist, so every
        # call raised AttributeError, was swallowed by the broad except
        # below, and this method returned None unconditionally — which made
        # the caller raise "Could not determine revision" for every model.
        return model_info.sha
    except Exception as e:
        # Best-effort lookup: log and signal failure with None rather than
        # propagating, matching the caller's None-check contract.
        logger.error(f"Error getting latest revision for {model_prefix}: {e}")
        return None
|
216 |
+
|
217 |
|
218 |
@app.post("/predict/")
|
219 |
async def predict(model_request: DownloadModelRequest):
|
220 |
try:
|
221 |
logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
|
222 |
model_id = model_request.model_id
|
|
|
223 |
task = model_request.pipeline_task
|
224 |
input_text = model_request.input_text
|
225 |
|
226 |
streamer = S3DirectStream(S3_BUCKET_NAME)
|
227 |
logger.info("Loading model and tokenizer...")
|
228 |
+
model = streamer.load_model_from_stream(model_id)
|
229 |
|
230 |
if model is None:
|
231 |
logger.error(f"Failed to load model {model_id}")
|