Hjgugugjhuhjggg committed on
Commit
44af224
·
verified ·
1 Parent(s): 42861e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -29
app.py CHANGED
@@ -79,7 +79,7 @@ class S3DirectStream:
79
  try:
80
  logger.info(f"Loading model {model_prefix} (revision {revision})...")
81
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
82
- (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
83
  logger.info(f"Model {model_prefix} found in S3. Loading...")
84
  return self.load_model_from_existing_s3(model_prefix)
85
 
@@ -98,22 +98,32 @@ class S3DirectStream:
98
  config = AutoConfig.from_pretrained(model_prefix, **config_dict)
99
  logger.info(f"Config loaded for {model_prefix}.")
100
 
101
- if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
102
- logger.info(f"Loading safetensors model for {model_prefix} from S3...")
103
- model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
104
- model = AutoModelForCausalLM.from_config(config)
105
- model.load_state_dict(safetensors.torch.load_stream(model_stream))
106
- logger.info(f"Safetensors model loaded for {model_prefix}.")
107
- elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
108
- logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
109
- model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
110
- model = AutoModelForCausalLM.from_config(config)
111
- state_dict = torch.load(model_stream, map_location="cpu")
112
- model.load_state_dict(state_dict)
113
- logger.info(f"PyTorch model loaded for {model_prefix}.")
114
- else:
115
- logger.error(f"No model file found for {model_prefix} in S3")
116
- raise EnvironmentError(f"No model file found for {model_prefix} in S3")
 
 
 
 
 
 
 
 
 
 
117
  return model
118
 
119
  def load_tokenizer_from_stream(self, model_prefix):
@@ -139,18 +149,37 @@ class S3DirectStream:
139
  return tokenizer
140
 
141
  def download_and_upload_to_s3(self, model_prefix, revision="main"):
142
- model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
143
- safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
144
- tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
145
  config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
 
146
 
147
- logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
148
- self.download_and_upload_to_s3_url(model_url, f"{model_prefix}/pytorch_model.bin")
149
- self.download_and_upload_to_s3_url(safetensors_url, f"{model_prefix}/model.safetensors")
 
 
 
 
 
150
  self.download_and_upload_to_s3_url(tokenizer_url, f"{model_prefix}/tokenizer.json")
151
- self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
152
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def download_and_upload_to_s3_url(self, url, s3_key):
155
  logger.info(f"Downloading from {url}...")
156
  with requests.get(url, stream=True) as response:
@@ -174,24 +203,22 @@ class S3DirectStream:
174
  async def predict(model_request: DownloadModelRequest):
175
  try:
176
  logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
177
- model_id = model_request.model_id # Fixed: Use model_id, not model_name
178
  revision = model_request.revision
179
  task = model_request.pipeline_task
180
  input_text = model_request.input_text
181
 
182
  streamer = S3DirectStream(S3_BUCKET_NAME)
183
  logger.info("Loading model and tokenizer...")
184
- model = streamer.load_model_from_stream(model_id, revision) # Use model_id
185
 
186
  if model is None:
187
  logger.error(f"Failed to load model {model_id}")
188
  raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")
189
 
190
- tokenizer = streamer.load_tokenizer_from_stream(model_id) # Use model_id
191
-
192
  logger.info("Model and tokenizer loaded.")
193
 
194
-
195
  if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
196
  raise HTTPException(status_code=400, detail="Unsupported pipeline task")
197
 
 
79
  try:
80
  logger.info(f"Loading model {model_prefix} (revision {revision})...")
81
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
82
+ any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
83
  logger.info(f"Model {model_prefix} found in S3. Loading...")
84
  return self.load_model_from_existing_s3(model_prefix)
85
 
 
98
  config = AutoConfig.from_pretrained(model_prefix, **config_dict)
99
  logger.info(f"Config loaded for {model_prefix}.")
100
 
101
+ model_files = self._get_model_files(model_prefix)
102
+ if not model_files:
103
+ logger.error(f"No model files found for {model_prefix} in S3")
104
+ raise EnvironmentError(f"No model files found for {model_prefix} in S3")
105
+
106
+ state_dict = {}
107
+ for model_file in model_files:
108
+ model_path = os.path.join(model_prefix, model_file)
109
+ logger.info(f"Loading model file: {model_path}")
110
+ model_stream = self.stream_from_s3(model_path)
111
+ try:
112
+ if model_path.endswith(".safetensors"):
113
+ shard_state = safetensors.torch.load_stream(model_stream)
114
+ elif model_path.endswith(".bin"):
115
+ shard_state = torch.load(model_stream, map_location="cpu")
116
+ else:
117
+ logger.error(f"Unsupported model file type: {model_path}")
118
+ raise ValueError(f"Unsupported model file type: {model_path}")
119
+
120
+ state_dict.update(shard_state)
121
+ except Exception as e:
122
+ logger.exception(f"Error loading model file {model_path}: {e}")
123
+ raise
124
+
125
+ model = AutoModelForCausalLM.from_config(config)
126
+ model.load_state_dict(state_dict)
127
  return model
128
 
129
  def load_tokenizer_from_stream(self, model_prefix):
 
149
  return tokenizer
150
 
151
  def download_and_upload_to_s3(self, model_prefix, revision="main"):
152
+ logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
 
 
153
  config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
154
+ self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
155
 
156
+ model_files = self._get_model_files(model_prefix, revision)
157
+ for model_file in model_files:
158
+ url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/{model_file}"
159
+ s3_key = f"{model_prefix}/{model_file}"
160
+ self.download_and_upload_to_s3_url(url, s3_key)
161
+ logger.info(f"Downloaded and uploaded {s3_key}")
162
+
163
+ tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
164
  self.download_and_upload_to_s3_url(tokenizer_url, f"{model_prefix}/tokenizer.json")
165
+
166
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
167
 
168
+ def _get_model_files(self, model_prefix, revision="main"):
169
+ index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
170
+ try:
171
+ index_response = requests.get(index_url)
172
+ index_response.raise_for_status()
173
+ index_content = index_response.text
174
+ model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
175
+ return model_files
176
+ except requests.exceptions.RequestException as e:
177
+ logger.error(f"Error retrieving model index: {e}")
178
+ raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
179
+ except (IndexError, ValueError) as e:
180
+ logger.error(f"Error parsing model file names from Hugging Face: {e}")
181
+ raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
182
+
183
  def download_and_upload_to_s3_url(self, url, s3_key):
184
  logger.info(f"Downloading from {url}...")
185
  with requests.get(url, stream=True) as response:
 
203
  async def predict(model_request: DownloadModelRequest):
204
  try:
205
  logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
206
+ model_id = model_request.model_id
207
  revision = model_request.revision
208
  task = model_request.pipeline_task
209
  input_text = model_request.input_text
210
 
211
  streamer = S3DirectStream(S3_BUCKET_NAME)
212
  logger.info("Loading model and tokenizer...")
213
+ model = streamer.load_model_from_stream(model_id, revision)
214
 
215
  if model is None:
216
  logger.error(f"Failed to load model {model_id}")
217
  raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")
218
 
219
+ tokenizer = streamer.load_tokenizer_from_stream(model_id)
 
220
  logger.info("Model and tokenizer loaded.")
221
 
 
222
  if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
223
  raise HTTPException(status_code=400, detail="Unsupported pipeline task")
224