Hjgugugjhuhjggg commited on
Commit
c9c8569
·
verified ·
1 Parent(s): bb9df29

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import torch
3
+ from langchain.chains.llm import LLMChain
4
+ from langchain.llms import VLLM
5
+ from langchain.cache import GPTCache
6
+ from transformers import pipeline
7
+ import uvicorn
8
+ import threading
9
+ import time
10
+ import nltk
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ import psutil
14
+ import os
15
+ import gc
16
+ import logging
17
+ from PIL import Image
18
+ from transformers import DALLEncoder, DALLDecoder
19
+ import uuid
20
+ from tqdm import tqdm
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+ nltk.download('punkt')
25
+ nltk.download('stopwords')
26
+
27
+ app = FastAPI()
28
+
29
+ Configurar dispositivo
30
+ if torch.cuda.is_available():
31
+ device = torch.device("cuda")
32
+ else:
33
+ device = torch.device("cpu")
34
+
35
+ print("Dispositivo:", device)
36
+
37
+ Cargar modelos con VLLM
38
+ modelos = {
39
+ "gpt2-medium": VLLM(model="gpt2-medium"),
40
+ "qwen2.5-0.5b": VLLM(model="qwen2.5-0.5b"),
41
+ "t5-base": VLLM(model="t5-base"),
42
+ "bert-base-uncased": VLLM(model="bert-base-uncased"),
43
+ "musicgen-small": VLLM(model="musicgen-small"),
44
+ "dall-e-mini": VLLM(model="dall-e-mini"),
45
+ "xlnet-base-uncased": VLLM(model="xlnet-base-uncased"),
46
+ "distilbert-base-uncased": VLLM(model="distilbert-base-uncased"),
47
+ "albert-base-v2": VLLM(model="albert-base-v2"),
48
+ "roberta-base": VLLM(model="roberta-base"),
49
+ }
50
+
51
+ print("Cargando modelos...")
52
+ for nombre, modelo in tqdm(modelos.items()):
53
+ modelos[nombre] = modelo(to=device)
54
+ print(f"Modelo {nombre} cargado")
55
+
56
+ Crear instancias de caché para cada modelo
57
+ caches = {
58
+ nombre: GPTCache(modelo, max_size=1000) for nombre, modelo in modelos.items()
59
+ }
60
+
61
+ print("Creando instancias de caché...")
62
+ for nombre, caché in tqdm(caches.items()):
63
+ print(f"Caché para modelo {nombre} creada")
64
+
65
+ Crear instancias de cadenas de modelo para cada modelo
66
+ cadenas = {
67
+ nombre: LLMChain(modelo, caché) for nombre, modelo, caché in zip(modelos.keys(), modelos.values(), caches.values())
68
+ }
69
+
70
+ print("Creando instancias de cadenas de modelo...")
71
+ for nombre, cadena in tqdm(cadenas.items()):
72
+ print(f"Cadena de modelo {nombre} creada")
73
+
74
+ Cargar modelo de resumen de texto
75
+ summarizer = pipeline("summarization", device=device)
76
+
77
+ print("Cargando modelo de resumen de texto...")
78
+
79
+ Cargar modelo de vectorizador TF-IDF
80
+ vectorizer = TfidfVectorizer()
81
+
82
+ print("Cargando modelo de vectorizador TF-IDF...")
83
+
84
+ Cargar modelo DALL-E
85
+ dalle_encoder = DALLEncoder(model_id="dall-e-mini")
86
+ dalle_decoder = DALLDecoder(model_id="dall-e-mini")
87
+
88
+ print("Cargando modelo DALL-E...")
89
+
90
+ def keep_alive():
91
+ while True:
92
+ # Realizar una petición a cada modelo cada 5 minutos
93
+ for cadena in cadenas.values():
94
+ try:
95
+ cadena.ask("¿Cuál es el sentido de la vida?")
96
+ except Exception as e:
97
+ logging.error(f"Error en modelo {cadena}: {e}")
98
+ cadenas.pop(cadena)
99
+ time.sleep(300)
100
+
101
+ def liberar_recursos():
102
+ while True:
103
+ # Obtener memoria RAM disponible
104
+ memoria_ram = psutil.virtual_memory().available / (1024.0 ** 3)
105
+
106
+ # Obtener espacio en disco disponible
107
+ espacio_disco = psutil.disk_usage('/').free / (1024.0 ** 3)
108
+
109
+ # Verificar si la memoria RAM o espacio en disco es menor al 5%
110
+ if memoria_ram < 5 or espacio_disco < 5:
111
+ # Liberar memoria RAM
112
+ gc.collect()
113
+
114
+ # Cerrar procesos innecesarios
115
+ for proc in psutil.process_iter(['pid', 'name']):
116
+ if proc.info['name'] == 'python':
117
+ os.kill(proc.info['pid'], 9)
118
+
119
+ time.sleep(60)
120
+
121
+ Iniciar hilos
122
+ threading.Thread(target=keep_alive, daemon=True).start()
123
+ threading.Thread(target=liberar_recursos, daemon=True).start()
124
+
125
+ print("Iniciando hilos...")
126
+ @app.post("/pregunta")
127
+ async def pregunta(pregunta: str, modelo: str):
128
+ print(f"Pregunta recibida: {pregunta}, Modelo: {modelo}")
129
+ try:
130
+ # Obtener respuesta del modelo seleccionado
131
+ respuesta = cadenas[modelo].ask(pregunta)
132
+ print(f"Respuesta obtenida: {respuesta}")
133
+
134
+ # Verificar si la respuesta excede el límite de tokens máximos
135
+ if len(respuesta.split()) > 2048:
136
+ # Dividir la respuesta en varios mensajes
137
+ mensajes = []
138
+ palabras = respuesta.split()
139
+ mensaje_actual = ""
140
+ for palabra in tqdm(palabras):
141
+ if len(mensaje_actual.split()) + len(palabra.split()) > 2048:
142
+ mensajes.append(mensaje_actual)
143
+ mensaje_actual = palabra
144
+ else:
145
+ mensaje_actual += " " + palabra
146
+ mensajes.append(mensaje_actual)
147
+
148
+ # Retornar los mensajes divididos
149
+ return {"respuestas": mensajes}
150
+ else:
151
+ # Obtener resumen de respuesta
152
+ resumen = summarizer(respuesta, max_length=50, min_length=5, do_sample=False)
153
+ print(f"Resumen obtenido: {resumen[0]['summary_text']}")
154
+
155
+ # Calcular similitud entre pregunta y respuesta
156
+ pregunta_vec = vectorizer.fit_transform([pregunta])
157
+ respuesta_vec = vectorizer.transform([respuesta])
158
+ similitud = cosine_similarity(pregunta_vec, respuesta_vec)
159
+ print(f"Similitud calculada: {similitud[0][0]}")
160
+
161
+ return {
162
+ "respuesta": respuesta,
163
+ "resumen": resumen[0]["summary_text"],
164
+ "similitud": similitud[0][0]
165
+ }
166
+ except Exception as e:
167
+ logging.error(f"Error en modelo {modelo}: {e}")
168
+ return {"error": f"Modelo {modelo} no disponible"}
169
+
170
+ @app.post("/resumen")
171
+ async def resumen(texto: str):
172
+ print(f"Texto recibido: {texto}")
173
+ try:
174
+ # Obtener resumen de texto
175
+ resumen = summarizer(texto, max_length=50, min_length=5, do_sample=False)
176
+ print(f"Resumen obtenido: {resumen[0]['summary_text']}")
177
+
178
+ return {"resumen": resumen[0]["summary_text"]}
179
+ except Exception as e:
180
+ logging.error(f"Error en resumen: {e}")
181
+ return {"error": "Error en resumen"}
182
+
183
+ @app.post("/similitud")
184
+ async def similitud(texto1: str, texto2: str):
185
+ print(f"Textos recibidos: {texto1}, {texto2}")
186
+ try:
187
+ # Calcular similitud entre dos textos
188
+ texto1_vec = vectorizer.fit_transform([texto1])
189
+ texto2_vec = vectorizer.transform([texto2])
190
+ similitud = cosine_similarity(texto1_vec, texto2_vec)
191
+ print(f"Similitud calculada: {similitud[0][0]}")
192
+
193
+ return {"similitud": similitud[0][0]}
194
+ except Exception as e:
195
+ logging.error(f"Error en similitud: {e}")
196
+ return {"error": "Error en similitud"}
197
+
198
+ @app.post("/imagen")
199
+ async def imagen(texto: str):
200
+ print(f"Texto recibido: {texto}")
201
+ try:
202
+ # Obtener imagen a partir del texto
203
+ imagen = dalle_decoder.generate_images(texto, num_images=1)
204
+ print(f"Imagen generada")
205
+
206
+ # Generar nombre aleatorio para el archivo
207
+ nombre_archivo = f"{uuid.uuid4()}.png"
208
+ print(f"Nombre de archivo: {nombre_archivo}")
209
+
210
+ # Guardar imagen en archivo
211
+ imagen.save(nombre_archivo)
212
+ print(f"Imagen guardada en {nombre_archivo}")
213
+
214
+ return {"imagen": nombre_archivo}
215
+ except Exception as e:
216
+ logging.error(f"Error en imagen: {e}")
217
+ return {"error": "Error en imagen"}
218
+
219
+ @app.get("/modelos")
220
+ async def modelos():
221
+ print("Modelos solicitados")
222
+ return {"modelos": list(cadenas.keys())}
223
+
224
+ @app.get("/estado")
225
+ async def estado():
226
+ print("Estado solicitado")
227
+ return {"estado": "activo"}
228
+
229
+ if __name__ == "__main__":
230
+ print("Iniciando API...")
231
+ uvicorn.run(app, host="0.0.0.0", port=8000)