import gc
import psutil
import os
import time
import torch
from fastapi import FastAPI
from vllm import LLM  # vLLM exposes its offline inference engine as LLM
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from collections import Counter
import asyncio
import torch.nn.utils.prune as prune
from concurrent.futures import ThreadPoolExecutor
nltk.download('punkt')
nltk.download('punkt_tab')  # Recent NLTK releases ship tokenizer data under this name as well
nltk.download('stopwords')
app = FastAPI()
# Model handles (loaded at startup)
model_1 = None
model_2 = None
model_3 = None
model_4 = None
class SimpleCache(dict):
    # Minimal in-memory cache keyed by input text; get() comes from dict and returns None on a miss
    def put(self, key, value):
        self[key] = value

cache_1 = SimpleCache()
cache_2 = SimpleCache()
cache_3 = SimpleCache()
cache_4 = SimpleCache()
previous_responses_1 = []
previous_responses_2 = []
previous_responses_3 = []
previous_responses_4 = []
MAX_TOKENS = 2048  # Maximum number of tokens for model input and output
# ThreadPoolExecutor for running blocking generation calls in parallel
executor = ThreadPoolExecutor(max_workers=4)
# Device configuration (CPU); vLLM expects the device name as a string
device = "cpu"
def get_best_response(new_response, previous_responses):
    # Reuse a previous response if it is very similar (cosine similarity > 0.7) to the new one
    if not previous_responses:
        return new_response
    tfidf_matrix = TfidfVectorizer().fit_transform(previous_responses + [new_response])
    cosine_sim = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1])
    max_sim_index = cosine_sim.argmax()
    max_sim_score = cosine_sim[0][max_sim_index]
    if max_sim_score > 0.7:
        return previous_responses[max_sim_index]
    return new_response
def summarize_text(text):
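    # Frequency-based extractive summary: score each sentence by how many of the 50 most
    # common non-stopword words it contains and keep the three top-scoring sentences.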
sentences = sent_tokenize(text)
stop_words = set(stopwords.words("english"))
word_frequencies = Counter()
for sentence in sentences:
words = word_tokenize(sentence.lower())
words = [word for word in words if word.isalpha() and word not in stop_words]
word_frequencies.update(words)
most_common_words = word_frequencies.most_common(50)
most_common_words = {word: freq for word, freq in most_common_words}
ranked_sentences = []
for sentence in sentences:
score = sum(most_common_words.get(word, 0) for word in word_tokenize(sentence.lower()))
ranked_sentences.append((score, sentence))
ranked_sentences.sort(reverse=True, key=lambda x: x[0])
summary = ' '.join([sentence for _, sentence in ranked_sentences[:3]])
return summary
def clear_memory():
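    # Force a garbage-collection pass and drop all models if system memory usage is critical (>90%)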
gc.collect()
process = psutil.Process(os.getpid())
memory_usage = psutil.virtual_memory().percent
if memory_usage > 90:
global model_1, model_2, model_3, model_4
model_1 = None
model_2 = None
model_3 = None
model_4 = None
gc.collect()
def apply_pruning(model):
    # Pruning only applies to torch.nn.Module models; vLLM engine objects do not expose
    # their weights as torch modules, so they are returned unchanged.
    if not isinstance(model, torch.nn.Module):
        return model
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.random_unstructured(module, name="weight", amount=0.2)
            prune.remove(module, name="weight")  # Optional: remove the pruning mask so the pruned weights are kept
    return model
def split_input(input_text, max_tokens):
    words = input_text.split()  # Split the input into whitespace-separated words
    chunks = []
    chunk = []
    total_tokens = 0
    for word in words:
        # Rough estimate: one token per word
        if total_tokens + 1 > max_tokens:
            chunks.append(" ".join(chunk))
            chunk = [word]
            total_tokens = 1
        else:
            chunk.append(word)
            total_tokens += 1
    if chunk:
        chunks.append(" ".join(chunk))  # Append the last chunk
    return chunks
def split_output(output_text, max_tokens):
    words = output_text.split()  # Split the output into whitespace-separated words
    chunks = []
    chunk = []
    total_tokens = 0
    for word in words:
        # Rough estimate: one token per word
        if total_tokens + 1 > max_tokens:
            chunks.append(" ".join(chunk))
            chunk = [word]
            total_tokens = 1
        else:
            chunk.append(word)
            total_tokens += 1
    if chunk:
        chunks.append(" ".join(chunk))  # Append the last chunk
    return chunks
async def load_model_async(model_name: str):
    max_model_len = MAX_TOKENS  # Maximum model context length (tokens)
    if model_name == "model_1":
        return LLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device, max_model_len=max_model_len)
    elif model_name == "model_2":
        return LLM(model="meta-llama/Llama-3.2-1B", device=device, max_model_len=max_model_len)
    elif model_name == "model_3":
        return LLM(model="Qwen/Qwen2.5-3B-Instruct", device=device, max_model_len=max_model_len)
    elif model_name == "model_4":
        return LLM(model="gpt2", device=device, max_model_len=max_model_len)
    return None
async def load_models():
global model_1, model_2, model_3, model_4
tasks = [
load_model_async("model_1"),
load_model_async("model_2"),
load_model_async("model_3"),
load_model_async("model_4"),
]
results = await asyncio.gather(*tasks)
model_1, model_2, model_3, model_4 = results
model_1 = apply_pruning(model_1)
model_2 = apply_pruning(model_2)
model_3 = apply_pruning(model_3)
model_4 = apply_pruning(model_4)
print("Modelos cargados y podados exitosamente.")
async def optimize_models_periodically():
    while True:
        await load_models()  # Reload and re-optimize the models automatically
        await asyncio.sleep(3600)  # Re-optimize every hour (adjust the interval as needed)
@app.on_event("startup")
async def startup():
    await load_models()
    # Run the maintenance loops as background tasks; registering them as startup handlers
    # would block startup on their infinite loops.
    asyncio.create_task(monitor_memory())
    asyncio.create_task(optimize_models_periodically())
async def monitor_memory():
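    # Periodically trigger memory cleanup so long-running generation does not exhaust RAM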
while True:
clear_memory()
await asyncio.sleep(60)
@app.get("/generate")
async def generate_response(model_name: str, input_text: str):
    def generate_for_model(model, input_text, cache, previous_responses):
        cached_output = cache.get(input_text)
        if cached_output:
            return cached_output
        input_chunks = split_input(input_text, MAX_TOKENS)
        output_text = ""
        prev_output = ""
        for chunk in input_chunks:
            prompt = prev_output + chunk
            # vLLM returns a list of RequestOutput objects; take the text of the first completion
            outputs = model.generate(prompt)
            output_text += outputs[0].outputs[0].text
            # Carry the last 50 generated words forward as context for the next chunk
            prev_output = " ".join(output_text.split()[-50:]) + " "
        output_chunks = split_output(output_text, MAX_TOKENS)
        best_response = get_best_response(output_chunks[0] if output_chunks else "", previous_responses)
        cache.put(input_text, best_response)
        previous_responses.append(best_response)
        return best_response
    models = {"model1": model_1, "model2": model_2, "model3": model_3, "model4": model_4}
    caches = {"model1": cache_1, "model2": cache_2, "model3": cache_3, "model4": cache_4}
    histories = {
        "model1": previous_responses_1,
        "model2": previous_responses_2,
        "model3": previous_responses_3,
        "model4": previous_responses_4,
    }
    result = await asyncio.get_event_loop().run_in_executor(
        executor,
        generate_for_model,
        models.get(model_name, model_4),
        input_text,
        caches.get(model_name, cache_4),
        histories.get(model_name, previous_responses_4),
    )
return {f"{model_name}_output": result}
@app.get("/unified_summary")
async def unified_summary(input_text: str):
output1 = await generate_response(model_name="model1", input_text=input_text)
output2 = await generate_response(model_name="model2", input_text=input_text)
output3 = await generate_response(model_name="model3", input_text=input_text)
output4 = await generate_response(model_name="model4", input_text=input_text)
combined_response = output1.get("model1_output", "") + " " + \
output2.get("model2_output", "") + " " + \
output3.get("model3_output", "") + " " + \
output4.get("model4_output", "")
summarized_response = summarize_text(combined_response)
return {"summary": summarized_response}