import os
import tempfile

import boto3
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables from the .env file
load_dotenv()

# AWS and Hugging Face configuration
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")  # reserved for gated models; not used below

# Amazon S3 client
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()


# Pydantic models
class DownloadModelRequest(BaseModel):
    model_name: str = Field(..., example="model_directory_name")
    pipeline_task: str = Field(..., example="text-generation")
    input_text: str = Field(..., example="Introduce your input text here.")


# Helper class for S3 interaction
class S3DirectStream:
    def __init__(self, bucket_name):
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    def stream_from_s3(self, key):
        try:
            print(f"[INFO] Downloading file {key} from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"File {key} does not exist in the S3 bucket.")
        except Exception as e:
            print(f"[ERROR] Failed to download {key}: {str(e)}")
            raise HTTPException(status_code=500, detail="Error downloading file from S3.")

    def load_model_and_tokenizer(self, model_prefix):
        try:
            print(f"[INFO] Loading model and tokenizer from S3 for {model_prefix}...")
            model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
            config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
            tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")

            # Materialize the three files in a temporary directory so the
            # standard from_pretrained loaders can resolve the config, the
            # safetensors weights, and the tokenizer together.
            with tempfile.TemporaryDirectory() as tmp_dir:
                for filename, stream in (
                    ("model.safetensors", model_stream),
                    ("config.json", config_stream),
                    ("tokenizer.json", tokenizer_stream),
                ):
                    with open(os.path.join(tmp_dir, filename), "wb") as f:
                        f.write(stream.read())

                model = AutoModelForCausalLM.from_pretrained(tmp_dir)
                tokenizer = AutoTokenizer.from_pretrained(tmp_dir)

            print("[INFO] Model and tokenizer loaded successfully.")
            return model, tokenizer
        except HTTPException:
            # Preserve the original status code (e.g. 404 from stream_from_s3).
            raise
        except Exception as e:
            print(f"[ERROR] Failed to load model/tokenizer from S3: {e}")
            raise HTTPException(status_code=500, detail="Error loading model/tokenizer.")
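
# A minimal sketch (not part of the original service): before attempting a full
# load, one could verify that an S3 prefix actually contains the three files
# S3DirectStream expects. The helper name and its use of the module-level
# s3_client are illustrative assumptions, not an established API of this file.
def s3_prefix_is_complete(bucket_name: str, model_prefix: str) -> bool:
    """Return True if model.safetensors, config.json and tokenizer.json exist under the prefix."""
    required = {"model.safetensors", "config.json", "tokenizer.json"}
    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=f"{model_prefix}/")
    # list_objects_v2 returns at most 1000 keys per call, which is enough here.
    found = {os.path.basename(obj["Key"]) for obj in response.get("Contents", [])}
    return required.issubset(found)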

# Prediction endpoint
@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    try:
        print(f"[INFO] Processing request for model {model_request.model_name}...")
        streamer = S3DirectStream(S3_BUCKET_NAME)
        model, tokenizer = streamer.load_model_and_tokenizer(model_request.model_name)

        if model_request.pipeline_task not in [
            "text-generation", "sentiment-analysis", "translation",
            "fill-mask", "question-answering",
            "text-to-speech", "text-to-image", "text-to-video"
        ]:
            raise HTTPException(status_code=400, detail="Unsupported pipeline task.")

        nlp_pipeline = pipeline(
            model_request.pipeline_task,
            model=model,
            tokenizer=tokenizer,
            max_length=2046
        )
        outputs = nlp_pipeline(model_request.input_text)

        # Respond according to the task: media tasks stream the generated file
        # back from S3, text tasks return JSON directly.
        if model_request.pipeline_task in ["text-to-speech", "text-to-image", "text-to-video"]:
            media_type_map = {
                "text-to-speech": "audio/wav",
                "text-to-image": "image/png",
                "text-to-video": "video/mp4"
            }
            s3_key = f"{model_request.model_name}/generated_output"
            return StreamingResponse(
                streamer.stream_from_s3(s3_key),
                media_type=media_type_map[model_request.pipeline_task]
            )

        return {"input_text": model_request.input_text, "output": outputs}
    except HTTPException:
        # Re-raise client errors (400/404) instead of masking them as 500s.
        raise
    except Exception as e:
        print(f"[ERROR] Failed to process request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal error: {e}")


# Main entry point
if __name__ == "__main__":
    print("[INFO] Starting the FastAPI server...")
    uvicorn.run(app, host="0.0.0.0", port=9000)
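
# Usage sketch (assumption, not part of the original file): a minimal client
# call, assuming the server runs locally on port 9000 and that a hypothetical
# S3 prefix "my-model" holds model.safetensors, config.json and tokenizer.json.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:9000/predict/",
#       json={
#           "model_name": "my-model",
#           "pipeline_task": "text-generation",
#           "input_text": "Hello, world!",
#       },
#   )
#   resp.raise_for_status()
#   print(resp.json()["output"])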