Hjgugugjhuhjggg commited on
Commit
00a3421
verified
1 Parent(s): 3f42f59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -37
app.py CHANGED
@@ -4,23 +4,23 @@ import torch
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  import safetensors.torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from fastapi.responses import StreamingResponse
9
  import io
10
- from tqdm import tqdm
11
  import requests
12
  import uvicorn
 
13
  import re
14
- import sys
 
15
  # Cargar las variables de entorno desde el archivo .env
16
- from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
- # Cargar las credenciales de AWS desde las variables de entorno
20
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
21
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
22
  AWS_REGION = os.getenv("AWS_REGION")
23
- S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket de S3
24
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face
25
 
26
  # Cliente S3 de Amazon
@@ -67,31 +67,32 @@ class S3DirectStream:
67
  except self.s3_client.exceptions.ClientError:
68
  return False
69
 
70
- def load_model_from_stream(self, model_prefix):
71
  try:
72
- print(f"[INFO] Cargando el modelo {model_prefix} desde S3...")
 
73
  model_files = self.get_model_file_parts(model_prefix)
74
  if not model_files:
75
- model_files = [f"{model_prefix}/model"] # Uso de modelo base
76
-
 
 
77
  # Leer y cargar todos los archivos del modelo
78
  model_streams = []
79
  for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"):
80
  model_streams.append(self.stream_from_s3(model_file))
81
 
82
- # Verificar si el archivo es un safetensor o un archivo binario
83
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
84
  config_data = config_stream.read().decode("utf-8")
85
 
86
- # Cargar el modelo dependiendo de si es safetensor o binario
87
  if model_files[0].endswith("model.safetensors"):
88
  print("[INFO] Cargando el modelo como safetensor...")
89
  model = AutoModelForCausalLM.from_config(config_data)
90
- model.load_state_dict(safetensors.torch.load_stream(model_streams[0])) # Cargar el modelo utilizando safetensors
91
  else:
92
  print("[INFO] Cargando el modelo como archivo binario de PyTorch...")
93
  model = AutoModelForCausalLM.from_config(config_data)
94
- model.load_state_dict(torch.load(model_streams[0], map_location="cpu")) # Cargar el modelo utilizando pytorch
95
 
96
  print("[INFO] Modelo cargado con 茅xito.")
97
  return model
@@ -100,28 +101,46 @@ class S3DirectStream:
100
  print(f"[ERROR] Error al cargar el modelo desde S3: {e}")
101
  raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.")
102
 
103
- def load_tokenizer_from_stream(self, model_prefix):
104
  try:
105
- print(f"[INFO] Cargando el tokenizer {model_prefix} desde S3...")
106
- tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
107
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
108
  return tokenizer
109
  except Exception as e:
110
  print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}")
111
  raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.")
112
 
113
- def get_model_file_parts(self, model_prefix):
114
- print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_prefix}...")
115
- files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_prefix)
116
  model_files = []
117
  for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"):
118
  key = obj['Key']
119
- if re.match(rf"{model_prefix}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"):
120
  model_files.append(key)
121
  if not model_files:
122
- print(f"[WARNING] No se encontraron archivos coincidentes con el patr贸n para el modelo {model_prefix}.")
123
  return model_files
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def download_and_upload_to_s3_url(self, url: str, s3_key: str):
126
  try:
127
  print(f"[INFO] Descargando archivo desde {url}...")
@@ -129,8 +148,6 @@ class S3DirectStream:
129
  if response.status_code == 200:
130
  print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...")
131
  self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
132
- # Eliminar el archivo local despu茅s de la carga exitosa
133
- self.delete_local_file(s3_key)
134
  else:
135
  print(f"[ERROR] Error al descargar el archivo desde {url}, c贸digo de estado {response.status_code}.")
136
  raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}")
@@ -138,18 +155,6 @@ class S3DirectStream:
138
  print(f"[ERROR] Error al procesar la URL {url}: {str(e)}")
139
  raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}")
140
 
141
- def delete_local_file(self, file_path: str):
142
- try:
143
- print(f"[INFO] Eliminando archivo local {file_path}...")
144
- if os.path.exists(file_path):
145
- os.remove(file_path)
146
- print(f"[INFO] Archivo local {file_path} eliminado correctamente.")
147
- else:
148
- print(f"[WARNING] El archivo local {file_path} no existe.")
149
- except Exception as e:
150
- print(f"[ERROR] Error al eliminar el archivo local {file_path}: {str(e)}")
151
-
152
-
153
  @app.post("/predict/")
154
  async def predict(model_request: DownloadModelRequest):
155
  try:
@@ -188,4 +193,4 @@ async def predict(model_request: DownloadModelRequest):
188
  # Iniciar servidor de predicciones
189
  if __name__ == "__main__":
190
  print("Iniciando servidor de predicciones en localhost:8000")
191
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  import safetensors.torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
  from fastapi.responses import StreamingResponse
9
  import io
 
10
  import requests
11
  import uvicorn
12
+ from dotenv import load_dotenv
13
  import re
14
+ from tqdm import tqdm
15
+
16
  # Cargar las variables de entorno desde el archivo .env
 
17
  load_dotenv()
18
 
19
+ # Configuraci贸n de AWS y Hugging Face
20
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
21
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
22
  AWS_REGION = os.getenv("AWS_REGION")
23
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") # Nombre del bucket S3
24
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") # Token de Hugging Face
25
 
26
  # Cliente S3 de Amazon
 
67
  except self.s3_client.exceptions.ClientError:
68
  return False
69
 
70
+ def load_model_from_stream(self, model_name):
71
  try:
72
+ print(f"[INFO] Cargando el modelo {model_name} desde S3...")
73
+ model_prefix = model_name.lower()
74
  model_files = self.get_model_file_parts(model_prefix)
75
  if not model_files:
76
+ print(f"[INFO] Modelo no encontrado en S3, descargando desde Hugging Face a S3...")
77
+ self.download_and_upload_from_huggingface(model_name)
78
+ model_files = self.get_model_file_parts(model_prefix)
79
+
80
  # Leer y cargar todos los archivos del modelo
81
  model_streams = []
82
  for model_file in tqdm(model_files, desc="Cargando archivos del modelo", unit="archivo"):
83
  model_streams.append(self.stream_from_s3(model_file))
84
 
 
85
  config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
86
  config_data = config_stream.read().decode("utf-8")
87
 
 
88
  if model_files[0].endswith("model.safetensors"):
89
  print("[INFO] Cargando el modelo como safetensor...")
90
  model = AutoModelForCausalLM.from_config(config_data)
91
+ model.load_state_dict(safetensors.torch.load_stream(model_streams[0]))
92
  else:
93
  print("[INFO] Cargando el modelo como archivo binario de PyTorch...")
94
  model = AutoModelForCausalLM.from_config(config_data)
95
+ model.load_state_dict(torch.load(model_streams[0], map_location="cpu"))
96
 
97
  print("[INFO] Modelo cargado con 茅xito.")
98
  return model
 
101
  print(f"[ERROR] Error al cargar el modelo desde S3: {e}")
102
  raise HTTPException(status_code=500, detail="Error al cargar el modelo desde S3.")
103
 
104
+ def load_tokenizer_from_stream(self, model_name):
105
  try:
106
+ print(f"[INFO] Cargando el tokenizer {model_name} desde S3...")
107
+ tokenizer_stream = self.stream_from_s3(f"{model_name}/tokenizer.json")
108
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
109
  return tokenizer
110
  except Exception as e:
111
  print(f"[ERROR] Error al cargar el tokenizer desde S3: {e}")
112
  raise HTTPException(status_code=500, detail="Error al cargar el tokenizer desde S3.")
113
 
114
+ def get_model_file_parts(self, model_name):
115
+ print(f"[INFO] Listando archivos del modelo en S3 con prefijo {model_name}...")
116
+ files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
117
  model_files = []
118
  for obj in tqdm(files.get('Contents', []), desc="Verificando archivos", unit="archivo"):
119
  key = obj['Key']
120
+ if re.match(rf"{model_name}/model(-\d+-of-\d+)?", key) or key.endswith("model.safetensors"):
121
  model_files.append(key)
122
  if not model_files:
123
+ print(f"[WARNING] No se encontraron archivos para el modelo {model_name}.")
124
  return model_files
125
 
126
+ def download_and_upload_from_huggingface(self, model_name):
127
+ try:
128
+ print(f"[INFO] Descargando {model_name} desde Hugging Face...")
129
+ model_url = f"https://huggingface.co/{model_name}/resolve/main/"
130
+ files_to_download = [
131
+ f"{model_name}/pytorch_model.bin",
132
+ f"{model_name}/config.json",
133
+ f"{model_name}/tokenizer.json",
134
+ ]
135
+
136
+ for file in files_to_download:
137
+ file_url = model_url + file
138
+ s3_key = file
139
+ self.download_and_upload_to_s3_url(file_url, s3_key)
140
+ except Exception as e:
141
+ print(f"[ERROR] Error al descargar y subir modelo desde Hugging Face: {e}")
142
+ raise HTTPException(status_code=500, detail="Error al descargar y subir modelo desde Hugging Face.")
143
+
144
  def download_and_upload_to_s3_url(self, url: str, s3_key: str):
145
  try:
146
  print(f"[INFO] Descargando archivo desde {url}...")
 
148
  if response.status_code == 200:
149
  print(f"[INFO] Subiendo archivo a S3 con key {s3_key}...")
150
  self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
 
 
151
  else:
152
  print(f"[ERROR] Error al descargar el archivo desde {url}, c贸digo de estado {response.status_code}.")
153
  raise HTTPException(status_code=500, detail=f"Error al descargar el archivo desde {url}")
 
155
  print(f"[ERROR] Error al procesar la URL {url}: {str(e)}")
156
  raise HTTPException(status_code=500, detail=f"Error al procesar la URL {url}")
157
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  @app.post("/predict/")
159
  async def predict(model_request: DownloadModelRequest):
160
  try:
 
193
  # Iniciar servidor de predicciones
194
  if __name__ == "__main__":
195
  print("Iniciando servidor de predicciones en localhost:8000")
196
+ uvicorn.run(app, host="0.0.0.0", port=8000)