Hjgugugjhuhjggg committed on
Commit
014edf2
verified
1 Parent(s): 5618c19

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ import boto3
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.responses import JSONResponse
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
+ from huggingface_hub import hf_hub_download
9
+ import asyncio
10
+
11
# Logger configuration: module-level logger writing to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
# NOTE(review): addHandler is unconditional — if this module is imported more
# than once, a duplicate handler is attached and every log line is doubled.
logger.addHandler(console_handler)

# AWS / Hugging Face settings, all read from the environment.
# Any of these is None when the variable is unset — TODO confirm the
# deployment always provides them (boto3 calls will fail later otherwise).
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Token window size used by the chunking/generation helpers below.
MAX_TOKENS = 1024

# Module-level S3 client.
# NOTE(review): this client is not referenced anywhere else in the file —
# S3DirectStream builds its own identical client in __init__. Candidate for
# reuse or removal.
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

# FastAPI application instance; routes are registered below.
app = FastAPI()
35
+
36
# Tasks this service accepts. Each request task name maps 1:1 onto the
# corresponding `transformers.pipeline` task identifier, so the mapping is
# the identity over this tuple.
_SUPPORTED_TASKS = (
    "text-generation",
    "sentiment-analysis",
    "translation",
    "fill-mask",
    "question-answering",
    "text-to-speech",
    "text-to-video",
    "text-to-image",
)
PIPELINE_MAP = {task: task for task in _SUPPORTED_TASKS}
46
+
47
class S3DirectStream:
    """Helper that reads/writes model artifacts in an S3 bucket.

    Blocking boto3 calls are pushed onto the default executor so the async
    FastAPI handlers are not blocked while streaming from S3.
    """

    def __init__(self, bucket_name):
        # Dedicated S3 client using the module-level credentials.
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        # Target bucket for every operation on this instance.
        self.bucket_name = bucket_name

    async def stream_from_s3(self, key):
        """Async wrapper: fetch `key` from S3 on the default executor."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._stream_from_s3, key)

    def _stream_from_s3(self, key):
        """Blocking fetch; returns the S3 response body (a readable stream).

        Raises HTTPException 404 if the key is missing, 500 on any other
        boto3 failure.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")

    async def get_model_file_parts(self, model_name):
        """Async wrapper around the blocking S3 listing below."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._get_model_file_parts, model_name)

    def _get_model_file_parts(self, model_name):
        """List object keys under the model's (lowercased) prefix.

        NOTE(review): the `model_prefix in obj['Key']` filter is redundant —
        list_objects_v2 already restricts results to the Prefix. Also, only
        the first page (up to 1000 keys) is returned; no pagination.
        """
        try:
            model_prefix = model_name.lower()
            files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
            model_files = [obj['Key'] for obj in files.get('Contents', []) if model_prefix in obj['Key']]
            return model_files
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}")

    async def load_model_from_s3(self, model_name):
        """Load a causal LM for `model_name` ("profile/model" or bare name).

        If no files exist under the model's prefix, config/tokenizer are first
        mirrored from the Hugging Face Hub into S3, then config.json is read
        back to validate it.

        NOTE(review): `from_pretrained` does not accept "s3://" URIs, and its
        `config` argument expects a PretrainedConfig or a path, not the dict
        produced by json.loads — this call likely fails at runtime. Also,
        download_and_upload_to_s3 only mirrors config.json/tokenizer.json, so
        no model weights ever exist at this prefix. TODO confirm and rework.
        """
        try:
            # Split "profile/model"; bare names get an empty profile.
            profile, model = model_name.split("/", 1) if "/" in model_name else ("", model_name)

            model_prefix = f"{profile}/{model}".lower()
            model_files = await self.get_model_file_parts(model_prefix)

            # Nothing in S3 yet: mirror the config/tokenizer from the Hub.
            if not model_files:
                await self.download_and_upload_to_s3(model_prefix, model)

            # Read config.json back to make sure it is present and non-empty.
            config_stream = await self.stream_from_s3(f"{model_prefix}/config.json")
            config_data = config_stream.read()

            if not config_data:
                raise HTTPException(status_code=500, detail=f"El archivo de configuraci贸n {model_prefix}/config.json est谩 vac铆o o no se pudo leer.")

            config_text = config_data.decode("utf-8")
            config_json = json.loads(config_text)

            # NOTE(review): `model` is rebound here from the repo-name string
            # to the model object — confusing shadowing.
            model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_prefix}", config=config_json, from_tf=False)
            return model

        except HTTPException as e:
            # Preserve the status code of HTTP errors raised above.
            raise e
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}")

    async def load_tokenizer_from_s3(self, model_name):
        """Load the tokenizer for `model_name`.

        NOTE(review): `tokenizer_data` is read from S3 but never used, and
        the prefix here is NOT lowercased (load_model_from_s3 lowercases it) —
        the two methods can look at different keys. The "s3://" URI passed to
        from_pretrained has the same unsupported-scheme problem noted above.
        """
        try:
            profile, model = model_name.split("/", 1) if "/" in model_name else ("", model_name)

            tokenizer_stream = await self.stream_from_s3(f"{profile}/{model}/tokenizer.json")
            tokenizer_data = tokenizer_stream.read().decode("utf-8")

            tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{profile}/{model}")
            return tokenizer
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}")

    async def create_s3_folders(self, s3_key):
        """Create zero-byte "folder" marker objects for each path segment.

        S3 has no real folders; this writes `a/`, `a/b/`, ... markers so the
        prefix shows up as a folder hierarchy in the console.
        NOTE(review): put_object here is a blocking call inside an async
        method — it runs on the event loop thread.
        """
        try:
            folder_keys = s3_key.split('/')
            for i in range(1, len(folder_keys)):
                folder_key = '/'.join(folder_keys[:i]) + '/'
                if not await self.file_exists_in_s3(folder_key):
                    logger.info(f"Creando carpeta en S3: {folder_key}")
                    self.s3_client.put_object(Bucket=self.bucket_name, Key=folder_key, Body='')

        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al crear carpetas en S3: {e}")

    async def file_exists_in_s3(self, s3_key):
        """Return True if `s3_key` exists (HEAD request), False otherwise.

        NOTE(review): head_object is a blocking call on the event loop thread,
        and any ClientError (including permission errors) is reported as
        "does not exist".
        """
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
            return True
        except self.s3_client.exceptions.ClientError:
            return False

    async def download_and_upload_to_s3(self, model_prefix, model_name):
        """Mirror config.json and tokenizer.json from the HF Hub into S3.

        NOTE(review): only these two files are mirrored — the model weights
        are never uploaded, so a later from_pretrained against this prefix
        cannot find them. hf_hub_download is also blocking inside an async
        method.
        """
        try:
            config_file = hf_hub_download(repo_id=model_name, filename="config.json", token=HUGGINGFACE_HUB_TOKEN)
            tokenizer_file = hf_hub_download(repo_id=model_name, filename="tokenizer.json", token=HUGGINGFACE_HUB_TOKEN)

            # Upload each file only if it is not already in the bucket.
            if not await self.file_exists_in_s3(f"{model_prefix}/config.json"):
                with open(config_file, "rb") as file:
                    self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_prefix}/config.json", Body=file)

            if not await self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
                with open(tokenizer_file, "rb") as file:
                    self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_prefix}/tokenizer.json", Body=file)

        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al descargar o cargar archivos desde Hugging Face a S3: {e}")
156
+
157
def split_text_by_tokens(text, tokenizer, max_tokens=MAX_TOKENS):
    """Split *text* into chunks of at most *max_tokens* tokens each.

    The text is encoded once; the resulting token ids are sliced into
    windows of ``max_tokens`` ids and each window is decoded back to a
    string. An empty input yields an empty list.
    """
    token_ids = tokenizer.encode(text)
    return [
        tokenizer.decode(token_ids[start:start + max_tokens])
        for start in range(0, len(token_ids), max_tokens)
    ]
164
+
165
def continue_generation(input_text, model, tokenizer, max_tokens=MAX_TOKENS):
    """Generate text for an arbitrarily long input by windowing it.

    The input is tokenized once, sliced into windows of at most
    ``max_tokens`` token ids, and ``model.generate`` is run on each window;
    the decoded generations are concatenated and returned.

    Fixes two defects in the original implementation:
    - The loop truncated the input to the first ``max_tokens`` tokens and
      then set ``input_text = input_text[len(input_text):]`` (always ``""``),
      so it ran exactly once and silently dropped everything past the first
      window.
    - It accessed ``.input_ids`` on the result of
      ``tokenizer.encode(..., return_tensors="pt")``, which already returns
      the ids tensor directly (``.input_ids`` belongs to the output of
      calling ``tokenizer(...)``, not ``encode``).

    Args:
        input_text: Raw text to generate from; may exceed the model window.
        model: Causal LM exposing ``generate(input_ids=...)``.
        tokenizer: Tokenizer exposing ``encode``/``decode``.
        max_tokens: Window size in tokens (module-level MAX_TOKENS default).

    Returns:
        Concatenation of the decoded generation for every window; ``""`` for
        empty input.
    """
    generated_text = ""
    tokens = tokenizer.encode(input_text)
    for start in range(0, len(tokens), max_tokens):
        # Re-decode the window so the model sees text tokenized on its own.
        window_text = tokenizer.decode(tokens[start:start + max_tokens])
        # encode(..., return_tensors="pt") returns the ids tensor directly.
        input_ids = tokenizer.encode(window_text, return_tensors="pt")
        output = model.generate(input_ids=input_ids)
        generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
174
+
175
@app.post("/predict/")
async def predict(model_request: dict):
    """Run a transformers pipeline for the requested model/task/input.

    Expects a JSON body with keys "model_name", "pipeline_task" and
    "input_text"; returns {"result": ...} on success, or a JSON error with
    the appropriate status code.
    """
    try:
        model_name = model_request.get("model_name")
        task = model_request.get("pipeline_task")
        input_text = model_request.get("input_text")

        # All three fields are required.
        if not model_name or not task or not input_text:
            raise HTTPException(status_code=400, detail="Faltan par谩metros en la solicitud.")

        streamer = S3DirectStream(S3_BUCKET_NAME)

        # Ensure the S3 "folder" markers for this model exist.
        await streamer.create_s3_folders(model_name)

        # NOTE(review): the model and tokenizer are loaded from scratch on
        # every request — no caching.
        model = await streamer.load_model_from_s3(model_name)
        tokenizer = await streamer.load_tokenizer_from_s3(model_name)

        # NOTE(review): the task is validated only AFTER the expensive model
        # load above; the check could run first.
        if task not in PIPELINE_MAP:
            raise HTTPException(status_code=400, detail="Pipeline task no soportado")

        nlp_pipeline = pipeline(PIPELINE_MAP[task], model=model, tokenizer=tokenizer)

        # Run the (blocking) pipeline off the event loop.
        result = await asyncio.to_thread(nlp_pipeline, input_text)

        # NOTE(review): transformers pipelines typically return a list of
        # dicts, so len(result) counts output items, not tokens, and
        # split_text_by_tokens expects a str — this branch likely never
        # behaves as intended. TODO confirm against the pipeline output type.
        if len(result) > MAX_TOKENS:
            chunks = split_text_by_tokens(result, tokenizer)
            full_result = ""
            for chunk in chunks:
                full_result += continue_generation(chunk, model, tokenizer)
            return {"result": full_result}

        return {"result": result}

    except HTTPException as e:
        # Preserve the HTTP status of errors raised above.
        logger.error(f"Error al realizar la predicci贸n: {str(e.detail)}")
        return JSONResponse(status_code=e.status_code, content={"detail": str(e.detail)})

    except Exception as e:
        # Last-resort handler: log and return a generic 500.
        logger.error(f"Error inesperado: {str(e)}")
        return JSONResponse(status_code=500, content={"detail": "Error inesperado. Intenta m谩s tarde."})
215
+
216
# Script entry point: serve the API with uvicorn on all interfaces.
# Port 7860 is presumably the deployment platform's expected port (it is the
# Hugging Face Spaces convention) — TODO confirm.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)