from flask import Flask, request, jsonify, render_template_string
from vllm import LLM, SamplingParams
from langchain_community.cache import GPTCache
import torch
app = Flask(__name__)
# Verificar si hay una GPU disponible, si no usar la CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inicializar los modelos con el dispositivo adecuado (GPU o CPU)
modelos = {
"facebook/opt-125m": LLM(model="facebook/opt-125m", device=device),
"llama-3.2-1B": LLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device),
"gpt2": LLM(model="gpt2", device=device)
}
# Configuración de caché para los modelos
caches = {
nombre: GPTCache(modelo, max_size=1000)
for nombre, modelo in modelos.items()
}
# Parámetros de muestreo para la generación de texto
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Código HTML para la documentación de la API
html_code_docs = """
Documentación de la API
API de Generación de Texto
Endpoints
-
Generar texto
Método: POST
Ruta: /generate
Parámetros:
- prompts: Lista de prompts para generar texto
- modelo: Nombre del modelo a utilizar
Ejemplo:
curl -X POST -H "Content-Type: application/json" -d '{"prompts": ["Hola, cómo estás?"], "modelo": "facebook/opt-125m"}' http://localhost:5000/generate
-
Obtener lista de modelos
Método: GET
Ruta: /modelos
Ejemplo:
curl -X GET http://localhost:5000/modelos
-
Chatbot
Método: POST
Ruta: /chatbot
Parámetros:
- mensaje: Mensaje para el chatbot
- modelo: Nombre del modelo a utilizar
Ejemplo:
curl -X POST -H "Content-Type: application/json" -d '{"mensaje": "Hola, cómo estás?", "modelo": "facebook/opt-125m"}' http://localhost:5000/chatbot
"""
# Código HTML para la interfaz del chatbot
html_code_chatbot = """
Chatbot
Chatbot
"""
@app.route('/generate', methods=['POST'])
def generate():
data = request.get_json()
prompts = data.get('prompts', [])
modelo_seleccionado = data.get('modelo', "facebook/opt-125m")
if modelo_seleccionado not in modelos:
return jsonify({"error": "Modelo no encontrado"}), 404
outputs = caches[modelo_seleccionado].generate(prompts, sampling_params)
results = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
results.append({
'prompt': prompt,
'generated_text': generated_text
})
return jsonify(results)
@app.route('/modelos', methods=['GET'])
def get_modelos():
return jsonify({"modelos": list(modelos.keys())})
@app.route('/docs', methods=['GET'])
def docs():
return render_template_string(html_code_docs)
@app.route('/chatbot', methods=['POST'])
def chatbot():
data = request.get_json()
mensaje = data.get('mensaje', '')
modelo_seleccionado = data.get('modelo', "facebook/opt-125m")
if modelo_seleccionado not in modelos:
return jsonify({"error": "Modelo no encontrado"}), 404
outputs = caches[modelo_seleccionado].generate([mensaje], sampling_params)
respuesta = outputs[0].outputs[0].text
return jsonify({"respuesta": respuesta})
@app.route('/chat', methods=['GET'])
def chat():
return render_template_string(html_code_chatbot)
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860)