Hjgugugjhuhjggg
committed
Update app.py
app.py
CHANGED
@@ -12,6 +12,7 @@ import uvicorn
 from dotenv import load_dotenv
 import re
 from tqdm import tqdm
+from huggingface_hub import hf_hub_download
 
 # Load the environment variables from the .env file
 load_dotenv()
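The new import is the only API this commit pulls from huggingface_hub: hf_hub_download fetches a single file from a Hub repo and returns the path of the locally cached copy, which is how download_and_upload_from_huggingface below consumes it. A minimal sketch of the call, assuming a public repo (the repo_id and filename values are illustrative, not from this commit):

    from huggingface_hub import hf_hub_download

    # Downloads one file into the local Hugging Face cache and
    # returns its filesystem path.
    local_path = hf_hub_download(repo_id="gpt2", filename="config.json")
    print(local_path)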
@@ -72,8 +73,12 @@ class S3DirectStream:
         print(f"[INFO] Loading the model {model_name} from S3...")
         model_prefix = model_name.lower()
         model_files = self.get_model_file_parts(model_prefix)
-
-
+
+        # Check whether the model is already in S3
+        if model_files:
+            print(f"[INFO] Model found in S3, loading it directly...")
+        else:
+            print(f"[INFO] Model not found in S3, downloading it from Hugging Face...")
             self.download_and_upload_from_huggingface(model_name)
             model_files = self.get_model_file_parts(model_prefix)
 
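get_model_file_parts itself is not part of this diff; from its use here it presumably lists the object keys stored under the model's prefix and returns an empty list when the model is absent. A sketch of what such a helper could look like with boto3, assuming list_objects_v2 pagination (the method body is an assumption, not code from this repo):

    def get_model_file_parts(self, model_prefix):
        # Collect every object key under the model's S3 prefix;
        # an empty list means the model has not been uploaded yet.
        paginator = self.s3_client.get_paginator("list_objects_v2")
        keys = []
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=f"{model_prefix}/"):
            keys.extend(obj["Key"] for obj in page.get("Contents", []))
        return keys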
@@ -85,6 +90,7 @@ class S3DirectStream:
         config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
         config_data = config_stream.read().decode("utf-8")
 
+        # Load the model depending on the file type (torch or safetensors)
         if model_files[0].endswith("model.safetensors"):
             print("[INFO] Loading the model as a safetensor...")
             model = AutoModelForCausalLM.from_config(config_data)
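Note that config_data is the raw JSON text read from S3, while AutoModelForCausalLM.from_config expects a PretrainedConfig instance, so this context line likely fails at runtime. A minimal sketch of the missing conversion, assuming transformers' AutoConfig (untested against this app):

    import json
    from transformers import AutoConfig, AutoModelForCausalLM

    config_dict = json.loads(config_data)  # parse the JSON text into a dict
    config = AutoConfig.for_model(config_dict.pop("model_type"), **config_dict)
    model = AutoModelForCausalLM.from_config(config)  # weights start randomly initialized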
@@ -126,34 +132,30 @@ class S3DirectStream:
     def download_and_upload_from_huggingface(self, model_name):
         try:
             print(f"[INFO] Downloading {model_name} from Hugging Face...")
-
+            # Download the required files from Hugging Face using huggingface_hub
             files_to_download = [
-
-
-
+                "pytorch_model.bin",
+                "config.json",
+                "tokenizer.json",
             ]
-
+
             for file in files_to_download:
-
-
-
+                # Download each file from Hugging Face and upload it to S3
+                file_path = hf_hub_download(repo_id=model_name, filename=file, use_auth_token=HUGGINGFACE_TOKEN)
+                s3_key = f"{model_name}/{file}"
+                self.upload_file_to_s3(file_path, s3_key)
         except Exception as e:
             print(f"[ERROR] Error downloading and uploading the model from Hugging Face: {e}")
             raise HTTPException(status_code=500, detail="Error downloading and uploading the model from Hugging Face.")
 
-    def
+    def upload_file_to_s3(self, file_path, s3_key):
         try:
-            print(f"[INFO]
-
-
-                print(f"[INFO] Uploading the file to S3 with key {s3_key}...")
-                self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=response.content)
-            else:
-                print(f"[ERROR] Error downloading the file from {url}, status code {response.status_code}.")
-                raise HTTPException(status_code=500, detail=f"Error downloading the file from {url}")
+            print(f"[INFO] Uploading file {file_path} to S3 with key {s3_key}...")
+            with open(file_path, 'rb') as data:
+                self.s3_client.put_object(Bucket=self.bucket_name, Key=s3_key, Body=data)
         except Exception as e:
-            print(f"[ERROR] Error
-            raise HTTPException(status_code=500, detail=
+            print(f"[ERROR] Error uploading the file to S3: {e}")
+            raise HTTPException(status_code=500, detail="Error uploading the file to S3.")
 
 @app.post("/predict/")
 async def predict(model_request: DownloadModelRequest):
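Two caveats on the new upload path. First, recent huggingface_hub releases deprecate the use_auth_token argument in favor of token. Second, put_object sends the file in a single PUT, which caps an object at 5 GB; for multi-gigabyte weight files, boto3's managed upload_file call performs multipart uploads with automatic retries. A sketch of that variant of the helper (a suggested alternative, not what this commit does):

    def upload_file_to_s3(self, file_path, s3_key):
        # upload_file streams from disk and switches to multipart
        # automatically once the file exceeds the default 8 MB threshold.
        self.s3_client.upload_file(file_path, self.bucket_name, s3_key)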
@@ -184,13 +186,11 @@ async def predict(model_request: DownloadModelRequest):
             s3_key = f"{model_request.model_name}/generated_video.mp4"
             return StreamingResponse(streamer.stream_from_s3(s3_key), media_type="video/mp4")
 
-        return {"
-
+        return {"output": outputs}
     except Exception as e:
-        print(f"[ERROR] Error in
+        print(f"[ERROR] Error in the prediction process: {str(e)}")
         raise HTTPException(status_code=500, detail="Error performing the prediction.")
 
-# Start the prediction server
 if __name__ == "__main__":
     print("Starting the prediction server on localhost:8000")
-    uvicorn.run(app, host="0.0.0.0", port=
+    uvicorn.run(app, host="0.0.0.0", port=8000)
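With the port now written out in full, the service can be exercised end to end. A minimal client sketch, assuming DownloadModelRequest carries at least the model_name field used throughout this diff (any other fields are not visible in this commit):

    import requests

    # POST a prediction request to the server started by uvicorn.run above.
    resp = requests.post(
        "http://localhost:8000/predict/",
        json={"model_name": "gpt2"},  # illustrative model name
    )
    print(resp.status_code)
    print(resp.json())  # for video models the body is a video/mp4 stream instead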