|
from fastapi import FastAPI |
|
import torch |
|
from langchain.chains.llm import LLMChain |
|
from langchain.llms import VLLM |
|
from langchain.cache import GPTCache |
|
from transformers import pipeline |
|
import uvicorn |
|
import threading |
|
import time |
|
import nltk |
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import psutil |
|
import os |
|
import gc |
|
import logging |
|
from PIL import Image |
|
from transformers import DALLEncoder, DALLDecoder |
|
import uuid |
|
from tqdm import tqdm |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
nltk.download('punkt') |
|
nltk.download('stopwords') |
|
|
|
app = FastAPI() |
|
|
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
else: |
|
device = torch.device("cpu") |
|
|
|
print("Dispositivo:", device) |
|
|
|
modelos = { |
|
"gpt2-medium": VLLM(model="gpt2-medium"), |
|
"qwen2.5-0.5b": VLLM(model="Qwen/Qwen2.5-0.5B-Instruct"), |
|
"llamaxd": VLLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf"), |
|
"t5-base": VLLM(model="t5-base"), |
|
"bert-base-uncased": VLLM(model="bert-base-uncased"), |
|
"musicgen-small": VLLM(model="musicgen-small"), |
|
"dall-e-mini": VLLM(model="dall-e-mini"), |
|
"xlnet-base-uncased": VLLM(model="xlnet-base-uncased"), |
|
"distilbert-base-uncased": VLLM(model="distilbert-base-uncased"), |
|
"albert-base-v2": VLLM(model="albert-base-v2"), |
|
"roberta-base": VLLM(model="roberta-base"), |
|
} |
|
|
|
print("Cargando modelos...") |
|
for nombre, modelo in tqdm(modelos.items()): |
|
modelos[nombre] = modelo(to=device) |
|
print(f"Modelo {nombre} cargado") |
|
|
|
Crear instancias de caché para cada modelo |
|
caches = { |
|
nombre: GPTCache(modelo, max_size=1000) for nombre, modelo in modelos.items() |
|
} |
|
|
|
print("Creando instancias de caché...") |
|
for nombre, caché in tqdm(caches.items()): |
|
print(f"Caché para modelo {nombre} creada") |
|
|
|
cadenas = { |
|
nombre: LLMChain(modelo, caché) for nombre, modelo, caché in zip(modelos.keys(), modelos.values(), caches.values()) |
|
} |
|
|
|
print("Creando instancias de cadenas de modelo...") |
|
for nombre, cadena in tqdm(cadenas.items()): |
|
print(f"Cadena de modelo {nombre} creada") |
|
|
|
summarizer = pipeline("summarization", device=device) |
|
|
|
print("Cargando modelo de resumen de texto...") |
|
|
|
vectorizer = TfidfVectorizer() |
|
|
|
print("Cargando modelo de vectorizador TF-IDF...") |
|
|
|
dalle_encoder = DALLEncoder(model_id="dall-e-mini") |
|
dalle_decoder = DALLDecoder(model_id="dall-e-mini") |
|
|
|
print("Cargando modelo DALL-E...") |
|
|
|
def keep_alive(): |
|
while True: |
|
|
|
for cadena in cadenas.values(): |
|
try: |
|
cadena.ask("¿Cuál es el sentido de la vida?") |
|
except Exception as e: |
|
logging.error(f"Error en modelo {cadena}: {e}") |
|
cadenas.pop(cadena) |
|
time.sleep(300) |
|
|
|
def liberar_recursos(): |
|
while True: |
|
|
|
memoria_ram = psutil.virtual_memory().available / (1024.0 ** 3) |
|
|
|
|
|
espacio_disco = psutil.disk_usage('/').free / (1024.0 ** 3) |
|
|
|
|
|
if memoria_ram < 5 or espacio_disco < 5: |
|
|
|
gc.collect() |
|
|
|
|
|
for proc in psutil.process_iter(['pid', 'name']): |
|
if proc.info['name'] == 'python': |
|
os.kill(proc.info['pid'], 9) |
|
|
|
time.sleep(60) |
|
|
|
threading.Thread(target=keep_alive, daemon=True).start() |
|
threading.Thread(target=liberar_recursos, daemon=True).start() |
|
|
|
print("Iniciando hilos...") |
|
@app.post("/pregunta") |
|
async def pregunta(pregunta: str, modelo: str): |
|
print(f"Pregunta recibida: {pregunta}, Modelo: {modelo}") |
|
try: |
|
|
|
respuesta = cadenas[modelo].ask(pregunta) |
|
print(f"Respuesta obtenida: {respuesta}") |
|
|
|
|
|
if len(respuesta.split()) > 2048: |
|
|
|
mensajes = [] |
|
palabras = respuesta.split() |
|
mensaje_actual = "" |
|
for palabra in tqdm(palabras): |
|
if len(mensaje_actual.split()) + len(palabra.split()) > 2048: |
|
mensajes.append(mensaje_actual) |
|
mensaje_actual = palabra |
|
else: |
|
mensaje_actual += " " + palabra |
|
mensajes.append(mensaje_actual) |
|
|
|
|
|
return {"respuestas": mensajes} |
|
else: |
|
|
|
resumen = summarizer(respuesta, max_length=50, min_length=5, do_sample=False) |
|
print(f"Resumen obtenido: {resumen[0]['summary_text']}") |
|
|
|
|
|
pregunta_vec = vectorizer.fit_transform([pregunta]) |
|
respuesta_vec = vectorizer.transform([respuesta]) |
|
similitud = cosine_similarity(pregunta_vec, respuesta_vec) |
|
print(f"Similitud calculada: {similitud[0][0]}") |
|
|
|
return { |
|
"respuesta": respuesta, |
|
"resumen": resumen[0]["summary_text"], |
|
"similitud": similitud[0][0] |
|
} |
|
except Exception as e: |
|
logging.error(f"Error en modelo {modelo}: {e}") |
|
return {"error": f"Modelo {modelo} no disponible"} |
|
|
|
@app.post("/resumen") |
|
async def resumen(texto: str): |
|
print(f"Texto recibido: {texto}") |
|
try: |
|
|
|
resumen = summarizer(texto, max_length=50, min_length=5, do_sample=False) |
|
print(f"Resumen obtenido: {resumen[0]['summary_text']}") |
|
|
|
return {"resumen": resumen[0]["summary_text"]} |
|
except Exception as e: |
|
logging.error(f"Error en resumen: {e}") |
|
return {"error": "Error en resumen"} |
|
|
|
@app.post("/similitud") |
|
async def similitud(texto1: str, texto2: str): |
|
print(f"Textos recibidos: {texto1}, {texto2}") |
|
try: |
|
|
|
texto1_vec = vectorizer.fit_transform([texto1]) |
|
texto2_vec = vectorizer.transform([texto2]) |
|
similitud = cosine_similarity(texto1_vec, texto2_vec) |
|
print(f"Similitud calculada: {similitud[0][0]}") |
|
|
|
return {"similitud": similitud[0][0]} |
|
except Exception as e: |
|
logging.error(f"Error en similitud: {e}") |
|
return {"error": "Error en similitud"} |
|
|
|
@app.post("/imagen") |
|
async def imagen(texto: str): |
|
print(f"Texto recibido: {texto}") |
|
try: |
|
|
|
imagen = dalle_decoder.generate_images(texto, num_images=1) |
|
print(f"Imagen generada") |
|
|
|
|
|
nombre_archivo = f"{uuid.uuid4()}.png" |
|
print(f"Nombre de archivo: {nombre_archivo}") |
|
|
|
|
|
imagen.save(nombre_archivo) |
|
print(f"Imagen guardada en {nombre_archivo}") |
|
|
|
return {"imagen": nombre_archivo} |
|
except Exception as e: |
|
logging.error(f"Error en imagen: {e}") |
|
return {"error": "Error en imagen"} |
|
|
|
@app.get("/modelos") |
|
async def modelos(): |
|
print("Modelos solicitados") |
|
return {"modelos": list(cadenas.keys())} |
|
|
|
@app.get("/estado") |
|
async def estado(): |
|
print("Estado solicitado") |
|
return {"estado": "activo"} |
|
|
|
if __name__ == "__main__": |
|
print("Iniciando API...") |
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
|
|