Hjgugugjhuhjggg commited on
Commit
d44fda2
verified
1 Parent(s): d200477

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -57
app.py CHANGED
@@ -1,28 +1,29 @@
1
  import os
2
  import boto3
 
3
  from fastapi import FastAPI, HTTPException
4
- from pydantic import BaseModel, Field
5
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
  import safetensors.torch
 
7
  from fastapi.responses import StreamingResponse
8
- from dotenv import load_dotenv
 
9
  import requests
10
- import torch
11
  import uvicorn
12
  import re
13
- from tqdm import tqdm
14
-
15
  # Cargar las variables de entorno desde el archivo .env
 
16
  load_dotenv()
17
 
18
- # Configuraci贸n AWS y Hugging Face
19
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
20
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
21
  AWS_REGION = os.getenv("AWS_REGION")
22
- S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
23
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
24
 
25
- # Cliente de Amazon S3
26
  s3_client = boto3.client(
27
  's3',
28
  aws_access_key_id=AWS_ACCESS_KEY_ID,
@@ -32,13 +33,12 @@ s3_client = boto3.client(
32
 
33
  app = FastAPI()
34
 
35
- # Modelos Pydantic
36
  class DownloadModelRequest(BaseModel):
37
- model_name: str = Field(..., example="model_directory_name")
38
- pipeline_task: str = Field(..., example="text-generation")
39
- input_text: str = Field(..., example="Introduce your input text here.")
40
 
41
- # Clase para interacci贸n con S3
42
  class S3DirectStream:
43
  def __init__(self, bucket_name):
44
  self.s3_client = boto3.client(
@@ -53,72 +53,139 @@ class S3DirectStream:
53
  try:
54
  print(f"[INFO] Descargando archivo {key} desde S3...")
55
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
56
- return response['Body']
57
  except self.s3_client.exceptions.NoSuchKey:
58
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
59
  except Exception as e:
60
  print(f"[ERROR] Error al descargar {key}: {str(e)}")
61
- raise HTTPException(status_code=500, detail="Error al descargar archivo desde S3.")
62
 
63
- def load_model_and_tokenizer(self, model_prefix):
64
  try:
65
- print(f"[INFO] Cargando modelo y tokenizer desde S3 para {model_prefix}...")
66
- model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
67
- config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
68
- tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
 
 
 
 
 
 
 
69
 
70
- # Cargar configuraci贸n del modelo
 
 
 
 
 
 
71
  config_data = config_stream.read().decode("utf-8")
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # Cargar modelo
74
- model = AutoModelForCausalLM.from_config(config_data)
75
- model.load_state_dict(safetensors.torch.load_stream(model_stream))
76
 
77
- # Cargar tokenizer
 
 
 
78
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- print("[INFO] Modelo y tokenizer cargados con 茅xito.")
81
- return model, tokenizer
 
 
 
 
 
 
82
  except Exception as e:
83
- print(f"[ERROR] Error al cargar modelo/tokenizer desde S3: {e}")
84
- raise HTTPException(status_code=500, detail="Error al cargar modelo/tokenizer.")
85
 
86
- # Endpoint para predicciones
87
  @app.post("/predict/")
88
  async def predict(model_request: DownloadModelRequest):
89
  try:
90
- print(f"[INFO] Procesando solicitud para el modelo {model_request.model_name}...")
91
  streamer = S3DirectStream(S3_BUCKET_NAME)
92
- model, tokenizer = streamer.load_model_and_tokenizer(model_request.model_name)
 
93
 
94
- if model_request.pipeline_task not in [
95
- "text-generation", "sentiment-analysis", "translation",
96
- "fill-mask", "question-answering", "text-to-speech",
97
- "text-to-image", "text-to-video"
98
- ]:
99
- raise HTTPException(status_code=400, detail="Pipeline task no soportado.")
100
 
101
- nlp_pipeline = pipeline(model_request.pipeline_task, model=model, tokenizer=tokenizer, max_length=2046)
102
 
103
- outputs = nlp_pipeline(model_request.input_text)
 
 
104
 
105
- # Responder seg煤n la tarea
106
- if model_request.pipeline_task in ["text-to-speech", "text-to-image", "text-to-video"]:
107
- media_type_map = {
108
- "text-to-speech": "audio/wav",
109
- "text-to-image": "image/png",
110
- "text-to-video": "video/mp4"
111
- }
112
- s3_key = f"{model_request.model_name}/generated_output"
113
- return StreamingResponse(streamer.stream_from_s3(s3_key), media_type=media_type_map[model_request.pipeline_task])
 
114
 
115
- return {"input_text": model_request.input_text, "output": outputs}
116
 
117
  except Exception as e:
118
- print(f"[ERROR] Error al procesar la solicitud: {e}")
119
- raise HTTPException(status_code=500, detail=f"Error interno: {e}")
120
 
121
- # Punto de entrada principal
122
  if __name__ == "__main__":
123
- print("[INFO] Iniciando el servidor FastAPI...")
124
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import boto3
3
+ import torch
4
  from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
 
6
  import safetensors.torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from fastapi.responses import StreamingResponse
9
+ import io
10
+ from tqdm import tqdm
11
  import requests
 
12
  import uvicorn
13
  import re
14
+ import sys
 
15
  # Cargar las variables de entorno desde el archivo .env
16
+ from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
+ # Cargar las credenciales de AWS desde las variables de entorno
20
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
21
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
22
  AWS_REGION = os.getenv("AWS_REGION")
23
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket de S3
24
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face
25
 
26
+ # Cliente S3 de Amazon
27
  s3_client = boto3.client(
28
  's3',
29
  aws_access_key_id=AWS_ACCESS_KEY_ID,
 
33
 
34
  app = FastAPI()
35
 
36
+ # Pydantic Model para el cuerpo de la solicitud del endpoint /predict/
37
  class DownloadModelRequest(BaseModel):
38
+ model_name: str
39
+ pipeline_task: str
40
+ input_text: str
41
 
 
42
  class S3DirectStream:
43
  def __init__(self, bucket_name):
44
  self.s3_client = boto3.client(
 
53
  try:
54
  print(f"[INFO] Descargando archivo {key} desde S3...")
55
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
56
+ return response['Body'] # Devolver el cuerpo directamente para el StreamingResponse
57
  except self.s3_client.exceptions.NoSuchKey:
58
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
59
  except Exception as e:
60
  print(f"[ERROR] Error al descargar {key}: {str(e)}")
61
+ raise HTTPException(status_code=500, detail=f"Error al descargar archivo {key} desde S3.")
62
 
63
+ def file_exists_in_s3(self, key):
64
  try:
65
+ self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
66
+ return True
67
+ except self.s3_client.exceptions.ClientError:
68
+ return False
69
+
70
+ def load_model_from_stream(self, model_prefix):
71
+ try:
72
+ print(f"[INFO] Cargando el modelo {model_prefix} desde S3...")
73
+ model_files = self.get_model_file_parts(model_prefix)
74
+ if not model_files:
75
+ model_files = [f"{model_prefix}/model"] # Uso de modelo base
76
 
77
+ # Leer y cargar todos los archivos del modelo
78
+ model_streams = []
79
+ for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"):
80
+ model_streams.append(self.stream_from_s3(model_file))
81
+
82
+ # Verificar si el archivo es un safetensor o un archivo binario
83
+ config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
84
  config_data = config_stream.read().decode("utf-8")
85
+
86
+ # Cargar el modelo dependiendo de si es safetensor o binario
87
+ if model_files[0].endswith("model.safetensors"):
88
+ print("[INFO] Cargando el modelo como safetensor...")
89
+ model = AutoModelForCausalLM.from_config(config_data)
90
+ model.load_state_dict(safetensors.torch.load_stream(model_streams[0])) # Cargar el modelo utilizando safetensors
91
+ else:
92
+ print("[INFO] Cargando el modelo como archivo binario de PyTorch...")
93
+ model = AutoModelForCausalLM.from_config(config_data)
94
+ model.load_state_dict(torch.load(model_streams[0], map_location="cpu")) # Cargar el modelo utilizando pytorch
95
+
96
+ print("[INFO] Modelo cargado con 茅xito.")
97
+ return model
98
 
99
+ except Exception as e:
100
+ print(f"[ERROR] Error al cargar el modelo desde S3: {e}")
101
+ raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.")
102
 
103
+ def load_tokenizer_from_stream(self, model_prefix):
104
+ try:
105
+ print(f"[INFO] Cargando el tokenizer {model_prefix} desde S3...")
106
+ tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
107
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
108
+ return tokenizer
109
+ except Exception as e:
110
+ print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}")
111
+ raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.")
112
+
113
+ def get_model_file_parts(self, model_prefix):
114
+ print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_prefix}...")
115
+ files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
116
+ model_files = []
117
+ for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"):
118
+ key = obj['Key']
119
+ if re.match(rf"{model_prefix}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"):
120
+ model_files.append(key)
121
+ if not model_files:
122
+ print(f"[WARNING] No se encontraron archivos coincidentes con el patr贸n para el modelo {model_prefix}.")
123
+ return model_files
124
+
125
+ def download_and_upload_to_s3_url(self, url: str, s3_key: str):
126
+ try:
127
+ print(f"[INFO] Descargando archivo desde {url}...")
128
+ response = requests.get(url)
129
+ if response.status_code == 200:
130
+ print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...")
131
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
132
+ # Eliminar el archivo local despu茅s de la carga exitosa
133
+ self.delete_local_file(s3_key)
134
+ else:
135
+ print(f"[ERROR] Error al descargar el archivo desde {url}, c贸digo de estado {response.status_code}.")
136
+ raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}")
137
+ except Exception as e:
138
+ print(f"[ERROR] Error al procesar la URL {url}: {str(e)}")
139
+ raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}")
140
 
141
+ def delete_local_file(self, file_path: str):
142
+ try:
143
+ print(f"[INFO] Eliminando archivo local {file_path}...")
144
+ if os.path.exists(file_path):
145
+ os.remove(file_path)
146
+ print(f"[INFO] Archivo local {file_path} eliminado correctamente.")
147
+ else:
148
+ print(f"[WARNING] El archivo local {file_path} no existe.")
149
  except Exception as e:
150
+ print(f"[ERROR] Error al eliminar el archivo local {file_path}: {str(e)}")
151
+
152
 
 
153
  @app.post("/predict/")
154
  async def predict(model_request: DownloadModelRequest):
155
  try:
156
+ print(f"[INFO] Recibiendo solicitud para predecir con el modelo {model_request.model_name}...")
157
  streamer = S3DirectStream(S3_BUCKET_NAME)
158
+ model = streamer.load_model_from_stream(model_request.model_name)
159
+ tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name)
160
 
161
+ task = model_request.pipeline_task
162
+ if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "text-to-speech", "text-to-image", "text-to-audio", "text-to-video"]:
163
+ raise HTTPException(status_code=400, detail="Pipeline task no soportado")
 
 
 
164
 
165
+ nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, max_length=2046)
166
 
167
+ input_text = model_request.input_text
168
+ print(f"[INFO] Ejecutando tarea {task} con el texto de entrada...")
169
+ outputs = nlp_pipeline(input_text)
170
 
171
+ # Eliminaci贸n de archivo local despu茅s de subir a S3
172
+ if task == "text-to-speech":
173
+ s3_key = f"{model_request.model_name}/generated_audio.wav"
174
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="audio/wav")
175
+ elif task == "text-to-image":
176
+ s3_key = f"{model_request.model_name}/generated_image.png"
177
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="image/png")
178
+ elif task == "text-to-video":
179
+ s3_key = f"{model_request.model_name}/generated_video.mp4"
180
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4")
181
 
182
+ return {"input_text": input_text, "output": outputs}
183
 
184
  except Exception as e:
185
+ print(f"[ERROR] Error en la predicci贸n: {str(e)}")
186
+ raise HTTPException(status_code=500, detail="Error al realizar la predicci贸n.")
187
 
188
+ # Iniciar servidor de predicciones
189
  if __name__ == "__main__":
190
+ print("Iniciando servidor de predicciones en localhost:8000")
191
+ uvicorn.run(app, host="0.0.0.0", port=8000)