Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
·
598b96c
1
Parent(s):
f1368ae
feat: add websocket
Browse files- main.py +40 -2
- prompt_template_utils.py +3 -2
- websocket/socketManager.py +14 -1
main.py
CHANGED
@@ -213,7 +213,26 @@ async def create_upload_file(file: UploadFile):
|
|
213 |
|
214 |
@api_app.websocket("/ws/{user_id}")
|
215 |
async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
message = {
|
219 |
"message": f"Student {user_id} connected"
|
@@ -243,7 +262,26 @@ async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
|
|
243 |
|
244 |
@api_app.websocket("/ws/{room_id}/{user_id}")
|
245 |
async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
|
246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
message = {
|
249 |
"message": f"Student {user_id} connected to the classroom"
|
|
|
213 |
|
214 |
@api_app.websocket("/ws/{user_id}")
|
215 |
async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
|
216 |
+
DB = Chroma(
|
217 |
+
persist_directory=PERSIST_DIRECTORY,
|
218 |
+
embedding_function=EMBEDDINGS,
|
219 |
+
client_settings=CHROMA_SETTINGS,
|
220 |
+
)
|
221 |
+
|
222 |
+
RETRIEVER = DB.as_retriever()
|
223 |
+
|
224 |
+
newInstanceQA = RetrievalQA.from_chain_type(
|
225 |
+
llm=LLM,
|
226 |
+
chain_type="stuff",
|
227 |
+
retriever=RETRIEVER,
|
228 |
+
return_source_documents=SHOW_SOURCES,
|
229 |
+
chain_type_kwargs={
|
230 |
+
"prompt": prompt,
|
231 |
+
"memory": memory
|
232 |
+
},
|
233 |
+
)
|
234 |
+
|
235 |
+
QA = socket_manager.get_instance_qa(user_id, newInstanceQA)
|
236 |
|
237 |
message = {
|
238 |
"message": f"Student {user_id} connected"
|
|
|
262 |
|
263 |
@api_app.websocket("/ws/{room_id}/{user_id}")
|
264 |
async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
|
265 |
+
DB = Chroma(
|
266 |
+
persist_directory=PERSIST_DIRECTORY,
|
267 |
+
embedding_function=EMBEDDINGS,
|
268 |
+
client_settings=CHROMA_SETTINGS,
|
269 |
+
)
|
270 |
+
|
271 |
+
RETRIEVER = DB.as_retriever()
|
272 |
+
|
273 |
+
newInstanceQA = RetrievalQA.from_chain_type(
|
274 |
+
llm=LLM,
|
275 |
+
chain_type="stuff",
|
276 |
+
retriever=RETRIEVER,
|
277 |
+
return_source_documents=SHOW_SOURCES,
|
278 |
+
chain_type_kwargs={
|
279 |
+
"prompt": prompt,
|
280 |
+
"memory": memory
|
281 |
+
},
|
282 |
+
)
|
283 |
+
|
284 |
+
QA = socket_manager.get_instance_qa(room_id, newInstanceQA)
|
285 |
|
286 |
message = {
|
287 |
"message": f"Student {user_id} connected to the classroom"
|
prompt_template_utils.py
CHANGED
@@ -15,8 +15,9 @@ from langchain.prompts import PromptTemplate
|
|
15 |
|
16 |
# system_prompt = """You are a helpful assistant, and you will use the context and documents provided in the training to answer users' questions. Please read the context provided carefully before responding to questions and follow a step-by-step thought process. If you cannot answer a user's question based on the provided context, please inform the user. Do not use any other information to answer the user. Provide a detailed response based on the content of locally trained documents."""
|
17 |
|
18 |
-
system_prompt = """It's a useful assistant
|
19 |
-
Read the context provided before answering the questions and think step by step.
|
|
|
20 |
|
21 |
def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):
|
22 |
if promptTemplate_type == "llama":
|
|
|
15 |
|
16 |
# system_prompt = """You are a helpful assistant, and you will use the context and documents provided in the training to answer users' questions. Please read the context provided carefully before responding to questions and follow a step-by-step thought process. If you cannot answer a user's question based on the provided context, please inform the user. Do not use any other information to answer the user. Provide a detailed response based on the content of locally trained documents."""
|
17 |
|
18 |
+
system_prompt = """It's a useful assistant that will use the context and documents provided in the training to answer users' questions.
|
19 |
+
Read the context provided before answering the questions and think step by step. Your answer cannot be more than 10 sentences long.
|
20 |
+
If you can't answer, just say "I don't know" and don't try to work out an answer to respond to the user."""
|
21 |
|
22 |
def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):
|
23 |
if promptTemplate_type == "llama":
|
websocket/socketManager.py
CHANGED
@@ -1,9 +1,13 @@
|
|
|
|
1 |
import asyncio
|
2 |
import redis.asyncio as aioredis
|
3 |
import json
|
4 |
from fastapi import WebSocket
|
5 |
|
6 |
|
|
|
|
|
|
|
7 |
class RedisPubSubManager:
|
8 |
"""
|
9 |
Initializes the RedisPubSubManager.
|
@@ -80,6 +84,7 @@ class WebSocketManager:
|
|
80 |
pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
|
81 |
"""
|
82 |
self.rooms: dict = {}
|
|
|
83 |
self.pubsub_client = RedisPubSubManager()
|
84 |
|
85 |
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
@@ -96,7 +101,6 @@ class WebSocketManager:
|
|
96 |
self.rooms[room_id].append(websocket)
|
97 |
else:
|
98 |
self.rooms[room_id] = [websocket]
|
99 |
-
|
100 |
await self.pubsub_client.connect()
|
101 |
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
102 |
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
@@ -120,6 +124,7 @@ class WebSocketManager:
|
|
120 |
websocket (WebSocket): WebSocket connection object.
|
121 |
"""
|
122 |
self.rooms[room_id].remove(websocket)
|
|
|
123 |
|
124 |
if len(self.rooms[room_id]) == 0:
|
125 |
del self.rooms[room_id]
|
@@ -140,3 +145,11 @@ class WebSocketManager:
|
|
140 |
for socket in all_sockets:
|
141 |
data = message['data'].decode('utf-8')
|
142 |
await socket.send_text(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
import asyncio
|
3 |
import redis.asyncio as aioredis
|
4 |
import json
|
5 |
from fastapi import WebSocket
|
6 |
|
7 |
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
class RedisPubSubManager:
|
12 |
"""
|
13 |
Initializes the RedisPubSubManager.
|
|
|
84 |
pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
|
85 |
"""
|
86 |
self.rooms: dict = {}
|
87 |
+
self.qa: dict = {}
|
88 |
self.pubsub_client = RedisPubSubManager()
|
89 |
|
90 |
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
|
|
101 |
self.rooms[room_id].append(websocket)
|
102 |
else:
|
103 |
self.rooms[room_id] = [websocket]
|
|
|
104 |
await self.pubsub_client.connect()
|
105 |
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
|
106 |
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
|
|
|
124 |
websocket (WebSocket): WebSocket connection object.
|
125 |
"""
|
126 |
self.rooms[room_id].remove(websocket)
|
127 |
+
self.qa.pop(room_id, None)
|
128 |
|
129 |
if len(self.rooms[room_id]) == 0:
|
130 |
del self.rooms[room_id]
|
|
|
145 |
for socket in all_sockets:
|
146 |
data = message['data'].decode('utf-8')
|
147 |
await socket.send_text(data)
|
148 |
+
|
149 |
+
async def get_instance_qa(self, room_id: str, QA: Any):
|
150 |
+
if room_id in self.qa:
|
151 |
+
return self.qa[room_id]
|
152 |
+
|
153 |
+
self.qa[room_id] = QA
|
154 |
+
return self.qa[room_id]
|
155 |
+
|