Spaces:
Sleeping
Sleeping
File size: 6,510 Bytes
410390c 227ec7b 0e63678 410390c 0e63678 00a3421 227ec7b 0e63678 00a3421 0e63678 410390c 0e63678 410390c 0e63678 410390c 0e63678 410390c 0e63678 410390c 0e63678 410390c 0e63678 d44fda2 227ec7b d44fda2 00a3421 d44fda2 299d616 00a3421 299d616 d44fda2 0e63678 d44fda2 410390c d44fda2 0e63678 410390c 227ec7b d44fda2 00a3421 0e63678 d44fda2 0e63678 d44fda2 00a3421 0e63678 26237b6 0e63678 227ec7b 00a3421 0e63678 00a3421 26237b6 d44fda2 26237b6 227ec7b d44fda2 0e63678 227ec7b c32fab0 410390c 0e63678 410390c 0e63678 c32fab0 410390c 0e63678 410390c 0e63678 d44fda2 410390c 0e63678 410390c 0e63678 227ec7b 0e63678 410390c 0e63678 2f5a890 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import json
import logging
import os
import tempfile

import boto3
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Root logging at INFO, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Deployment settings, read once from the environment; each is None when unset.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

# Module-level S3 client built from the settings above.
s3_client = boto3.client(
    "s3",
    region_name=AWS_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

app = FastAPI()

# Accepted "pipeline_task" values mapped to the task name handed to
# transformers.pipeline (currently an identity mapping).
# NOTE(review): confirm every task listed here is actually supported by
# transformers.pipeline in the installed version.
PIPELINE_MAP = {
    task: task
    for task in (
        "text-generation",
        "sentiment-analysis",
        "translation",
        "fill-mask",
        "question-answering",
        "text-to-speech",
        "text-to-video",
        "text-to-image",
    )
}
class S3DirectStream:
    """Load Hugging Face models and tokenizers mirrored in an S3 bucket.

    Every model lives under one key prefix: the lower-cased model name followed
    by a single ``/``. Because ``from_pretrained`` only accepts a Hub id or a
    local directory, objects are first copied into a temporary directory and
    loaded from there.

    Failures are reported as ``fastapi.HTTPException``; an ``HTTPException``
    raised by an inner step is propagated unchanged so its status code survives.
    """

    # Weight files are skipped when only the tokenizer is being loaded.
    _WEIGHT_SUFFIXES = (".bin", ".safetensors", ".h5", ".msgpack", ".ot")

    def __init__(self, bucket_name):
        """Create an S3 client from the module-level AWS settings.

        Args:
            bucket_name: Name of the bucket that stores the model files.
        """
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    @staticmethod
    def _model_prefix(model_name):
        """Return the S3 key prefix for *model_name* (lower-case, one trailing slash)."""
        return model_name.lower().rstrip("/") + "/"

    def stream_from_s3(self, key):
        """Return the streaming body of the object stored at *key*.

        The caller is responsible for closing the returned stream.

        Raises:
            HTTPException: 404 when the key does not exist, 500 on any other error.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")

    def _read_object(self, key):
        """Read the whole object at *key* and close its stream."""
        body = self.stream_from_s3(key)
        try:
            return body.read()
        finally:
            body.close()

    def get_model_file_parts(self, model_name):
        """List the keys of every file stored for *model_name*.

        All result pages are followed, and the prefix ends with ``/`` so that
        e.g. ``gpt2`` does not also match ``gpt2-large/...``. Keys ending in
        ``/`` (folder placeholders) are left out.

        Returns:
            A list of S3 keys; empty when nothing is stored for the model.

        Raises:
            HTTPException: 500 when the listing fails.
        """
        prefix = self._model_prefix(model_name)
        try:
            paginator = self.s3_client.get_paginator("list_objects_v2")
            model_files = []
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get("Contents", []):
                    if not obj["Key"].endswith("/"):
                        model_files.append(obj["Key"])
            return model_files
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}") from e

    def _download_prefix(self, prefix, target_dir, skip_suffixes=()):
        """Copy the objects under *prefix* into *target_dir*, keeping sub-paths.

        Returns:
            The number of files written.
        """
        written = 0
        for key in self.get_model_file_parts(prefix):
            relative = key[len(prefix):]
            # An empty tuple never matches, so by default nothing is skipped.
            if not relative or relative.endswith(skip_suffixes):
                continue
            local_path = os.path.join(target_dir, *relative.split("/"))
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            self.s3_client.download_file(self.bucket_name, key, local_path)
            written += 1
        return written

    def load_model_from_s3(self, model_name):
        """Load the model for *model_name*, mirroring it from the Hub if absent.

        NOTE(review): this always builds an ``AutoModelForCausalLM``, whatever
        task the caller later runs - confirm that is intended.

        Raises:
            HTTPException: 404 when no files can be found even after the Hub
                download, 500 for an empty/invalid config or any load failure.
        """
        prefix = self._model_prefix(model_name)
        try:
            if not self.get_model_file_parts(prefix):
                self.download_and_upload_from_huggingface(model_name)
                if not self.get_model_file_parts(prefix):
                    raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")

            # Validate the config before downloading the (large) weight files.
            config_key = f"{prefix}config.json"
            config_data = self._read_object(config_key)
            if not config_data:
                raise HTTPException(status_code=500, detail=f"El archivo de configuración {config_key} está vacío.")
            json.loads(config_data.decode("utf-8"))

            with tempfile.TemporaryDirectory() as local_dir:
                self._download_prefix(prefix, local_dir)
                return AutoModelForCausalLM.from_pretrained(local_dir)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}") from e

    def load_tokenizer_from_s3(self, model_name):
        """Load the tokenizer for *model_name* from its S3 prefix.

        Only non-weight files are downloaded.

        Raises:
            HTTPException: 404 when the prefix holds no usable files, 500 on
                any other failure.
        """
        prefix = self._model_prefix(model_name)
        try:
            with tempfile.TemporaryDirectory() as local_dir:
                written = self._download_prefix(prefix, local_dir, skip_suffixes=self._WEIGHT_SUFFIXES)
                if not written:
                    raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
                return AutoTokenizer.from_pretrained(local_dir)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}") from e

    def download_and_upload_from_huggingface(self, model_name):
        """Download the full *model_name* repository and mirror it to S3.

        Files are staged in a temporary directory and uploaded under the same
        normalized prefix the loaders read from; keys that already exist are
        not uploaded again.

        Raises:
            HTTPException: 500 when the download or an upload fails.
        """
        # Imported here so the method brings its own dependency into scope.
        from huggingface_hub import snapshot_download

        prefix = self._model_prefix(model_name)
        try:
            with tempfile.TemporaryDirectory() as local_dir:
                snapshot_download(repo_id=model_name, token=HUGGINGFACE_TOKEN, local_dir=local_dir)

                local_files = []
                for root, dirs, files in os.walk(local_dir):
                    # Skip the hub client's bookkeeping folder, if present.
                    dirs[:] = [d for d in dirs if d != ".cache"]
                    local_files.extend(os.path.join(root, name) for name in files)

                for file_path in tqdm(local_files, desc="Subiendo archivos a S3"):
                    relative = os.path.relpath(file_path, local_dir).replace(os.sep, "/")
                    s3_key = f"{prefix}{relative}"
                    if not self.file_exists_in_s3(s3_key):
                        self.upload_file_to_s3(file_path, s3_key)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar y subir modelo desde Hugging Face: {e}") from e

    def upload_file_to_s3(self, file_path, s3_key):
        """Upload *file_path* to *s3_key*, then delete the local file.

        Uses the client's managed transfer instead of a single ``put_object``.

        Raises:
            HTTPException: 500 when the upload or the local delete fails.
        """
        try:
            self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
            os.remove(file_path)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al subir archivo a S3: {e}") from e

    def file_exists_in_s3(self, s3_key):
        """Return True when a HEAD request for *s3_key* succeeds.

        Any ``ClientError`` is reported as "does not exist".
        """
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
            return True
        except self.s3_client.exceptions.ClientError:
            return False
@app.post("/predict/")
async def predict(model_request: dict):
    """Run a transformers pipeline on the text supplied in the request body.

    Args:
        model_request: JSON object with the keys ``model_name``,
            ``pipeline_task`` (one of the keys of ``PIPELINE_MAP``) and
            ``input_text``.

    Returns:
        ``{"file": ...}`` when the pipeline result is a dict containing
        ``"file"``, otherwise ``{"result": ...}``.

    Raises:
        HTTPException: 400 for missing parameters or an unsupported task;
            errors raised while loading the model keep their own status;
            500 for any other failure.
    """
    try:
        model_name = model_request.get("model_name")
        task = model_request.get("pipeline_task")
        input_text = model_request.get("input_text")
        if not model_name or not task or not input_text:
            raise HTTPException(status_code=400, detail="Faltan parámetros en la solicitud.")
        # Reject unknown tasks before paying for the model download/load.
        if task not in PIPELINE_MAP:
            raise HTTPException(status_code=400, detail="Pipeline task no soportado")

        streamer = S3DirectStream(S3_BUCKET_NAME)
        model = streamer.load_model_from_s3(model_name)
        tokenizer = streamer.load_tokenizer_from_s3(model_name)

        nlp_pipeline = pipeline(PIPELINE_MAP[task], model=model, tokenizer=tokenizer)
        result = nlp_pipeline(input_text)

        if isinstance(result, dict) and 'file' in result:
            return JSONResponse(content={"file": result['file']})
        return JSONResponse(content={"result": result})
    except HTTPException:
        # Keep the intended status code (400/404/...) instead of masking it as 500.
        raise
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=f"Error al realizar la predicción: {e}") from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as run_server

    run_server(app, port=7860, host="0.0.0.0")
|