from fastapi import FastAPI
from langchain_community.llms import VLLM
from langchain_community.cache import GPTCache
import torch
from langchain.chains.llm import LLMChain
from transformers import pipeline
import uvicorn
import threading
import time
import nltk
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import psutil
import os
import gc
import logging

logging.basicConfig(level=logging.INFO)
nltk.download('punkt')
nltk.download('stopwords')

app = FastAPI()

device = torch.device("cpu")

modelos = {
    "gpt2-medium": VLLM(model="gpt2-medium", device=device),
    "qwen2.5-0.5b": VLLM(model="Qwen/Qwen2.5-0.5B-Instruct", device=device),
    "llamaxd": VLLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device)
}

caches = {
    nombre: GPTCache(modelo, max_size=1000) 
    for nombre, modelo in modelos.items()
}

cadenas = {
    nombre: LLMChain(modelo, caché) 
    for nombre, modelo, caché in zip(modelos.keys(), modelos.values(), caches.values())
}

summarizer = pipeline("summarization", device=device)
vectorizer = TfidfVectorizer()

def keep_alive():
    while True:
        for cadena in cadenas.values():
            try:
                cadena.ask("¿Cuál es el sentido de la vida?")
            except Exception as e:
                logging.error(f"Error en modelo {cadena}: {e}")
                cadenas.pop(cadena)
        time.sleep(300)

def liberar_recursos():
    while True:
        memoria_ram = psutil.virtual_memory().available / (1024.0 ** 3)
        espacio_disco = psutil.disk_usage('/').free / (1024.0 ** 3)
        if memoria_ram < 5 or espacio_disco < 5:
            gc.collect()
            for proc in psutil.process_iter(['pid', 'name']):
                if proc.info['name'] == 'python':
                    os.kill(proc.info['pid'], 9)
        time.sleep(60)

threading.Thread(target=keep_alive, daemon=True).start()
threading.Thread(target=liberar_recursos, daemon=True).start()

@app.post("/pregunta")
async def pregunta(pregunta: str, modelo: str):
    try:
        respuesta = cadenas[modelo].ask(pregunta)
        if len(respuesta.split()) > 2048:
            mensajes = []
            palabras = respuesta.split()
            mensaje_actual = ""
            for palabra in palabras:
                if len(mensaje_actual.split()) + len(palabra.split()) > 2048:
                    mensajes.append(mensaje_actual)
                    mensaje_actual = palabra
                else:
                    mensaje_actual += " " + palabra
            mensajes.append(mensaje_actual)
            return {"respuestas": mensajes}
        else:
            resumen = summarizer(respuesta, max_length=50, min_length=5, do_sample=False)
            pregunta_vec = vectorizer.fit_transform([pregunta])
            respuesta_vec = vectorizer.transform([respuesta])
            similitud = cosine_similarity(pregunta_vec, respuesta_vec)
            return {
                "respuesta": respuesta, 
                "resumen": resumen[0]["summary_text"], 
                "similitud": similitud[0][0]
            }
    except Exception as e:
        logging.error(f"Error en modelo {modelo}: {e}")
        return {"error": f"Modelo {modelo} no disponible"}

@app.post("/resumen")
async def resumen(texto: str):
    try:
        resumen = summarizer(texto, max_length=50, min_length=5, do_sample=False)
        return {"resumen": resumen[0]["summary_text"]}
    except Exception as e:
        logging.error(f"Error en resumen: {e}")
        return {"error": "Error en resumen"}

@app.post("/similitud")
async def similitud(texto1: str, texto2: str):
    try:
        texto1_vec = vectorizer.fit_transform([texto1])
        texto2_vec = vectorizer.transform([texto2])
        similitud = cosine_similarity(texto1_vec, texto2_vec)
        return {"similitud": similitud[0][0]}
    except Exception as e:
        logging.error(f"Error en similitud: {e}")
        return {"error": "Error en similitud"}

@app.get("/modelos")
async def modelos():
    return {"modelos": list(cadenas.keys())}

@app.get("/estado")
async def estado():
    return {"estado": "activo"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)