Hjgugugjhuhjggg commited on
Commit
1064fad
·
verified ·
1 Parent(s): 2c59376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -85
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from huggingface_hub import HfApi
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel, field_validator
4
  import requests
@@ -82,109 +82,104 @@ class S3DirectStream:
82
  logger.error(f"Could not determine revision for {model_prefix}")
83
  raise ValueError(f"Could not determine revision for {model_prefix}")
84
 
85
- if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
86
- any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
87
- logger.info(f"Model {model_prefix} found in S3. Loading...")
88
- return self.load_model_from_existing_s3(model_prefix)
 
 
 
 
 
 
 
89
 
90
- logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
91
- self.download_and_upload_to_s3(model_prefix, revision)
92
- logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
93
- return self.load_model_from_stream(model_prefix)
94
  except HTTPException as e:
95
  logger.error(f"Error loading model: {e}")
96
- return None
 
 
 
97
 
98
- def load_model_from_existing_s3(self, model_prefix):
99
- logger.info(f"Loading config for {model_prefix} from S3...")
100
- config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
101
- config_dict = json.load(config_stream)
102
- config = AutoConfig.from_pretrained(model_prefix, **config_dict)
103
- logger.info(f"Config loaded for {model_prefix}.")
104
-
105
- model_files = self._get_model_files(model_prefix)
106
- if not model_files:
107
- logger.error(f"No model files found for {model_prefix} in S3")
108
- raise EnvironmentError(f"No model files found for {model_prefix} in S3")
109
-
110
- state_dict = {}
111
- for model_file in model_files:
112
- model_path = os.path.join(model_prefix, model_file)
113
- logger.info(f"Loading model file: {model_path}")
114
- model_stream = self.stream_from_s3(model_path)
115
- try:
116
- if model_path.endswith(".safetensors"):
117
- shard_state = safetensors.torch.load_stream(model_stream)
118
- elif model_path.endswith(".bin"):
119
- shard_state = torch.load(model_stream, map_location="cpu")
120
- else:
121
- logger.error(f"Unsupported model file type: {model_path}")
122
- raise ValueError(f"Unsupported model file type: {model_path}")
123
 
124
- state_dict.update(shard_state)
125
- except Exception as e:
126
- logger.exception(f"Error loading model file {model_path}: {e}")
127
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- model = AutoModelForCausalLM.from_config(config)
130
- model.load_state_dict(state_dict)
131
- return model
132
 
133
  def load_tokenizer_from_stream(self, model_prefix):
134
  try:
135
  logger.info(f"Loading tokenizer for {model_prefix}...")
136
- if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
137
- logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
138
- return self.load_tokenizer_from_existing_s3(model_prefix, config)
139
-
140
- logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
141
- self.download_and_upload_to_s3(model_prefix)
142
- logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
143
- return self.load_tokenizer_from_stream(model_prefix)
 
 
144
  except HTTPException as e:
145
  logger.error(f"Error loading tokenizer: {e}")
146
  return None
 
 
 
147
 
148
- def load_tokenizer_from_existing_s3(self, model_prefix, config):
149
- logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
150
- tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
151
- tokenizer = AutoTokenizer.from_pretrained(None, config=config)
152
- logger.info(f"Tokenizer loaded for {model_prefix}.")
153
- return tokenizer
154
-
155
- def download_and_upload_to_s3(self, model_prefix, revision="main"):
156
- logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
157
- config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
158
- self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
159
-
160
- model_files = self._get_model_files(model_prefix, revision)
161
- for model_file in model_files:
162
- url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/{model_file}"
163
- s3_key = f"{model_prefix}/{model_file}"
164
- self.download_and_upload_to_s3_url(url, s3_key)
165
- logger.info(f"Downloaded and uploaded {s3_key}")
166
-
167
- tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
168
- self.download_and_upload_to_s3_url(tokenizer_url, f"{model_prefix}/tokenizer.json")
169
-
170
- logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
171
 
172
 
173
  def _get_model_files(self, model_prefix, revision="main"):
174
- index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
175
  try:
176
- index_response = requests.get(index_url)
177
- index_response.raise_for_status()
178
- logger.info(f"Hugging Face API Response: Status Code = {index_response.status_code}, Headers = {index_response.headers}")
179
- index_content = index_response.text
180
- logger.info(f"Index content: {index_content}")
181
- model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
182
  return model_files
183
- except requests.exceptions.RequestException as e:
184
- logger.error(f"Error retrieving model index: {e}")
185
- raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
186
- except (IndexError, ValueError) as e:
187
- logger.error(f"Error parsing model file names from Hugging Face: {e}")
188
  raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
189
 
190
  def download_and_upload_to_s3_url(self, url, s3_key):
 
1
+ from huggingface_hub import HfApi, hf_hub_download
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel, field_validator
4
  import requests
 
82
  logger.error(f"Could not determine revision for {model_prefix}")
83
  raise ValueError(f"Could not determine revision for {model_prefix}")
84
 
85
+ config = self._load_config(model_prefix, revision)
86
+ if config is None:
87
+ logger.error(f"Failed to load config for {model_prefix}")
88
+ raise ValueError(f"Failed to load config for {model_prefix}")
89
+
90
+ model = self._load_model(model_prefix, config, revision)
91
+ if model is None:
92
+ logger.error(f"Failed to load model {model_prefix}")
93
+ raise ValueError(f"Failed to load model {model_prefix}")
94
+
95
+ return model
96
 
 
 
 
 
97
  except HTTPException as e:
98
  logger.error(f"Error loading model: {e}")
99
+ raise
100
+ except Exception as e:
101
+ logger.exception(f"Unexpected error loading model: {e}")
102
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred while loading the model.")
103
 
104
+ def _load_config(self, model_prefix, revision):
105
+ try:
106
+ logger.info(f"Downloading config for {model_prefix} (revision {revision})...")
107
+ config_path = hf_hub_download(repo_id=model_prefix, filename="config.json", revision=revision)
108
+ with open(config_path, "r", encoding="utf-8") as f:
109
+ config_dict = json.load(f)
110
+ return AutoConfig.from_pretrained(model_prefix, **config_dict)
111
+ except Exception as e:
112
+ logger.error(f"Error loading config: {e}")
113
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ def _load_model(self, model_prefix, config, revision):
116
+ try:
117
+ logger.info(f"Downloading model files for {model_prefix} (revision {revision})...")
118
+ model_files = self._get_model_files(model_prefix, revision)
119
+ if not model_files:
120
+ logger.error(f"No model files found for {model_prefix}")
121
+ return None
122
+
123
+ state_dict = {}
124
+ for model_file in model_files:
125
+ logger.info(f"Downloading model file: {model_file}")
126
+ file_path = hf_hub_download(repo_id=model_prefix, filename=model_file, revision=revision)
127
+ with open(file_path, "rb") as f:
128
+ if model_file.endswith(".safetensors"):
129
+ shard_state = safetensors.torch.load_file(file_path)
130
+ elif model_file.endswith(".bin"):
131
+ shard_state = torch.load(f, map_location="cpu")
132
+ else:
133
+ logger.error(f"Unsupported model file type: {model_file}")
134
+ raise ValueError(f"Unsupported model file type: {model_file}")
135
+ state_dict.update(shard_state)
136
+
137
+ model = AutoModelForCausalLM.from_config(config)
138
+ model.load_state_dict(state_dict)
139
+ return model
140
 
141
+ except Exception as e:
142
+ logger.exception(f"Error loading model: {e}")
143
+ return None
144
 
145
  def load_tokenizer_from_stream(self, model_prefix):
146
  try:
147
  logger.info(f"Loading tokenizer for {model_prefix}...")
148
+ revision = self._get_latest_revision(model_prefix)
149
+ if revision is None:
150
+ logger.error(f"Could not determine revision for {model_prefix}")
151
+ raise ValueError(f"Could not determine revision for {model_prefix}")
152
+
153
+ tokenizer = self._load_tokenizer(model_prefix, revision)
154
+ if tokenizer is None:
155
+ logger.error(f"Failed to load tokenizer for {model_prefix}")
156
+ raise ValueError(f"Failed to load tokenizer for {model_prefix}")
157
+ return tokenizer
158
  except HTTPException as e:
159
  logger.error(f"Error loading tokenizer: {e}")
160
  return None
161
+ except Exception as e:
162
+ logger.exception(f"Unexpected error loading tokenizer: {e}")
163
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred while loading the tokenizer.")
164
 
165
+ def _load_tokenizer(self, model_prefix, revision):
166
+ try:
167
+ logger.info(f"Downloading tokenizer for {model_prefix} (revision {revision})...")
168
+ tokenizer_path = hf_hub_download(repo_id=model_prefix, filename="tokenizer.json", revision=revision)
169
+ return AutoTokenizer.from_pretrained(tokenizer_path)
170
+ except Exception as e:
171
+ logger.error(f"Error loading tokenizer: {e}")
172
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
 
175
  def _get_model_files(self, model_prefix, revision="main"):
 
176
  try:
177
+ api = HfApi()
178
+ model_files = api.list_repo_files(model_prefix, revision=revision)
179
+ model_files = [file.rfilename for file in model_files if file.rfilename.endswith(('.bin', '.safetensors'))]
 
 
 
180
  return model_files
181
+ except Exception as e:
182
+ logger.error(f"Error retrieving model files from Hugging Face: {e}")
 
 
 
183
  raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
184
 
185
  def download_and_upload_to_s3_url(self, url, s3_key):