import os import shutil import boto3 from fastapi import FastAPI, HTTPException from pydantic import BaseModel import requests from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer import safetensors.torch from fastapi.responses import StreamingResponse import io from tqdm import tqdm import re import torch import uvicorn # Cargar las variables de entorno desde el archivo .env from dotenv import load_dotenv load_dotenv() # Cargar las credenciales de AWS desde las variables de entorno AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") AWS_REGION = os.getenv("AWS_REGION") S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket de S3 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face # Cliente S3 de Amazon s3_client = boto3.client( 's3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION ) app = FastAPI() # Pydantic Model para el cuerpo de la solicitud del endpoint /predict/ class DownloadModelRequest(BaseModel): model_name: str pipeline_task: str input_text: str class S3DirectStream: def __init__(self, bucket_name): self.s3_client = boto3.client( 's3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION ) self.bucket_name = bucket_name def stream_from_s3(self, key): try: print(f"[INFO] Descargando archivo {key} desde S3...") response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key) return response['Body'] # Devolver el cuerpo directamente para el StreamingResponse except self.s3_client.exceptions.NoSuchKey: raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.") except Exception as e: print(f"[ERROR] Error al descargar {key}: {str(e)}") raise HTTPException(status_code=500, detail=f"Error al descargar archivo {key} desde S3.") def file_exists_in_s3(self, key): try: self.s3_client.head_object(Bucket=self.bucket_name, Key=key) return True except self.s3_client.exceptions.ClientError: return False def load_model_from_stream(self, model_prefix): try: print(f"[INFO] Cargando el modelo {model_prefix} desde S3...") model_files = self.get_model_file_parts(model_prefix) if not model_files: model_files = [f"{model_prefix}/model"] # Uso de modelo base # Leer y cargar todos los archivos del modelo model_streams = [] for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"): model_streams.append(self.stream_from_s3(model_file)) # Verificar si el archivo es un safetensor o un archivo binario config_stream = self.stream_from_s3(f"{model_prefix}/config.json") config_data = config_stream.read().decode("utf-8") # Cargar el modelo dependiendo de si es safetensor o binario if model_files[0].endswith("model.safetensors"): print("[INFO] Cargando el modelo como safetensor...") model = AutoModelForCausalLM.from_config(config_data) model.load_state_dict(safetensors.torch.load_stream(model_streams[0])) # Cargar el modelo utilizando safetensors else: print("[INFO] Cargando el modelo como archivo binario de PyTorch...") model = AutoModelForCausalLM.from_config(config_data) model.load_state_dict(torch.load(model_streams[0], map_location="cpu")) # Cargar el modelo utilizando pytorch print("[INFO] Modelo cargado con éxito.") return model except Exception as e: print(f"[ERROR] Error al cargar el modelo desde S3: {e}") raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.") def load_tokenizer_from_stream(self, model_prefix): try: print(f"[INFO] Cargando el tokenizer {model_prefix} desde S3...") tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json") tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream) return tokenizer except Exception as e: print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}") raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.") def get_model_file_parts(self, model_prefix): print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_prefix}...") files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix) model_files = [] for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"): key = obj['Key'] if re.match(rf"{model_prefix}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"): model_files.append(key) if not model_files: print(f"[WARNING] No se encontraron archivos coincidentes con el patrón para el modelo {model_prefix}.") return model_files def download_and_upload_to_s3_url(self, url: str, s3_key: str): try: print(f"[INFO] Descargando archivo desde {url}...") response = requests.get(url) if response.status_code == 200: print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...") self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content) # Eliminar el archivo local después de la carga exitosa self.delete_local_file(s3_key) else: print(f"[ERROR] Error al descargar el archivo desde {url}, código de estado {response.status_code}.") raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}") except Exception as e: print(f"[ERROR] Error al procesar la URL {url}: {str(e)}") raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}") def delete_local_file(self, file_path: str): try: print(f"[INFO] Eliminando archivo local {file_path}...") if os.path.exists(file_path): os.remove(file_path) print(f"[INFO] Archivo local {file_path} eliminado correctamente.") else: print(f"[WARNING] El archivo local {file_path} no existe.") except Exception as e: print(f"[ERROR] Error al eliminar el archivo local {file_path}: {str(e)}") @app.post("/predict/") async def predict(model_request: DownloadModelRequest): try: print(f"[INFO] Recibiendo solicitud para predecir con el modelo {model_request.model_name}...") streamer = S3DirectStream(S3_BUCKET_NAME) model = streamer.load_model_from_stream(model_request.model_name) tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name) task = model_request.pipeline_task if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "text-to-speech", "text-to-image", "text-to-audio", "text-to-video"]: raise HTTPException(status_code=400, detail="Pipeline task no soportado") nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, max_length=2046) input_text = model_request.input_text print(f"[INFO] Ejecutando tarea {task} con el texto de entrada...") outputs = nlp_pipeline(input_text) # Eliminación de archivo local después de subir a S3 if task == "text-to-speech": s3_key = f"{model_request.model_name}/generated_audio.wav" return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="audio/wav") elif task == "text-to-image": s3_key = f"{model_request.model_name}/generated_image.png" return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="image/png") elif task == "text-to-video": s3_key = f"{model_request.model_name}/generated_video.mp4" return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4") return {"input_text": input_text, "output": outputs} except Exception as e: print(f"[ERROR] Error al procesar la solicitud de predicción: {str(e)}") raise HTTPException(status_code=500, detail=f"Error interno: {str(e)}") if __name__ == "__main__": print("Iniciando el servidor FastAPI...") uvicorn.run(app, host="0.0.0.0", port=7860)