Build error
Update app.py
app.py CHANGED
@@ -1,96 +1,225 @@
 import os
-import
-import uvicorn
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse
-from langchain.llms import VLLM
-from gptcache import Cache
-from gptcache.manager.factory import manager_factory
-from gptcache.processor.pre import get_prompt
-from langchain_community.callbacks.manager import get_openai_callback
-from sklearn.metrics.pairwise import cosine_similarity
-from sentence_transformers import SentenceTransformer
 import torch
-import
-import

 app = FastAPI()

[old lines 19-96 were also deleted; their content is not rendered in this view]
+import gc
+import psutil
 import os
+import time
 import torch
+from fastapi import FastAPI
+from vllm import VLLM
+from chatgptcache import cache
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk
+from nltk.tokenize import sent_tokenize, word_tokenize
+from nltk.corpus import stopwords
+from collections import Counter
+import asyncio
+import torch.nn.utils.prune as prune
+from concurrent.futures import ThreadPoolExecutor
+
+nltk.download('punkt')
+nltk.download('stopwords')

 app = FastAPI()

+# Define the models (they will be loaded later)
+model_1 = None
+model_2 = None
+model_3 = None
+model_4 = None
+
+cache_1 = cache.SimpleCache()
+cache_2 = cache.SimpleCache()
+cache_3 = cache.SimpleCache()
+cache_4 = cache.SimpleCache()
+
+previous_responses_1 = []
+previous_responses_2 = []
+previous_responses_3 = []
+previous_responses_4 = []
+
+MAX_TOKENS = 2048  # Maximum number of tokens for model input and output
+
+# Use ThreadPoolExecutor for parallel execution
+executor = ThreadPoolExecutor(max_workers=4)
+
+# Device configuration (CPU)
+device = torch.device("cpu")
+
+def get_best_response(new_response, previous_responses):
+    if not previous_responses:
+        return new_response
+    vectorizer = TfidfVectorizer().fit_transform(previous_responses + [new_response])
+    cosine_sim = cosine_similarity(vectorizer[-1], vectorizer[:-1])
+    max_sim_index = cosine_sim.argmax()
+    max_sim_score = cosine_sim[0][max_sim_index]
+    if max_sim_score > 0.7:
+        return previous_responses[max_sim_index]
+    return new_response
+
+def summarize_text(text):
+    sentences = sent_tokenize(text)
+    stop_words = set(stopwords.words("english"))
+    word_frequencies = Counter()
+    for sentence in sentences:
+        words = word_tokenize(sentence.lower())
+        words = [word for word in words if word.isalpha() and word not in stop_words]
+        word_frequencies.update(words)
+    most_common_words = word_frequencies.most_common(50)
+    most_common_words = {word: freq for word, freq in most_common_words}
+    ranked_sentences = []
+    for sentence in sentences:
+        score = sum(most_common_words.get(word, 0) for word in word_tokenize(sentence.lower()))
+        ranked_sentences.append((score, sentence))
+    ranked_sentences.sort(reverse=True, key=lambda x: x[0])
+    summary = ' '.join([sentence for _, sentence in ranked_sentences[:3]])
+    return summary
+
+def clear_memory():
+    gc.collect()
+    process = psutil.Process(os.getpid())
+    memory_usage = psutil.virtual_memory().percent
+    if memory_usage > 90:
+        global model_1, model_2, model_3, model_4
+        model_1 = None
+        model_2 = None
+        model_3 = None
+        model_4 = None
+        gc.collect()
+
+def apply_pruning(model):
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            prune.random_unstructured(module, name="weight", amount=0.2)
+            prune.remove(module, name="weight")  # Optional: remove the pruning mask to keep the pruned weights
+    return model
+
+def split_input(input_text, max_tokens):
+    tokens = input_text.split()  # Split the input into words (tokens)
+    chunks = []
+    chunk = []
+    total_tokens = 0
+
+    for word in tokens:
+        word_length = len(word.split())  # Estimate the token length
+        if total_tokens + word_length > max_tokens:
+            chunks.append(" ".join(chunk))
+            chunk = [word]
+            total_tokens = word_length
+        else:
+            chunk.append(word)
+            total_tokens += word_length
+
+    if chunk:
+        chunks.append(" ".join(chunk))  # Add the last chunk
+
+    return chunks
+
+def split_output(output_text, max_tokens):
+    tokens = output_text.split()  # Split the output into words (tokens)
+    chunks = []
+    chunk = []
+    total_tokens = 0
+
+    for word in tokens:
+        word_length = len(word.split())  # Estimate the token length
+        if total_tokens + word_length > max_tokens:
+            chunks.append(" ".join(chunk))
+            chunk = [word]
+            total_tokens = word_length
+        else:
+            chunk.append(word)
+            total_tokens += word_length
+
+    if chunk:
+        chunks.append(" ".join(chunk))  # Add the last chunk
+
+    return chunks
+
+async def load_model_async(model_name: str):
+    max_model_len = MAX_TOKENS  # Set the model's maximum length (tokens)
+    if model_name == "model_1":
+        return VLLM("Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device, max_model_len=max_model_len)
+    elif model_name == "model_2":
+        return VLLM("meta-llama/Llama-3.2-1B", device=device, max_model_len=max_model_len)
+    elif model_name == "model_3":
+        return VLLM("Qwen2.5-3B-Instruct", device=device, max_model_len=max_model_len)
+    elif model_name == "model_4":
+        return VLLM("gpt2", device=device, max_model_len=max_model_len)
+    return None
+
+async def load_models():
+    global model_1, model_2, model_3, model_4
+    tasks = [
+        load_model_async("model_1"),
+        load_model_async("model_2"),
+        load_model_async("model_3"),
+        load_model_async("model_4"),
+    ]
+    results = await asyncio.gather(*tasks)
+    model_1, model_2, model_3, model_4 = results
+    model_1 = apply_pruning(model_1)
+    model_2 = apply_pruning(model_2)
+    model_3 = apply_pruning(model_3)
+    model_4 = apply_pruning(model_4)
+    print("Models loaded and pruned successfully.")
+
+async def optimize_models_periodically():
+    while True:
+        await load_models()  # Reload and re-optimize the models automatically
+        await asyncio.sleep(3600)  # Optimize the models every hour (adjust the interval as needed)
+
+@app.on_event("startup")
+async def startup():
+    await load_models()
+    app.add_event_handler("startup", monitor_memory)
+    app.add_event_handler("startup", optimize_models_periodically)
+
+async def monitor_memory():
+    while True:
+        clear_memory()
+        await asyncio.sleep(60)
+
+@app.get("/generate")
+async def generate_response(model_name: str, input_text: str):
+    def generate_for_model(model, input_text, cache, previous_responses):
+        cached_output = cache.get(input_text)
+        if cached_output:
+            return cached_output
+
+        input_chunks = split_input(input_text, MAX_TOKENS)
+        output_text = ""
+        prev_output = ""
+
+        for chunk in input_chunks:
+            prompt = prev_output + chunk
+            output_text += model.generate(prompt)
+            prev_output = output_text.split()[-50:]
+
+        output_chunks = split_output(output_text, MAX_TOKENS)
+        best_response = get_best_response(output_chunks[0], previous_responses)
+        cache.put(input_text, best_response)
+        previous_responses.append(best_response)
+        return best_response
+
+    result = await asyncio.get_event_loop().run_in_executor(
+        executor,
+        generate_for_model,
+        model_1 if model_name == "model1" else model_2 if model_name == "model2" else model_3 if model_name == "model3" else model_4,
+        input_text,
+        cache_1 if model_name == "model1" else cache_2 if model_name == "model2" else cache_3 if model_name == "model3" else cache_4,
+        previous_responses_1 if model_name == "model1" else previous_responses_2 if model_name == "model2" else previous_responses_3 if model_name == "model3" else previous_responses_4
+    )
+    return {f"{model_name}_output": result}
+
+@app.get("/unified_summary")
+async def unified_summary(input_text: str):
+    output1 = await generate_response(model_name="model1", input_text=input_text)
+    output2 = await generate_response(model_name="model2", input_text=input_text)
+    output3 = await generate_response(model_name="model3", input_text=input_text)
+    output4 = await generate_response(model_name="model4", input_text=input_text)
+    combined_response = output1.get("model1_output", "") + " " + \
+                        output2.get("model2_output", "") + " " + \
+                        output3.get("model3_output", "") + " " + \
+                        output4.get("model4_output", "")
+    summarized_response = summarize_text(combined_response)
+    return {"summary": summarized_response}
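
A note on the build error: two of the new imports do not match any published package API I am aware of. The vllm package exposes an LLM class (and SamplingParams), not VLLM, and I know of no chatgptcache package providing cache.SimpleCache(); the previous version of this file imported GPTCache, which is probably what was intended. A minimal sketch of the same load / generate / cache flow using vLLM's documented API, with a plain dict standing in for the cache (the model name and token limit are taken from the diff; everything else is an assumption, not the author's code):

from vllm import LLM, SamplingParams

# Assumption: reuse one model name and the MAX_TOKENS value from the diff.
llm = LLM(model="meta-llama/Llama-3.2-1B", max_model_len=2048)
sampling = SamplingParams(max_tokens=256)

# Plain dict as a stand-in for cache.SimpleCache(), which is not an API I recognize.
response_cache = {}

def cached_generate(prompt: str) -> str:
    # Return a cached completion for an exact prompt match, if available.
    if prompt in response_cache:
        return response_cache[prompt]
    # llm.generate returns a list of RequestOutput objects; take the first completion's text.
    result = llm.generate([prompt], sampling)[0].outputs[0].text
    response_cache[prompt] = result
    return result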
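
Separately, the startup handler registers monitor_memory and optimize_models_periodically via app.add_event_handler("startup", ...) while startup is already running, and both coroutines loop forever, so if they do get awaited as startup handlers the app never finishes starting. A common alternative, sketched here under the assumption that the two functions stay as defined in the diff, is to schedule them as background tasks:

import asyncio

@app.on_event("startup")
async def startup():
    await load_models()
    # Run the two monitoring loops in the background instead of as extra startup handlers.
    asyncio.create_task(monitor_memory())
    asyncio.create_task(optimize_models_periodically())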
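
Inside generate_for_model, prev_output = output_text.split()[-50:] assigns a list of words, so prompt = prev_output + chunk raises a TypeError on the next loop iteration. A small sketch of the apparently intended loop, reusing split_input and MAX_TOKENS from the diff and still assuming model.generate returns a plain string (which vLLM's real LLM.generate does not; see the sketch above):

def generate_in_chunks(model, input_text: str) -> str:
    # Hypothetical helper illustrating the intended chunked-generation loop.
    output_text = ""
    prev_output = ""
    for chunk in split_input(input_text, MAX_TOKENS):
        prompt = prev_output + " " + chunk
        output_text += model.generate(prompt)
        # Keep roughly the last 50 words as a string (not a list) for the next prompt.
        prev_output = " ".join(output_text.split()[-50:])
    return output_text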
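
For reference, a small client sketch for the two GET endpoints defined above; the host and port are assumptions (uvicorn's default is 127.0.0.1:8000, while Spaces typically serve on port 7860):

import requests

BASE_URL = "http://127.0.0.1:8000"  # assumption: local uvicorn default

# Single-model generation.
r = requests.get(f"{BASE_URL}/generate",
                 params={"model_name": "model1", "input_text": "Hello there"})
print(r.json())  # e.g. {"model1_output": "..."}

# Combined summary across the four models.
r = requests.get(f"{BASE_URL}/unified_summary", params={"input_text": "Hello there"})
print(r.json())  # e.g. {"summary": "..."}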