Hjgugugjhuhjggg committed on
Commit f5b9942 · verified · 1 Parent(s): 8becaf9

Update app.py

Files changed (1): app.py +31 -6
app.py CHANGED
@@ -12,6 +12,7 @@ import torch
 from fastapi.responses import StreamingResponse
 from tqdm import tqdm
 import logging
+import json
 
 load_dotenv()
 
@@ -49,6 +50,7 @@ class S3DirectStream:
         )
         self.bucket_name = bucket_name
 
+
     def stream_from_s3(self, key):
         try:
             logger.info(f"Downloading {key} from S3...")
@@ -71,6 +73,7 @@ class S3DirectStream:
     def load_model_from_stream(self, model_prefix, revision):
         try:
             logger.info(f"Loading model {model_prefix} (revision {revision})...")
+
             if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
                (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
                 logger.info(f"Model {model_prefix} found in S3. Loading...")
@@ -84,18 +87,22 @@ class S3DirectStream:
             logger.error(f"Error loading model: {e}")
             return None
 
+
     def load_model_from_existing_s3(self, model_prefix):
         logger.info(f"Loading config for {model_prefix} from S3...")
         config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
-        config = AutoConfig.from_pretrained(config_stream)
+        config_dict = json.load(config_stream)
+        config = AutoConfig.from_pretrained(config_dict)
         logger.info(f"Config loaded for {model_prefix}.")
 
+
         if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
             logger.info(f"Loading safetensors model for {model_prefix} from S3...")
             model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
             model = AutoModelForCausalLM.from_config(config)
             model.load_state_dict(safetensors.torch.load_stream(model_stream))
             logger.info(f"Safetensors model loaded for {model_prefix}.")
+
         elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
             logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
             model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
@@ -106,15 +113,18 @@ class S3DirectStream:
         else:
             logger.error(f"No model file found for {model_prefix} in S3")
             raise EnvironmentError(f"No model file found for {model_prefix} in S3")
+
         return model
 
     def load_tokenizer_from_stream(self, model_prefix):
         try:
             logger.info(f"Loading tokenizer for {model_prefix}...")
+
             if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
                 logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
                 return self.load_tokenizer_from_existing_s3(model_prefix)
 
+
             logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
             self.download_and_upload_to_s3(model_prefix)
             logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
@@ -123,20 +133,24 @@ class S3DirectStream:
             logger.error(f"Error loading tokenizer: {e}")
             return None
 
+
     def load_tokenizer_from_existing_s3(self, model_prefix):
         logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
         tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
+        tokenizer_config = json.load(tokenizer_stream)  # Corrected this line
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=None, config=tokenizer_config)  # Corrected this line
         logger.info(f"Tokenizer loaded for {model_prefix}.")
         return tokenizer
 
 
+
     def download_and_upload_to_s3(self, model_prefix, revision="main"):
         model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
         safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
         tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
         config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
 
+
         logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
         self.download_and_upload_to_s3_url(model_url, f"{model_prefix}/pytorch_model.bin")
         self.download_and_upload_to_s3_url(safetensors_url, f"{model_prefix}/model.safetensors")
@@ -146,15 +160,18 @@ class S3DirectStream:
 
     def download_and_upload_to_s3_url(self, url, s3_key):
         logger.info(f"Downloading from {url}...")
+
         with requests.get(url, stream=True) as response:
             if response.status_code == 200:
                 total_size_in_bytes = int(response.headers.get('content-length', 0))
                 block_size = 1024
                 progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                 logger.info(f"Uploading to S3: {s3_key}...")
+
                 self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
                 progress_bar.close()
                 logger.info(f"Uploaded {s3_key} to S3 successfully.")
+
             elif response.status_code == 404:
                 logger.error(f"File not found at {url}")
                 raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
@@ -167,26 +184,31 @@ class S3DirectStream:
 async def predict(model_request: DownloadModelRequest):
     try:
         logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
+
         model_name = model_request.model_name
         revision = model_request.revision
 
         streamer = S3DirectStream(S3_BUCKET_NAME)
-        logger.info(f"Loading model and tokenizer...")
+        logger.info("Loading model and tokenizer...")
         model = streamer.load_model_from_stream(model_name, revision)
+
         tokenizer = streamer.load_tokenizer_from_stream(model_name)
-        logger.info(f"Model and tokenizer loaded.")
+        logger.info("Model and tokenizer loaded.")
+
 
         task = model_request.pipeline_task
         if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
             raise HTTPException(status_code=400, detail="Unsupported pipeline task")
 
+
         if task == "text-generation":
-            logger.info(f"Starting text generation...")
+            logger.info("Starting text generation...")
             text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
             inputs = tokenizer(model_request.input_text, return_tensors="pt").to(model.device)
             generation_kwargs = dict(inputs, streamer=text_streamer)
             model.generate(**generation_kwargs)
+            logger.info("Text generation finished.")
+
             return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
 
         else:
@@ -196,10 +218,13 @@ async def predict(model_request: DownloadModelRequest):
         logger.info(f"Pipeline task {task} finished.")
         return {"result": outputs}
 
+
+
     except Exception as e:
         logger.exception(f"Error processing request: {e}")
         raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
 
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
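
Notes on the changes (minimal sketches, not the committed implementation):

On the config change: AutoConfig.from_pretrained expects a model id or a local path, so passing it the dict produced by json.load will not work as written. One way to load a config from a streamed config.json, assuming stream_from_s3 returns a file-like object (the helper name config_from_stream is hypothetical):

import tempfile
from pathlib import Path

from transformers import AutoConfig

def config_from_stream(config_stream):
    # Hypothetical helper: materialize the streamed config.json to disk,
    # because AutoConfig.from_pretrained reads from a model id or local path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        (Path(tmp_dir) / "config.json").write_bytes(config_stream.read())
        return AutoConfig.from_pretrained(tmp_dir)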
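
On the safetensors branch: safetensors.torch exposes load(data: bytes) for in-memory payloads; there is no load_stream function, so that call would fail at runtime. A sketch under the same assumption that the stream is file-like:

import safetensors.torch
from transformers import AutoModelForCausalLM

def load_model_from_bytes(config, model_stream):
    # Build the architecture from the config, then fill in the weights;
    # safetensors.torch.load turns raw bytes into a state dict.
    model = AutoModelForCausalLM.from_config(config)
    model.load_state_dict(safetensors.torch.load(model_stream.read()))
    model.eval()  # inference-only service, so disable dropout
    return model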
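
On the tokenizer change: AutoTokenizer.from_pretrained(pretrained_model_name_or_path=None, config=...) does not build a tokenizer from a parsed tokenizer.json. A raw tokenizer.json payload can instead back a PreTrainedTokenizerFast through a tokenizers.Tokenizer object; special tokens normally supplied by tokenizer_config.json would still have to be set separately. A sketch:

from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

def tokenizer_from_stream(tokenizer_stream):
    # tokenizer.json is the serialized form of a tokenizers.Tokenizer,
    # which PreTrainedTokenizerFast can wrap directly.
    tokenizer_object = Tokenizer.from_str(tokenizer_stream.read().decode("utf-8"))
    return PreTrainedTokenizerFast(tokenizer_object=tokenizer_object)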
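
The diff never shows the body of stream_from_s3. For the sketches above to hold, it would need to return a seekable file-like object; one plausible shape (an assumption, not the committed code):

import io

def stream_from_s3(self, key):
    # get_object returns a dict whose "Body" is a StreamingBody; reading it
    # into BytesIO gives callers something json.load and .read() both accept.
    response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
    return io.BytesIO(response["Body"].read())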
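
On the text-generation branch: model.generate blocks until the whole sequence is produced, so the StreamingResponse only starts emitting after generation has finished, and TextIteratorStreamer yields already-decoded strings, so tokenizer.decode(token) is applied to text rather than token ids. The usual pattern runs generate on a worker thread (a sketch, not the committed code):

from threading import Thread

from fastapi.responses import StreamingResponse
from transformers import TextIteratorStreamer

def stream_generation(model, tokenizer, input_text):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Run generation off the request thread so chunks stream as produced.
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer)).start()
    # The streamer yields decoded text chunks; forward them as-is.
    return StreamingResponse((chunk for chunk in streamer), media_type="text/event-stream")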