Hjgugugjhuhjggg committed on
Commit
685ddd1
·
verified ·
1 Parent(s): f5b9942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -28
app.py CHANGED
@@ -50,7 +50,6 @@ class S3DirectStream:
50
  )
51
  self.bucket_name = bucket_name
52
 
53
-
54
  def stream_from_s3(self, key):
55
  try:
56
  logger.info(f"Downloading {key} from S3...")
@@ -73,7 +72,6 @@ class S3DirectStream:
73
  def load_model_from_stream(self, model_prefix, revision):
74
  try:
75
  logger.info(f"Loading model {model_prefix} (revision {revision})...")
76
-
77
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
78
  (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
79
  logger.info(f"Model {model_prefix} found in S3. Loading...")
@@ -87,22 +85,19 @@ class S3DirectStream:
87
  logger.error(f"Error loading model: {e}")
88
  return None
89
 
90
-
91
  def load_model_from_existing_s3(self, model_prefix):
92
  logger.info(f"Loading config for {model_prefix} from S3...")
93
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
94
  config_dict = json.load(config_stream)
95
- config = AutoConfig.from_pretrained(config_dict)
96
  logger.info(f"Config loaded for {model_prefix}.")
97
 
98
-
99
  if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
100
  logger.info(f"Loading safetensors model for {model_prefix} from S3...")
101
  model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
102
  model = AutoModelForCausalLM.from_config(config)
103
  model.load_state_dict(safetensors.torch.load_stream(model_stream))
104
  logger.info(f"Safetensors model loaded for {model_prefix}.")
105
-
106
  elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
107
  logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
108
  model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
@@ -113,17 +108,14 @@ class S3DirectStream:
113
  else:
114
  logger.error(f"No model file found for {model_prefix} in S3")
115
  raise EnvironmentError(f"No model file found for {model_prefix} in S3")
116
-
117
  return model
118
 
119
  def load_tokenizer_from_stream(self, model_prefix):
120
  try:
121
  logger.info(f"Loading tokenizer for {model_prefix}...")
122
-
123
  if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
124
  logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
125
- return self.load_tokenizer_from_existing_s3(model_prefix)
126
-
127
 
128
  logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
129
  self.download_and_upload_to_s3(model_prefix)
@@ -134,23 +126,21 @@ class S3DirectStream:
134
  return None
135
 
136
 
137
- def load_tokenizer_from_existing_s3(self, model_prefix):
138
  logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
 
139
  tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
140
- tokenizer_config = json.load(tokenizer_stream) # Corrected this line
141
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=None, config=tokenizer_config) # Corrected this line
142
  logger.info(f"Tokenizer loaded for {model_prefix}.")
143
  return tokenizer
144
 
145
 
146
-
147
  def download_and_upload_to_s3(self, model_prefix, revision="main"):
148
  model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
149
  safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
150
  tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
151
  config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
152
 
153
-
154
  logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
155
  self.download_and_upload_to_s3_url(model_url, f"{model_prefix}/pytorch_model.bin")
156
  self.download_and_upload_to_s3_url(safetensors_url, f"{model_prefix}/model.safetensors")
@@ -158,23 +148,22 @@ class S3DirectStream:
158
  self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
159
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
160
 
 
161
  def download_and_upload_to_s3_url(self, url, s3_key):
162
  logger.info(f"Downloading from {url}...")
163
-
164
  with requests.get(url, stream=True) as response:
165
  if response.status_code == 200:
166
- total_size_in_bytes= int(response.headers.get('content-length', 0))
167
  block_size = 1024
168
  progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
169
  logger.info(f"Uploading to S3: {s3_key}...")
170
-
171
  self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
172
  progress_bar.close()
173
  logger.info(f"Uploaded {s3_key} to S3 successfully.")
174
-
175
  elif response.status_code == 404:
176
  logger.error(f"File not found at {url}")
177
  raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
 
178
  else:
179
  logger.error(f"Error downloading from {url}: Status code {response.status_code}")
180
  raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
@@ -184,23 +173,19 @@ class S3DirectStream:
184
  async def predict(model_request: DownloadModelRequest):
185
  try:
186
  logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
187
-
188
  model_name = model_request.model_name
189
  revision = model_request.revision
190
 
191
  streamer = S3DirectStream(S3_BUCKET_NAME)
192
  logger.info("Loading model and tokenizer...")
193
  model = streamer.load_model_from_stream(model_name, revision)
194
-
195
- tokenizer = streamer.load_tokenizer_from_stream(model_name)
196
  logger.info("Model and tokenizer loaded.")
197
 
198
-
199
  task = model_request.pipeline_task
200
  if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
201
  raise HTTPException(status_code=400, detail="Unsupported pipeline task")
202
 
203
-
204
  if task == "text-generation":
205
  logger.info("Starting text generation...")
206
  text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -208,7 +193,6 @@ async def predict(model_request: DownloadModelRequest):
208
  generation_kwargs = dict(inputs, streamer=text_streamer)
209
  model.generate(**generation_kwargs)
210
  logger.info("Text generation finished.")
211
-
212
  return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
213
 
214
  else:
@@ -218,13 +202,10 @@ async def predict(model_request: DownloadModelRequest):
218
  logger.info(f"Pipeline task {task} finished.")
219
  return {"result": outputs}
220
 
221
-
222
-
223
  except Exception as e:
224
  logger.exception(f"Error processing request: {e}")
225
  raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
226
 
227
 
228
-
229
  if __name__ == "__main__":
230
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
50
  )
51
  self.bucket_name = bucket_name
52
 
 
53
  def stream_from_s3(self, key):
54
  try:
55
  logger.info(f"Downloading {key} from S3...")
 
72
  def load_model_from_stream(self, model_prefix, revision):
73
  try:
74
  logger.info(f"Loading model {model_prefix} (revision {revision})...")
 
75
  if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
76
  (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
77
  logger.info(f"Model {model_prefix} found in S3. Loading...")
 
85
  logger.error(f"Error loading model: {e}")
86
  return None
87
 
 
88
  def load_model_from_existing_s3(self, model_prefix):
89
  logger.info(f"Loading config for {model_prefix} from S3...")
90
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
91
  config_dict = json.load(config_stream)
92
+ config = AutoConfig.from_dict(config_dict)
93
  logger.info(f"Config loaded for {model_prefix}.")
94
 
 
95
  if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
96
  logger.info(f"Loading safetensors model for {model_prefix} from S3...")
97
  model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
98
  model = AutoModelForCausalLM.from_config(config)
99
  model.load_state_dict(safetensors.torch.load_stream(model_stream))
100
  logger.info(f"Safetensors model loaded for {model_prefix}.")
 
101
  elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
102
  logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
103
  model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
 
108
  else:
109
  logger.error(f"No model file found for {model_prefix} in S3")
110
  raise EnvironmentError(f"No model file found for {model_prefix} in S3")
 
111
  return model
112
 
113
  def load_tokenizer_from_stream(self, model_prefix):
114
  try:
115
  logger.info(f"Loading tokenizer for {model_prefix}...")
 
116
  if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
117
  logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
118
+ return self.load_tokenizer_from_existing_s3(model_prefix, config) # Pass config
 
119
 
120
  logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
121
  self.download_and_upload_to_s3(model_prefix)
 
126
  return None
127
 
128
 
129
+ def load_tokenizer_from_existing_s3(self, model_prefix, config): # Receive config
130
  logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
131
+
132
  tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
133
+ tokenizer = AutoTokenizer.from_pretrained(None, config=config) # Use config
 
134
  logger.info(f"Tokenizer loaded for {model_prefix}.")
135
  return tokenizer
136
 
137
 
 
138
  def download_and_upload_to_s3(self, model_prefix, revision="main"):
139
  model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
140
  safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
141
  tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
142
  config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
143
 
 
144
  logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
145
  self.download_and_upload_to_s3_url(model_url, f"{model_prefix}/pytorch_model.bin")
146
  self.download_and_upload_to_s3_url(safetensors_url, f"{model_prefix}/model.safetensors")
 
148
  self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
149
  logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
150
 
151
+
152
  def download_and_upload_to_s3_url(self, url, s3_key):
153
  logger.info(f"Downloading from {url}...")
 
154
  with requests.get(url, stream=True) as response:
155
  if response.status_code == 200:
156
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
157
  block_size = 1024
158
  progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
159
  logger.info(f"Uploading to S3: {s3_key}...")
 
160
  self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
161
  progress_bar.close()
162
  logger.info(f"Uploaded {s3_key} to S3 successfully.")
 
163
  elif response.status_code == 404:
164
  logger.error(f"File not found at {url}")
165
  raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
166
+
167
  else:
168
  logger.error(f"Error downloading from {url}: Status code {response.status_code}")
169
  raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
 
173
  async def predict(model_request: DownloadModelRequest):
174
  try:
175
  logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
 
176
  model_name = model_request.model_name
177
  revision = model_request.revision
178
 
179
  streamer = S3DirectStream(S3_BUCKET_NAME)
180
  logger.info("Loading model and tokenizer...")
181
  model = streamer.load_model_from_stream(model_name, revision)
182
+ tokenizer = streamer.load_tokenizer_from_stream(model_name) # Moved after model loading
 
183
  logger.info("Model and tokenizer loaded.")
184
 
 
185
  task = model_request.pipeline_task
186
  if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
187
  raise HTTPException(status_code=400, detail="Unsupported pipeline task")
188
 
 
189
  if task == "text-generation":
190
  logger.info("Starting text generation...")
191
  text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
193
  generation_kwargs = dict(inputs, streamer=text_streamer)
194
  model.generate(**generation_kwargs)
195
  logger.info("Text generation finished.")
 
196
  return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
197
 
198
  else:
 
202
  logger.info(f"Pipeline task {task} finished.")
203
  return {"result": outputs}
204
 
 
 
205
  except Exception as e:
206
  logger.exception(f"Error processing request: {e}")
207
  raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
208
 
209
 
 
210
  if __name__ == "__main__":
211
  uvicorn.run(app, host="0.0.0.0", port=7860)