Hjgugugjhuhjggg committed on
Commit
1949d3a
verified
1 Parent(s): 2b23f87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -120
app.py CHANGED
@@ -1,23 +1,20 @@
1
  import os
2
  import json
3
- import logging
4
- import boto3
5
  from fastapi import FastAPI, HTTPException
6
- from fastapi.responses import JSONResponse
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
- import asyncio
9
- import concurrent.futures
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
 
 
14
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
15
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
16
  AWS_REGION = os.getenv("AWS_REGION")
17
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 
18
 
19
- MAX_TOKENS = 1024 # Limite de tokens por fragmento
20
-
21
  s3_client = boto3.client(
22
  's3',
23
  aws_access_key_id=AWS_ACCESS_KEY_ID,
@@ -27,16 +24,10 @@ s3_client = boto3.client(
27
 
28
  app = FastAPI()
29
 
30
- PIPELINE_MAP = {
31
- "text-generation": "text-generation",
32
- "sentiment-analysis": "sentiment-analysis",
33
- "translation": "translation",
34
- "fill-mask": "fill-mask",
35
- "question-answering": "question-answering",
36
- "text-to-speech": "text-to-speech",
37
- "text-to-video": "text-to-video",
38
- "text-to-image": "text-to-image"
39
- }
40
 
41
  class S3DirectStream:
42
  def __init__(self, bucket_name):
@@ -48,140 +39,130 @@ class S3DirectStream:
48
  )
49
  self.bucket_name = bucket_name
50
 
51
- async def stream_from_s3(self, key):
52
- loop = asyncio.get_event_loop()
53
- return await loop.run_in_executor(None, self._stream_from_s3, key)
54
-
55
- def _stream_from_s3(self, key):
56
  try:
57
- response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
58
- return response['Body']
59
- except self.s3_client.exceptions.NoSuchKey:
60
- raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
61
- except Exception as e:
62
- raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")
63
 
64
- async def get_model_file_parts(self, model_name):
65
- loop = asyncio.get_event_loop()
66
- return await loop.run_in_executor(None, self._get_model_file_parts, model_name)
 
 
 
 
 
 
 
 
 
67
 
68
- def _get_model_file_parts(self, model_name):
69
- try:
70
- model_prefix = model_name.lower()
71
- files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
72
- model_files = [obj['Key'] for obj in files.get('Contents', []) if model_prefix in obj['Key']]
73
- return model_files
74
  except Exception as e:
75
- raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}")
 
76
 
77
- async def load_model_from_s3(self, model_name):
78
  try:
79
- profile, model = model_name.split("/", 1) if "/" in model_name else ("", model_name)
 
 
 
80
 
81
- model_prefix = f"{profile}/{model}".lower()
82
- model_files = await self.get_model_file_parts(model_prefix)
 
 
83
 
84
  if not model_files:
85
- raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
86
-
87
- config_stream = await self.stream_from_s3(f"{model_prefix}/config.json")
88
- config_data = config_stream.read()
89
 
90
- if not config_data:
91
- raise HTTPException(status_code=500, detail=f"El archivo de configuración {model_prefix}/config.json está vacío.")
 
 
92
 
93
- config_text = config_data.decode("utf-8")
94
- config_json = json.loads(config_text)
95
 
96
- model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_prefix}", config=config_json, from_tf=False)
 
97
  return model
98
 
99
  except HTTPException as e:
100
  raise e
101
  except Exception as e:
102
- raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}")
 
103
 
104
  async def load_tokenizer_from_s3(self, model_name):
105
  try:
106
- profile, model = model_name.split("/", 1) if "/" in model_name else ("", model_name)
107
-
108
- tokenizer_stream = await self.stream_from_s3(f"{profile}/{model}/tokenizer.json")
109
- tokenizer_data = tokenizer_stream.read().decode("utf-8")
110
 
111
- tokenizer = AutoTokenizer.from_pretrained(f"{profile}/{model}")
 
 
 
112
  return tokenizer
113
  except Exception as e:
114
- raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}")
 
115
 
116
- async def create_s3_folders(self, s3_key):
117
  try:
118
- folder_keys = s3_key.split('/')
119
- for i in range(1, len(folder_keys)):
120
- folder_key = '/'.join(folder_keys[:i]) + '/'
121
- if not await self.file_exists_in_s3(folder_key):
122
- self.s3_client.put_object(Bucket=self.bucket_name, Key=folder_key, Body='')
123
-
124
  except Exception as e:
125
- raise HTTPException(status_code=500, detail=f"Error al crear carpetas en S3: {e}")
126
 
127
- async def file_exists_in_s3(self, s3_key):
128
  try:
129
- self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
130
- return True
131
- except self.s3_client.exceptions.ClientError:
132
- return False
 
 
133
 
134
- def split_text_by_tokens(text, tokenizer, max_tokens=MAX_TOKENS):
135
- tokens = tokenizer.encode(text)
136
- chunks = []
137
- for i in range(0, len(tokens), max_tokens):
138
- chunk = tokens[i:i+max_tokens]
139
- chunks.append(tokenizer.decode(chunk))
140
- return chunks
141
-
142
- def continue_generation(input_text, model, tokenizer, max_tokens=MAX_TOKENS):
143
- generated_text = ""
144
- while len(input_text) > 0:
145
- tokens = tokenizer.encode(input_text)
146
- input_text = tokenizer.decode(tokens[:max_tokens])
147
- output = model.generate(input_ids=tokenizer.encode(input_text, return_tensors="pt").input_ids)
148
- generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
149
- input_text = input_text[len(input_text):] # Si la entrada se agotó, ya no hay más que procesar
150
- return generated_text
151
-
152
- @app.post("/predict/")
153
- async def predict(model_request: dict):
154
  try:
155
- model_name = model_request.get("model_name")
156
- task = model_request.get("pipeline_task")
157
- input_text = model_request.get("input_text")
158
-
159
- if not model_name or not task or not input_text:
160
- raise HTTPException(status_code=400, detail="Faltan parámetros en la solicitud.")
161
-
162
- streamer = S3DirectStream(S3_BUCKET_NAME)
163
- model = await streamer.load_model_from_s3(model_name)
164
- tokenizer = await streamer.load_tokenizer_from_s3(model_name)
165
-
166
- if task not in PIPELINE_MAP:
167
- raise HTTPException(status_code=400, detail="Pipeline task no soportado")
168
-
169
- nlp_pipeline = pipeline(PIPELINE_MAP[task], model=model, tokenizer=tokenizer)
170
-
171
- result = await asyncio.to_thread(nlp_pipeline, input_text)
172
-
173
- chunks = split_text_by_tokens(result, tokenizer)
 
 
 
 
 
 
 
 
 
174
 
175
- if len(chunks) > 1:
176
- full_result = ""
177
- for chunk in chunks:
178
- full_result += continue_generation(chunk, model, tokenizer)
179
- return JSONResponse(content={"result": full_result})
180
  else:
181
- return JSONResponse(content={"result": result})
182
 
 
 
183
  except Exception as e:
184
- raise HTTPException(status_code=500, detail=f"Error al realizar la predicción: {e}")
185
 
186
  if __name__ == "__main__":
187
  import uvicorn
 
1
  import os
2
  import json
 
 
3
  from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
+ import boto3
7
+ import logging
8
+ from huggingface_hub import hf_hub_download
 
 
9
 
10
+ # Configuraciones de AWS y Hugging Face
11
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
12
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
13
  AWS_REGION = os.getenv("AWS_REGION")
14
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
15
+ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
16
 
17
+ # Cliente de S3
 
18
  s3_client = boto3.client(
19
  's3',
20
  aws_access_key_id=AWS_ACCESS_KEY_ID,
 
24
 
25
  app = FastAPI()
26
 
27
+ class GenerateRequest(BaseModel):
28
+ model_name: str
29
+ input_text: str
30
+ task_type: str
 
 
 
 
 
 
31
 
32
  class S3DirectStream:
33
  def __init__(self, bucket_name):
 
39
  )
40
  self.bucket_name = bucket_name
41
 
42
+ async def download_and_upload_to_s3(self, model_name):
 
 
 
 
43
  try:
44
+ model_name = model_name.replace("/", "-").lower()
 
 
 
 
 
45
 
46
+ # Descargar el archivo config.json desde Hugging Face
47
+ config_file = hf_hub_download(repo_id=model_name, filename="config.json", token=HUGGINGFACE_HUB_TOKEN)
48
+ tokenizer_file = hf_hub_download(repo_id=model_name, filename="tokenizer.json", token=HUGGINGFACE_HUB_TOKEN)
49
+
50
+ # Verificar si el archivo ya existe en S3
51
+ if not await self.file_exists_in_s3(f"{model_name}/config.json"):
52
+ with open(config_file, "rb") as file:
53
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/config.json", Body=file)
54
+
55
+ if not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
56
+ with open(tokenizer_file, "rb") as file:
57
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/tokenizer.json", Body=file)
58
 
 
 
 
 
 
 
59
  except Exception as e:
60
+ logging.error(f"Error al cargar el modelo desde Hugging Face a S3: {e}")
61
+ raise HTTPException(status_code=500, detail=f"Error al cargar el modelo: {str(e)}")
62
 
63
+ async def file_exists_in_s3(self, s3_key):
64
  try:
65
+ self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
66
+ return True
67
+ except self.s3_client.exceptions.ClientError:
68
+ return False
69
 
70
+ async def load_model_from_s3(self, model_name):
71
+ try:
72
+ model_name = model_name.replace("/", "-").lower()
73
+ model_files = await self.get_model_file_parts(model_name)
74
 
75
  if not model_files:
76
+ await self.download_and_upload_to_s3(model_name)
 
 
 
77
 
78
+ # Cargar configuración del modelo desde S3
79
+ config_data = await self.stream_from_s3(f"{model_name}/config.json")
80
+ if isinstance(config_data, bytes):
81
+ config_data = config_data.decode("utf-8")
82
 
83
+ config_json = json.loads(config_data)
 
84
 
85
+ # Cargar el modelo
86
+ model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}", config=config_json)
87
  return model
88
 
89
  except HTTPException as e:
90
  raise e
91
  except Exception as e:
92
+ logging.error(f"Error al cargar el modelo desde S3: {e}")
93
+ raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {str(e)}")
94
 
95
  async def load_tokenizer_from_s3(self, model_name):
96
  try:
97
+ model_name = model_name.replace("/", "-").lower()
98
+ tokenizer_data = await self.stream_from_s3(f"{model_name}/tokenizer.json")
 
 
99
 
100
+ if isinstance(tokenizer_data, bytes):
101
+ tokenizer_data = tokenizer_data.decode("utf-8")
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
104
  return tokenizer
105
  except Exception as e:
106
+ logging.error(f"Error al cargar el tokenizer desde S3: {e}")
107
+ raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {str(e)}")
108
 
109
+ async def stream_from_s3(self, key):
110
  try:
111
+ response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
112
+ return response['Body'].read()
113
+ except self.s3_client.exceptions.NoSuchKey:
114
+ raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
 
 
115
  except Exception as e:
116
+ raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")
117
 
118
+ async def get_model_file_parts(self, model_name):
119
  try:
120
+ model_name = model_name.replace("/", "-").lower()
121
+ files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
122
+ model_files = [obj['Key'] for obj in files.get('Contents', []) if model_name in obj['Key']]
123
+ return model_files
124
+ except Exception as e:
125
+ raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {str(e)}")
126
 
127
+ @app.post("/generate")
128
+ async def generate(request: GenerateRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  try:
130
+ model_name = request.model_name
131
+ input_text = request.input_text
132
+ task_type = request.task_type
133
+
134
+ s3_direct_stream = S3DirectStream(S3_BUCKET_NAME)
135
+
136
+ model = await s3_direct_stream.load_model_from_s3(model_name)
137
+ tokenizer = await s3_direct_stream.load_tokenizer_from_s3(model_name)
138
+
139
+ if task_type == "text-to-text":
140
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
141
+ result = generator(input_text, max_length=1024, num_return_sequences=1)
142
+ return {"result": result[0]["generated_text"]}
143
+
144
+ elif task_type == "text-to-image":
145
+ generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=0)
146
+ image = generator(input_text)
147
+ return {"result": image}
148
+
149
+ elif task_type == "text-to-speech":
150
+ generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=0)
151
+ audio = generator(input_text)
152
+ return {"result": audio}
153
+
154
+ elif task_type == "text-to-video":
155
+ generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=0)
156
+ video = generator(input_text)
157
+ return {"result": video}
158
 
 
 
 
 
 
159
  else:
160
+ raise HTTPException(status_code=400, detail="Tipo de tarea no soportada")
161
 
162
+ except HTTPException as e:
163
+ raise e
164
  except Exception as e:
165
+ raise HTTPException(status_code=500, detail=str(e))
166
 
167
  if __name__ == "__main__":
168
  import uvicorn