Hjgugugjhuhjggg
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
import torch
|
3 |
+
from langchain.chains.llm import LLMChain
|
4 |
+
from langchain.llms import VLLM
|
5 |
+
from langchain.cache import GPTCache
|
6 |
+
from transformers import pipeline
|
7 |
+
import uvicorn
|
8 |
+
import threading
|
9 |
+
import time
|
10 |
+
import nltk
|
11 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
12 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
13 |
+
import psutil
|
14 |
+
import os
|
15 |
+
import gc
|
16 |
+
import logging
|
17 |
+
from PIL import Image
|
18 |
+
from transformers import DALLEncoder, DALLDecoder
|
19 |
+
import uuid
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
|
24 |
+
nltk.download('punkt')
|
25 |
+
nltk.download('stopwords')
|
26 |
+
|
27 |
+
app = FastAPI()
|
28 |
+
|
29 |
+
Configurar dispositivo
|
30 |
+
if torch.cuda.is_available():
|
31 |
+
device = torch.device("cuda")
|
32 |
+
else:
|
33 |
+
device = torch.device("cpu")
|
34 |
+
|
35 |
+
print("Dispositivo:", device)
|
36 |
+
|
37 |
+
Cargar modelos con VLLM
|
38 |
+
modelos = {
|
39 |
+
"gpt2-medium": VLLM(model="gpt2-medium"),
|
40 |
+
"qwen2.5-0.5b": VLLM(model="qwen2.5-0.5b"),
|
41 |
+
"t5-base": VLLM(model="t5-base"),
|
42 |
+
"bert-base-uncased": VLLM(model="bert-base-uncased"),
|
43 |
+
"musicgen-small": VLLM(model="musicgen-small"),
|
44 |
+
"dall-e-mini": VLLM(model="dall-e-mini"),
|
45 |
+
"xlnet-base-uncased": VLLM(model="xlnet-base-uncased"),
|
46 |
+
"distilbert-base-uncased": VLLM(model="distilbert-base-uncased"),
|
47 |
+
"albert-base-v2": VLLM(model="albert-base-v2"),
|
48 |
+
"roberta-base": VLLM(model="roberta-base"),
|
49 |
+
}
|
50 |
+
|
51 |
+
print("Cargando modelos...")
|
52 |
+
for nombre, modelo in tqdm(modelos.items()):
|
53 |
+
modelos[nombre] = modelo(to=device)
|
54 |
+
print(f"Modelo {nombre} cargado")
|
55 |
+
|
56 |
+
Crear instancias de caché para cada modelo
|
57 |
+
caches = {
|
58 |
+
nombre: GPTCache(modelo, max_size=1000) for nombre, modelo in modelos.items()
|
59 |
+
}
|
60 |
+
|
61 |
+
print("Creando instancias de caché...")
|
62 |
+
for nombre, caché in tqdm(caches.items()):
|
63 |
+
print(f"Caché para modelo {nombre} creada")
|
64 |
+
|
65 |
+
Crear instancias de cadenas de modelo para cada modelo
|
66 |
+
cadenas = {
|
67 |
+
nombre: LLMChain(modelo, caché) for nombre, modelo, caché in zip(modelos.keys(), modelos.values(), caches.values())
|
68 |
+
}
|
69 |
+
|
70 |
+
print("Creando instancias de cadenas de modelo...")
|
71 |
+
for nombre, cadena in tqdm(cadenas.items()):
|
72 |
+
print(f"Cadena de modelo {nombre} creada")
|
73 |
+
|
74 |
+
Cargar modelo de resumen de texto
|
75 |
+
summarizer = pipeline("summarization", device=device)
|
76 |
+
|
77 |
+
print("Cargando modelo de resumen de texto...")
|
78 |
+
|
79 |
+
Cargar modelo de vectorizador TF-IDF
|
80 |
+
vectorizer = TfidfVectorizer()
|
81 |
+
|
82 |
+
print("Cargando modelo de vectorizador TF-IDF...")
|
83 |
+
|
84 |
+
Cargar modelo DALL-E
|
85 |
+
dalle_encoder = DALLEncoder(model_id="dall-e-mini")
|
86 |
+
dalle_decoder = DALLDecoder(model_id="dall-e-mini")
|
87 |
+
|
88 |
+
print("Cargando modelo DALL-E...")
|
89 |
+
|
90 |
+
def keep_alive():
|
91 |
+
while True:
|
92 |
+
# Realizar una petición a cada modelo cada 5 minutos
|
93 |
+
for cadena in cadenas.values():
|
94 |
+
try:
|
95 |
+
cadena.ask("¿Cuál es el sentido de la vida?")
|
96 |
+
except Exception as e:
|
97 |
+
logging.error(f"Error en modelo {cadena}: {e}")
|
98 |
+
cadenas.pop(cadena)
|
99 |
+
time.sleep(300)
|
100 |
+
|
101 |
+
def liberar_recursos():
|
102 |
+
while True:
|
103 |
+
# Obtener memoria RAM disponible
|
104 |
+
memoria_ram = psutil.virtual_memory().available / (1024.0 ** 3)
|
105 |
+
|
106 |
+
# Obtener espacio en disco disponible
|
107 |
+
espacio_disco = psutil.disk_usage('/').free / (1024.0 ** 3)
|
108 |
+
|
109 |
+
# Verificar si la memoria RAM o espacio en disco es menor al 5%
|
110 |
+
if memoria_ram < 5 or espacio_disco < 5:
|
111 |
+
# Liberar memoria RAM
|
112 |
+
gc.collect()
|
113 |
+
|
114 |
+
# Cerrar procesos innecesarios
|
115 |
+
for proc in psutil.process_iter(['pid', 'name']):
|
116 |
+
if proc.info['name'] == 'python':
|
117 |
+
os.kill(proc.info['pid'], 9)
|
118 |
+
|
119 |
+
time.sleep(60)
|
120 |
+
|
121 |
+
Iniciar hilos
|
122 |
+
threading.Thread(target=keep_alive, daemon=True).start()
|
123 |
+
threading.Thread(target=liberar_recursos, daemon=True).start()
|
124 |
+
|
125 |
+
print("Iniciando hilos...")
|
126 |
+
@app.post("/pregunta")
|
127 |
+
async def pregunta(pregunta: str, modelo: str):
|
128 |
+
print(f"Pregunta recibida: {pregunta}, Modelo: {modelo}")
|
129 |
+
try:
|
130 |
+
# Obtener respuesta del modelo seleccionado
|
131 |
+
respuesta = cadenas[modelo].ask(pregunta)
|
132 |
+
print(f"Respuesta obtenida: {respuesta}")
|
133 |
+
|
134 |
+
# Verificar si la respuesta excede el límite de tokens máximos
|
135 |
+
if len(respuesta.split()) > 2048:
|
136 |
+
# Dividir la respuesta en varios mensajes
|
137 |
+
mensajes = []
|
138 |
+
palabras = respuesta.split()
|
139 |
+
mensaje_actual = ""
|
140 |
+
for palabra in tqdm(palabras):
|
141 |
+
if len(mensaje_actual.split()) + len(palabra.split()) > 2048:
|
142 |
+
mensajes.append(mensaje_actual)
|
143 |
+
mensaje_actual = palabra
|
144 |
+
else:
|
145 |
+
mensaje_actual += " " + palabra
|
146 |
+
mensajes.append(mensaje_actual)
|
147 |
+
|
148 |
+
# Retornar los mensajes divididos
|
149 |
+
return {"respuestas": mensajes}
|
150 |
+
else:
|
151 |
+
# Obtener resumen de respuesta
|
152 |
+
resumen = summarizer(respuesta, max_length=50, min_length=5, do_sample=False)
|
153 |
+
print(f"Resumen obtenido: {resumen[0]['summary_text']}")
|
154 |
+
|
155 |
+
# Calcular similitud entre pregunta y respuesta
|
156 |
+
pregunta_vec = vectorizer.fit_transform([pregunta])
|
157 |
+
respuesta_vec = vectorizer.transform([respuesta])
|
158 |
+
similitud = cosine_similarity(pregunta_vec, respuesta_vec)
|
159 |
+
print(f"Similitud calculada: {similitud[0][0]}")
|
160 |
+
|
161 |
+
return {
|
162 |
+
"respuesta": respuesta,
|
163 |
+
"resumen": resumen[0]["summary_text"],
|
164 |
+
"similitud": similitud[0][0]
|
165 |
+
}
|
166 |
+
except Exception as e:
|
167 |
+
logging.error(f"Error en modelo {modelo}: {e}")
|
168 |
+
return {"error": f"Modelo {modelo} no disponible"}
|
169 |
+
|
170 |
+
@app.post("/resumen")
|
171 |
+
async def resumen(texto: str):
|
172 |
+
print(f"Texto recibido: {texto}")
|
173 |
+
try:
|
174 |
+
# Obtener resumen de texto
|
175 |
+
resumen = summarizer(texto, max_length=50, min_length=5, do_sample=False)
|
176 |
+
print(f"Resumen obtenido: {resumen[0]['summary_text']}")
|
177 |
+
|
178 |
+
return {"resumen": resumen[0]["summary_text"]}
|
179 |
+
except Exception as e:
|
180 |
+
logging.error(f"Error en resumen: {e}")
|
181 |
+
return {"error": "Error en resumen"}
|
182 |
+
|
183 |
+
@app.post("/similitud")
|
184 |
+
async def similitud(texto1: str, texto2: str):
|
185 |
+
print(f"Textos recibidos: {texto1}, {texto2}")
|
186 |
+
try:
|
187 |
+
# Calcular similitud entre dos textos
|
188 |
+
texto1_vec = vectorizer.fit_transform([texto1])
|
189 |
+
texto2_vec = vectorizer.transform([texto2])
|
190 |
+
similitud = cosine_similarity(texto1_vec, texto2_vec)
|
191 |
+
print(f"Similitud calculada: {similitud[0][0]}")
|
192 |
+
|
193 |
+
return {"similitud": similitud[0][0]}
|
194 |
+
except Exception as e:
|
195 |
+
logging.error(f"Error en similitud: {e}")
|
196 |
+
return {"error": "Error en similitud"}
|
197 |
+
|
198 |
+
@app.post("/imagen")
|
199 |
+
async def imagen(texto: str):
|
200 |
+
print(f"Texto recibido: {texto}")
|
201 |
+
try:
|
202 |
+
# Obtener imagen a partir del texto
|
203 |
+
imagen = dalle_decoder.generate_images(texto, num_images=1)
|
204 |
+
print(f"Imagen generada")
|
205 |
+
|
206 |
+
# Generar nombre aleatorio para el archivo
|
207 |
+
nombre_archivo = f"{uuid.uuid4()}.png"
|
208 |
+
print(f"Nombre de archivo: {nombre_archivo}")
|
209 |
+
|
210 |
+
# Guardar imagen en archivo
|
211 |
+
imagen.save(nombre_archivo)
|
212 |
+
print(f"Imagen guardada en {nombre_archivo}")
|
213 |
+
|
214 |
+
return {"imagen": nombre_archivo}
|
215 |
+
except Exception as e:
|
216 |
+
logging.error(f"Error en imagen: {e}")
|
217 |
+
return {"error": "Error en imagen"}
|
218 |
+
|
219 |
+
@app.get("/modelos")
|
220 |
+
async def modelos():
|
221 |
+
print("Modelos solicitados")
|
222 |
+
return {"modelos": list(cadenas.keys())}
|
223 |
+
|
224 |
+
@app.get("/estado")
|
225 |
+
async def estado():
|
226 |
+
print("Estado solicitado")
|
227 |
+
return {"estado": "activo"}
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
print("Iniciando API...")
|
231 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|