"""FastAPI service that runs Hugging Face pipelines on models cached in S3.

Model artifacts (config, weights, tokenizer) are looked up under
``<model_name>/`` in the configured S3 bucket.  When they are missing they are
fetched from the Hugging Face Hub, uploaded to the bucket, and then loaded.
"""
import io
import json
import os
import re
import tempfile

import boto3
import requests
import safetensors.torch
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

load_dotenv()

AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()

# (connect, read) timeouts in seconds for Hub downloads, so a stalled
# connection cannot hang a request forever.
_HUB_TIMEOUT = (10, 300)

# Weight files are alternatives: a repository usually ships only one format,
# so a 404 on one of them is not an error as long as another one exists.
_HUB_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")
_HUB_REQUIRED_FILES = ("tokenizer.json", "config.json")

# File-name pattern of sharded PyTorch checkpoints.
_SHARD_PATTERN = re.compile(r"pytorch_model-\d+-of-\d+")

# HEAD requests carry no error body, so a missing key is reported as "404".
_NOT_FOUND_CODES = ('404', 'NoSuchKey', 'NotFound')

_SUPPORTED_TASKS = (
    "text-generation", "sentiment-analysis", "translation", "fill-mask",
    "question-answering", "text-to-speech", "text-to-image", "text-to-audio",
    "text-to-video",
)
# Tasks whose pipeline output is returned directly as JSON.
_TEXT_TASKS = (
    "text-generation", "translation", "fill-mask", "sentiment-analysis",
    "question-answering",
)
# Tasks answered by streaming an object from S3: task -> (file name, media type).
# "text-to-speech" shares the "text-to-audio" entry; it used to be accepted by
# the task validation but then rejected with a 400 after the pipeline had run.
_MEDIA_RESPONSES = {
    "text-to-image": ("generated_image.png", "image/png"),
    "text-to-audio": ("generated_audio.wav", "audio/wav"),
    "text-to-speech": ("generated_audio.wav", "audio/wav"),
    "text-to-video": ("generated_video.mp4", "video/mp4"),
}


class DownloadModelRequest(BaseModel):
    """Request body of ``POST /predict/``."""

    # Hub repository id; also used as the key prefix inside the S3 bucket.
    model_name: str
    # Pipeline task name; must be one of ``_SUPPORTED_TASKS``.
    pipeline_task: str
    # Input passed to the pipeline.
    input_text: str


class S3DirectStream:
    """Load model artifacts from an S3 bucket, filling it from the Hub on demand."""

    def __init__(self, bucket_name):
        """Create an S3 client for *bucket_name* from the AWS_* environment values."""
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    def stream_from_s3(self, key):
        """Return the streaming body of the object stored under *key*.

        Raises:
            HTTPException: 404 when the key does not exist, 500 on any other
                failure.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
        except self.s3_client.exceptions.NoSuchKey as e:
            raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar de S3: {e}") from e
        return response['Body']

    def file_exists_in_s3(self, key):
        """Return True when *key* exists in the bucket, False when it does not.

        Raises:
            HTTPException: 500 when S3 answers with an error other than
                "not found".
        """
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
        except self.s3_client.exceptions.ClientError as e:
            if e.response['Error']['Code'] in _NOT_FOUND_CODES:
                return False
            raise HTTPException(status_code=500, detail=f"Error al verificar archivo en S3: {e}") from e
        return True

    def load_model_from_stream(self, model_prefix):
        """Build the causal-LM model stored under *model_prefix* and load its weights.

        If the bucket holds no weights or no ``config.json`` for the prefix,
        the artifacts are downloaded from the Hub exactly once.  (The previous
        implementation called itself after downloading and, because the
        download never produces sharded files, recursed without end.)

        Returns:
            The model in evaluation mode.

        Raises:
            HTTPException: 500 when the model cannot be found or loaded; errors
                raised by the S3/Hub helpers are propagated unchanged.
        """
        try:
            weight_keys = self._locate_weights(model_prefix)
            if not weight_keys or not self.file_exists_in_s3(f"{model_prefix}/config.json"):
                self.download_and_upload_to_s3(model_prefix)
                weight_keys = self._locate_weights(model_prefix)
            if not weight_keys:
                raise HTTPException(status_code=500, detail="Modelo no encontrado")

            config = self._load_config(model_prefix)
            model = AutoModelForCausalLM.from_config(config)

            state_dict = {}
            for key in weight_keys:
                state_dict.update(self._read_state_dict(key))
            # NOTE(review): strict loading is kept from the original code;
            # checkpoints that omit tied weights may need strict=False plus
            # model.tie_weights() -- confirm against the models served.
            model.load_state_dict(state_dict)
            # A model built with from_config starts in training mode; switch to
            # eval so dropout is disabled during inference.
            model.eval()
            return model
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el modelo: {e}") from e

    def _locate_weights(self, model_prefix):
        """Return the S3 keys of the checkpoint to load (empty list when none exists)."""
        safetensors_key = f"{model_prefix}/model.safetensors"
        if self.file_exists_in_s3(safetensors_key):
            return [safetensors_key]
        shard_names = self.list_model_files(model_prefix)
        if shard_names:
            return [f"{model_prefix}/{name}" for name in sorted(shard_names)]
        bin_key = f"{model_prefix}/pytorch_model.bin"
        if self.file_exists_in_s3(bin_key):
            return [bin_key]
        return []

    def _load_config(self, model_prefix):
        """Parse ``config.json`` from S3 into a transformers configuration object."""
        raw_config = self.stream_from_s3(f"{model_prefix}/config.json").read().decode("utf-8")
        config_dict = json.loads(raw_config)
        if "model_type" not in config_dict:
            raise HTTPException(status_code=500, detail="config.json no contiene 'model_type'")
        # from_config needs a configuration object, not the JSON text;
        # "model_type" in the dict selects the concrete configuration class.
        return AutoConfig.for_model(**config_dict)

    def _read_state_dict(self, key):
        """Read one checkpoint file from S3 and return its tensors as a dict."""
        payload = self.stream_from_s3(key).read()
        if key.endswith(".safetensors"):
            return safetensors.torch.load(payload)
        # torch.load needs a seekable file object, which the S3 body is not.
        # SECURITY: torch.load unpickles the file; only point this service at
        # buckets/repositories you trust (consider weights_only=True on torch
        # versions that support it).
        return torch.load(io.BytesIO(payload), map_location="cpu")

    def list_model_files(self, model_prefix):
        """Return the file names of sharded PyTorch checkpoints under *model_prefix*.

        Raises:
            HTTPException: 500 when the bucket cannot be listed.  (Previously
                every error was swallowed and ``None`` was returned, which
                callers could not tell apart from "no shards".)
        """
        try:
            response = self.s3_client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=f"{model_prefix}/pytorch_model-"
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al listar archivos en S3: {e}") from e
        file_names = (obj['Key'].split('/')[-1] for obj in response.get('Contents', []))
        return [name for name in file_names if _SHARD_PATTERN.match(name)]

    def load_tokenizer_from_stream(self, model_prefix):
        """Load the tokenizer stored under *model_prefix*, downloading it if missing.

        ``AutoTokenizer.from_pretrained`` only accepts a repository id or a
        local path, so the (small) tokenizer and config files are written to a
        temporary directory that is removed again after loading.

        Raises:
            HTTPException: 404 when no ``tokenizer.json`` is available, 500 on
                any other failure.
        """
        try:
            tokenizer_key = f"{model_prefix}/tokenizer.json"
            if not self.file_exists_in_s3(tokenizer_key):
                self.download_and_upload_to_s3(model_prefix)
            with tempfile.TemporaryDirectory() as workdir:
                # config.json lets AutoTokenizer pick the tokenizer class.
                for filename in ("tokenizer.json", "config.json"):
                    body = self.stream_from_s3(f"{model_prefix}/{filename}")
                    with open(os.path.join(workdir, filename), "wb") as target:
                        target.write(body.read())
                return AutoTokenizer.from_pretrained(workdir)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer: {e}") from e

    def download_and_upload_to_s3(self, model_prefix):
        """Copy the model's files from the Hugging Face Hub into the bucket.

        ``tokenizer.json`` and ``config.json`` are required.  The weight files
        are alternatives: each one that exists is uploaded and at least one
        must exist.

        Raises:
            HTTPException: 404 when the repository has no supported weight
                file, 500 when a download or upload fails.
        """
        base_url = f"https://huggingface.co/{model_prefix}/resolve/main"
        # The token is optional; it is only needed for private or gated repos.
        headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"} if HUGGINGFACE_TOKEN else {}
        uploaded_weights = 0
        for filename in _HUB_WEIGHT_FILES + _HUB_REQUIRED_FILES:
            is_weight_file = filename in _HUB_WEIGHT_FILES
            try:
                with requests.get(
                    f"{base_url}/{filename}", headers=headers, stream=True, timeout=_HUB_TIMEOUT
                ) as response:
                    if is_weight_file and response.status_code == 404:
                        continue
                    response.raise_for_status()
                    # response.raw bypasses requests' content decoding; without
                    # this flag a gzip-encoded response would be stored compressed.
                    response.raw.decode_content = True
                    self.s3_client.upload_fileobj(
                        response.raw, self.bucket_name, f"{model_prefix}/{filename}"
                    )
            except requests.exceptions.RequestException as e:
                raise HTTPException(status_code=500, detail=f"Error al descargar {filename}: {e}") from e
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error al subir {filename} a S3: {e}") from e
            if is_weight_file:
                uploaded_weights += 1
        if not uploaded_weights:
            raise HTTPException(
                status_code=404,
                detail=f"No se encontraron pesos del modelo {model_prefix} en Hugging Face.",
            )


@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    """Run the requested pipeline task on the requested model.

    Text tasks return ``{"response": <pipeline output>}``.  Media tasks stream
    ``generated_image.png`` / ``generated_audio.wav`` / ``generated_video.mp4``
    from the model's S3 prefix.

    Raises:
        HTTPException: 400 for an unsupported task, 404/500 for storage,
            download, or inference failures.
    """
    # NOTE(review): model loading and inference are blocking calls inside an
    # async endpoint, so they hold the event loop while they run.
    # NOTE(review): nothing in this file writes the generated_* objects, and
    # the model is always built with AutoModelForCausalLM regardless of the
    # task -- confirm that this matches the intended media workflow.
    try:
        task = model_request.pipeline_task
        # Validate before the expensive model download/load.
        if task not in _SUPPORTED_TASKS:
            raise HTTPException(status_code=400, detail="Pipeline task no soportado")

        streamer = S3DirectStream(S3_BUCKET_NAME)
        model = streamer.load_model_from_stream(model_request.model_name)
        tokenizer = streamer.load_tokenizer_from_stream(model_request.model_name)

        nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer)
        outputs = nlp_pipeline(model_request.input_text)

        if task in _TEXT_TASKS:
            return {"response": outputs}

        media = _MEDIA_RESPONSES.get(task)
        if media is None:
            raise HTTPException(status_code=400, detail="Tipo de tarea desconocido")
        filename, media_type = media
        s3_key = f"{model_request.model_name}/{filename}"
        return StreamingResponse(streamer.stream_from_s3(s3_key), media_type=media_type)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error inesperado: {str(e)}") from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)