Hjgugugjhuhjggg committed on
Commit 8becaf9 · verified · 1 Parent(s): 972e5ee

Update app.py

Files changed (1): app.py (+60 -19)
app.py CHANGED
@@ -10,9 +10,14 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import safetensors.torch
 import torch
 from fastapi.responses import StreamingResponse
+from tqdm import tqdm
+import logging
 
 load_dotenv()
 
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
@@ -46,60 +51,83 @@ class S3DirectStream:
 
     def stream_from_s3(self, key):
         try:
+            logger.info(f"Downloading {key} from S3...")
             response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
+            logger.info(f"Downloaded {key} from S3 successfully.")
             return response['Body']
         except self.s3_client.exceptions.NoSuchKey:
+            logger.error(f"File {key} not found in S3")
             raise HTTPException(status_code=404, detail=f"File {key} not found in S3")
 
     def file_exists_in_s3(self, key):
         try:
             self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
+            logger.info(f"File {key} exists in S3.")
             return True
         except self.s3_client.exceptions.ClientError:
+            logger.info(f"File {key} does not exist in S3.")
            return False
 
     def load_model_from_stream(self, model_prefix, revision):
         try:
+            logger.info(f"Loading model {model_prefix} (revision {revision})...")
             if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
                (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
+                logger.info(f"Model {model_prefix} found in S3. Loading...")
                 return self.load_model_from_existing_s3(model_prefix)
 
+            logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
             self.download_and_upload_to_s3(model_prefix, revision)
+            logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
             return self.load_model_from_stream(model_prefix, revision)
         except HTTPException as e:
+            logger.error(f"Error loading model: {e}")
             return None
 
     def load_model_from_existing_s3(self, model_prefix):
+        logger.info(f"Loading config for {model_prefix} from S3...")
         config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
-        config = AutoConfig.from_pretrained(config_stream)  # Directly from stream
+        config = AutoConfig.from_pretrained(config_stream)
+        logger.info(f"Config loaded for {model_prefix}.")
 
         if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
+            logger.info(f"Loading safetensors model for {model_prefix} from S3...")
             model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
             model = AutoModelForCausalLM.from_config(config)
             model.load_state_dict(safetensors.torch.load_stream(model_stream))
+            logger.info(f"Safetensors model loaded for {model_prefix}.")
         elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
+            logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
             model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
             model = AutoModelForCausalLM.from_config(config)
-            state_dict = torch.load(model_stream, map_location="cpu")  # Load directly
+            state_dict = torch.load(model_stream, map_location="cpu")
             model.load_state_dict(state_dict)
+            logger.info(f"PyTorch model loaded for {model_prefix}.")
         else:
-            raise EnvironmentError(f"No model file found for {model_prefix} in S3")
+            logger.error(f"No model file found for {model_prefix} in S3")
+            raise EnvironmentError(f"No model file found for {model_prefix} in S3")
         return model
 
-
-
     def load_tokenizer_from_stream(self, model_prefix):
         try:
+            logger.info(f"Loading tokenizer for {model_prefix}...")
             if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
+                logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
                 return self.load_tokenizer_from_existing_s3(model_prefix)
+
+            logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
             self.download_and_upload_to_s3(model_prefix)
+            logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
             return self.load_tokenizer_from_stream(model_prefix)
         except HTTPException as e:
+            logger.error(f"Error loading tokenizer: {e}")
             return None
 
     def load_tokenizer_from_existing_s3(self, model_prefix):
+        logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
         tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)  # Directly from stream
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
+        logger.info(f"Tokenizer loaded for {model_prefix}.")
         return tokenizer
 
 
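Review note on this hunk: both before and after the change, the loaders hand raw S3 streams to APIs that do not accept them. AutoConfig.from_pretrained and AutoTokenizer.from_pretrained expect a local path or repo id rather than a byte stream, safetensors.torch has no load_stream function, and torch.load needs a seekable file-like object while boto3's StreamingBody is not seekable. (Also, load_tokenizer_from_stream calls download_and_upload_to_s3 without the revision argument that the model path passes.) Below is a minimal sketch of one way to buffer the objects first, assuming the same bucket layout; load_model_buffered and bucket are illustrative names, not part of the commit:

import tempfile
from pathlib import Path

import boto3
import safetensors.torch
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerFast

s3 = boto3.client("s3")

def load_model_buffered(bucket: str, model_prefix: str):
    # Hypothetical helper: materialize the small JSON files so the HF loaders
    # can read them like a local checkpoint directory.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        for name in ("config.json", "tokenizer.json"):
            s3.download_file(bucket, f"{model_prefix}/{name}", str(tmp_path / name))
        config = AutoConfig.from_pretrained(tmp_path)
        tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tmp_path / "tokenizer.json"))

        model = AutoModelForCausalLM.from_config(config)
        body = s3.get_object(Bucket=bucket, Key=f"{model_prefix}/model.safetensors")["Body"]
        # safetensors.torch.load() parses the raw bytes of a .safetensors file;
        # this buffers the whole checkpoint in memory, the cost of skipping disk.
        model.load_state_dict(safetensors.torch.load(body.read()))
        return model, tokenizer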
@@ -109,56 +137,69 @@ class S3DirectStream:
         tokenizer_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/tokenizer.json"
         config_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/config.json"
 
+        logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
         self.download_and_upload_to_s3_url(model_url, f"{model_prefix}/pytorch_model.bin")
         self.download_and_upload_to_s3_url(safetensors_url, f"{model_prefix}/model.safetensors")
         self.download_and_upload_to_s3_url(tokenizer_url, f"{model_prefix}/tokenizer.json")
         self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
-
+        logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
 
     def download_and_upload_to_s3_url(self, url, s3_key):
-        response = requests.get(url, stream=True)
-        if response.status_code == 200:
-            self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key)  # Direct upload
-        elif response.status_code == 404:
-            raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
-
-        else:
-            raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
-
+        logger.info(f"Downloading from {url}...")
+        with requests.get(url, stream=True) as response:
+            if response.status_code == 200:
+                total_size_in_bytes = int(response.headers.get('content-length', 0))
+                block_size = 1024
+                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+                logger.info(f"Uploading to S3: {s3_key}...")
+                self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
+                progress_bar.close()
+                logger.info(f"Uploaded {s3_key} to S3 successfully.")
+            elif response.status_code == 404:
+                logger.error(f"File not found at {url}")
+                raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
+            else:
+                logger.error(f"Error downloading from {url}: Status code {response.status_code}")
+                raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
 
 
 @app.post("/predict/")
 async def predict(model_request: DownloadModelRequest):
     try:
+        logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
         model_name = model_request.model_name
         revision = model_request.revision
 
         streamer = S3DirectStream(S3_BUCKET_NAME)
+        logger.info(f"Loading model and tokenizer...")
         model = streamer.load_model_from_stream(model_name, revision)
         tokenizer = streamer.load_tokenizer_from_stream(model_name)
+        logger.info(f"Model and tokenizer loaded.")
 
         task = model_request.pipeline_task
         if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
             raise HTTPException(status_code=400, detail="Unsupported pipeline task")
 
         if task == "text-generation":
+            logger.info(f"Starting text generation...")
             text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
             inputs = tokenizer(model_request.input_text, return_tensors="pt").to(model.device)
             generation_kwargs = dict(inputs, streamer=text_streamer)
             model.generate(**generation_kwargs)
+            logger.info(f"Text generation finished.")
             return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
 
         else:
+            logger.info(f"Starting pipeline task: {task}...")
             nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, device_map="auto", trust_remote_code=True)
             outputs = nlp_pipeline(model_request.input_text)
+            logger.info(f"Pipeline task {task} finished.")
             return {"result": outputs}
 
-
     except Exception as e:
-        print(f"Complete Error: {e}")
+        logger.exception(f"Error processing request: {e}")
         raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
 
-
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
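Review note on this hunk: in both versions of the text-generation branch, model.generate(**generation_kwargs) runs to completion before the StreamingResponse is built, so nothing actually streams, and TextIteratorStreamer already yields decoded text chunks, so the extra tokenizer.decode(token) is applied to strings rather than token ids. (Also, block_size = 1024 in download_and_upload_to_s3_url is assigned but never used; the boto3 Callback hook alone drives the tqdm bar.) TextIteratorStreamer is designed to be drained while generate runs in a worker thread; a sketch of that shape follows, where max_new_tokens=256 is an assumed cap, not from the commit:

from threading import Thread

from fastapi.responses import StreamingResponse
from transformers import TextIteratorStreamer

def stream_generation(model, tokenizer, input_text: str) -> StreamingResponse:
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=256)

    # generate() fills the streamer from a worker thread while the response
    # body iterates over the already-decoded chunks as they arrive.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return StreamingResponse((chunk for chunk in streamer), media_type="text/event-stream")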
 
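For reference, a request that exercises the updated endpoint; the field names mirror how model_request is used in the handler, the host and port match the uvicorn.run() call, and "gpt2" is just a placeholder repo id:

import requests

payload = {
    "model_name": "gpt2",  # placeholder repo id, any model the server can fetch
    "revision": "main",
    "pipeline_task": "text-generation",
    "input_text": "Once upon a time",
}
resp = requests.post("http://localhost:7860/predict/", json=payload, stream=True)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)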