Hjgugugjhuhjggg committed on
Commit
b4532e1
verified
1 Parent(s): c1d4983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -152
app.py CHANGED
@@ -1,20 +1,13 @@
1
  import os
2
- import json
 
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
  from huggingface_hub import hf_hub_download
7
- import boto3
8
- import logging
9
  import asyncio
10
 
11
- logger = logging.getLogger(__name__)
12
- logger.setLevel(logging.INFO)
13
- console_handler = logging.StreamHandler()
14
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
15
- console_handler.setFormatter(formatter)
16
- logger.addHandler(console_handler)
17
-
18
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
19
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
20
  AWS_REGION = os.getenv("AWS_REGION")
@@ -23,6 +16,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
23
 
24
  MAX_TOKENS = 1024
25
 
 
26
  s3_client = boto3.client(
27
  's3',
28
  aws_access_key_id=AWS_ACCESS_KEY_ID,
@@ -30,194 +24,135 @@ s3_client = boto3.client(
30
  region_name=AWS_REGION
31
  )
32
 
 
33
  app = FastAPI()
34
 
 
35
  class GenerateRequest(BaseModel):
36
  model_name: str
37
  input_text: str
38
- task_type: str # Added task type to handle different tasks (e.g., text-to-image, text-to-speech)
39
 
40
- class S3DirectStream:
 
41
  def __init__(self, bucket_name):
42
- self.s3_client = boto3.client(
43
- 's3',
44
- aws_access_key_id=AWS_ACCESS_KEY_ID,
45
- aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
46
- region_name=AWS_REGION
47
- )
48
  self.bucket_name = bucket_name
 
49
 
50
- async def stream_from_s3(self, key):
 
51
  loop = asyncio.get_event_loop()
52
- return await loop.run_in_executor(None, self._stream_from_s3, key)
53
 
54
- def _stream_from_s3(self, key):
55
  try:
56
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
57
- return response['Body'].read() # This is a bytes object
58
  except self.s3_client.exceptions.NoSuchKey:
59
- raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
60
  except Exception as e:
61
- raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")
62
 
63
- async def get_model_file_parts(self, model_name):
 
64
  loop = asyncio.get_event_loop()
65
- return await loop.run_in_executor(None, self._get_model_file_parts, model_name)
66
 
67
- def _get_model_file_parts(self, model_name):
68
  try:
69
- model_name = model_name.replace("/", "-").lower()
70
- files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
71
- model_files = [obj['Key'] for obj in files.get('Contents', []) if model_name in obj['Key']]
72
- return model_files
73
  except Exception as e:
74
- raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}")
75
-
76
- async def load_model_from_s3(self, model_name):
77
- try:
78
- model_name = model_name.replace("/", "-").lower()
79
- model_files = await self.get_model_file_parts(model_name)
80
-
81
- if not model_files:
82
- await self.download_and_upload_to_s3(model_name)
83
-
84
- config_data = await self.stream_from_s3(f"{model_name}/config.json")
85
- if not config_data:
86
- raise HTTPException(status_code=500, detail=f"El archivo de configuraci贸n {model_name}/config.json est谩 vac铆o o no se pudo leer.")
87
-
88
- # Ensure config_data is a string or bytes-like object
89
- if isinstance(config_data, bytes):
90
- config_data = config_data.decode("utf-8") # Decodificar los bytes a string si es necesario
91
-
92
- config_json = json.loads(config_data) # Ahora podemos usar json.loads sin problema
93
-
94
- model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}", config=config_json, from_tf=False)
95
- return model
96
-
97
- except HTTPException as e:
98
- raise e
99
- except Exception as e:
100
- raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}")
101
-
102
- async def load_tokenizer_from_s3(self, model_name):
103
- try:
104
- model_name = model_name.replace("/", "-").lower()
105
- tokenizer_data = await self.stream_from_s3(f"{model_name}/tokenizer.json")
106
-
107
- # Ensure tokenizer_data is a string or bytes-like object
108
- if isinstance(tokenizer_data, bytes):
109
- tokenizer_data = tokenizer_data.decode("utf-8") # Decodificar los bytes a string si es necesario
110
-
111
- tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
112
- return tokenizer
113
- except Exception as e:
114
- raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}")
115
-
116
- async def create_s3_folders(self, s3_key):
117
- try:
118
- folder_keys = s3_key.split('-')
119
- for i in range(1, len(folder_keys)):
120
- folder_key = '-'.join(folder_keys[:i]) + '/'
121
- if not await self.file_exists_in_s3(folder_key):
122
- logger.info(f"Creando carpeta en S3: {folder_key}")
123
- self.s3_client.put_object(Bucket=self.bucket_name, Key=folder_key, Body='')
124
 
125
- except Exception as e:
126
- raise HTTPException(status_code=500, detail=f"Error al crear carpetas en S3: {e}")
 
 
127
 
128
- async def file_exists_in_s3(self, s3_key):
129
  try:
130
- self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
131
  return True
132
  except self.s3_client.exceptions.ClientError:
133
  return False
 
 
134
 
135
- async def download_and_upload_to_s3(self, model_name, force_download=False):
136
- try:
137
- if force_download:
138
- logger.info(f"Forzando la descarga del modelo {model_name} y la carga a S3.")
139
-
140
- model_name = model_name.replace("/", "-").lower()
141
-
142
- if not await self.file_exists_in_s3(f"{model_name}/config.json") or not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
143
- config_file = hf_hub_download(repo_id=model_name, filename="config.json", token=HUGGINGFACE_HUB_TOKEN, force_download=force_download)
144
- tokenizer_file = hf_hub_download(repo_id=model_name, filename="tokenizer.json", token=HUGGINGFACE_HUB_TOKEN, force_download=force_download)
145
-
146
- await self.create_s3_folders(f"{model_name}/")
147
-
148
- if not await self.file_exists_in_s3(f"{model_name}/config.json"):
149
- with open(config_file, "rb") as file:
150
- self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/config.json", Body=file)
151
 
152
- if not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
153
- with open(tokenizer_file, "rb") as file:
154
- self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/tokenizer.json", Body=file)
155
- else:
156
- logger.info(f"Los archivos del modelo {model_name} ya existen en S3. No es necesario descargarlos de nuevo.")
157
 
158
- except Exception as e:
159
- raise HTTPException(status_code=500, detail=f"Error al descargar o cargar archivos desde Hugging Face a S3: {e}")
 
 
 
 
 
 
160
 
161
- async def resume_download(self, model_name):
162
- try:
163
- logger.info(f"Reanudando la descarga del modelo {model_name} desde Hugging Face.")
164
- config_file = hf_hub_download(repo_id=model_name, filename="config.json", token=HUGGINGFACE_HUB_TOKEN, resume_download=True)
165
- tokenizer_file = hf_hub_download(repo_id=model_name, filename="tokenizer.json", token=HUGGINGFACE_HUB_TOKEN, resume_download=True)
166
 
167
- if not await self.file_exists_in_s3(f"{model_name}/config.json"):
168
- with open(config_file, "rb") as file:
169
- self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/config.json", Body=file)
170
 
171
- if not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
172
- with open(tokenizer_file, "rb") as file:
173
- self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/tokenizer.json", Body=file)
174
 
175
- except Exception as e:
176
- raise HTTPException(status_code=500, detail=f"Error al reanudar la descarga del modelo: {e}")
177
 
178
  @app.post("/generate")
179
  async def generate(request: GenerateRequest):
180
  try:
181
- model_name = request.model_name
182
- input_text = request.input_text
183
- task_type = request.task_type
184
-
185
- # Create an instance of S3DirectStream
186
- s3_direct_stream = S3DirectStream(S3_BUCKET_NAME)
187
-
188
- # Load model and tokenizer
189
- model = await s3_direct_stream.load_model_from_s3(model_name)
190
- tokenizer = await s3_direct_stream.load_tokenizer_from_s3(model_name)
191
-
192
- # Generate based on task type
193
- if task_type == "text-to-text":
194
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
195
- result = generator(input_text, max_length=MAX_TOKENS, num_return_sequences=1)
 
196
  return {"result": result[0]["generated_text"]}
197
 
198
- elif task_type == "text-to-image":
199
- generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=0)
200
- image = generator(input_text)
201
- return {"result": image}
202
-
203
- elif task_type == "text-to-speech":
204
- generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=0)
205
- audio = generator(input_text)
206
- return {"result": audio}
207
 
208
- elif task_type == "text-to-video":
209
- generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=0)
210
- video = generator(input_text)
211
- return {"result": video}
212
 
213
- else:
214
- raise HTTPException(status_code=400, detail="Tipo de tarea no soportada")
 
 
215
 
216
  except HTTPException as e:
217
  raise e
218
  except Exception as e:
219
- raise HTTPException(status_code=500, detail=str(e))
220
 
221
  if __name__ == "__main__":
222
  import uvicorn
223
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
import asyncio
import logging
import os
import tempfile

import boto3
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
9
 
10
# Environment-driven configuration.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
# Restored from the collapsed diff context: both names are read by
# S3Manager.download_model_files and the /generate endpoint below.
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Hard cap on the generated sequence length for text-generation.
MAX_TOKENS = 1024

# Shared S3 client; credentials come from the environment.
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    # Restored: without the secret key the client cannot authenticate.
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

# FastAPI application instance.
app = FastAPI()
29
 
30
# Request payload for the /generate endpoint.
class GenerateRequest(BaseModel):
    """Which model to use, the input text, and the kind of task to run.

    task_type must be one of: text-to-text, text-to-image,
    text-to-speech, text-to-video (validated in the endpoint).
    """

    model_name: str
    input_text: str
    task_type: str
35
 
36
# S3 helper: mirrors Hugging Face model files into a bucket and loads them.
class S3Manager:
    """Async wrapper around the module-level boto3 S3 client.

    Blocking boto3 / huggingface_hub calls are run on the default
    executor so the FastAPI event loop is never blocked.
    """

    def __init__(self, bucket_name):
        # Bucket that stores the mirrored model files.
        self.bucket_name = bucket_name
        # Reuse the module-level client instead of building one per instance.
        self.s3_client = s3_client

    async def get_file(self, key: str) -> bytes:
        """Download an object from S3 and return its raw bytes."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._get_file_sync, key)

    def _get_file_sync(self, key: str) -> bytes:
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body'].read()
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"Archivo {key} no encontrado en S3.")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al obtener el archivo {key} de S3: {str(e)}")

    async def upload_file(self, file_path: str, key: str) -> None:
        """Upload a local file to S3 under *key*."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._upload_file_sync, file_path, key)

    def _upload_file_sync(self, file_path: str, key: str) -> None:
        try:
            with open(file_path, "rb") as file:
                self.s3_client.put_object(Bucket=self.bucket_name, Key=key, Body=file)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al subir {key} a S3: {str(e)}")

    async def file_exists(self, key: str) -> bool:
        """Return True when *key* exists in the bucket."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._file_exists_sync, key)

    def _file_exists_sync(self, key: str) -> bool:
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            return True
        except self.s3_client.exceptions.ClientError:
            # head_object raises ClientError (404) for a missing key.
            return False
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error al verificar existencia de {key}: {str(e)}")

    async def download_model_files(self, model_name: str) -> None:
        """Mirror the model's files from Hugging Face into S3 if absent."""
        model_name_s3 = model_name.replace("/", "-").lower()
        files = ["pytorch_model.bin", "tokenizer.json", "config.json"]

        loop = asyncio.get_event_loop()
        for file in files:
            if not await self.file_exists(f"{model_name_s3}/{file}"):
                # hf_hub_download is blocking I/O; keep it off the event loop.
                local_file = await loop.run_in_executor(
                    None,
                    lambda f=file: hf_hub_download(
                        repo_id=model_name, filename=f, token=HUGGINGFACE_HUB_TOKEN
                    ),
                )
                await self.upload_file(local_file, f"{model_name_s3}/{file}")

    async def load_model_from_s3(self, model_name: str):
        """Load the model and tokenizer from their S3 mirror.

        Fix: the previous implementation passed raw bytes directly to
        ``from_pretrained``, which expects a local path or a repo id, so
        loading always failed.  The files are now staged into a temporary
        directory and loaded from there.

        Returns:
            (model, tokenizer) tuple.

        Raises:
            HTTPException: 404 if any required file is missing in S3,
                500 on download/load errors (propagated from get_file).
        """
        model_name_s3 = model_name.replace("/", "-").lower()
        files = {
            "pytorch_model.bin": f"{model_name_s3}/pytorch_model.bin",
            "tokenizer.json": f"{model_name_s3}/tokenizer.json",
            "config.json": f"{model_name_s3}/config.json",
        }

        for path in files.values():
            if not await self.file_exists(path):
                raise HTTPException(status_code=404, detail=f"Archivo {path} no encontrado en S3.")

        # from_pretrained needs real files on disk, so stage them locally.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for filename, key in files.items():
                data = await self.get_file(key)
                with open(os.path.join(tmp_dir, filename), "wb") as fh:
                    fh.write(data)

            # Model/tokenizer construction is CPU-heavy and blocking.
            loop = asyncio.get_event_loop()
            model = await loop.run_in_executor(
                None, lambda: AutoModelForCausalLM.from_pretrained(tmp_dir)
            )
            tokenizer = await loop.run_in_executor(
                None, lambda: AutoTokenizer.from_pretrained(tmp_dir)
            )

        return model, tokenizer
 
113
 
114
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Run the requested generation task and return its output.

    Validates the payload, ensures the model files are mirrored in S3,
    loads model + tokenizer, then dispatches on task_type.
    """
    try:
        # Guard clauses: reject incomplete or unsupported requests up front.
        if not (request.model_name and request.input_text and request.task_type):
            raise HTTPException(status_code=400, detail="Todos los campos son obligatorios.")

        supported_tasks = {"text-to-text", "text-to-image", "text-to-speech", "text-to-video"}
        if request.task_type not in supported_tasks:
            raise HTTPException(status_code=400, detail="Tipo de tarea no soportado.")

        # Mirror the model files into S3 if needed, then load them.
        s3_manager = S3Manager(S3_BUCKET_NAME)
        await s3_manager.download_model_files(request.model_name)
        model, tokenizer = await s3_manager.load_model_from_s3(request.model_name)

        task = request.task_type
        if task == "text-to-text":
            # Text generation takes extra length/sequence arguments.
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
            result = generator(request.input_text, max_length=MAX_TOKENS, num_return_sequences=1)
            return {"result": result[0]["generated_text"]}

        # The remaining tasks share one call shape; map task -> response key.
        response_key = {
            "text-to-image": "image",
            "text-to-speech": "audio",
            "text-to-video": "video",
        }[task]
        generator = pipeline(task, model=model, tokenizer=tokenizer)
        output = generator(request.input_text)
        return {response_key: output}

    except HTTPException as e:
        raise e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error en la generaci贸n: {str(e)}")
154
 
155
# Run a local development server when executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
158
+