Hhblvjgvg commited on
Commit
9c15b84
verified
1 Parent(s): 3fd0ab6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from pydantic import BaseModel, Field
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+ from typing import Optional, List
6
+ import asyncio
7
+ from fastapi.responses import StreamingResponse, HTMLResponse
8
+ import uvicorn
9
+ import psutil
10
+
11
+ app = FastAPI()
12
+
13
+ dispositivo = torch.device("cpu")
14
+ CPU_LIMIT = 30.0
15
+ RAM_LIMIT = 30.0
16
+
17
+ html_code = """
18
+ <!DOCTYPE html>
19
+ <html>
20
+ <head>
21
+ <title>Chatbot</title>
22
+ <style>
23
+ body { font-family: Arial, sans-serif; margin: 50px; }
24
+ #chat { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: scroll; }
25
+ #input { width: 80%; padding: 10px; }
26
+ #send { padding: 10px; }
27
+ </style>
28
+ </head>
29
+ <body>
30
+ <h1>Chatbot</h1>
31
+ <div id="chat"></div>
32
+ <input type="text" id="input" placeholder="Escribe tu mensaje...">
33
+ <button id="send">Enviar</button>
34
+
35
+ <script>
36
+ const sendButton = document.getElementById('send');
37
+ const inputBox = document.getElementById('input');
38
+ const chatBox = document.getElementById('chat');
39
+ let history = [];
40
+
41
+ sendButton.addEventListener('click', () => {
42
+ const message = inputBox.value;
43
+ if (message.trim() === '') return;
44
+ history.push(`T煤: ${message}`);
45
+ chatBox.innerHTML += `<div><strong>T煤:</strong> ${message}</div>`;
46
+ inputBox.value = '';
47
+ fetch('/generar', {
48
+ method: 'POST',
49
+ headers: {
50
+ 'Content-Type': 'application/json'
51
+ },
52
+ body: JSON.stringify({
53
+ texto: message,
54
+ history: history
55
+ })
56
+ })
57
+ .then(response => {
58
+ if (!response.body) {
59
+ throw new Error('No soporta streaming');
60
+ }
61
+ const reader = response.body.getReader();
62
+ const decoder = new TextDecoder();
63
+ let botMessage = '';
64
+ function read() {
65
+ reader.read().then(({ done, value }) => {
66
+ if (done) {
67
+ history.push(`Bot: ${botMessage}`);
68
+ chatBox.innerHTML += `<div><strong>Bot:</strong> ${botMessage}</div>`;
69
+ chatBox.scrollTop = chatBox.scrollHeight;
70
+ return;
71
+ }
72
+ const chunk = decoder.decode(value, { stream: true });
73
+ botMessage += chunk;
74
+ chatBox.innerHTML += `<div><strong>Bot:</strong> ${botMessage}</div>`;
75
+ chatBox.scrollTop = chatBox.scrollHeight;
76
+ read();
77
+ }).catch(error => {
78
+ chatBox.innerHTML += `<div><strong>Bot:</strong> Error: ${error}</div>`;
79
+ chatBox.scrollTop = chatBox.scrollHeight;
80
+ });
81
+ }
82
+ read();
83
+ })
84
+ .catch(error => {
85
+ chatBox.innerHTML += `<div><strong>Bot:</strong> Error: ${error}</div>`;
86
+ chatBox.scrollTop = chatBox.scrollHeight;
87
+ });
88
+ });
89
+ </script>
90
+ </body>
91
+ </html>
92
+ """
93
+
94
+ class Entrada(BaseModel):
95
+ texto: str = Field(..., example="Hola, 驴c贸mo est谩s?")
96
+ history: Optional[List[str]] = Field(default_factory=list)
97
+ top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
98
+ top_k: Optional[int] = Field(50, ge=0)
99
+ temperature: Optional[float] = Field(1.0, gt=0.0)
100
+ max_length: Optional[int] = Field(100, ge=10, le=1000)
101
+ chunk_size: Optional[int] = Field(10, ge=1)
102
+
103
+ @app.middleware("http")
104
+ async def limitar_recursos(request: Request, call_next):
105
+ cpu = psutil.cpu_percent(interval=0.1)
106
+ ram = psutil.virtual_memory().percent
107
+ if cpu > CPU_LIMIT or ram > RAM_LIMIT:
108
+ raise HTTPException(status_code=503, detail="Servidor sobrecargado. Intenta de nuevo m谩s tarde.")
109
+ response = await call_next(request)
110
+ return response
111
+
112
+ @app.on_event("startup")
113
+ def cargar_modelo():
114
+ global tokenizador, modelo, eos_token, pad_token
115
+ tokenizador = AutoTokenizer.from_pretrained("Yhhxhfh/dgdggd")
116
+ modelo = AutoModelForCausalLM.from_pretrained(
117
+ "Yhhxhfh/dgdggd",
118
+ torch_dtype=torch.float32,
119
+ device_map="cpu"
120
+ )
121
+ modelo.eval()
122
+ eos_token = tokenizador.eos_token
123
+ pad_token = tokenizador.pad_token
124
+
125
+ async def generar_stream(prompt, top_p, top_k, temperature, max_length, chunk_size):
126
+ input_ids = tokenizador.encode(prompt, return_tensors="pt").to(dispositivo)
127
+ outputs = modelo.generate(
128
+ input_ids,
129
+ max_length=input_ids.shape[1] + max_length,
130
+ do_sample=True,
131
+ top_p=top_p,
132
+ top_k=top_k,
133
+ temperature=temperature,
134
+ no_repeat_ngram_size=2,
135
+ eos_token_id=tokenizador.eos_token_id if tokenizador.eos_token_id is not None else -1
136
+ )
137
+ generated_ids = outputs[0][input_ids.shape[1]:]
138
+ generated_text = tokenizador.decode(generated_ids, skip_special_tokens=True)
139
+ for i in range(0, len(generated_text), chunk_size):
140
+ yield generated_text[i:i+chunk_size]
141
+ await asyncio.sleep(0)
142
+
143
+ @app.post("/generar")
144
+ async def generar_texto(entrada: Entrada):
145
+ try:
146
+ prompt = "\n".join(entrada.history + [f"T煤: {entrada.texto}", "Bot:"])
147
+ async def stream():
148
+ async for chunk in generar_stream(
149
+ prompt,
150
+ entrada.top_p,
151
+ entrada.top_k,
152
+ entrada.temperature,
153
+ entrada.max_length,
154
+ entrada.chunk_size
155
+ ):
156
+ yield chunk
157
+ return StreamingResponse(stream(), media_type="text/plain")
158
+ except Exception as e:
159
+ raise HTTPException(status_code=500, detail=str(e))
160
+
161
+ @app.get("/", response_class=HTMLResponse)
162
+ async def get_home():
163
+ return html_code
164
+
165
+ if __name__ == "__main__":
166
+ uvicorn.run(app, host="0.0.0.0", port=7860)