Hjgugugjhuhjggg committed · verified
Commit 3a145aa · 1 Parent(s): ec488c7

Update app.py
Files changed (1): app.py (+100, −180)
app.py CHANGED
@@ -1,24 +1,19 @@
-from huggingface_hub import HfApi
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
 import requests
 import boto3
 from dotenv import load_dotenv
 import os
 import uvicorn
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextIteratorStreamer
-import safetensors.torch
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 import torch
+import safetensors.torch
 from fastapi.responses import StreamingResponse
 from tqdm import tqdm
-import logging
-import json
+import re

 load_dotenv()

-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
@@ -35,16 +30,10 @@ s3_client = boto3.client(
 app = FastAPI()

 class DownloadModelRequest(BaseModel):
-    model_id: str
+    model_name: str
     pipeline_task: str
     input_text: str

-    @field_validator('model_id')
-    def validate_model_id(cls, value):
-        if not value:
-            raise ValueError("model_id cannot be empty")
-        return value
-
 class S3DirectStream:
     def __init__(self, bucket_name):
         self.s3_client = boto3.client(
@@ -57,204 +46,135 @@ class S3DirectStream:

     def stream_from_s3(self, key):
         try:
-            logger.info(f"Downloading {key} from S3...")
             response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
-            logger.info(f"Downloaded {key} from S3 successfully.")
             return response['Body']
         except self.s3_client.exceptions.NoSuchKey:
-            logger.error(f"File {key} not found in S3")
-            raise HTTPException(status_code=404, detail=f"File {key} not found in S3")
+            raise HTTPException(status_code=404, detail=f"File {key} does not exist in the S3 bucket.")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error downloading from S3: {e}")

     def file_exists_in_s3(self, key):
         try:
             self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
-            logger.info(f"File {key} exists in S3.")
             return True
-        except self.s3_client.exceptions.ClientError:
-            logger.info(f"File {key} does not exist in S3.")
-            return False
+        except self.s3_client.exceptions.ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                return False
+            raise HTTPException(status_code=500, detail=f"Error checking file in S3: {e}")

     def load_model_from_stream(self, model_prefix):
         try:
-            logger.info(f"Loading model {model_prefix}...")
-            if self.file_exists_in_s3(f"{model_prefix}/config.json"):
-                logger.info(f"Model {model_prefix} found in S3. Loading...")
-                return self.load_model_from_existing_s3(model_prefix)
+            model_files = self.list_model_files(model_prefix)
+            if not model_files:
+                self.download_and_upload_to_s3(model_prefix)
+                return self.load_model_from_stream(model_prefix)
+
+            config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
+            config_data = config_stream.read().decode("utf-8")
+
+            model_path = f"{model_prefix}/model.safetensors"
+            if self.file_exists_in_s3(model_path):
+                model_stream = self.stream_from_s3(model_path)
+                model = AutoModelForCausalLM.from_config(config_data)
+                model.load_state_dict(safetensors.torch.load_stream(model_stream))
+            elif model_files:
+                model = AutoModelForCausalLM.from_config(config_data)
+                state_dict = {}
+                for file_name in model_files:
+                    file_stream = self.stream_from_s3(f"{model_prefix}/{file_name}")
+                    tmp = torch.load(file_stream, map_location="cpu")
+                    state_dict.update(tmp)
+                model.load_state_dict(state_dict)
+            else:
+                raise HTTPException(status_code=500, detail="Model not found")

-            logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
-            self.download_and_upload_to_s3(model_prefix)
-            logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
-            return self.load_model_from_stream(model_prefix)
+            return model
         except HTTPException as e:
-            logger.error(f"Error loading model: {e}")
-            return None
-
-    def load_model_from_existing_s3(self, model_prefix):
-        logger.info(f"Loading config for {model_prefix} from S3...")
-        config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
-        config_dict = json.load(config_stream)
-        config = AutoConfig.from_pretrained(model_prefix, **config_dict)
-        logger.info(f"Config loaded for {model_prefix}.")
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error loading the model: {e}")

+    def list_model_files(self, model_prefix):
         try:
-            api = HfApi()
-            model_files = api.list_repo_files(model_prefix)
-            state_dict = {}
-            for file_info in model_files:
-                if file_info.rfilename.endswith(('.bin', '.safetensors')):
-                    file_url = api.download_file(model_prefix, file_info.rfilename)
-                    model_path = os.path.join(model_prefix, file_info.rfilename)
-                    logger.info(f"Downloading model file from {file_url} to {model_path} ...")
-                    with requests.get(file_url, stream=True) as response:
-                        if response.status_code == 200:
-                            try:
-                                model_stream = response.raw
-                                if model_path.endswith(".safetensors"):
-                                    shard_state = safetensors.torch.load_stream(model_stream)
-                                elif model_path.endswith(".bin"):
-                                    shard_state = torch.load(model_stream, map_location="cpu")
-                                state_dict.update(shard_state)
-                                logger.info(f"Downloaded and loaded model file {model_path}")
-                            except Exception as e:
-                                logger.exception(f"Error loading model file {model_path}: {e}")
-                                raise
-                        else:
-                            logger.error(f"Error downloading {file_url} with status code: {response.status_code}")
-                            raise HTTPException(status_code=500, detail=f"Error downloading model file from Hugging Face")
+            response = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=f"{model_prefix}/pytorch_model-")
+            model_files = []
+            if 'Contents' in response:
+                for obj in response['Contents']:
+                    if re.match(r"pytorch_model-\d+-of-\d+", obj['Key'].split('/')[-1]):
+                        model_files.append(obj['Key'].split('/')[-1])
+            return model_files
         except Exception as e:
-            logger.exception(f"Error loading model files for {model_prefix} : {e}")
-            raise
-
-
-        model = AutoModelForCausalLM.from_config(config)
-        model.load_state_dict(state_dict)
-        return model
+            return None

     def load_tokenizer_from_stream(self, model_prefix):
         try:
-            logger.info(f"Loading tokenizer for {model_prefix}...")
-            if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
-                logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
-                return self.load_tokenizer_from_existing_s3(model_prefix, config)
-
-            logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
-            self.download_and_upload_to_s3(model_prefix)
-            logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
-            return self.load_tokenizer_from_stream(model_prefix)
+            tokenizer_path = f"{model_prefix}/tokenizer.json"
+            if self.file_exists_in_s3(tokenizer_path):
+                tokenizer_stream = self.stream_from_s3(tokenizer_path)
+                tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
+                return tokenizer
+            else:
+                self.download_and_upload_to_s3(model_prefix)
+                return self.load_tokenizer_from_stream(model_prefix)
         except HTTPException as e:
-            logger.error(f"Error loading tokenizer: {e}")
-            return None
-
-    def load_tokenizer_from_existing_s3(self, model_prefix, config):
-        logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
-        tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
-        tokenizer = AutoTokenizer.from_pretrained(None, config=config)
-        logger.info(f"Tokenizer loaded for {model_prefix}.")
-        return tokenizer
-
-    def download_and_upload_to_s3(self, model_prefix):
-        logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
-        try:
-            api = HfApi()
-            model_files = api.list_repo_files(model_prefix)
-
-            for file_info in model_files:
-                if file_info.rfilename.endswith(('.bin', '.safetensors', 'config.json', 'tokenizer.json')):
-                    file_url = api.download_file(model_prefix, file_info.rfilename)
-                    s3_key = f"{model_prefix}/{file_info.rfilename}"
-                    try:
-                        self.download_and_upload_to_s3_url(file_url, s3_key)
-                        logger.info(f"Downloaded and uploaded {s3_key}")
-                    except Exception as e:
-                        logger.exception(f"Error downloading/uploading {s3_key}: {e}")
-
-            logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
-
-        except requests.exceptions.RequestException as e:
-            logger.error(f"Error downloading model files from HuggingFace: {e}")
-            raise HTTPException(status_code=500, detail=f"Error downloading model files from Hugging Face") from e
+            raise
         except Exception as e:
-            logger.error(f"An unexpected error occurred: {e}")
-            raise HTTPException(status_code=500, detail=f"An unexpected error occurred during model download") from e
+            raise HTTPException(status_code=500, detail=f"Error loading the tokenizer: {e}")


-    def download_and_upload_to_s3_url(self, url, s3_key):
-        logger.info(f"Downloading from {url}...")
-        with requests.get(url, stream=True) as response:
-            if response.status_code == 200:
-                total_size_in_bytes = int(response.headers.get('content-length', 0))
-                block_size = 1024
-                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
-                logger.info(f"Uploading to S3: {s3_key}...")
-                self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
-                progress_bar.close()
-                logger.info(f"Uploaded {s3_key} to S3 successfully.")
-            elif response.status_code == 404:
-                logger.error(f"File not found at {url}")
-                raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
-            else:
-                logger.error(f"Error downloading from {url}: Status code {response.status_code}")
-                raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
-
-    def _get_latest_revision(self, model_prefix):
-        try:
-            api = HfApi()
-            model_info = api.model_info(model_prefix)
-            if hasattr(model_info, 'revision'):
-                revision = model_info.revision
-                if revision:
-                    return revision
-                else:
-                    logger.warning(f"No revision found for {model_prefix}, using 'main'")
-                    return "main"
-            else:
-                logger.warning(f"ModelInfo object for {model_prefix} does not have a 'revision' attribute, using 'main'")
-                return "main"
-        except Exception as e:
-            logger.error(f"Error getting latest revision for {model_prefix}: {e}")
-            return None
+    def download_and_upload_to_s3(self, model_prefix):
+        urls = {
+            "pytorch_model.bin": f"https://huggingface.co/{model_prefix}/resolve/main/pytorch_model.bin",
+            "model.safetensors": f"https://huggingface.co/{model_prefix}/resolve/main/model.safetensors",
+            "tokenizer.json": f"https://huggingface.co/{model_prefix}/resolve/main/tokenizer.json",
+            "config.json": f"https://huggingface.co/{model_prefix}/resolve/main/config.json"
+        }
+
+        for filename, url in urls.items():
+            try:
+                response = requests.get(url, stream=True)
+                response.raise_for_status()
+                self.s3_client.upload_fileobj(response.raw, self.bucket_name, f"{model_prefix}/{filename}")
+            except requests.exceptions.RequestException as e:
+                raise HTTPException(status_code=500, detail=f"Error downloading {filename}: {e}")
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"Error uploading {filename} to S3: {e}")


 @app.post("/predict/")
 async def predict(model_request: DownloadModelRequest):
     try:
-        logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
-        model_id = model_request.model_id
-        task = model_request.pipeline_task
-        input_text = model_request.input_text
-
         streamer = S3DirectStream(S3_BUCKET_NAME)
-        logger.info("Loading model and tokenizer...")
-        model = streamer.load_model_from_stream(model_id)
-
-        if model is None:
-            logger.error(f"Failed to load model {model_id}")
-            raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")
+        model = streamer.load_model_from_stream(model_request.model_name)
+        tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name)

-        tokenizer = streamer.load_tokenizer_from_stream(model_id)
-        logger.info("Model and tokenizer loaded.")
-
-        if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
-            raise HTTPException(status_code=400, detail="Unsupported pipeline task")
+        task = model_request.pipeline_task
+        if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "text-to-speech", "text-to-image", "text-to-audio", "text-to-video"]:
+            raise HTTPException(status_code=400, detail="Pipeline task not supported")

-        if task == "text-generation":
-            logger.info("Starting text generation...")
-            text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-            generation_kwargs = dict(inputs, streamer=text_streamer)
-            model.generate(**generation_kwargs)
-            logger.info("Text generation finished.")
-            return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
+        nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer)
+        input_text = model_request.input_text
+        outputs = nlp_pipeline(input_text)
+
+        if task in ["text-generation", "translation", "fill-mask", "sentiment-analysis", "question-answering"]:
+            return {"response": outputs}
+        elif task == "text-to-image":
+            s3_key = f"{model_request.model_name}/generated_image.png"
+            return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="image/png")
+        elif task == "text-to-audio":
+            s3_key = f"{model_request.model_name}/generated_audio.wav"
+            return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="audio/wav")
+        elif task == "text-to-video":
+            s3_key = f"{model_request.model_name}/generated_video.mp4"
+            return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4")
         else:
-            logger.info(f"Starting pipeline task: {task}...")
-            nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, device_map="auto", trust_remote_code=True)
-            outputs = nlp_pipeline(input_text)
-            logger.info(f"Pipeline task {task} finished.")
-            return {"result": outputs}
+            raise HTTPException(status_code=400, detail="Unknown task type")

+    except HTTPException as e:
+        raise
     except Exception as e:
-        logger.exception(f"Error processing request: {e}")
-        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
+

 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
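A few notes on the new implementation. safetensors.torch exposes load(data) for in-memory bytes and load_file(path) for files on disk, but no load_stream(), so the call in load_model_from_stream() will raise AttributeError. A minimal sketch of the in-memory variant against S3 (bucket and key names here are placeholders, not values from this repo):

    import boto3
    import safetensors.torch

    # Placeholder bucket/key, assumed to hold a mirrored checkpoint.
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket="my-model-bucket", Key="my-org/my-model/model.safetensors")

    # load() deserializes a complete byte string into a dict of tensors;
    # the library has no incremental loader, so the file is read fully.
    state_dict = safetensors.torch.load(obj["Body"].read())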
 
 
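The sharded-checkpoint branch passes a botocore StreamingBody directly to torch.load(). Zip-serialized checkpoints (the default format since PyTorch 1.6) need a seekable file object, which StreamingBody is not, so buffering each shard first is the safer pattern. A sketch under the same placeholder names:

    import io
    import boto3
    import torch

    s3 = boto3.client("s3")

    def load_shard(bucket, key):
        # Buffer the shard in memory: torch.load() calls seek() on
        # zip-format checkpoints, and StreamingBody does not support it.
        body = s3.get_object(Bucket=bucket, Key=key)["Body"]
        return torch.load(io.BytesIO(body.read()), map_location="cpu")

    # Merge shards named like pytorch_model-00001-of-00002.bin into one
    # state dict, as load_model_from_stream() intends.
    state_dict = {}
    for key in ["my-org/my-model/pytorch_model-00001-of-00002.bin",
                "my-org/my-model/pytorch_model-00002-of-00002.bin"]:
        state_dict.update(load_shard("my-model-bucket", key))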
 
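Two other calls are worth flagging: AutoModelForCausalLM.from_config() expects a PretrainedConfig object rather than the decoded JSON string that config_data holds, and AutoTokenizer.from_pretrained() expects a path or repo id rather than a stream. One way to bridge from streamed bytes is to materialize the JSON in a temporary directory; a sketch under that assumption (build_from_bytes is a hypothetical helper, not part of the diff):

    import os
    import tempfile
    from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerFast

    def build_from_bytes(config_bytes, tokenizer_bytes):
        # Write the streamed config.json/tokenizer.json to disk so the
        # transformers loaders can parse them.
        with tempfile.TemporaryDirectory() as tmp:
            with open(os.path.join(tmp, "config.json"), "wb") as f:
                f.write(config_bytes)
            tok_path = os.path.join(tmp, "tokenizer.json")
            with open(tok_path, "wb") as f:
                f.write(tokenizer_bytes)

            config = AutoConfig.from_pretrained(tmp)
            model = AutoModelForCausalLM.from_config(config)  # random weights
            tokenizer = PreTrainedTokenizerFast(tokenizer_file=tok_path)
        return model, tokenizer

Since from_config() initializes random weights, the state dict pulled from S3 still has to be applied with model.load_state_dict() afterwards, as the diff does.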
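The rewritten download_and_upload_to_s3() streams each Hugging Face file into S3 without touching disk, which works because upload_fileobj() accepts any file-like object and response.raw exposes one. The same pattern in isolation (placeholder repo and bucket):

    import boto3
    import requests

    s3 = boto3.client("s3")

    url = "https://huggingface.co/gpt2/resolve/main/config.json"  # placeholder repo
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # upload_fileobj() reads the raw urllib3 stream in chunks, so the
        # whole file is never held in memory at once.
        s3.upload_fileobj(response.raw, "my-model-bucket", "gpt2/config.json")

One caveat: many repos ship only one of pytorch_model.bin and model.safetensors, so fetching all four URLs unconditionally will turn an expected 404 on the missing file into a 500.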
 
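For reference, a request matching the new DownloadModelRequest schema (host, port, and model name are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:7860/predict/",
        json={
            "model_name": "gpt2",
            "pipeline_task": "text-generation",
            "input_text": "Hello, world",
        },
    )
    print(resp.json())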