Hjgugugjhuhjggg commited on
Commit
a91fbbb
verified
1 Parent(s): ec0fbfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -124
app.py CHANGED
@@ -1,30 +1,28 @@
1
  import os
2
- import shutil
3
  import boto3
4
  from fastapi import FastAPI, HTTPException
5
- from pydantic import BaseModel
6
- import requests
7
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
8
  import safetensors.torch
9
  from fastapi.responses import StreamingResponse
10
- import io
11
- from tqdm import tqdm
12
- import re
13
  import torch
14
  import uvicorn
 
 
15
 
16
  # Cargar las variables de entorno desde el archivo .env
17
- from dotenv import load_dotenv
18
  load_dotenv()
19
 
20
- # Cargar las credenciales de AWS desde las variables de entorno
21
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
22
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
23
  AWS_REGION = os.getenv("AWS_REGION")
24
- S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket de S3
25
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face
26
 
27
- # Cliente S3 de Amazon
28
  s3_client = boto3.client(
29
  's3',
30
  aws_access_key_id=AWS_ACCESS_KEY_ID,
@@ -34,12 +32,13 @@ s3_client = boto3.client(
34
 
35
  app = FastAPI()
36
 
37
- # Pydantic Model para el cuerpo de la solicitud del endpoint /predict/
38
  class DownloadModelRequest(BaseModel):
39
- model_name: str
40
- pipeline_task: str
41
- input_text: str
42
 
 
43
  class S3DirectStream:
44
  def __init__(self, bucket_name):
45
  self.s3_client = boto3.client(
@@ -54,139 +53,72 @@ class S3DirectStream:
54
  try:
55
  print(f"[INFO] Descargando archivo {key} desde S3...")
56
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
57
- return response['Body'] # Devolver el cuerpo directamente para el StreamingResponse
58
  except self.s3_client.exceptions.NoSuchKey:
59
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
60
  except Exception as e:
61
  print(f"[ERROR] Error al descargar {key}: {str(e)}")
62
- raise HTTPException(status_code=500, detail=f"Error al descargar archivo {key} desde S3.")
63
-
64
- def file_exists_in_s3(self, key):
65
- try:
66
- self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
67
- return True
68
- except self.s3_client.exceptions.ClientError:
69
- return False
70
 
71
- def load_model_from_stream(self, model_prefix):
72
  try:
73
- print(f"[INFO] Cargando el modelo {model_prefix} desde S3...")
74
- model_files = self.get_model_file_parts(model_prefix)
75
- if not model_files:
76
- model_files = [f"{model_prefix}/model"] # Uso de modelo base
77
-
78
- # Leer y cargar todos los archivos del modelo
79
- model_streams = []
80
- for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"):
81
- model_streams.append(self.stream_from_s3(model_file))
82
-
83
- # Verificar si el archivo es un safetensor o un archivo binario
84
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
 
 
 
85
  config_data = config_stream.read().decode("utf-8")
86
-
87
- # Cargar el modelo dependiendo de si es safetensor o binario
88
- if model_files[0].endswith("model.safetensors"):
89
- print("[INFO] Cargando el modelo como safetensor...")
90
- model = AutoModelForCausalLM.from_config(config_data)
91
- model.load_state_dict(safetensors.torch.load_stream(model_streams[0])) # Cargar el modelo utilizando safetensors
92
- else:
93
- print("[INFO] Cargando el modelo como archivo binario de PyTorch...")
94
- model = AutoModelForCausalLM.from_config(config_data)
95
- model.load_state_dict(torch.load(model_streams[0], map_location="cpu")) # Cargar el modelo utilizando pytorch
96
-
97
- print("[INFO] Modelo cargado con 茅xito.")
98
- return model
99
 
100
- except Exception as e:
101
- print(f"[ERROR] Error al cargar el modelo desde S3: {e}")
102
- raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.")
103
 
104
- def load_tokenizer_from_stream(self, model_prefix):
105
- try:
106
- print(f"[INFO] Cargando el tokenizer {model_prefix} desde S3...")
107
- tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
108
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
109
- return tokenizer
110
- except Exception as e:
111
- print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}")
112
- raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.")
113
-
114
- def get_model_file_parts(self, model_prefix):
115
- print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_prefix}...")
116
- files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
117
- model_files = []
118
- for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"):
119
- key = obj['Key']
120
- if re.match(rf"{model_prefix}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"):
121
- model_files.append(key)
122
- if not model_files:
123
- print(f"[WARNING] No se encontraron archivos coincidentes con el patr贸n para el modelo {model_prefix}.")
124
- return model_files
125
-
126
- def download_and_upload_to_s3_url(self, url: str, s3_key: str):
127
- try:
128
- print(f"[INFO] Descargando archivo desde {url}...")
129
- response = requests.get(url)
130
- if response.status_code == 200:
131
- print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...")
132
- self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
133
- # Eliminar el archivo local despu茅s de la carga exitosa
134
- self.delete_local_file(s3_key)
135
- else:
136
- print(f"[ERROR] Error al descargar el archivo desde {url}, c贸digo de estado {response.status_code}.")
137
- raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}")
138
- except Exception as e:
139
- print(f"[ERROR] Error al procesar la URL {url}: {str(e)}")
140
- raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}")
141
 
142
- def delete_local_file(self, file_path: str):
143
- try:
144
- print(f"[INFO] Eliminando archivo local {file_path}...")
145
- if os.path.exists(file_path):
146
- os.remove(file_path)
147
- print(f"[INFO] Archivo local {file_path} eliminado correctamente.")
148
- else:
149
- print(f"[WARNING] El archivo local {file_path} no existe.")
150
  except Exception as e:
151
- print(f"[ERROR] Error al eliminar el archivo local {file_path}: {str(e)}")
152
-
153
 
 
154
  @app.post("/predict/")
155
  async def predict(model_request: DownloadModelRequest):
156
  try:
157
- print(f"[INFO] Recibiendo solicitud para predecir con el modelo {model_request.model_name}...")
158
  streamer = S3DirectStream(S3_BUCKET_NAME)
159
- model = streamer.load_model_from_stream(model_request.model_name)
160
- tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name)
161
 
162
- task = model_request.pipeline_task
163
- if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "text-to-speech", "text-to-image", "text-to-audio", "text-to-video"]:
164
- raise HTTPException(status_code=400, detail="Pipeline task no soportado")
 
 
 
165
 
166
- nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, max_length=2046)
167
 
168
- input_text = model_request.input_text
169
- print(f"[INFO] Ejecutando tarea {task} con el texto de entrada...")
170
- outputs = nlp_pipeline(input_text)
171
 
172
- # Eliminaci贸n de archivo local despu茅s de subir a S3
173
- if task == "text-to-speech":
174
- s3_key = f"{model_request.model_name}/generated_audio.wav"
175
- return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="audio/wav")
176
- elif task == "text-to-image":
177
- s3_key = f"{model_request.model_name}/generated_image.png"
178
- return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="image/png")
179
- elif task == "text-to-video":
180
- s3_key = f"{model_request.model_name}/generated_video.mp4"
181
- return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4")
182
 
183
- return {"input_text": input_text, "output": outputs}
184
 
185
  except Exception as e:
186
- print(f"[ERROR] Error al procesar la solicitud de predicci贸n: {str(e)}")
187
- raise HTTPException(status_code=500, detail=f"Error interno: {str(e)}")
188
-
189
 
 
190
  if __name__ == "__main__":
191
- print("Iniciando el servidor FastAPI...")
192
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
 
2
  import boto3
3
  from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel, Field
 
5
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
  import safetensors.torch
7
  from fastapi.responses import StreamingResponse
8
+ from dotenv import load_dotenv
9
+ import requests
 
10
  import torch
11
  import uvicorn
12
+ import re
13
+ from tqdm import tqdm
14
 
15
  # Cargar las variables de entorno desde el archivo .env
 
16
  load_dotenv()
17
 
18
+ # Configuraci贸n AWS y Hugging Face
19
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
20
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
21
  AWS_REGION = os.getenv("AWS_REGION")
22
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
23
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
24
 
25
+ # Cliente de Amazon S3
26
  s3_client = boto3.client(
27
  's3',
28
  aws_access_key_id=AWS_ACCESS_KEY_ID,
 
32
 
33
  app = FastAPI()
34
 
35
+ # Modelos Pydantic
36
  class DownloadModelRequest(BaseModel):
37
+ model_name: str = Field(..., example="model_directory_name")
38
+ pipeline_task: str = Field(..., example="text-generation")
39
+ input_text: str = Field(..., example="Introduce your input text here.")
40
 
41
+ # Clase para interacci贸n con S3
42
  class S3DirectStream:
43
  def __init__(self, bucket_name):
44
  self.s3_client = boto3.client(
 
53
  try:
54
  print(f"[INFO] Descargando archivo {key} desde S3...")
55
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
56
+ return response['Body']
57
  except self.s3_client.exceptions.NoSuchKey:
58
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
59
  except Exception as e:
60
  print(f"[ERROR] Error al descargar {key}: {str(e)}")
61
+ raise HTTPException(status_code=500, detail="Error al descargar archivo desde S3.")
 
 
 
 
 
 
 
62
 
63
+ def load_model_and_tokenizer(self, model_prefix):
64
  try:
65
+ print(f"[INFO] Cargando modelo y tokenizer desde S3 para {model_prefix}...")
66
+ model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
 
 
 
 
 
 
 
 
 
67
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
68
+ tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
69
+
70
+ # Cargar configuraci贸n del modelo
71
  config_data = config_stream.read().decode("utf-8")
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # Cargar modelo
74
+ model = AutoModelForCausalLM.from_config(config_data)
75
+ model.load_state_dict(safetensors.torch.load_stream(model_stream))
76
 
77
+ # Cargar tokenizer
 
 
 
78
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ print("[INFO] Modelo y tokenizer cargados con 茅xito.")
81
+ return model, tokenizer
 
 
 
 
 
 
82
  except Exception as e:
83
+ print(f"[ERROR] Error al cargar modelo/tokenizer desde S3: {e}")
84
+ raise HTTPException(status_code=500, detail="Error al cargar modelo/tokenizer.")
85
 
86
+ # Endpoint para predicciones
87
  @app.post("/predict/")
88
  async def predict(model_request: DownloadModelRequest):
89
  try:
90
+ print(f"[INFO] Procesando solicitud para el modelo {model_request.model_name}...")
91
  streamer = S3DirectStream(S3_BUCKET_NAME)
92
+ model, tokenizer = streamer.load_model_and_tokenizer(model_request.model_name)
 
93
 
94
+ if model_request.pipeline_task not in [
95
+ "text-generation", "sentiment-analysis", "translation",
96
+ "fill-mask", "question-answering", "text-to-speech",
97
+ "text-to-image", "text-to-video"
98
+ ]:
99
+ raise HTTPException(status_code=400, detail="Pipeline task no soportado.")
100
 
101
+ nlp_pipeline = pipeline(model_request.pipeline_task, model=model, tokenizer=tokenizer, max_length=2046)
102
 
103
+ outputs = nlp_pipeline(model_request.input_text)
 
 
104
 
105
+ # Responder seg煤n la tarea
106
+ if model_request.pipeline_task in ["text-to-speech", "text-to-image", "text-to-video"]:
107
+ media_type_map = {
108
+ "text-to-speech": "audio/wav",
109
+ "text-to-image": "image/png",
110
+ "text-to-video": "video/mp4"
111
+ }
112
+ s3_key = f"{model_request.model_name}/generated_output"
113
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type=media_type_map[model_request.pipeline_task])
 
114
 
115
+ return {"input_text": model_request.input_text, "output": outputs}
116
 
117
  except Exception as e:
118
+ print(f"[ERROR] Error al procesar la solicitud: {e}")
119
+ raise HTTPException(status_code=500, detail=f"Error interno: {e}")
 
120
 
121
+ # Punto de entrada principal
122
  if __name__ == "__main__":
123
+ print("[INFO] Iniciando el servidor FastAPI...")
124
+ uvicorn.run(app, host="0.0.0.0", port=9000)