import os
import tempfile

import boto3
import safetensors.torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

# Load environment variables from the .env file
load_dotenv()

# AWS and Hugging Face configuration
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
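
# A sketch of the .env file this script expects (all values below are
# placeholders, not real credentials):
#
#   AWS_ACCESS_KEY_ID=AKIA...
#   AWS_SECRET_ACCESS_KEY=...
#   AWS_REGION=us-east-1
#   S3_BUCKET_NAME=my-models-bucket
#   HUGGINGFACE_TOKEN=hf_...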

# Amazon S3 client
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()

# Pydantic models
class DownloadModelRequest(BaseModel):
    model_name: str = Field(..., example="model_directory_name")
    pipeline_task: str = Field(..., example="text-generation")
    input_text: str = Field(..., example="Enter your input text here.")

# Helper class for streaming objects from S3
class S3DirectStream:
    def __init__(self, bucket_name):
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    def stream_from_s3(self, key):
        try:
            print(f"[INFO] Downloading {key} from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            # Return the streaming body so callers can read it incrementally.
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"File {key} does not exist in the S3 bucket.")
        except Exception as e:
            print(f"[ERROR] Failed to download {key}: {str(e)}")
            raise HTTPException(status_code=500, detail="Error downloading file from S3.")

    def load_model_and_tokenizer(self, model_prefix):
        try:
            print(f"[INFO] Loading model and tokenizer from S3 for {model_prefix}...")
            model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
            config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
            tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")

            # from_pretrained() expects files on disk, so write the config and
            # tokenizer to a temporary directory first. Note that some models
            # may need extra files (e.g. tokenizer_config.json) beyond these.
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(os.path.join(tmp_dir, "config.json"), "wb") as f:
                    f.write(config_stream.read())
                with open(os.path.join(tmp_dir, "tokenizer.json"), "wb") as f:
                    f.write(tokenizer_stream.read())

                config = AutoConfig.from_pretrained(tmp_dir)
                model = AutoModelForCausalLM.from_config(config)
                tokenizer = AutoTokenizer.from_pretrained(tmp_dir)

            # safetensors.torch.load() takes the raw bytes of a .safetensors
            # file and returns a state dict; read the stream fully first.
            state_dict = safetensors.torch.load(model_stream.read())
            model.load_state_dict(state_dict)

            print("[INFO] Model and tokenizer loaded successfully.")
            return model, tokenizer
        except HTTPException:
            raise
        except Exception as e:
            print(f"[ERROR] Failed to load model/tokenizer from S3: {e}")
            raise HTTPException(status_code=500, detail="Error loading model/tokenizer.")
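
# Expected S3 layout for each model, inferred from the keys requested in
# load_model_and_tokenizer above (the prefix is whatever the client sends
# as model_name):
#
#   s3://<S3_BUCKET_NAME>/<model_name>/model.safetensors
#   s3://<S3_BUCKET_NAME>/<model_name>/config.json
#   s3://<S3_BUCKET_NAME>/<model_name>/tokenizer.json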

# Prediction endpoint
@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    try:
        print(f"[INFO] Processing request for model {model_request.model_name}...")
        streamer = S3DirectStream(S3_BUCKET_NAME)
        model, tokenizer = streamer.load_model_and_tokenizer(model_request.model_name)

        if model_request.pipeline_task not in [
            "text-generation", "sentiment-analysis", "translation",
            "fill-mask", "question-answering", "text-to-speech",
            "text-to-image", "text-to-video"
        ]:
            raise HTTPException(status_code=400, detail="Unsupported pipeline task.")

        nlp_pipeline = pipeline(model_request.pipeline_task, model=model, tokenizer=tokenizer, max_length=2046)

        outputs = nlp_pipeline(model_request.input_text)

        # Respond according to the task: media tasks stream a previously
        # generated artifact back from S3; text tasks return JSON directly.
        if model_request.pipeline_task in ["text-to-speech", "text-to-image", "text-to-video"]:
            media_type_map = {
                "text-to-speech": "audio/wav",
                "text-to-image": "image/png",
                "text-to-video": "video/mp4"
            }
            s3_key = f"{model_request.model_name}/generated_output"
            return StreamingResponse(streamer.stream_from_s3(s3_key), media_type=media_type_map[model_request.pipeline_task])

        return {"input_text": model_request.input_text, "output": outputs}

    except HTTPException:
        # Let deliberate HTTP errors (400/404) pass through unchanged instead
        # of collapsing them into a generic 500.
        raise
    except Exception as e:
        print(f"[ERROR] Failed to process the request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal error: {e}")

# Main entry point
if __name__ == "__main__":
    print("[INFO] Starting the FastAPI server...")
    uvicorn.run(app, host="0.0.0.0", port=9000)
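
# A minimal usage sketch, assuming the server is running locally on port 9000
# and that "model_directory_name" exists as a prefix in the bucket (all values
# are placeholders):
#
#   curl -X POST http://localhost:9000/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "model_directory_name", "pipeline_task": "text-generation", "input_text": "Hello, world!"}'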