Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Request
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
import torch
|
5 |
+
from typing import Optional, List
|
6 |
+
import asyncio
|
7 |
+
from fastapi.responses import StreamingResponse, HTMLResponse
|
8 |
+
import uvicorn
|
9 |
+
import psutil
|
10 |
+
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
dispositivo = torch.device("cpu")
|
14 |
+
CPU_LIMIT = 30.0
|
15 |
+
RAM_LIMIT = 30.0
|
16 |
+
|
17 |
+
html_code = """
|
18 |
+
<!DOCTYPE html>
|
19 |
+
<html>
|
20 |
+
<head>
|
21 |
+
<title>Chatbot</title>
|
22 |
+
<style>
|
23 |
+
body { font-family: Arial, sans-serif; margin: 50px; }
|
24 |
+
#chat { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: scroll; }
|
25 |
+
#input { width: 80%; padding: 10px; }
|
26 |
+
#send { padding: 10px; }
|
27 |
+
</style>
|
28 |
+
</head>
|
29 |
+
<body>
|
30 |
+
<h1>Chatbot</h1>
|
31 |
+
<div id="chat"></div>
|
32 |
+
<input type="text" id="input" placeholder="Escribe tu mensaje...">
|
33 |
+
<button id="send">Enviar</button>
|
34 |
+
|
35 |
+
<script>
|
36 |
+
const sendButton = document.getElementById('send');
|
37 |
+
const inputBox = document.getElementById('input');
|
38 |
+
const chatBox = document.getElementById('chat');
|
39 |
+
let history = [];
|
40 |
+
|
41 |
+
sendButton.addEventListener('click', () => {
|
42 |
+
const message = inputBox.value;
|
43 |
+
if (message.trim() === '') return;
|
44 |
+
history.push(`T煤: ${message}`);
|
45 |
+
chatBox.innerHTML += `<div><strong>T煤:</strong> ${message}</div>`;
|
46 |
+
inputBox.value = '';
|
47 |
+
fetch('/generar', {
|
48 |
+
method: 'POST',
|
49 |
+
headers: {
|
50 |
+
'Content-Type': 'application/json'
|
51 |
+
},
|
52 |
+
body: JSON.stringify({
|
53 |
+
texto: message,
|
54 |
+
history: history
|
55 |
+
})
|
56 |
+
})
|
57 |
+
.then(response => {
|
58 |
+
if (!response.body) {
|
59 |
+
throw new Error('No soporta streaming');
|
60 |
+
}
|
61 |
+
const reader = response.body.getReader();
|
62 |
+
const decoder = new TextDecoder();
|
63 |
+
let botMessage = '';
|
64 |
+
function read() {
|
65 |
+
reader.read().then(({ done, value }) => {
|
66 |
+
if (done) {
|
67 |
+
history.push(`Bot: ${botMessage}`);
|
68 |
+
chatBox.innerHTML += `<div><strong>Bot:</strong> ${botMessage}</div>`;
|
69 |
+
chatBox.scrollTop = chatBox.scrollHeight;
|
70 |
+
return;
|
71 |
+
}
|
72 |
+
const chunk = decoder.decode(value, { stream: true });
|
73 |
+
botMessage += chunk;
|
74 |
+
chatBox.innerHTML += `<div><strong>Bot:</strong> ${botMessage}</div>`;
|
75 |
+
chatBox.scrollTop = chatBox.scrollHeight;
|
76 |
+
read();
|
77 |
+
}).catch(error => {
|
78 |
+
chatBox.innerHTML += `<div><strong>Bot:</strong> Error: ${error}</div>`;
|
79 |
+
chatBox.scrollTop = chatBox.scrollHeight;
|
80 |
+
});
|
81 |
+
}
|
82 |
+
read();
|
83 |
+
})
|
84 |
+
.catch(error => {
|
85 |
+
chatBox.innerHTML += `<div><strong>Bot:</strong> Error: ${error}</div>`;
|
86 |
+
chatBox.scrollTop = chatBox.scrollHeight;
|
87 |
+
});
|
88 |
+
});
|
89 |
+
</script>
|
90 |
+
</body>
|
91 |
+
</html>
|
92 |
+
"""
|
93 |
+
|
94 |
+
class Entrada(BaseModel):
|
95 |
+
texto: str = Field(..., example="Hola, 驴c贸mo est谩s?")
|
96 |
+
history: Optional[List[str]] = Field(default_factory=list)
|
97 |
+
top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
|
98 |
+
top_k: Optional[int] = Field(50, ge=0)
|
99 |
+
temperature: Optional[float] = Field(1.0, gt=0.0)
|
100 |
+
max_length: Optional[int] = Field(100, ge=10, le=1000)
|
101 |
+
chunk_size: Optional[int] = Field(10, ge=1)
|
102 |
+
|
103 |
+
@app.middleware("http")
|
104 |
+
async def limitar_recursos(request: Request, call_next):
|
105 |
+
cpu = psutil.cpu_percent(interval=0.1)
|
106 |
+
ram = psutil.virtual_memory().percent
|
107 |
+
if cpu > CPU_LIMIT or ram > RAM_LIMIT:
|
108 |
+
raise HTTPException(status_code=503, detail="Servidor sobrecargado. Intenta de nuevo m谩s tarde.")
|
109 |
+
response = await call_next(request)
|
110 |
+
return response
|
111 |
+
|
112 |
+
@app.on_event("startup")
|
113 |
+
def cargar_modelo():
|
114 |
+
global tokenizador, modelo, eos_token, pad_token
|
115 |
+
tokenizador = AutoTokenizer.from_pretrained("Yhhxhfh/dgdggd")
|
116 |
+
modelo = AutoModelForCausalLM.from_pretrained(
|
117 |
+
"Yhhxhfh/dgdggd",
|
118 |
+
torch_dtype=torch.float32,
|
119 |
+
device_map="cpu"
|
120 |
+
)
|
121 |
+
modelo.eval()
|
122 |
+
eos_token = tokenizador.eos_token
|
123 |
+
pad_token = tokenizador.pad_token
|
124 |
+
|
125 |
+
async def generar_stream(prompt, top_p, top_k, temperature, max_length, chunk_size):
|
126 |
+
input_ids = tokenizador.encode(prompt, return_tensors="pt").to(dispositivo)
|
127 |
+
outputs = modelo.generate(
|
128 |
+
input_ids,
|
129 |
+
max_length=input_ids.shape[1] + max_length,
|
130 |
+
do_sample=True,
|
131 |
+
top_p=top_p,
|
132 |
+
top_k=top_k,
|
133 |
+
temperature=temperature,
|
134 |
+
no_repeat_ngram_size=2,
|
135 |
+
eos_token_id=tokenizador.eos_token_id if tokenizador.eos_token_id is not None else -1
|
136 |
+
)
|
137 |
+
generated_ids = outputs[0][input_ids.shape[1]:]
|
138 |
+
generated_text = tokenizador.decode(generated_ids, skip_special_tokens=True)
|
139 |
+
for i in range(0, len(generated_text), chunk_size):
|
140 |
+
yield generated_text[i:i+chunk_size]
|
141 |
+
await asyncio.sleep(0)
|
142 |
+
|
143 |
+
@app.post("/generar")
|
144 |
+
async def generar_texto(entrada: Entrada):
|
145 |
+
try:
|
146 |
+
prompt = "\n".join(entrada.history + [f"T煤: {entrada.texto}", "Bot:"])
|
147 |
+
async def stream():
|
148 |
+
async for chunk in generar_stream(
|
149 |
+
prompt,
|
150 |
+
entrada.top_p,
|
151 |
+
entrada.top_k,
|
152 |
+
entrada.temperature,
|
153 |
+
entrada.max_length,
|
154 |
+
entrada.chunk_size
|
155 |
+
):
|
156 |
+
yield chunk
|
157 |
+
return StreamingResponse(stream(), media_type="text/plain")
|
158 |
+
except Exception as e:
|
159 |
+
raise HTTPException(status_code=500, detail=str(e))
|
160 |
+
|
161 |
+
@app.get("/", response_class=HTMLResponse)
|
162 |
+
async def get_home():
|
163 |
+
return html_code
|
164 |
+
|
165 |
+
if __name__ == "__main__":
|
166 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|