Hjgugugjhuhjggg
committed on
Update app.py
app.py
CHANGED
@@ -79,7 +79,7 @@ class S3DirectStream:
         try:
             logger.info(f"Loading model {model_prefix} (revision {revision})...")
             if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
-                (self.file_exists_in_s3(f"{model_prefix}/ [rest of line not recovered]
+                any(self.file_exists_in_s3(f"{model_prefix}/{file}") for file in self._get_model_files(model_prefix, revision)):
                 logger.info(f"Model {model_prefix} found in S3. Loading...")
                 return self.load_model_from_existing_s3(model_prefix)
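
Both sides of this condition go through file_exists_in_s3, which the commit does not show. A minimal sketch of how such a check is typically written with boto3 (the constructor and attribute names here are assumptions, not the Space's actual code). Note also that the new condition calls _get_model_files() inside the check, so the test now costs an extra HTTP round trip to Hugging Face.

import boto3
from botocore.exceptions import ClientError

class S3DirectStream:
    def __init__(self, bucket_name):
        # Assumed constructor; the Space's real class may take more arguments.
        self.bucket_name = bucket_name
        self.s3_client = boto3.client("s3")

    def file_exists_in_s3(self, key):
        # head_object fetches only metadata; a ClientError (usually a 404)
        # means the key is absent or unreadable.
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            return True
        except ClientError:
            return False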
@@ -98,22 +98,32 @@ class S3DirectStream:
             config = AutoConfig.from_pretrained(model_prefix, **config_dict)
             logger.info(f"Config loaded for {model_prefix}.")

-            [removed block (old lines 101-116) not recovered from the page extraction]
+            model_files = self._get_model_files(model_prefix)
+            if not model_files:
+                logger.error(f"No model files found for {model_prefix} in S3")
+                raise EnvironmentError(f"No model files found for {model_prefix} in S3")
+
+            state_dict = {}
+            for model_file in model_files:
+                model_path = os.path.join(model_prefix, model_file)
+                logger.info(f"Loading model file: {model_path}")
+                model_stream = self.stream_from_s3(model_path)
+                try:
+                    if model_path.endswith(".safetensors"):
+                        shard_state = safetensors.torch.load_stream(model_stream)
+                    elif model_path.endswith(".bin"):
+                        shard_state = torch.load(model_stream, map_location="cpu")
+                    else:
+                        logger.error(f"Unsupported model file type: {model_path}")
+                        raise ValueError(f"Unsupported model file type: {model_path}")
+
+                    state_dict.update(shard_state)
+                except Exception as e:
+                    logger.exception(f"Error loading model file {model_path}: {e}")
+                    raise
+
+            model = AutoModelForCausalLM.from_config(config)
+            model.load_state_dict(state_dict)
             return model

     def load_tokenizer_from_stream(self, model_prefix):
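
One caveat on the loading loop above: safetensors.torch.load_stream is not, as far as I can tell, part of the published safetensors API, so the .safetensors branch would raise AttributeError at runtime. A sketch of an equivalent that buffers the shard first (stream_from_s3 returning a file-like object is taken from the diff; everything else is an assumption):

import io

import safetensors.torch
import torch

def load_shard(model_stream, model_path):
    # Read the whole shard into memory: safetensors.torch.load() accepts
    # raw bytes, and torch.load() needs a seekable file-like object.
    data = model_stream.read()
    if model_path.endswith(".safetensors"):
        return safetensors.torch.load(data)  # bytes -> {name: tensor}
    return torch.load(io.BytesIO(data), map_location="cpu")

Buffering costs RAM proportional to the largest shard, which is the price of skipping the local disk.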
@@ -139,18 +149,37 @@ class S3DirectStream:
         return tokenizer

     def download_and_upload_to_s3(self, model_prefix, revision="main"):
-        [removed line not recovered]
-        safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
-        tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
+        logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
         config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
+        self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")

-        [removed lines not recovered]
+        model_files = self._get_model_files(model_prefix, revision)
+        for model_file in model_files:
+            url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/{model_file}"
+            s3_key = f"{model_prefix}/{model_file}"
+            self.download_and_upload_to_s3_url(url, s3_key)
+            logger.info(f"Downloaded and uploaded {s3_key}")
+
+        tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
         self.download_and_upload_to_s3_url(tokenizer_url, f"{model_prefix}/tokenizer.json")
-        [removed line not recovered]
+
         logger.info(f"Finished downloading and uploading model files for {model_prefix}.")

+    def _get_model_files(self, model_prefix, revision="main"):
+        index_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/"
+        try:
+            index_response = requests.get(index_url)
+            index_response.raise_for_status()
+            index_content = index_response.text
+            model_files = [f for f in index_content.split('\n') if f.endswith(('.bin', '.safetensors'))]
+            return model_files
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Error retrieving model index: {e}")
+            raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
+        except (IndexError, ValueError) as e:
+            logger.error(f"Error parsing model file names from Hugging Face: {e}")
+            raise HTTPException(status_code=500, detail=f"Error retrieving model files from Hugging Face") from e
+
     def download_and_upload_to_s3_url(self, url, s3_key):
         logger.info(f"Downloading from {url}...")
         with requests.get(url, stream=True) as response:
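
A likely problem with _get_model_files as added here: https://huggingface.co/{model_prefix}/resolve/{revision}/ serves an HTML page, not a newline-separated file list, so splitting the body on '\n' and filtering by extension will almost certainly match nothing. A more reliable variant (a sketch, not the commit's code) asks the Hub for the file list directly:

from huggingface_hub import list_repo_files

def get_model_files(model_prefix, revision="main"):
    # list_repo_files returns repo-relative paths; keep only weight shards.
    files = list_repo_files(model_prefix, revision=revision)
    return [f for f in files if f.endswith((".bin", ".safetensors"))]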
@@ -174,24 +203,22 @@ class S3DirectStream:
 async def predict(model_request: DownloadModelRequest):
     try:
         logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
-        model_id = model_request.model_id
+        model_id = model_request.model_id
         revision = model_request.revision
         task = model_request.pipeline_task
         input_text = model_request.input_text

         streamer = S3DirectStream(S3_BUCKET_NAME)
         logger.info("Loading model and tokenizer...")
-        model = streamer.load_model_from_stream(model_id, revision)
+        model = streamer.load_model_from_stream(model_id, revision)

         if model is None:
             logger.error(f"Failed to load model {model_id}")
             raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")

-        tokenizer = streamer.load_tokenizer_from_stream(model_id)
-
+        tokenizer = streamer.load_tokenizer_from_stream(model_id)
         logger.info("Model and tokenizer loaded.")

-
         if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
             raise HTTPException(status_code=400, detail="Unsupported pipeline task")

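
The paired removed/added lines in this last hunk are most likely whitespace-only changes; the page extraction dropped the indentation that distinguished them. The hunk also stops at the task whitelist, so the diff does not show how the endpoint actually runs inference. Presumably it hands the loaded objects to a transformers pipeline; a minimal sketch under that assumption:

from transformers import pipeline

def run_inference(task, model, tokenizer, input_text):
    # pipeline() accepts in-memory model/tokenizer instances as well as ids.
    nlp = pipeline(task, model=model, tokenizer=tokenizer)
    return nlp(input_text)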