aws_test / app.py
Hjgugugjhuhjggg's picture
Update app.py
bb6ace2 verified
raw
history blame
7.64 kB
import json
import logging
import os
import tempfile

import boto3
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Deployment settings come from the environment; an unset variable yields None.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

# Module-level S3 client built from the settings above.
s3_client = boto3.client(
    service_name="s3",
    region_name=AWS_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

# ASGI application that the route decorators below register against.
app = FastAPI()
# Task names accepted in a request, mapped to the identifier passed to
# transformers.pipeline(). Every entry currently maps to itself.
PIPELINE_MAP = {
    task: task
    for task in (
        "text-generation",
        "sentiment-analysis",
        "translation",
        "fill-mask",
        "question-answering",
        "text-to-speech",
        "text-to-video",
        "text-to-image",
    )
}
class S3DirectStream:
    """Load Hugging Face models and tokenizers through an S3 bucket.

    Files of a model are stored under the key prefix ``<model_name>/`` (exact
    case, because S3 keys are case-sensitive). A model that is not in the
    bucket yet is fetched from the Hugging Face Hub and uploaded first.
    Failures are reported as ``fastapi.HTTPException``.
    """

    # Weight-file suffixes skipped when only tokenizer assets are needed.
    _WEIGHT_SUFFIXES = (
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ot",
        ".onnx",
        ".pt",
        ".pth",
        ".gguf",
    )

    def __init__(self, bucket_name):
        """Create an S3 client from the module-level AWS settings.

        Args:
            bucket_name: Name of the bucket that stores the model files.
        """
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    @staticmethod
    def _model_prefix(model_name):
        """Return the S3 key prefix for ``model_name`` without edge slashes."""
        return model_name.strip("/")

    def stream_from_s3(self, key):
        """Return the streaming body of the object stored at ``key``.

        Raises:
            HTTPException: 404 when the key does not exist, 500 on any other
                S3 failure.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey as e:
            raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}") from e

    def get_model_file_parts(self, model_name):
        """List the object keys that belong to ``model_name``.

        The listing is paginated so models with many files are returned in
        full. The prefix ends with "/" so ``org/model`` does not also match
        ``org/model-large``. Zero-byte "folder" marker keys (ending in "/")
        are not reported as model files.

        Returns:
            list[str]: Full S3 keys, in listing order.

        Raises:
            HTTPException: 500 when the listing fails.
        """
        prefix = f"{self._model_prefix(model_name)}/"
        try:
            paginator = self.s3_client.get_paginator("list_objects_v2")
            model_files = []
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                model_files.extend(
                    obj['Key']
                    for obj in page.get('Contents', [])
                    if not obj['Key'].endswith("/")
                )
            return model_files
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}") from e

    def _download_to_dir(self, model_prefix, keys, local_dir):
        """Copy ``keys`` (all under ``model_prefix``/) into ``local_dir``."""
        for key in keys:
            parts = key[len(model_prefix) + 1:].split("/")
            # Never let an odd key ("..", empty segment) escape local_dir.
            if any(part in ("", ".", "..") for part in parts):
                continue
            target = os.path.join(local_dir, *parts)
            os.makedirs(os.path.dirname(target), exist_ok=True)
            self.s3_client.download_file(Bucket=self.bucket_name, Key=key, Filename=target)

    def load_model_from_s3(self, model_name):
        """Load a causal-LM model from the bucket, seeding the bucket if needed.

        ``from_pretrained`` only understands Hub ids and local directories, so
        the model files are first copied from S3 into a temporary directory.
        If anything other than an ``HTTPException`` goes wrong on the S3 path,
        the model is loaded straight from the Hugging Face Hub instead.

        Raises:
            HTTPException: 404 when the model files or ``config.json`` are
                missing from the bucket, 500 for an empty config or when the
                Hub fallback fails as well.
        """
        model_prefix = self._model_prefix(model_name)
        try:
            model_files = self.get_model_file_parts(model_prefix)
            if not model_files:
                self.download_and_upload_from_huggingface(model_name)
                model_files = self.get_model_file_parts(model_prefix)
                if not model_files:
                    raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
            config_key = f"{model_prefix}/config.json"
            config_data = self.stream_from_s3(config_key).read()
            if not config_data:
                raise HTTPException(status_code=500, detail=f"El archivo de configuración {config_key} está vacío.")
            # Parse only to fail early on a corrupt config; from_pretrained
            # reads config.json from the directory on its own.
            json.loads(config_data.decode("utf-8"))
            with tempfile.TemporaryDirectory() as local_dir:
                self._download_to_dir(model_prefix, model_files, local_dir)
                return AutoModelForCausalLM.from_pretrained(local_dir)
        except HTTPException:
            raise
        except Exception as e:
            try:
                logger.error(f"Error al cargar el modelo desde S3, intentando desde Hugging Face: {e}")
                return AutoModelForCausalLM.from_pretrained(model_name)
            except Exception as hf_error:
                raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde Hugging Face: {hf_error}") from hf_error

    def load_tokenizer_from_s3(self, model_name):
        """Load the tokenizer for ``model_name``, preferring the bucket copy.

        Non-weight files of the model are copied from S3 into a temporary
        directory and loaded from there. If the bucket has no usable files or
        that load fails, the tokenizer is loaded from the Hugging Face Hub.

        Raises:
            HTTPException: 500 when the tokenizer cannot be loaded at all.
        """
        model_prefix = self._model_prefix(model_name)
        try:
            tokenizer_keys = [
                key
                for key in self.get_model_file_parts(model_prefix)
                if not key.lower().endswith(self._WEIGHT_SUFFIXES)
            ]
            if tokenizer_keys:
                try:
                    with tempfile.TemporaryDirectory() as local_dir:
                        self._download_to_dir(model_prefix, tokenizer_keys, local_dir)
                        return AutoTokenizer.from_pretrained(local_dir)
                except Exception as s3_error:
                    logger.error(f"Error al cargar el tokenizer desde S3, intentando desde Hugging Face: {s3_error}")
            return AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}") from e

    def download_and_upload_from_huggingface(self, model_name):
        """Fetch every file of a Hub repository and upload it to the bucket.

        The whole repository snapshot is downloaded into a temporary
        directory; each file is uploaded under ``<model_name>/<relative path>``
        unless that key already exists.

        Raises:
            HTTPException: 500 when the download or any upload fails.
        """
        model_prefix = self._model_prefix(model_name)
        try:
            with tempfile.TemporaryDirectory() as local_dir:
                snapshot_download(repo_id=model_name, token=HUGGINGFACE_TOKEN, local_dir=local_dir)
                uploads = []
                for root, dirs, files in os.walk(local_dir):
                    # Skip hidden bookkeeping directories created by the download.
                    dirs[:] = [d for d in dirs if not d.startswith(".")]
                    for file_name in files:
                        file_path = os.path.join(root, file_name)
                        relative = os.path.relpath(file_path, local_dir).replace(os.sep, "/")
                        uploads.append((file_path, f"{model_prefix}/{relative}"))
                for file_path, s3_key in tqdm(uploads, desc="Subiendo archivos a S3"):
                    if not self.file_exists_in_s3(s3_key):
                        self.upload_file_to_s3(file_path, s3_key)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar y subir modelo desde Hugging Face: {e}") from e

    def upload_file_to_s3(self, file_path, s3_key):
        """Upload a local file to ``s3_key`` and delete the local copy.

        Uses the instance client's managed transfer, which handles large
        files and closes the file handle before the local file is removed.

        Raises:
            HTTPException: 500 when the upload or the cleanup fails.
        """
        try:
            self.create_s3_folders(s3_key)
            self.s3_client.upload_file(Filename=file_path, Bucket=self.bucket_name, Key=s3_key)
            os.remove(file_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al subir archivo a S3: {e}") from e

    def create_s3_folders(self, s3_key):
        """Create an empty marker object for every parent "folder" of ``s3_key``.

        For ``a/b/c.txt`` this ensures the keys ``a/`` and ``a/b/`` exist.

        Raises:
            HTTPException: 500 when a marker cannot be created.
        """
        try:
            folder_keys = s3_key.split('/')
            for i in range(1, len(folder_keys)):
                folder_key = '/'.join(folder_keys[:i]) + '/'
                if not self.file_exists_in_s3(folder_key):
                    self.s3_client.put_object(Bucket=self.bucket_name, Key=folder_key, Body='')
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al crear carpetas en S3: {e}") from e

    def file_exists_in_s3(self, s3_key):
        """Return True when a HEAD request for ``s3_key`` succeeds.

        Any client error from the HEAD request is reported as "does not
        exist".
        """
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
            return True
        except self.s3_client.exceptions.ClientError:
            return False
@app.post("/predict/")
async def predict(model_request: dict):
    """Run a transformers pipeline for the requested model, task and input.

    Args:
        model_request: JSON body with the keys ``model_name``,
            ``pipeline_task`` and ``input_text``; all three are required.

    Returns:
        JSONResponse: ``{"file": ...}`` when the pipeline result is a dict
        with a ``file`` entry, otherwise ``{"result": ...}``.

    Raises:
        HTTPException: 400 for missing parameters or an unsupported task;
            errors raised while loading the model or tokenizer keep their own
            status code; any other failure is reported as 500.
    """
    model_name = model_request.get("model_name")
    task = model_request.get("pipeline_task")
    input_text = model_request.get("input_text")
    if not model_name or not task or not input_text:
        raise HTTPException(status_code=400, detail="Faltan parámetros en la solicitud.")
    # Reject unknown tasks before paying for a model load.
    if task not in PIPELINE_MAP:
        raise HTTPException(status_code=400, detail="Pipeline task no soportado")
    try:
        # NOTE(review): loading and inference are synchronous calls inside an
        # async handler, so they occupy the event loop while they run.
        streamer = S3DirectStream(S3_BUCKET_NAME)
        model = streamer.load_model_from_s3(model_name)
        tokenizer = streamer.load_tokenizer_from_s3(model_name)
        nlp_pipeline = pipeline(PIPELINE_MAP[task], model=model, tokenizer=tokenizer)
        result = nlp_pipeline(input_text)
        if isinstance(result, dict) and 'file' in result:
            return JSONResponse(content={"file": result['file']})
        return JSONResponse(content={"result": result})
    except HTTPException:
        # Keep deliberate status codes (400/404/...) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error al realizar la predicción: {e}") from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)