Hjgugugjhuhjggg commited on
Commit
410390c
verified
1 Parent(s): 29219a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import boto3
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
+ import requests
7
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
8
+ import safetensors.torch
9
+ from fastapi.responses import StreamingResponse
10
+ import io
11
+ from tqdm import tqdm
12
+ import re
13
+ import torch
14
+ import uvicorn
15
+
16
+ # Cargar las variables de entorno desde el archivo .env
17
+ from dotenv import load_dotenv
18
+ load_dotenv()
19
+
20
+ # Cargar las credenciales de AWS desde las variables de entorno
21
+ AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
22
+ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
23
+ AWS_REGION = os.getenv("AWS_REGION")
24
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket de S3
25
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face
26
+
27
+ # Cliente S3 de Amazon
28
+ s3_client = boto3.client(
29
+ 's3',
30
+ aws_access_key_id=AWS_ACCESS_KEY_ID,
31
+ aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
32
+ region_name=AWS_REGION
33
+ )
34
+
35
+ app = FastAPI()
36
+
37
+ # Pydantic Model para el cuerpo de la solicitud del endpoint /predict/
38
+ class DownloadModelRequest(BaseModel):
39
+ model_name: str
40
+ pipeline_task: str
41
+ input_text: str
42
+
43
+ class S3DirectStream:
44
+ def __init__(self, bucket_name):
45
+ self.s3_client = boto3.client(
46
+ 's3',
47
+ aws_access_key_id=AWS_ACCESS_KEY_ID,
48
+ aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
49
+ region_name=AWS_REGION
50
+ )
51
+ self.bucket_name = bucket_name
52
+
53
+ def stream_from_s3(self, key):
54
+ try:
55
+ print(f"[INFO] Descargando archivo {key} desde S3...")
56
+ response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
57
+ return response['Body'] # Devolver el cuerpo directamente para el StreamingResponse
58
+ except self.s3_client.exceptions.NoSuchKey:
59
+ raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
60
+ except Exception as e:
61
+ print(f"[ERROR] Error al descargar {key}: {str(e)}")
62
+ raise HTTPException(status_code=500, detail=f"Error al descargar archivo {key} desde S3.")
63
+
64
+ def file_exists_in_s3(self, key):
65
+ try:
66
+ self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
67
+ return True
68
+ except self.s3_client.exceptions.ClientError:
69
+ return False
70
+
71
+ def load_model_from_stream(self, model_prefix):
72
+ try:
73
+ print(f"[INFO] Cargando el modelo {model_prefix} desde S3...")
74
+ model_files = self.get_model_file_parts(model_prefix)
75
+ if not model_files:
76
+ model_files = [f"{model_prefix}/model"] # Uso de modelo base
77
+
78
+ # Leer y cargar todos los archivos del modelo
79
+ model_streams = []
80
+ for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"):
81
+ model_streams.append(self.stream_from_s3(model_file))
82
+
83
+ # Verificar si el archivo es un safetensor o un archivo binario
84
+ config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
85
+ config_data = config_stream.read().decode("utf-8")
86
+
87
+ # Cargar el modelo dependiendo de si es safetensor o binario
88
+ if model_files[0].endswith("model.safetensors"):
89
+ print("[INFO] Cargando el modelo como safetensor...")
90
+ model = AutoModelForCausalLM.from_config(config_data)
91
+ model.load_state_dict(safetensors.torch.load_stream(model_streams[0])) # Cargar el modelo utilizando safetensors
92
+ else:
93
+ print("[INFO] Cargando el modelo como archivo binario de PyTorch...")
94
+ model = AutoModelForCausalLM.from_config(config_data)
95
+ model.load_state_dict(torch.load(model_streams[0], map_location="cpu")) # Cargar el modelo utilizando pytorch
96
+
97
+ print("[INFO] Modelo cargado con 茅xito.")
98
+ return model
99
+
100
+ except Exception as e:
101
+ print(f"[ERROR] Error al cargar el modelo desde S3: {e}")
102
+ raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.")
103
+
104
+ def load_tokenizer_from_stream(self, model_prefix):
105
+ try:
106
+ print(f"[INFO] Cargando el tokenizer {model_prefix} desde S3...")
107
+ tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
108
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
109
+ return tokenizer
110
+ except Exception as e:
111
+ print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}")
112
+ raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.")
113
+
114
+ def get_model_file_parts(self, model_prefix):
115
+ print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_prefix}...")
116
+ files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
117
+ model_files = []
118
+ for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"):
119
+ key = obj['Key']
120
+ if re.match(rf"{model_prefix}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"):
121
+ model_files.append(key)
122
+ if not model_files:
123
+ print(f"[WARNING] No se encontraron archivos coincidentes con el patr贸n para el modelo {model_prefix}.")
124
+ return model_files
125
+
126
+ def download_and_upload_to_s3_url(self, url: str, s3_key: str):
127
+ try:
128
+ print(f"[INFO] Descargando archivo desde {url}...")
129
+ response = requests.get(url)
130
+ if response.status_code == 200:
131
+ print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...")
132
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
133
+ # Eliminar el archivo local despu茅s de la carga exitosa
134
+ self.delete_local_file(s3_key)
135
+ else:
136
+ print(f"[ERROR] Error al descargar el archivo desde {url}, c贸digo de estado {response.status_code}.")
137
+ raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}")
138
+ except Exception as e:
139
+ print(f"[ERROR] Error al procesar la URL {url}: {str(e)}")
140
+ raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}")
141
+
142
+ def delete_local_file(self, file_path: str):
143
+ try:
144
+ print(f"[INFO] Eliminando archivo local {file_path}...")
145
+ if os.path.exists(file_path):
146
+ os.remove(file_path)
147
+ print(f"[INFO] Archivo local {file_path} eliminado correctamente.")
148
+ else:
149
+ print(f"[WARNING] El archivo local {file_path} no existe.")
150
+ except Exception as e:
151
+ print(f"[ERROR] Error al eliminar el archivo local {file_path}: {str(e)}")
152
+
153
+
154
+ @app.post("/predict/")
155
+ async def predict(model_request: DownloadModelRequest):
156
+ try:
157
+ print(f"[INFO] Recibiendo solicitud para predecir con el modelo {model_request.model_name}...")
158
+ streamer = S3DirectStream(S3_BUCKET_NAME)
159
+ model = streamer.load_model_from_stream(model_request.model_name)
160
+ tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name)
161
+
162
+ task = model_request.pipeline_task
163
+ if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "text-to-speech", "text-to-image", "text-to-audio", "text-to-video"]:
164
+ raise HTTPException(status_code=400, detail="Pipeline task no soportado")
165
+
166
+ nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, max_length=2046)
167
+
168
+ input_text = model_request.input_text
169
+ print(f"[INFO] Ejecutando tarea {task} con el texto de entrada...")
170
+ outputs = nlp_pipeline(input_text)
171
+
172
+ # Eliminaci贸n de archivo local despu茅s de subir a S3
173
+ if task == "text-to-speech":
174
+ s3_key = f"{model_request.model_name}/generated_audio.wav"
175
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="audio/wav")
176
+ elif task == "text-to-image":
177
+ s3_key = f"{model_request.model_name}/generated_image.png"
178
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="image/png")
179
+ elif task == "text-to-video":
180
+ s3_key = f"{model_request.model_name}/generated_video.mp4"
181
+ return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4")
182
+
183
+ return {"input_text": input_text, "output": outputs}
184
+
185
+ except Exception as e:
186
+ print(f"[ERROR] Error al procesar la solicitud de predicci贸n: {str(e)}")
187
+ raise HTTPException(status_code=500, detail=f"Error interno: {str(e)}")
188
+
189
+
190
+ if __name__ == "__main__":
191
+ print("Iniciando el servidor FastAPI...")
192
+ uvicorn.run(app, host="0.0.0.0", port=9000)