Hjgugugjhuhjggg commited on
Commit
e58c8bb
·
verified ·
1 Parent(s): 44af224

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -38,7 +38,6 @@ class DownloadModelRequest(BaseModel):
38
  model_id: str
39
  pipeline_task: str
40
  input_text: str
41
- revision: str = "main"
42
 
43
  @field_validator('model_id')
44
  def validate_model_id(cls, value):
@@ -75,9 +74,14 @@ class S3DirectStream:
75
  logger.info(f"File {key} does not exist in S3.")
76
  return False
77
 
78
- def load_model_from_stream(self, model_prefix, revision):
79
  try:
80
- logger.info(f"Loading model {model_prefix} (revision {revision})...")
 
 
 
 
 
81
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
82
  any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
83
  logger.info(f"Model {model_prefix} found in S3. Loading...")
@@ -86,7 +90,7 @@ class S3DirectStream:
86
  logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
87
  self.download_and_upload_to_s3(model_prefix, revision)
88
  logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
89
- return self.load_model_from_stream(model_prefix, revision)
90
  except HTTPException as e:
91
  logger.error(f"Error loading model: {e}")
92
  return None
@@ -165,12 +169,15 @@ class S3DirectStream:
165
 
166
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
167
 
 
168
  def _get_model_files(self, model_prefix, revision="main"):
169
  index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
170
  try:
171
  index_response = requests.get(index_url)
172
  index_response.raise_for_status()
 
173
  index_content = index_response.text
 
174
  model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
175
  return model_files
176
  except requests.exceptions.RequestException as e:
@@ -198,19 +205,27 @@ class S3DirectStream:
198
  logger.error(f"Error downloading from {url}: Status code {response.status_code}")
199
  raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
200
 
 
 
 
 
 
 
 
 
 
201
 
202
  @app.post("/predict/")
203
  async def predict(model_request: DownloadModelRequest):
204
  try:
205
  logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
206
  model_id = model_request.model_id
207
- revision = model_request.revision
208
  task = model_request.pipeline_task
209
  input_text = model_request.input_text
210
 
211
  streamer = S3DirectStream(S3_BUCKET_NAME)
212
  logger.info("Loading model and tokenizer...")
213
- model = streamer.load_model_from_stream(model_id, revision)
214
 
215
  if model is None:
216
  logger.error(f"Failed to load model {model_id}")
 
38
  model_id: str
39
  pipeline_task: str
40
  input_text: str
 
41
 
42
  @field_validator('model_id')
43
  def validate_model_id(cls, value):
 
74
  logger.info(f"File {key} does not exist in S3.")
75
  return False
76
 
77
+ def load_model_from_stream(self, model_prefix):
78
  try:
79
+ logger.info(f"Loading model {model_prefix}...")
80
+ revision = self._get_latest_revision(model_prefix)
81
+ if revision is None:
82
+ logger.error(f"Could not determine revision for {model_prefix}")
83
+ raise ValueError(f"Could not determine revision for {model_prefix}")
84
+
85
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
86
  any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
87
  logger.info(f"Model {model_prefix} found in S3. Loading...")
 
90
  logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
91
  self.download_and_upload_to_s3(model_prefix, revision)
92
  logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
93
+ return self.load_model_from_stream(model_prefix)
94
  except HTTPException as e:
95
  logger.error(f"Error loading model: {e}")
96
  return None
 
169
 
170
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
171
 
172
+
173
  def _get_model_files(self, model_prefix, revision="main"):
174
  index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
175
  try:
176
  index_response = requests.get(index_url)
177
  index_response.raise_for_status()
178
+ logger.info(f"Hugging Face API Response: Status Code = {index_response.status_code}, Headers = {index_response.headers}")
179
  index_content = index_response.text
180
+ logger.info(f"Index content: {index_content}")
181
  model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
182
  return model_files
183
  except requests.exceptions.RequestException as e:
 
205
  logger.error(f"Error downloading from {url}: Status code {response.status_code}")
206
  raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
207
 
208
+ def _get_latest_revision(self, model_prefix):
209
+ try:
210
+ api = HfApi()
211
+ model_info = api.model_info(model_prefix)
212
+ return model_info.default_revision
213
+ except Exception as e:
214
+ logger.error(f"Error getting latest revision for {model_prefix}: {e}")
215
+ return None
216
+
217
 
218
  @app.post("/predict/")
219
  async def predict(model_request: DownloadModelRequest):
220
  try:
221
  logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
222
  model_id = model_request.model_id
 
223
  task = model_request.pipeline_task
224
  input_text = model_request.input_text
225
 
226
  streamer = S3DirectStream(S3_BUCKET_NAME)
227
  logger.info("Loading model and tokenizer...")
228
+ model = streamer.load_model_from_stream(model_id)
229
 
230
  if model is None:
231
  logger.error(f"Failed to load model {model_id}")