# Hugging Face Spaces page header (status: Sleeping) - capture residue, not part of the program.
import os
import re
import shutil
import tempfile

import boto3
import requests
import safetensors.torch
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from tqdm import tqdm
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Load environment variables from the .env file.
load_dotenv()

# AWS and Hugging Face settings; each one is None when its variable is unset.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
# Module-level Amazon S3 client, configured from the AWS_* constants.
s3_client = boto3.client(
    service_name="s3",
    region_name=AWS_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

# FastAPI application object; handed to uvicorn in the __main__ guard.
app = FastAPI()
# Pydantic models
class DownloadModelRequest(BaseModel):
    # Request body for a prediction. Documented with comments rather than a
    # docstring so the generated schema is left untouched.

    # Key prefix ("directory") of the model inside the S3 bucket.
    model_name: str = Field(default=..., example="model_directory_name")
    # Task name passed to transformers.pipeline().
    pipeline_task: str = Field(default=..., example="text-generation")
    # Text given to the pipeline as its input.
    input_text: str = Field(default=..., example="Introduce your input text here.")
# Helper class for S3 interaction
class S3DirectStream:
    """Fetch objects and Hugging Face model artifacts from a single S3 bucket."""

    # Artifacts fetched for every model, in download order.
    _MODEL_FILES = ("model.safetensors", "config.json", "tokenizer.json")

    def __init__(self, bucket_name):
        """Create an S3 client from the module-level AWS settings.

        Args:
            bucket_name: Name of the bucket that every key is resolved against.
        """
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    def stream_from_s3(self, key):
        """Return the ``Body`` stream of the object stored under ``key``.

        Args:
            key: Object key inside ``self.bucket_name``.

        Returns:
            The ``Body`` entry of the ``get_object`` response.

        Raises:
            HTTPException: 404 when the key does not exist, 500 for any other
                failure while requesting the object.
        """
        try:
            print(f"[INFO] Descargando archivo {key} desde S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
        except Exception as e:
            print(f"[ERROR] Error al descargar {key}: {str(e)}")
            raise HTTPException(status_code=500, detail="Error al descargar archivo desde S3.")

    def load_model_and_tokenizer(self, model_prefix):
        """Download a model's artifacts from S3 and load them with transformers.

        ``model.safetensors``, ``config.json`` and ``tokenizer.json`` are copied
        from ``<model_prefix>/`` into a temporary directory, because
        ``from_pretrained`` loads from a local directory (or hub id), not from
        an open stream. The directory is removed once loading has finished.

        Args:
            model_prefix: Key prefix under which the model files are stored.

        Returns:
            A ``(model, tokenizer)`` tuple.

        Raises:
            HTTPException: 404 when one of the files is missing in S3 (passed
                through from ``stream_from_s3``), 500 for download or loading
                failures.
        """
        try:
            print(f"[INFO] Cargando modelo y tokenizer desde S3 para {model_prefix}...")
            with tempfile.TemporaryDirectory() as staging_dir:
                for filename in self._MODEL_FILES:
                    body = self.stream_from_s3(f"{model_prefix}/{filename}")
                    try:
                        # Copy in chunks so the weights are never held in
                        # memory as one bytes object.
                        with open(os.path.join(staging_dir, filename), "wb") as target:
                            shutil.copyfileobj(body, target)
                    finally:
                        # Always release the underlying HTTP connection.
                        body.close()
                # Load both objects before the staging directory is deleted.
                model = AutoModelForCausalLM.from_pretrained(staging_dir)
                tokenizer = AutoTokenizer.from_pretrained(staging_dir)
            print("[INFO] Modelo y tokenizer cargados con éxito.")
            return model, tokenizer
        except HTTPException:
            # Keep the status code chosen by stream_from_s3 (e.g. 404) instead
            # of collapsing it into a generic 500.
            raise
        except Exception as e:
            print(f"[ERROR] Error al cargar modelo/tokenizer desde S3: {e}")
            raise HTTPException(status_code=500, detail="Error al cargar modelo/tokenizer.")
# Prediction endpoint

# Pipeline tasks accepted by predict().
_SUPPORTED_PIPELINE_TASKS = frozenset({
    "text-generation", "sentiment-analysis", "translation",
    "fill-mask", "question-answering", "text-to-speech",
    "text-to-image", "text-to-video",
})

# Response media type for the tasks whose result is streamed back from S3.
_MEDIA_TYPE_BY_TASK = {
    "text-to-speech": "audio/wav",
    "text-to-image": "image/png",
    "text-to-video": "video/mp4",
}


# Registered explicitly: without a route decorator the app exposes no endpoint.
@app.post("/predict")
async def predict(model_request: DownloadModelRequest):
    """Load the requested model from S3, run the pipeline and return the result.

    Args:
        model_request: Model prefix in S3, pipeline task and input text.

    Returns:
        For media tasks (speech/image/video), a ``StreamingResponse`` that
        streams the S3 object ``<model_name>/generated_output``. For every
        other task, a dict with the ``input_text`` and the pipeline ``output``.

    Raises:
        HTTPException: 400 for an unsupported task; 404/500 as raised while
            fetching from S3 or loading the model; 500 for any other error.
    """
    print(f"[INFO] Procesando solicitud para el modelo {model_request.model_name}...")
    task = model_request.pipeline_task
    # Validate before the expensive model download, and outside the try block
    # so the 400 is not rewritten into a 500.
    if task not in _SUPPORTED_PIPELINE_TASKS:
        raise HTTPException(status_code=400, detail="Pipeline task no soportado.")
    try:
        streamer = S3DirectStream(S3_BUCKET_NAME)
        model, tokenizer = streamer.load_model_and_tokenizer(model_request.model_name)
        nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, max_length=2046)
        outputs = nlp_pipeline(model_request.input_text)
        # Respond according to the task.
        media_type = _MEDIA_TYPE_BY_TASK.get(task)
        if media_type is not None:
            # NOTE(review): the pipeline output is not used here; the response
            # is whatever is already stored under this key - confirm intent.
            s3_key = f"{model_request.model_name}/generated_output"
            return StreamingResponse(streamer.stream_from_s3(s3_key), media_type=media_type)
        return {"input_text": model_request.input_text, "output": outputs}
    except HTTPException:
        # Preserve status codes raised deliberately further down (404, 500).
        raise
    except Exception as e:
        print(f"[ERROR] Error al procesar la solicitud: {e}")
        raise HTTPException(status_code=500, detail=f"Error interno: {e}")
# Main entry point: start the server only when the file is run as a script.
if __name__ == "__main__":
    print("[INFO] Iniciando el servidor FastAPI...")
    # Listen on all interfaces, port 9000.
    server_options = {"host": "0.0.0.0", "port": 9000}
    uvicorn.run(app, **server_options)