Merge pull request #51 from DL4DS/dev_branch
- .github/workflows/push_to_hf_space_prototype.yml +14 -13
- .vscode/launch.json +35 -0
- .vscode/tasks.json +13 -0
- Dockerfile.dev +2 -2
- README.md +1 -1
- code/.chainlit/config.toml +19 -16
- code/main.py +448 -159
- code/modules/chat/base.py +13 -0
- code/modules/chat/chat_model_loader.py +25 -8
- code/modules/chat/helpers.py +126 -61
- code/modules/chat/langchain/langchain_rag.py +274 -0
- code/modules/chat/langchain/utils.py +340 -0
- code/modules/chat/llm_tutor.py +128 -176
- code/modules/chat_processor/base.py +0 -12
- code/modules/chat_processor/chat_processor.py +0 -30
- code/modules/chat_processor/literal_ai.py +40 -33
- code/modules/config/config.yml +18 -8
- code/modules/config/constants.py +9 -68
- code/modules/config/prompts.py +97 -0
- code/modules/dataloader/data_loader.py +121 -60
- code/modules/dataloader/helpers.py +22 -2
- code/modules/dataloader/pdf_readers/base.py +14 -0
- code/modules/dataloader/pdf_readers/gpt.py +81 -0
- code/modules/dataloader/pdf_readers/llama.py +92 -0
- code/modules/dataloader/webpage_crawler.py +0 -1
- code/modules/vectorstore/base.py +3 -0
- code/modules/vectorstore/chroma.py +3 -0
- code/modules/vectorstore/colbert.py +73 -0
- code/modules/vectorstore/faiss.py +16 -14
- code/modules/vectorstore/raptor.py +8 -1
- code/modules/vectorstore/store_manager.py +37 -3
- code/modules/vectorstore/vectorstore.py +34 -0
- code/public/test.css +10 -0
- requirements.txt +24 -26
.github/workflows/push_to_hf_space_prototype.yml
CHANGED
@@ -1,20 +1,21 @@
name: Push Prototype to HuggingFace

on:
  push:
    branches: [dev_branch]

  # run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true
      - name: Deploy Prototype to HuggingFace
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push https://trgardos:$HF_TOKEN@huggingface.co/spaces/dl4ds/tutor_dev dev_branch:main
.vscode/launch.json
ADDED
@@ -0,0 +1,35 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Chainlit run main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/chainlit",
            "console": "integratedTerminal",
            "args": ["run", "main.py"],
            "cwd": "${workspaceFolder}/code",
            "justMyCode": true
        },
        {
            "name": "Python Debugger: Module store_manager",
            "type": "debugpy",
            "request": "launch",
            "module": "modules.vectorstore.store_manager",
            "env": {"PYTHONPATH": "${workspaceFolder}/code"},
            "cwd": "${workspaceFolder}/code",
            "justMyCode": true
        },
        {
            "name": "Python Debugger: Module data_loader",
            "type": "debugpy",
            "request": "launch",
            "module": "modules.dataloader.data_loader",
            "env": {"PYTHONPATH": "${workspaceFolder}/code"},
            "cwd": "${workspaceFolder}/code",
            "justMyCode": true
        }
    ]
}
.vscode/tasks.json
ADDED
@@ -0,0 +1,13 @@
{
    // See https://go.microsoft.com/fwlink/?LinkId=733558
    // for the documentation about the tasks.json format
    "version": "2.0.0",
    "tasks": [
        {
            "label": "echo",
            "type": "shell",
            "command": "echo ${workspaceFolder}; ls ${workspaceFolder}/code",
            "problemMatcher": []
        }
    ]
}
Dockerfile.dev
CHANGED
@@ -25,7 +25,7 @@ RUN mkdir /.cache && chmod -R 777 /.cache
WORKDIR /code/code

# Expose the port the app runs on
EXPOSE 8000

# Default command to run the application
CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8000"]
README.md
CHANGED
@@ -76,7 +76,7 @@ The HuggingFace Space is built using the `Dockerfile` in the repository. To run

```bash
docker build --tag dev -f Dockerfile.dev .
docker run -it --rm -p 8000:8000 dev
```

## Contributing
code/.chainlit/config.toml
CHANGED
@@ -23,7 +23,7 @@ allow_origins = ["*"]
unsafe_allow_html = false

# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = true

# Automatically tag threads with the current chat profile (if a chat profile is used)
auto_tag_thread = true

@@ -85,31 +85,34 @@ custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/
# custom_build = "./public/build"

[UI.theme]
default = "dark"
#layout = "wide"
#font_family = "Inter, sans-serif"
# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
    #background = "#FAFAFA"
    #paper = "#FFFFFF"

    [UI.theme.light.primary]
        #main = "#F80061"
        #dark = "#980039"
        #light = "#FFE7EB"
    [UI.theme.light.text]
        #primary = "#212121"
        #secondary = "#616161"

# Override default MUI dark theme. (Check theme.ts)
[UI.theme.dark]
    #background = "#FAFAFA"
    #paper = "#FFFFFF"

    [UI.theme.dark.primary]
        #main = "#F80061"
        #dark = "#980039"
        #light = "#FFE7EB"
    [UI.theme.dark.text]
        #primary = "#EEEEEE"
        #secondary = "#BDBDBD"

[meta]
generated_by = "1.1.304"
code/main.py
CHANGED
@@ -1,176 +1,465 @@
import chainlit.data as cl_data
import asyncio
from modules.config.constants import (
    LLAMA_PATH,
    LITERAL_API_KEY_LOGGING,
    LITERAL_API_URL,
)
from modules.chat_processor.literal_ai import CustomLiteralDataLayer

import json
import yaml
import os
from typing import Any, Dict, no_type_check
import chainlit as cl
from modules.chat.llm_tutor import LLMTutor
from modules.chat.helpers import (
    get_sources,
    get_history_chat_resume,
    get_history_setup_llm,
    get_last_config,
)
import copy
from typing import Optional
from chainlit.types import ThreadDict
import time

USER_TIMEOUT = 60_000
SYSTEM = "System 🖥️"
LLM = "LLM 🧠"
AGENT = "Agent <>"
YOU = "You 😃"
ERROR = "Error 🚫"

with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)


async def setup_data_layer():
    """
    Set up the data layer for chat logging.
    """
    if config["chat_logging"]["log_chat"]:
        data_layer = CustomLiteralDataLayer(
            api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
        )
    else:
        data_layer = None

    return data_layer


class Chatbot:
    def __init__(self, config):
        """
        Initialize the Chatbot class.
        """
        self.config = config

    async def _load_config(self):
        """
        Load the configuration from a YAML file.
        """
        with open("modules/config/config.yml", "r") as f:
            return yaml.safe_load(f)

    @no_type_check
    async def setup_llm(self):
        """
        Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.

        #TODO: Clean this up.
        """
        start_time = time.time()

        llm_settings = cl.user_session.get("llm_settings", {})
        chat_profile, retriever_method, memory_window, llm_style, generate_follow_up, chunking_mode = (
            llm_settings.get("chat_model"),
            llm_settings.get("retriever_method"),
            llm_settings.get("memory_window"),
            llm_settings.get("llm_style"),
            llm_settings.get("follow_up_questions"),
            llm_settings.get("chunking_mode"),
        )

        chain = cl.user_session.get("chain")
        memory_list = cl.user_session.get(
            "memory",
            (
                list(chain.store.values())[0].messages
                if len(chain.store.values()) > 0
                else []
            ),
        )
        conversation_list = get_history_setup_llm(memory_list)

        old_config = copy.deepcopy(self.config)
        self.config["vectorstore"]["db_option"] = retriever_method
        self.config["llm_params"]["memory_window"] = memory_window
        self.config["llm_params"]["llm_style"] = llm_style
        self.config["llm_params"]["llm_loader"] = chat_profile
        self.config["llm_params"]["generate_follow_up"] = generate_follow_up
        self.config["splitter_options"]["chunking_mode"] = chunking_mode

        self.llm_tutor.update_llm(
            old_config, self.config
        )  # update only llm attributes that are changed
        self.chain = self.llm_tutor.qa_bot(
            memory=conversation_list,
            callbacks=(
                [cl.LangchainCallbackHandler()]
                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
                else None
            ),
        )

        tags = [chat_profile, self.config["vectorstore"]["db_option"]]

        cl.user_session.set("chain", self.chain)
        cl.user_session.set("llm_tutor", self.llm_tutor)

        print("Time taken to setup LLM: ", time.time() - start_time)

    @no_type_check
    async def update_llm(self, new_settings: Dict[str, Any]):
        """
        Update the LLM settings and reinitialize the LLM with the new settings.

        Args:
            new_settings (Dict[str, Any]): The new settings to update.
        """
        cl.user_session.set("llm_settings", new_settings)
        await self.inform_llm_settings()
        await self.setup_llm()

    async def make_llm_settings_widgets(self, config=None):
        """
        Create and send the widgets for LLM settings configuration.

        Args:
            config: The configuration to use for setting up the widgets.
        """
        config = config or self.config
        await cl.ChatSettings(
            [
                cl.input_widget.Select(
                    id="chat_model",
                    label="Model Name (Default GPT-3)",
                    values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
                    initial_index=[
                        "local_llm",
                        "gpt-3.5-turbo-1106",
                        "gpt-4",
                        "gpt-4o-mini",
                    ].index(config["llm_params"]["llm_loader"]),
                ),
                cl.input_widget.Select(
                    id="retriever_method",
                    label="Retriever (Default FAISS)",
                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
                    initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
                        config["vectorstore"]["db_option"]
                    ),
                ),
                cl.input_widget.Slider(
                    id="memory_window",
                    label="Memory Window (Default 3)",
                    initial=3,
                    min=0,
                    max=10,
                    step=1,
                ),
                cl.input_widget.Switch(
                    id="view_sources", label="View Sources", initial=False
                ),
                cl.input_widget.Switch(
                    id="stream_response",
                    label="Stream response",
                    initial=config["llm_params"]["stream"],
                ),
                cl.input_widget.Select(
                    id="chunking_mode",
                    label="Chunking mode",
                    values=['fixed', 'semantic'],
                    initial_index=1,
                ),
                cl.input_widget.Switch(
                    id="follow_up_questions",
                    label="Generate follow up questions",
                    initial=False,
                ),
                cl.input_widget.Select(
                    id="llm_style",
                    label="Type of Conversation (Default Normal)",
                    values=["Normal", "ELI5"],
                    initial_index=0,
                ),
            ]
        ).send()

    @no_type_check
    async def inform_llm_settings(self):
        """
        Inform the user about the updated LLM settings and display them as a message.
        """
        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
        llm_tutor = cl.user_session.get("llm_tutor")
        settings_dict = {
            "model": llm_settings.get("chat_model"),
            "retriever": llm_settings.get("retriever_method"),
            "memory_window": llm_settings.get("memory_window"),
            "num_docs_in_db": (
                len(llm_tutor.vector_db)
                if llm_tutor and hasattr(llm_tutor, "vector_db")
                else 0
            ),
            "view_sources": llm_settings.get("view_sources"),
            "follow_up_questions": llm_settings.get("follow_up_questions"),
        }
        await cl.Message(
            author=SYSTEM,
            content="LLM settings have been updated. You can continue with your Query!",
            elements=[
                cl.Text(
                    name="settings",
                    display="side",
                    content=json.dumps(settings_dict, indent=4),
                    language="json",
                ),
            ],
        ).send()

    async def set_starters(self):
        """
        Set starter messages for the chatbot.
        """
        # Return Starters only if the chat is new

        try:
            thread = cl_data._data_layer.get_thread(
                cl.context.session.thread_id
            )  # see if the thread has any steps
            if thread.steps or len(thread.steps) > 0:
                return None
        except:
            return [
                cl.Starter(
                    label="recording on CNNs?",
                    message="Where can I find the recording for the lecture on Transformers?",
                    icon="/public/adv-screen-recorder-svgrepo-com.svg",
                ),
                cl.Starter(
                    label="where's the slides?",
                    message="When are the lectures? I can't find the schedule.",
                    icon="/public/alarmy-svgrepo-com.svg",
                ),
                cl.Starter(
                    label="Due Date?",
                    message="When is the final project due?",
                    icon="/public/calendar-samsung-17-svgrepo-com.svg",
                ),
                cl.Starter(
                    label="Explain backprop.",
                    message="I didn't understand the math behind backprop, could you explain it?",
                    icon="/public/acastusphoton-svgrepo-com.svg",
                ),
            ]

    def rename(self, orig_author: str):
        """
        Rename the original author to a more user-friendly name.

        Args:
            orig_author (str): The original author's name.

        Returns:
            str: The renamed author.
        """
        rename_dict = {"Chatbot": "AI Tutor"}
        return rename_dict.get(orig_author, orig_author)

    async def start(self, config=None):
        """
        Start the chatbot, initialize settings widgets,
        and display and load previous conversation if chat logging is enabled.
        """

        start_time = time.time()

        self.config = (
            await self._load_config() if config is None else config
        )  # Reload the configuration on chat resume

        await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets

        await self.make_llm_settings_widgets(self.config)
        user = cl.user_session.get("user")
        self.user = {
            "user_id": user.identifier,
            "session_id": cl.context.session.thread_id,
        }

        memory = cl.user_session.get("memory", [])

        cl.user_session.set("user", self.user)
        self.llm_tutor = LLMTutor(self.config, user=self.user)

        self.chain = self.llm_tutor.qa_bot(
            memory=memory,
            callbacks=(
                [cl.LangchainCallbackHandler()]
                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
                else None
            ),
        )
        self.question_generator = self.llm_tutor.question_generator
        cl.user_session.set("llm_tutor", self.llm_tutor)
        cl.user_session.set("chain", self.chain)

        print("Time taken to start LLM: ", time.time() - start_time)

    async def stream_response(self, response):
        """
        Stream the response from the LLM.

        Args:
            response: The response from the LLM.
        """
        msg = cl.Message(content="")
        await msg.send()

        output = {}
        for chunk in response:
            if "answer" in chunk:
                await msg.stream_token(chunk["answer"])

            for key in chunk:
                if key not in output:
                    output[key] = chunk[key]
                else:
                    output[key] += chunk[key]
        return output

    async def main(self, message):
        """
        Process and Display the Conversation.

        Args:
            message: The incoming chat message.
        """

        start_time = time.time()

        chain = cl.user_session.get("chain")

        llm_settings = cl.user_session.get("llm_settings", {})
        view_sources = llm_settings.get("view_sources", False)
        stream = llm_settings.get("stream_response", False)
        steam = False  # Fix streaming
        user_query_dict = {"input": message.content}
        # Define the base configuration
        chain_config = {
            "configurable": {
                "user_id": self.user["user_id"],
                "conversation_id": self.user["session_id"],
                "memory_window": self.config["llm_params"]["memory_window"],
            }
        }

        if stream:
            res = chain.stream(user_query=user_query_dict, config=chain_config)
            res = await self.stream_response(res)
        else:
            res = await chain.invoke(
                user_query=user_query_dict,
                config=chain_config,
            )

        answer = res.get("answer", res.get("result"))

        answer_with_sources, source_elements, sources_dict = get_sources(
            res, answer, stream=stream, view_sources=view_sources
        )
        answer_with_sources = answer_with_sources.replace("$$", "$")

        print("Time taken to process the message: ", time.time() - start_time)

        actions = []

        if self.config["llm_params"]["generate_follow_up"]:
            start_time = time.time()
            list_of_questions = self.question_generator.generate_questions(
                query=user_query_dict["input"],
                response=answer,
                chat_history=res.get("chat_history"),
                context=res.get("context"),
            )

            for question in list_of_questions:
                actions.append(
                    cl.Action(
                        name="follow up question",
                        value="example_value",
                        description=question,
                        label=question,
                    )
                )

            print("Time taken to generate questions: ", time.time() - start_time)

        await cl.Message(
            content=answer_with_sources,
            elements=source_elements,
            author=LLM,
            actions=actions,
            metadata=self.config,
        ).send()

    async def on_chat_resume(self, thread: ThreadDict):
        thread_config = None
        steps = thread["steps"]
        k = self.config["llm_params"][
            "memory_window"
        ]  # on resume, always use the default memory window
        conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
        thread_config = get_last_config(
            steps
        )  # TODO: Returns None for now - which causes config to be reloaded with default values
        cl.user_session.set("memory", conversation_list)
        await self.start(config=thread_config)

    @cl.oauth_callback
    def auth_callback(
        provider_id: str,
        token: str,
        raw_user_data: Dict[str, str],
        default_user: cl.User,
    ) -> Optional[cl.User]:
        return default_user

    async def on_follow_up(self, action: cl.Action):
        message = await cl.Message(
            content=action.description,
            type="user_message",
            author=self.user["user_id"],
        ).send()
        await self.main(message)


chatbot = Chatbot(config=config)


async def start_app():
    cl_data._data_layer = await setup_data_layer()
    chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
    cl.set_starters(chatbot.set_starters)
    cl.author_rename(chatbot.rename)
    cl.on_chat_start(chatbot.start)
    cl.on_chat_resume(chatbot.on_chat_resume)
    cl.on_message(chatbot.main)
    cl.on_settings_update(chatbot.update_llm)
    cl.action_callback("follow up question")(chatbot.on_follow_up)


asyncio.run(start_app())
code/modules/chat/base.py
ADDED
@@ -0,0 +1,13 @@
class BaseRAG:
    """
    Base class for RAG chatbot.
    """

    def __init__():
        pass

    def invoke():
        """
        Invoke the RAG chatbot.
        """
        pass
code/modules/chat/chat_model_loader.py
CHANGED
@@ -1,12 +1,15 @@
from langchain_openai import ChatOpenAI
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoTokenizer, TextStreamer
from langchain_community.llms import LlamaCpp
import torch
import transformers
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from modules.config.constants import LLAMA_PATH


class ChatModelLoader:
@@ -14,16 +17,28 @@ class ChatModelLoader:
        self.config = config
        self.huggingface_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

    def _verify_model_cache(self, model_cache_path):
        hf_hub_download(
            repo_id=self.config["llm_params"]["local_llm_params"]["repo_id"],
            filename=self.config["llm_params"]["local_llm_params"]["filename"],
            cache_dir=model_cache_path,
        )
        return str(list(Path(model_cache_path).glob("*/snapshots/*/*.gguf"))[0])

    def load_chat_model(self):
        if self.config["llm_params"]["llm_loader"] in [
            "gpt-3.5-turbo-1106",
            "gpt-4",
            "gpt-4o-mini",
        ]:
            llm = ChatOpenAI(model_name=self.config["llm_params"]["llm_loader"])
        elif self.config["llm_params"]["llm_loader"] == "local_llm":
            n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
            model_path = self._verify_model_cache(
                self.config["llm_params"]["local_llm_params"]["model"]
            )
            llm = LlamaCpp(
                model_path=LLAMA_PATH,
                n_batch=n_batch,
                n_ctx=2048,
                f16_kv=True,
@@ -34,5 +49,7 @@ class ChatModelLoader:
                ],
            )
        else:
            raise ValueError(
                f"Invalid LLM Loader: {self.config['llm_params']['llm_loader']}"
            )
        return llm
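For reference, a minimal sketch of how the updated loader might be exercised. The config fragment below is hypothetical: it only mirrors the keys the new code reads (`llm_params.llm_loader` and `llm_params.local_llm_params`), and the repo id, filename, and cache path are placeholders, not values from this PR.

```python
# Minimal usage sketch (assumptions: ChatModelLoader takes the parsed config dict,
# and OPENAI_API_KEY / HUGGINGFACEHUB_API_TOKEN are set in the environment).
from modules.chat.chat_model_loader import ChatModelLoader

config = {
    "llm_params": {
        "llm_loader": "gpt-4o-mini",  # one of the OpenAI names, or "local_llm" for the LlamaCpp branch
        "local_llm_params": {
            "model": "model_cache",                # placeholder cache directory
            "repo_id": "some-org/some-gguf-repo",  # placeholder Hugging Face repo id
            "filename": "model.Q4_K_M.gguf",       # placeholder GGUF filename
        },
    }
}

loader = ChatModelLoader(config)
llm = loader.load_chat_model()  # ChatOpenAI for the OpenAI names, LlamaCpp for "local_llm"
```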
code/modules/chat/helpers.py
CHANGED
@@ -1,13 +1,12 @@
from modules.config.prompts import prompts
import chainlit as cl


def get_sources(res, answer, stream=True, view_sources=False):
    source_elements = []
    source_dict = {}  # Dictionary to store URL elements

    for idx, source in enumerate(res["context"]):
        source_metadata = source.metadata
        url = source_metadata.get("source", "N/A")
        score = source_metadata.get("score", "N/A")
@@ -36,69 +35,135 @@ def get_sources(res, answer):
        else:
            source_dict[url_name]["text"] += f"\n\n{source.page_content}"

    full_answer = ""  # Not to include the answer again if streaming

    if not stream:  # First, display the answer if not streaming
        full_answer = "**Answer:**\n"
        full_answer += answer

    if view_sources:

        # Then, display the sources
        # check if the answer has sources
        if len(source_dict) == 0:
            full_answer += "\n\n**No sources found.**"
            return full_answer, source_elements, source_dict
        else:
            full_answer += "\n\n**Sources:**\n"
            for idx, (url_name, source_data) in enumerate(source_dict.items()):

                full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"

                name = f"Source {idx + 1} Text\n"
                full_answer += name
                source_elements.append(
                    cl.Text(name=name, content=source_data["text"], display="side")
                )

                # Add a PDF element if the source is a PDF file
                if source_data["url"].lower().endswith(".pdf"):
                    name = f"Source {idx + 1} PDF\n"
                    full_answer += name
                    pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
                    source_elements.append(
                        cl.Pdf(name=name, url=pdf_url, display="side")
                    )

            full_answer += "\n**Metadata:**\n"
            for idx, (url_name, source_data) in enumerate(source_dict.items()):
                full_answer += f"\nSource {idx + 1} Metadata:\n"
                source_elements.append(
                    cl.Text(
                        name=f"Source {idx + 1} Metadata",
                        content=f"Source: {source_data['url']}\n"
                        f"Page: {source_data['page']}\n"
                        f"Type: {source_data['source_type']}\n"
                        f"Date: {source_data['date']}\n"
                        f"TL;DR: {source_data['lecture_tldr']}\n"
                        f"Lecture Recording: {source_data['lecture_recording']}\n"
                        f"Suggested Readings: {source_data['suggested_readings']}\n",
                        display="side",
                    )
                )

    return full_answer, source_elements, source_dict


def get_prompt(config, prompt_type):
    llm_params = config["llm_params"]
    llm_loader = llm_params["llm_loader"]
    use_history = llm_params["use_history"]
    llm_style = llm_params["llm_style"].lower()

    if prompt_type == "qa":
        if llm_loader == "local_llm":
            if use_history:
                return prompts["tiny_llama"]["prompt_with_history"]
            else:
                return prompts["tiny_llama"]["prompt_no_history"]
        else:
            if use_history:
                return prompts["openai"]["prompt_with_history"][llm_style]
            else:
                return prompts["openai"]["prompt_no_history"]
    elif prompt_type == "rephrase":
        return prompts["openai"]["rephrase_prompt"]


def get_history_chat_resume(steps, k, SYSTEM, LLM):
    conversation_list = []
    count = 0
    for step in reversed(steps):
        if step["name"] not in [SYSTEM]:
            if step["type"] == "user_message":
                conversation_list.append(
                    {"type": "user_message", "content": step["output"]}
                )
            elif step["type"] == "assistant_message":
                if step["name"] == LLM:
                    conversation_list.append(
                        {"type": "ai_message", "content": step["output"]}
                    )
            else:
                raise ValueError("Invalid message type")
            count += 1
            if count >= 2 * k:  # 2 * k to account for both user and assistant messages
                break
    conversation_list = conversation_list[::-1]
    return conversation_list


def get_history_setup_llm(memory_list):
    conversation_list = []
    for message in memory_list:
        message_dict = message.to_dict() if hasattr(message, "to_dict") else message

        # Check if the type attribute is present as a key or attribute
        message_type = (
            message_dict.get("type", None)
            if isinstance(message_dict, dict)
            else getattr(message, "type", None)
        )

        # Check if content is present as a key or attribute
        message_content = (
            message_dict.get("content", None)
            if isinstance(message_dict, dict)
            else getattr(message, "content", None)
        )

        if message_type in ["ai", "ai_message"]:
            conversation_list.append({"type": "ai_message", "content": message_content})
        elif message_type in ["human", "user_message"]:
            conversation_list.append(
                {"type": "user_message", "content": message_content}
            )
        else:
            raise ValueError("Invalid message type")

    return conversation_list


def get_last_config(steps):
    # TODO: Implement this function
    return None
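A small usage sketch for the new `get_prompt` helper. The config fragment shows only the keys the function reads; the values are illustrative, not taken from this PR's config.yml.

```python
from modules.chat.helpers import get_prompt

# Illustrative config fragment; only the keys read by get_prompt() are shown.
config = {
    "llm_params": {
        "llm_loader": "gpt-4o-mini",  # anything other than "local_llm" selects the OpenAI prompts
        "use_history": True,
        "llm_style": "Normal",        # lower-cased before indexing prompts["openai"]["prompt_with_history"]
    }
}

qa_prompt = get_prompt(config, prompt_type="qa")
rephrase_prompt = get_prompt(config, prompt_type="rephrase")
```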
code/modules/chat/langchain/langchain_rag.py
ADDED
@@ -0,0 +1,274 @@
from langchain_core.prompts import ChatPromptTemplate

from modules.chat.langchain.utils import *
from langchain.memory import ChatMessageHistory
from modules.chat.base import BaseRAG
from langchain_core.prompts import PromptTemplate
from langchain.memory import (
    ConversationBufferWindowMemory,
    ConversationSummaryBufferMemory,
)

import chainlit as cl
from langchain_community.chat_models import ChatOpenAI


class Langchain_RAG_V1(BaseRAG):

    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Initialize the Langchain_RAG class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (BaseChatMessageHistory): The chat message history instance.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
        """
        self.llm = llm
        self.config = config
        # self.memory = self.add_history_from_list(memory)
        self.memory = ConversationBufferWindowMemory(
            k=self.config["llm_params"]["memory_window"],
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
            max_token_limit=128,
        )
        self.retriever = retriever
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        self.store = {}

        self.qa_prompt = PromptTemplate(
            template=self.qa_prompt,
            input_variables=["context", "chat_history", "input"],
        )

        self.rag_chain = CustomConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
            response_if_no_docs_found="No context found",
        )

    def add_history_from_list(self, history_list):
        """
        TODO: Add messages from a list to the chat history.
        """
        history = []

        return history

    async def invoke(self, user_query, config):
        """
        Invoke the chain.

        Args:
            kwargs: The input variables.

        Returns:
            dict: The output variables.
        """
        res = await self.rag_chain.acall(user_query["input"])
        return res


class QuestionGenerator:
    """
    Generate a question from the LLMs response and users input and past conversations.
    """

    def __init__(self):
        pass

    def generate_questions(self, query, response, chat_history, context):
        questions = return_questions(query, response, chat_history, context)
        return questions


class Langchain_RAG_V2(BaseRAG):
    def __init__(
        self,
        llm,
        memory,
        retriever,
        qa_prompt: str,
        rephrase_prompt: str,
        config: dict,
        callbacks=None,
    ):
        """
        Initialize the Langchain_RAG class.

        Args:
            llm (LanguageModelLike): The language model instance.
            memory (BaseChatMessageHistory): The chat message history instance.
            retriever (BaseRetriever): The retriever instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
        """
        self.llm = llm
        self.memory = self.add_history_from_list(memory)
        self.retriever = retriever
        self.qa_prompt = qa_prompt
        self.rephrase_prompt = rephrase_prompt
        self.store = {}

        # Contextualize question prompt
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
        )
        self.contextualize_q_prompt = ChatPromptTemplate.from_template(
            contextualize_q_system_prompt
        )

        # History-aware retriever
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.retriever, self.contextualize_q_prompt
        )

        # Answer question prompt
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
        )
        self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)

        # Question-answer chain
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.qa_prompt_template
        )

        # Final retrieval chain
        self.rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )

        self.rag_chain = CustomRunnableWithHistory(
            self.rag_chain,
            get_session_history=self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
            history_factory_config=[
                ConfigurableFieldSpec(
                    id="user_id",
                    annotation=str,
                    name="User ID",
                    description="Unique identifier for the user.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="conversation_id",
                    annotation=str,
                    name="Conversation ID",
                    description="Unique identifier for the conversation.",
                    default="",
                    is_shared=True,
                ),
                ConfigurableFieldSpec(
                    id="memory_window",
                    annotation=int,
                    name="Number of Conversations",
                    description="Number of conversations to consider for context.",
                    default=1,
                    is_shared=True,
                ),
            ],
        )

        if callbacks is not None:
            self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)

    def get_session_history(
        self, user_id: str, conversation_id: str, memory_window: int
    ) -> BaseChatMessageHistory:
        """
        Get the session history for a user and conversation.

        Args:
            user_id (str): The user identifier.
            conversation_id (str): The conversation identifier.
            memory_window (int): The number of conversations to consider for context.

        Returns:
            BaseChatMessageHistory: The chat message history.
        """
        if (user_id, conversation_id) not in self.store:
            self.store[(user_id, conversation_id)] = InMemoryHistory()
            self.store[(user_id, conversation_id)].add_messages(
                self.memory.messages
            )  # add previous messages to the store. Note: the store is in-memory.
        return self.store[(user_id, conversation_id)]

    async def invoke(self, user_query, config, **kwargs):
        """
        Invoke the chain.

        Args:
            kwargs: The input variables.

        Returns:
            dict: The output variables.
        """
        res = await self.rag_chain.ainvoke(user_query, config, **kwargs)
        res["rephrase_prompt"] = self.rephrase_prompt
        res["qa_prompt"] = self.qa_prompt
        return res

    def stream(self, user_query, config):
        res = self.rag_chain.stream(user_query, config)
        return res

    def add_history_from_list(self, conversation_list):
        """
        Add messages from a list to the chat history.

        Args:
            messages (list): The list of messages to add.
        """
        history = ChatMessageHistory()

        for idx, message in enumerate(conversation_list):
            message_type = (
                message.get("type", None)
                if isinstance(message, dict)
                else getattr(message, "type", None)
            )

            message_content = (
                message.get("content", None)
                if isinstance(message, dict)
                else getattr(message, "content", None)
            )

            if message_type in ["human", "user_message"]:
                history.add_user_message(message_content)
            elif message_type in ["ai", "ai_message"]:
                history.add_ai_message(message_content)

        return history
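A sketch of how the new `Langchain_RAG_V2` chain is driven, mirroring the call pattern in `code/main.py`. The keys under `configurable` match the `ConfigurableFieldSpec` entries declared above; the literal id values and the question are illustrative only, and `chain` is assumed to be the object returned by `llm_tutor.qa_bot()`.

```python
import asyncio

# `chain` is assumed to be a Langchain_RAG_V2 instance (e.g. what llm_tutor.qa_bot() returns).
async def ask(chain, question: str):
    res = await chain.invoke(
        user_query={"input": question},
        config={
            "configurable": {
                "user_id": "student-1",        # illustrative ids
                "conversation_id": "thread-1",
                "memory_window": 3,
            }
        },
    )
    return res.get("answer"), res.get("context")

# answer, context_docs = asyncio.run(ask(chain, "What is backpropagation?"))
```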
code/modules/chat/langchain/utils.py
ADDED
@@ -0,0 +1,340 @@
from typing import Any, Dict, List, Union, Tuple, Optional
from langchain_core.messages import (
    BaseMessage,
    AIMessage,
    FunctionMessage,
    HumanMessage,
)

from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.chains.combine_documents.base import (
    DEFAULT_DOCUMENT_PROMPT,
    DEFAULT_DOCUMENT_SEPARATOR,
    DOCUMENTS_KEY,
    BaseCombineDocumentsChain,
    _validate_prompt,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document


CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage


from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOpenAI

from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
import inspect
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain_core.messages import BaseMessage


class CustomConversationalRetrievalChain(ConversationalRetrievalChain):

    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
        buffer = ""
        for dialogue_turn in chat_history:
            if isinstance(dialogue_turn, BaseMessage):
                role_prefix = _ROLE_MAP.get(
                    dialogue_turn.type, f"{dialogue_turn.type}: "
                )
                buffer += f"\n{role_prefix}{dialogue_turn.content}"
            elif isinstance(dialogue_turn, tuple):
                human = "Student: " + dialogue_turn[0]
                ai = "AI Tutor: " + dialogue_turn[1]
                buffer += "\n" + "\n".join([human, ai])
            else:
                raise ValueError(
                    f"Unsupported chat history format: {type(dialogue_turn)}."
                    f" Full chat history: {chat_history} "
                )
        return buffer

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self._get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        if chat_history_str:
            # callbacks = _run_manager.get_child()
            # new_question = await self.question_generator.arun(
            #     question=question, chat_history=chat_history_str, callbacks=callbacks
            # )
            system = (
                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
                "Incorporate relevant details from the chat history to make the question clearer and more specific."
                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
                "If the question is conversational and doesn't require context, do not rephrase it. "
                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
                "Chat history: \n{chat_history_str}\n"
                "Rephrase the following question only if necessary: '{input}'"
            )

            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system),
                    ("human", "{input}, {chat_history_str}"),
                ]
            )
            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
            step_back = prompt | llm | StrOutputParser()
            new_question = step_back.invoke(
                {"input": question, "chat_history_str": chat_history_str}
            )
        else:
            new_question = question
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        output["original_question"] = question
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str

            # Prepare the final prompt with metadata
            context = "\n\n".join(
                [
                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
                    for idx, doc in enumerate(docs)
                ]
            )
            final_prompt = (
                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
                f"Chat History:\n{chat_history_str}\n\n"
                f"Context:\n{context}\n\n"
                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
                f"Student: {input}\n"
                "AI Tutor:"
            )

            new_inputs["input"] = final_prompt
            # new_inputs["question"] = final_prompt
            # output["final_prompt"] = final_prompt

            answer = await self.combine_docs_chain.arun(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer

        if self.return_source_documents:
            output["source_documents"] = docs
        output["rephrased_question"] = new_question
        output["context"] = output["source_documents"]
        return output


class CustomRunnableWithHistory(RunnableWithMessageHistory):

    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
        buffer = ""
        for dialogue_turn in chat_history:
            if isinstance(dialogue_turn, BaseMessage):
                role_prefix = _ROLE_MAP.get(
                    dialogue_turn.type, f"{dialogue_turn.type}: "
                )
                buffer += f"\n{role_prefix}{dialogue_turn.content}"
            elif isinstance(dialogue_turn, tuple):
                human = "Student: " + dialogue_turn[0]
                ai = "AI Tutor: " + dialogue_turn[1]
                buffer += "\n" + "\n".join([human, ai])
            else:
                raise ValueError(
                    f"Unsupported chat history format: {type(dialogue_turn)}."
                    f" Full chat history: {chat_history} "
                )
        return buffer

    async def _aenter_history(
        self, input: Any, config: RunnableConfig
    ) -> List[BaseMessage]:
        """
        Get the last k conversations from the message history.

        Args:
            input (Any): The input data.
            config (RunnableConfig): The runnable configuration.

        Returns:
            List[BaseMessage]: The last k conversations.
        """
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = (await hist.aget_messages()).copy()
        if not self.history_messages_key:
            # return all messages
            input_val = (
                input if not self.input_messages_key else input[self.input_messages_key]
|
206 |
+
)
|
207 |
+
messages += self._get_input_messages(input_val)
|
208 |
+
|
209 |
+
# return last k conversations
|
210 |
+
if config["configurable"]["memory_window"] == 0: # if k is 0, return empty list
|
211 |
+
messages = []
|
212 |
+
else:
|
213 |
+
messages = messages[-2 * config["configurable"]["memory_window"] :]
|
214 |
+
|
215 |
+
messages = self._get_chat_history(messages)
|
216 |
+
|
217 |
+
return messages
|
218 |
+
|
219 |
+
|
220 |
+
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
|
221 |
+
"""In-memory implementation of chat message history."""
|
222 |
+
|
223 |
+
messages: List[BaseMessage] = Field(default_factory=list)
|
224 |
+
|
225 |
+
def add_messages(self, messages: List[BaseMessage]) -> None:
|
226 |
+
"""Add a list of messages to the store."""
|
227 |
+
self.messages.extend(messages)
|
228 |
+
|
229 |
+
def clear(self) -> None:
|
230 |
+
"""Clear the message history."""
|
231 |
+
self.messages = []
|
232 |
+
|
233 |
+
def __len__(self) -> int:
|
234 |
+
"""Return the number of messages."""
|
235 |
+
return len(self.messages)
|
236 |
+
|
237 |
+
|
238 |
+
def create_history_aware_retriever(
|
239 |
+
llm: LanguageModelLike,
|
240 |
+
retriever: BaseRetriever,
|
241 |
+
prompt: BasePromptTemplate,
|
242 |
+
) -> Runnable[Dict[str, Any], RetrieverOutput]:
|
243 |
+
"""Create a chain that takes conversation history and returns documents."""
|
244 |
+
if "input" not in prompt.input_variables:
|
245 |
+
raise ValueError(
|
246 |
+
"Expected `input` to be a prompt variable, "
|
247 |
+
f"but got {prompt.input_variables}"
|
248 |
+
)
|
249 |
+
|
250 |
+
retrieve_documents = RunnableBranch(
|
251 |
+
(
|
252 |
+
lambda x: not x["chat_history"],
|
253 |
+
(lambda x: x["input"]) | retriever,
|
254 |
+
),
|
255 |
+
prompt | llm | StrOutputParser() | retriever,
|
256 |
+
).with_config(run_name="chat_retriever_chain")
|
257 |
+
|
258 |
+
return retrieve_documents
|
259 |
+
|
260 |
+
|
261 |
+
def create_stuff_documents_chain(
|
262 |
+
llm: LanguageModelLike,
|
263 |
+
prompt: BasePromptTemplate,
|
264 |
+
output_parser: Optional[BaseOutputParser] = None,
|
265 |
+
document_prompt: Optional[BasePromptTemplate] = None,
|
266 |
+
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
267 |
+
) -> Runnable[Dict[str, Any], Any]:
|
268 |
+
"""Create a chain for passing a list of Documents to a model."""
|
269 |
+
_validate_prompt(prompt)
|
270 |
+
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
|
271 |
+
_output_parser = output_parser or StrOutputParser()
|
272 |
+
|
273 |
+
def format_docs(inputs: dict) -> str:
|
274 |
+
return document_separator.join(
|
275 |
+
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
|
276 |
+
)
|
277 |
+
|
278 |
+
return (
|
279 |
+
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
|
280 |
+
run_name="format_inputs"
|
281 |
+
)
|
282 |
+
| prompt
|
283 |
+
| llm
|
284 |
+
| _output_parser
|
285 |
+
).with_config(run_name="stuff_documents_chain")
|
286 |
+
|
287 |
+
|
288 |
+
def create_retrieval_chain(
|
289 |
+
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
|
290 |
+
combine_docs_chain: Runnable[Dict[str, Any], str],
|
291 |
+
) -> Runnable:
|
292 |
+
"""Create retrieval chain that retrieves documents and then passes them on."""
|
293 |
+
if not isinstance(retriever, BaseRetriever):
|
294 |
+
retrieval_docs = retriever
|
295 |
+
else:
|
296 |
+
retrieval_docs = (lambda x: x["input"]) | retriever
|
297 |
+
|
298 |
+
retrieval_chain = (
|
299 |
+
RunnablePassthrough.assign(
|
300 |
+
context=retrieval_docs.with_config(run_name="retrieve_documents"),
|
301 |
+
).assign(answer=combine_docs_chain)
|
302 |
+
).with_config(run_name="retrieval_chain")
|
303 |
+
|
304 |
+
return retrieval_chain
|
305 |
+
|
306 |
+
|
307 |
+
def return_questions(query, response, chat_history_str, context):
|
308 |
+
|
309 |
+
system = (
|
310 |
+
"You are someone that suggests a question based on the student's input and chat history. "
|
311 |
+
"Generate a question that is relevant to the student's input and chat history. "
|
312 |
+
"Incorporate relevant details from the chat history to make the question clearer and more specific. "
|
313 |
+
"Chat history: \n{chat_history_str}\n"
|
314 |
+
"Use the context to generate a question that is relevant to the student's input and chat history: Context: {context}"
|
315 |
+
"Generate 3 short and concise questions from the students voice based on the following input and response: "
|
316 |
+
"The 3 short and concise questions should be sperated by dots. Example: 'What is the capital of France?...What is the population of France?...What is the currency of France?'"
|
317 |
+
"User Query: {query}"
|
318 |
+
"AI Response: {response}"
|
319 |
+
"The 3 short and concise questions seperated by dots (...) are:"
|
320 |
+
)
|
321 |
+
|
322 |
+
prompt = ChatPromptTemplate.from_messages(
|
323 |
+
[
|
324 |
+
("system", system),
|
325 |
+
("human", "{chat_history_str}, {context}, {query}, {response}"),
|
326 |
+
]
|
327 |
+
)
|
328 |
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
329 |
+
question_generator = prompt | llm | StrOutputParser()
|
330 |
+
new_questions = question_generator.invoke(
|
331 |
+
{
|
332 |
+
"chat_history_str": chat_history_str,
|
333 |
+
"context": context,
|
334 |
+
"query": query,
|
335 |
+
"response": response,
|
336 |
+
}
|
337 |
+
)
|
338 |
+
|
339 |
+
list_of_questions = new_questions.split("...")
|
340 |
+
return list_of_questions
|
code/modules/chat/llm_tutor.py
CHANGED
@@ -1,211 +1,163 @@
-from langchain.chains import RetrievalQA, ConversationalRetrievalChain
-from langchain.memory import (
-    ConversationBufferWindowMemory,
-    ConversationSummaryBufferMemory,
-)
-from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
-import os
-from modules.config.constants import *
 from modules.chat.helpers import get_prompt
 from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
-
 from modules.retriever.retriever import Retriever
-
-
-
-
-
-from langchain_core.messages import BaseMessage
-
-CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
-
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.chat_models import ChatOpenAI
-
-
-class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
-
-    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
-        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
-        buffer = ""
-        for dialogue_turn in chat_history:
-            if isinstance(dialogue_turn, BaseMessage):
-                role_prefix = _ROLE_MAP.get(
-                    dialogue_turn.type, f"{dialogue_turn.type}: "
-                )
-                buffer += f"\n{role_prefix}{dialogue_turn.content}"
-            elif isinstance(dialogue_turn, tuple):
-                human = "Student: " + dialogue_turn[0]
-                ai = "AI Tutor: " + dialogue_turn[1]
-                buffer += "\n" + "\n".join([human, ai])
-            else:
-                raise ValueError(
-                    f"Unsupported chat history format: {type(dialogue_turn)}."
-                    f" Full chat history: {chat_history} "
-                )
-        return buffer
-
-    async def _acall(
-        self,
-        inputs: Dict[str, Any],
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
-        question = inputs["question"]
-        get_chat_history = self._get_chat_history
-        chat_history_str = get_chat_history(inputs["chat_history"])
-        if chat_history_str:
-            # callbacks = _run_manager.get_child()
-            # new_question = await self.question_generator.arun(
-            #     question=question, chat_history=chat_history_str, callbacks=callbacks
-            # )
-            system = (
-                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
-                "Incorporate relevant details from the chat history to make the question clearer and more specific."
-                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
-                "If the question is conversational and doesn't require context, do not rephrase it. "
-                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
-                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
-                "Chat history: \n{chat_history_str}\n"
-                "Rephrase the following question only if necessary: '{question}'"
-            )
-
-            prompt = ChatPromptTemplate.from_messages(
-                [
-                    ("system", system),
-                    ("human", "{question}, {chat_history_str}"),
-                ]
-            )
-            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-            step_back = prompt | llm | StrOutputParser()
-            new_question = step_back.invoke(
-                {"question": question, "chat_history_str": chat_history_str}
-            )
-        else:
-            new_question = question
-        accepts_run_manager = (
-            "run_manager" in inspect.signature(self._aget_docs).parameters
-        )
-        if accepts_run_manager:
-            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
-        else:
-            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
-
-        output: Dict[str, Any] = {}
-        output["original_question"] = question
-        if self.response_if_no_docs_found is not None and len(docs) == 0:
-            output[self.output_key] = self.response_if_no_docs_found
-        else:
-            new_inputs = inputs.copy()
-            if self.rephrase_question:
-                new_inputs["question"] = new_question
-            new_inputs["chat_history"] = chat_history_str
-
-            # Prepare the final prompt with metadata
-            context = "\n\n".join(
-                [
-                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
-                    for idx, doc in enumerate(docs)
-                ]
-            )
-            final_prompt = (
-                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
-                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
-                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
-                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
-                f"Chat History:\n{chat_history_str}\n\n"
-                f"Context:\n{context}\n\n"
-                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
-                f"Student: {question}\n"
-                "AI Tutor:"
-            )
-
-            # new_inputs["input"] = final_prompt
-            new_inputs["question"] = final_prompt
-            # output["final_prompt"] = final_prompt
-
-            answer = await self.combine_docs_chain.arun(
-                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
-            )
-            output[self.output_key] = answer
-
-        if self.return_source_documents:
-            output["source_documents"] = docs
-            output["rephrased_question"] = new_question
-        return output


 class LLMTutor:
-    def __init__(self, config, logger=None):
         self.config = config
         self.llm = self.load_llm()
         self.logger = logger
-        self.vector_db = VectorStoreManager(config, logger=self.logger)
         if self.config["vectorstore"]["embedd_files"]:
             self.vector_db.create_database()
             self.vector_db.save_database()

-    def
         """
-
         """
-
-

-

-
-

         retriever = Retriever(self.config)._return_retriever(db)

-        if self.config["llm_params"]["
-
-                k=self.config["llm_params"]["memory_window"],
-                memory_key="chat_history",
-                return_messages=True,
-                output_key="answer",
-                max_token_limit=128,
-            )
-            qa_chain = CustomConversationalRetrievalChain.from_llm(
                 llm=llm,
-                chain_type="stuff",
-                retriever=retriever,
-                return_source_documents=True,
                 memory=memory,
-
-
             )
         else:
-
-
-                chain_type="stuff",
-                retriever=retriever,
-                return_source_documents=True,
-                chain_type_kwargs={"prompt": prompt},
             )
-        return qa_chain

-    # Loading the model
     def load_llm(self):
         chat_model_loader = ChatModelLoader(self.config)
         llm = chat_model_loader.load_chat_model()
         return llm

-
-
-
-
         qa = self.retrieval_qa_chain(
-            self.llm,
-

         return qa
-
-    # output function
-    def final_result(query):
-        qa_result = qa_bot()
-        response = qa_result({"query": query})
-        return response

 from modules.chat.helpers import get_prompt
 from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
 from modules.retriever.retriever import Retriever
+from modules.chat.langchain.langchain_rag import (
+    Langchain_RAG_V1,
+    Langchain_RAG_V2,
+    QuestionGenerator,
+)


 class LLMTutor:
+    def __init__(self, config, user, logger=None):
+        """
+        Initialize the LLMTutor class.
+
+        Args:
+            config (dict): Configuration dictionary.
+            user (str): User identifier.
+            logger (Logger, optional): Logger instance. Defaults to None.
+        """
         self.config = config
         self.llm = self.load_llm()
+        self.user = user
         self.logger = logger
+        self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
+        self.qa_prompt = get_prompt(config, "qa")  # Initialize qa_prompt
+        self.rephrase_prompt = get_prompt(
+            config, "rephrase"
+        )  # Initialize rephrase_prompt
         if self.config["vectorstore"]["embedd_files"]:
             self.vector_db.create_database()
             self.vector_db.save_database()

+    def update_llm(self, old_config, new_config):
         """
+        Update the LLM and VectorStoreManager based on new configuration.
+
+        Args:
+            new_config (dict): New configuration dictionary.
         """
+        changes = self.get_config_changes(old_config, new_config)
+
+        if "llm_params.llm_loader" in changes:
+            self.llm = self.load_llm()  # Reinitialize LLM if chat_model changes

+        if "vectorstore.db_option" in changes:
+            self.vector_db = VectorStoreManager(
+                self.config, logger=self.logger
+            ).load_database()  # Reinitialize VectorStoreManager if vectorstore changes
+            if self.config["vectorstore"]["embedd_files"]:
+                self.vector_db.create_database()
+                self.vector_db.save_database()

+        if "llm_params.llm_style" in changes:
+            self.qa_prompt = get_prompt(
+                self.config, "qa"
+            )  # Update qa_prompt if ELI5 changes

+    def get_config_changes(self, old_config, new_config):
+        """
+        Get the changes between the old and new configuration.
+
+        Args:
+            old_config (dict): Old configuration dictionary.
+            new_config (dict): New configuration dictionary.
+
+        Returns:
+            dict: Dictionary containing the changes.
+        """
+        changes = {}
+
+        def compare_dicts(old, new, parent_key=""):
+            for key in new:
+                full_key = f"{parent_key}.{key}" if parent_key else key
+                if isinstance(new[key], dict) and isinstance(old.get(key), dict):
+                    compare_dicts(old.get(key, {}), new[key], full_key)
+                elif old.get(key) != new[key]:
+                    changes[full_key] = (old.get(key), new[key])
+            # Include keys that are in old but not in new
+            for key in old:
+                if key not in new:
+                    full_key = f"{parent_key}.{key}" if parent_key else key
+                    changes[full_key] = (old[key], None)
+
+        compare_dicts(old_config, new_config)
+        return changes
+
+    def retrieval_qa_chain(
+        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
+    ):
+        """
+        Create a Retrieval QA Chain.
+
+        Args:
+            llm (LLM): The language model instance.
+            qa_prompt (str): The QA prompt string.
+            rephrase_prompt (str): The rephrase prompt string.
+            db (VectorStore): The vector store instance.
+            memory (Memory, optional): Memory instance. Defaults to None.
+
+        Returns:
+            Chain: The retrieval QA chain instance.
+        """
         retriever = Retriever(self.config)._return_retriever(db)

+        if self.config["llm_params"]["llm_arch"] == "langchain":
+            self.qa_chain = Langchain_RAG_V2(
                 llm=llm,
                 memory=memory,
+                retriever=retriever,
+                qa_prompt=qa_prompt,
+                rephrase_prompt=rephrase_prompt,
+                config=self.config,
+                callbacks=callbacks,
             )
+
+            self.question_generator = QuestionGenerator()
         else:
+            raise ValueError(
+                f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
             )
+        return self.qa_chain

     def load_llm(self):
+        """
+        Load the language model.
+
+        Returns:
+            LLM: The loaded language model instance.
+        """
         chat_model_loader = ChatModelLoader(self.config)
         llm = chat_model_loader.load_chat_model()
         return llm

+    def qa_bot(self, memory=None, callbacks=None):
+        """
+        Create a QA bot instance.
+
+        Args:
+            memory (Memory, optional): Memory instance. Defaults to None.
+            qa_prompt (str, optional): QA prompt string. Defaults to None.
+            rephrase_prompt (str, optional): Rephrase prompt string. Defaults to None.
+
+        Returns:
+            Chain: The QA bot chain instance.
+        """
+        # sanity check to see if there are any documents in the database
+        if len(self.vector_db) == 0:
+            raise ValueError(
+                "No documents in the database. Populate the database first."
+            )
+
         qa = self.retrieval_qa_chain(
+            self.llm,
+            self.qa_prompt,
+            self.rephrase_prompt,
+            self.vector_db,
+            memory,
+            callbacks=callbacks,
+        )

         return qa

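A standalone sketch of the nested-dict comparison that the new update_llm relies on, reproduced here so the dotted change keys ("llm_params.llm_loader", "vectorstore.db_option", "llm_params.llm_style") are easy to see. Illustrative only; the configs below are assumed examples, not part of the PR.

# Illustrative sketch of the get_config_changes logic from llm_tutor.py.
def config_changes(old: dict, new: dict, parent_key: str = "") -> dict:
    changes = {}
    for key in new:
        full_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(new[key], dict) and isinstance(old.get(key), dict):
            changes.update(config_changes(old.get(key, {}), new[key], full_key))
        elif old.get(key) != new[key]:
            changes[full_key] = (old.get(key), new[key])
    for key in old:  # keys present in old but dropped from new
        if key not in new:
            full_key = f"{parent_key}.{key}" if parent_key else key
            changes[full_key] = (old[key], None)
    return changes

old_cfg = {"llm_params": {"llm_loader": "gpt-3.5-turbo-1106", "llm_style": "Normal"}}
new_cfg = {"llm_params": {"llm_loader": "gpt-4o-mini", "llm_style": "Normal"}}
print(config_changes(old_cfg, new_cfg))
# {'llm_params.llm_loader': ('gpt-3.5-turbo-1106', 'gpt-4o-mini')}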
code/modules/chat_processor/base.py
DELETED
@@ -1,12 +0,0 @@
-# Template for chat processor classes
-
-
-class ChatProcessorBase:
-    def __init__(self, config):
-        self.config = config
-
-    def process(self, message):
-        """
-        Processes and Logs the message
-        """
-        raise NotImplementedError("process method not implemented")

code/modules/chat_processor/chat_processor.py
DELETED
@@ -1,30 +0,0 @@
-from modules.chat_processor.literal_ai import LiteralaiChatProcessor
-
-
-class ChatProcessor:
-    def __init__(self, config, tags=None):
-        self.chat_processor_type = config["chat_logging"]["platform"]
-        self.logging = config["chat_logging"]["log_chat"]
-        self.tags = tags
-        if self.logging:
-            self._init_processor()
-
-    def _init_processor(self):
-        if self.chat_processor_type == "literalai":
-            self.processor = LiteralaiChatProcessor(self.tags)
-        else:
-            raise ValueError(
-                f"Chat processor type {self.chat_processor_type} not supported"
-            )
-
-    def _process(self, user_message, assistant_message, source_dict):
-        if self.logging:
-            return self.processor.process(user_message, assistant_message, source_dict)
-        else:
-            pass
-
-    async def rag(self, user_query: str, chain, cb):
-        if self.logging:
-            return await self.processor.rag(user_query, chain, cb)
-        else:
-            return await chain.acall(user_query, callbacks=[cb])

code/modules/chat_processor/literal_ai.py
CHANGED
@@ -1,37 +1,44 @@
-from
-import os
-from .base import ChatProcessorBase


-
-
-
-
-        with self.literal_client.thread(name="TEST") as thread:
-            self.thread_id = thread.id
-            self.thread = thread
-            if tags is not None and type(tags) == list:
-                self.thread.tags = tags
-        print(f"Thread ID: {self.thread}")

-
-
-
-
-
-
-
-
-
-
-            name="AI_Tutor",
-        )

-
-
-
-
-
-
-
-
+from chainlit.data import ChainlitDataLayer, queue_until_user_message


+# update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
+class CustomLiteralDataLayer(ChainlitDataLayer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)

+    @queue_until_user_message()
+    async def create_step(self, step_dict: "StepDict"):
+        metadata = dict(
+            step_dict.get("metadata", {}),
+            **{
+                "waitForAnswer": step_dict.get("waitForAnswer"),
+                "language": step_dict.get("language"),
+                "showInput": step_dict.get("showInput"),
+            },
+        )

+        step: LiteralStepDict = {
+            "createdAt": step_dict.get("createdAt"),
+            "startTime": step_dict.get("start"),
+            "endTime": step_dict.get("end"),
+            "generation": step_dict.get("generation"),
+            "id": step_dict.get("id"),
+            "parentId": step_dict.get("parentId"),
+            "name": step_dict.get("name"),
+            "threadId": step_dict.get("threadId"),
+            "type": step_dict.get("type"),
+            "tags": step_dict.get("tags"),
+            "metadata": metadata,
+        }
+        if step_dict.get("input"):
+            step["input"] = {"content": step_dict.get("input")}
+        if step_dict.get("output"):
+            step["output"] = {"content": step_dict.get("output")}
+        if step_dict.get("isError"):
+            step["error"] = step_dict.get("output")
+
+        # print("\n\n\n")
+        # print("Step: ", step)
+        # print("\n\n\n")
+
+        await self.client.api.send_steps([step])
code/modules/config/config.yml
CHANGED
@@ -3,18 +3,19 @@ log_chunk_dir: '../storage/logs/chunks' # str
 device: 'cpu' # str [cuda, cpu]

 vectorstore:
   embedd_files: False # bool
   data_path: '../storage/data' # str
   url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
-  db_option : '
   db_path : '../vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
   score_threshold : 0.2 # float

   faiss_params: # Not used as of now
-    index_path: '
     index_type: 'Flat' # str [Flat, HNSW, IVF]
     index_dimension: 384 # int
     index_nlist: 100 # int
@@ -24,27 +25,36 @@ vectorstore:
   index_name: "new_idx" # str

 llm_params:
   use_history: True # bool
   memory_window: 3 # int
-
   openai_params:
-
   local_llm_params:
-
-

 chat_logging:
-  log_chat:
   platform: 'literalai'

 splitter_options:
   use_splitter: True # bool
   split_by_token : True # bool
   remove_leftover_delimiters: True # bool
   remove_chunks: False # bool
   chunk_size : 300 # int
   chunk_overlap : 30 # int
   chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
   front_chunks_to_remove : null # int or None
   last_chunks_to_remove : null # int or None
-  delimiters_to_remove : ['\t', '\n', '   ', '  '] # list of strings

 device: 'cpu' # str [cuda, cpu]

 vectorstore:
+  load_from_HF: True # bool
   embedd_files: False # bool
   data_path: '../storage/data' # str
   url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
+  db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille, RAPTOR]
   db_path : '../vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
   score_threshold : 0.2 # float

   faiss_params: # Not used as of now
+    index_path: 'vectorstores/faiss.index' # str
     index_type: 'Flat' # str [Flat, HNSW, IVF]
     index_dimension: 384 # int
     index_nlist: 100 # int

   index_name: "new_idx" # str

 llm_params:
+  llm_arch: 'langchain' # [langchain]
   use_history: True # bool
+  generate_follow_up: False # bool
   memory_window: 3 # int
+  llm_style: 'Normal' # str [Normal, ELI5]
+  llm_loader: 'gpt-4o-mini' # str [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
   openai_params:
+    temperature: 0.7 # float
   local_llm_params:
+    temperature: 0.7 # float
+    repo_id: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF' # HuggingFace repo id
+    filename: 'tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Specific name of gguf file in the repo
+    pdf_reader: 'pymupdf' # str [llama, pymupdf, gpt]
+  stream: False # bool
+  pdf_reader: 'gpt' # str [llama, pymupdf, gpt]

 chat_logging:
+  log_chat: True # bool
   platform: 'literalai'
+  callbacks: False # bool

 splitter_options:
   use_splitter: True # bool
   split_by_token : True # bool
   remove_leftover_delimiters: True # bool
   remove_chunks: False # bool
+  chunking_mode: 'semantic' # str [fixed, semantic]
   chunk_size : 300 # int
   chunk_overlap : 30 # int
   chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
   front_chunks_to_remove : null # int or None
   last_chunks_to_remove : null # int or None
+  delimiters_to_remove : ['\t', '\n', '   ', '  '] # list of strings
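A brief sketch of how the new config keys above might be read before building the tutor; the relative path and the surrounding script are assumed for illustration only.

# Illustrative sketch: load config.yml and inspect the new keys.
import yaml

with open("code/modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

print(config["llm_params"]["llm_arch"])             # 'langchain'
print(config["llm_params"]["llm_loader"])           # 'gpt-4o-mini'
print(config["vectorstore"]["db_option"])           # 'RAGatouille'
print(config["splitter_options"]["chunking_mode"])  # 'semantic'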
code/modules/config/constants.py
CHANGED
@@ -6,77 +6,18 @@ load_dotenv()
 # API Keys - Loaded from the .env file

 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
-

-
-
-# Prompt Templates
-
-openai_prompt_template = """Use the following pieces of information to answer the user's question.
-If you don't know the answer, just say that you don't know.
-
-Context: {context}
-Question: {question}
-
-Only return the helpful answer below and nothing else.
-Helpful answer:
-"""
-
-openai_prompt_template_with_history = """Use the following pieces of information to answer the user's question.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-Use the history to answer the question if you can.
-Chat History:
-{chat_history}
-Context: {context}
-Question: {question}
-
-Only return the helpful answer below and nothing else.
-Helpful answer:
-"""
-
-tinyllama_prompt_template = """
-<|im_start|>system
-Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a breif and concise answer to the question. Use the history to answer the question if you can.
-
-Context:
-{context}
-<|im_end|>
-<|im_start|>user
-Question: Who is the instructor for this course?
-<|im_end|>
-<|im_start|>assistant
-The instructor for this course is Prof. Thomas Gardos.
-<|im_end|>
-<|im_start|>user
-Question: {question}
-<|im_end|>
-<|im_start|>assistant
-"""
-
-tinyllama_prompt_template_with_history = """
-<|im_start|>system
-Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a breif and concise answer to the question.
-
-Chat History:
-{chat_history}
-Context:
-{context}
-<|im_end|>
-<|im_start|>user
-Question: Who is the instructor for this course?
-<|im_end|>
-<|im_start|>assistant
-The instructor for this course is Prof. Thomas Gardos.
-<|im_end|>
-<|im_start|>user
-Question: {question}
-<|im_end|>
-<|im_start|>assistant
-"""


 # Model Paths

-LLAMA_PATH = "../storage/models/tinyllama
-

 # API Keys - Loaded from the .env file

 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
+LITERAL_API_URL = os.getenv("LITERAL_API_URL")

+OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
+OAUTH_GOOGLE_CLIENT_SECRET = os.getenv("OAUTH_GOOGLE_CLIENT_SECRET")

+opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"

 # Model Paths

+LLAMA_PATH = "../storage/models/tinyllama"
+
+RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
code/modules/config/prompts.py
ADDED
@@ -0,0 +1,97 @@
+prompts = {
+    "openai": {
+        "rephrase_prompt": (
+            "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+            "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+            "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+            "If the question is conversational and doesn't require context, do not rephrase it. "
+            "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. "
+            "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project' "
+            "Chat history: \n{chat_history}\n"
+            "Rephrase the following question only if necessary: '{input}'"
+            "Rephrased Question:'"
+        ),
+        "prompt_with_history": {
+            "normal": (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+                "Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context. Be sure to explain the parameters and variables in the equations."
+                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+                "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
+                "Chat History:\n{chat_history}\n\n"
+                "Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+                "Student: {input}\n"
+                "AI Tutor:"
+            ),
+            "eli5": (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Your job is to explain things in the simplest and most engaging way possible, just like the 'Explain Like I'm 5' (ELI5) concept."
+                "If you don't know the answer, do your best without making things up. Keep your explanations straightforward and very easy to understand."
+                "Use the chat history and context to help you, but avoid repeating past responses. Provide links from the source_file metadata when they're helpful."
+                "Use very simple language and examples to explain any math equations, and put the equations in LaTeX format between $ or $$ signs."
+                "Be friendly and engaging, like you're chatting with a young child who's curious and eager to learn. Avoid complex terms and jargon."
+                "Include simple and clear examples wherever you can to make things easier to understand."
+                "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
+                "Chat History:\n{chat_history}\n\n"
+                "Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, simple, and engaging way, just like the ELI5 concept. Use the context and history only if they're relevant, otherwise, just have a natural conversation."
+                "Give a clear and detailed explanation with simple examples to make it easier to understand. Remember, your goal is to break down complex topics into very simple terms, just like ELI5."
+                "Student: {input}\n"
+                "AI Tutor:"
+            ),
+            "socratic": (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Engage the student in a Socratic dialogue to help them discover answers on their own. Use the provided context to guide your questioning."
+                "If you don't know the answer, do your best without making things up. Keep the conversation engaging and inquisitive."
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata when relevant. Use the source context that is most relevant."
+                "Speak in a friendly and engaging manner, encouraging critical thinking and self-discovery."
+                "Use questions to lead the student to explore the topic and uncover answers."
+                "Chat History:\n{chat_history}\n\n"
+                "Context:\n{context}\n\n"
+                "Answer the student's question below by guiding them through a series of questions and insights that lead to deeper understanding. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation."
+                "Foster an inquisitive mindset and help the student discover answers through dialogue."
+                "Student: {input}\n"
+                "AI Tutor:"
+            ),
+        },
+        "prompt_no_history": (
+            "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+            "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+            "Provide links from the source_file metadata. Use the source context that is most relevant. "
+            "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+            "Context:\n{context}\n\n"
+            "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+            "Student: {input}\n"
+            "AI Tutor:"
+        ),
+    },
+    "tiny_llama": {
+        "prompt_no_history": (
+            "system\n"
+            "Assistant is an intelligent chatbot designed to help students with questions regarding the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance.\n"
+            "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally.\n"
+            "Provide links from the source_file metadata. Use the source context that is most relevant.\n"
+            "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
+            "\n\n"
+            "user\n"
+            "Context:\n{context}\n\n"
+            "Question: {input}\n"
+            "\n\n"
+            "assistant"
+        ),
+        "prompt_with_history": (
+            "system\n"
+            "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+            "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+            "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+            "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
+            "\n\n"
+            "user\n"
+            "Chat History:\n{chat_history}\n\n"
+            "Context:\n{context}\n\n"
+            "Question: {input}\n"
+            "\n\n"
+            "assistant"
+        ),
+    },
+}
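A small sketch of filling one of the templates above with str.format; the real selection logic lives in modules.chat.helpers.get_prompt, so the direct dictionary access and the sample strings below are assumptions for illustration only.

# Illustrative sketch: render the "normal" with-history template by hand.
from modules.config.prompts import prompts

template = prompts["openai"]["prompt_with_history"]["normal"]
filled = template.format(
    chat_history="Student: What is a CNN?\nAI Tutor: A convolutional neural network...",
    context="Context 1: (Document content: ...convolutions share weights across positions...)",
    input="How is it different from an MLP?",
)
print(filled[:200])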
code/modules/dataloader/data_loader.py
CHANGED
@@ -14,32 +14,89 @@ from llama_parse import LlamaParse
|
|
14 |
from langchain.schema import Document
|
15 |
import logging
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
17 |
from ragatouille import RAGPretrainedModel
|
18 |
from langchain.chains import LLMChain
|
19 |
from langchain_community.llms import OpenAI
|
20 |
from langchain import PromptTemplate
|
21 |
import json
|
22 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
return loader
|
34 |
|
35 |
-
|
36 |
-
|
|
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
class FileReader:
|
40 |
-
def __init__(self, logger):
|
41 |
-
self.pdf_reader = PDFReader()
|
42 |
self.logger = logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def extract_text_from_pdf(self, pdf_path):
|
45 |
text = ""
|
@@ -51,20 +108,8 @@ class FileReader:
|
|
51 |
text += page.extract_text()
|
52 |
return text
|
53 |
|
54 |
-
def download_pdf_from_url(self, pdf_url):
|
55 |
-
response = requests.get(pdf_url)
|
56 |
-
if response.status_code == 200:
|
57 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
|
58 |
-
temp_file.write(response.content)
|
59 |
-
temp_file_path = temp_file.name
|
60 |
-
return temp_file_path
|
61 |
-
else:
|
62 |
-
self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
|
63 |
-
return None
|
64 |
-
|
65 |
def read_pdf(self, temp_file_path: str):
|
66 |
-
|
67 |
-
documents = self.pdf_reader.get_documents(loader)
|
68 |
return documents
|
69 |
|
70 |
def read_txt(self, temp_file_path: str):
|
@@ -89,8 +134,7 @@ class FileReader:
|
|
89 |
return loader.load()
|
90 |
|
91 |
def read_html(self, url: str):
|
92 |
-
|
93 |
-
return loader.load()
|
94 |
|
95 |
def read_tex_from_url(self, tex_url):
|
96 |
response = requests.get(tex_url)
|
@@ -110,21 +154,31 @@ class ChunkProcessor:
|
|
110 |
self.document_metadata = {}
|
111 |
self.document_chunks_full = []
|
112 |
|
|
|
|
|
|
|
113 |
if config["splitter_options"]["use_splitter"]:
|
114 |
-
if config["splitter_options"]["
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
else:
|
122 |
-
self.splitter =
|
123 |
-
|
124 |
-
|
125 |
-
separators=config["splitter_options"]["chunk_separators"],
|
126 |
-
disallowed_special=(),
|
127 |
)
|
|
|
128 |
else:
|
129 |
self.splitter = None
|
130 |
self.logger.info("ChunkProcessor instance created")
|
@@ -147,16 +201,12 @@ class ChunkProcessor:
|
|
147 |
def process_chunks(
|
148 |
self, documents, file_type="txt", source="", page=0, metadata={}
|
149 |
):
|
|
|
150 |
documents = [Document(page_content=documents, source=source, page=page)]
|
151 |
-
if
|
152 |
-
|
153 |
-
|
154 |
-
or file_type == "srt"
|
155 |
-
or file_type == "tex"
|
156 |
-
):
|
157 |
document_chunks = self.splitter.split_documents(documents)
|
158 |
-
elif file_type == "pdf":
|
159 |
-
document_chunks = documents # Full page for now
|
160 |
|
161 |
# add the source and page number back to the metadata
|
162 |
for chunk in document_chunks:
|
@@ -179,7 +229,6 @@ class ChunkProcessor:
|
|
179 |
"https://dl4ds.github.io/sp2024/lectures/",
|
180 |
"https://dl4ds.github.io/sp2024/schedule/",
|
181 |
) # For any additional metadata
|
182 |
-
|
183 |
with ThreadPoolExecutor() as executor:
|
184 |
executor.map(
|
185 |
self.process_file,
|
@@ -228,11 +277,11 @@ class ChunkProcessor:
|
|
228 |
|
229 |
page_num = doc.metadata.get("page", 0)
|
230 |
file_data[page_num] = doc.page_content
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
file_metadata[page_num] = metadata
|
237 |
|
238 |
if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
|
@@ -250,11 +299,8 @@ class ChunkProcessor:
|
|
250 |
|
251 |
def process_file(self, file_path, file_index, file_reader, addl_metadata):
|
252 |
file_name = os.path.basename(file_path)
|
253 |
-
if file_name in self.document_data:
|
254 |
-
return
|
255 |
|
256 |
-
file_type = file_name.split(".")[-1]
|
257 |
-
self.logger.info(f"Reading file {file_index + 1}: {file_path}")
|
258 |
|
259 |
read_methods = {
|
260 |
"pdf": file_reader.read_pdf,
|
@@ -268,7 +314,13 @@ class ChunkProcessor:
|
|
268 |
return
|
269 |
|
270 |
try:
|
271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
self.process_documents(
|
273 |
documents, file_path, file_type, "file", addl_metadata
|
274 |
)
|
@@ -326,11 +378,14 @@ class ChunkProcessor:
|
|
326 |
f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
|
327 |
) as json_file:
|
328 |
self.document_metadata = json.load(json_file)
|
|
|
|
|
|
|
329 |
|
330 |
|
331 |
class DataLoader:
|
332 |
def __init__(self, config, logger=None):
|
333 |
-
self.file_reader = FileReader(logger=logger)
|
334 |
self.chunk_processor = ChunkProcessor(config, logger=logger)
|
335 |
|
336 |
def get_chunks(self, uploaded_files, weblinks):
|
@@ -348,13 +403,19 @@ if __name__ == "__main__":
|
|
348 |
with open("../code/modules/config/config.yml", "r") as f:
|
349 |
config = yaml.safe_load(f)
|
350 |
|
|
|
|
|
|
|
|
|
|
|
351 |
data_loader = DataLoader(config, logger=logger)
|
352 |
document_chunks, document_names, documents, document_metadata = (
|
353 |
data_loader.get_chunks(
|
|
|
354 |
[],
|
355 |
-
["https://dl4ds.github.io/sp2024/"],
|
356 |
)
|
357 |
)
|
358 |
|
359 |
-
print(document_names)
|
360 |
print(len(document_chunks))
|
|
|
|
14 |
from langchain.schema import Document
|
15 |
import logging
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
18 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
19 |
from ragatouille import RAGPretrainedModel
|
20 |
from langchain.chains import LLMChain
|
21 |
from langchain_community.llms import OpenAI
|
22 |
from langchain import PromptTemplate
|
23 |
import json
|
24 |
from concurrent.futures import ThreadPoolExecutor
|
25 |
+
from urllib.parse import urljoin
|
26 |
+
import html2text
|
27 |
+
import bs4
|
28 |
+
import tempfile
|
29 |
+
import PyPDF2
|
30 |
+
from modules.dataloader.pdf_readers.base import PDFReader
|
31 |
+
from modules.dataloader.pdf_readers.llama import LlamaParser
|
32 |
+
from modules.dataloader.pdf_readers.gpt import GPTParser
|
33 |
+
|
34 |
+
try:
|
35 |
+
from modules.dataloader.helpers import get_metadata, download_pdf_from_url
|
36 |
+
from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
|
37 |
+
except:
|
38 |
+
from dataloader.helpers import get_metadata, download_pdf_from_url
|
39 |
+
from config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
|
40 |
+
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
BASE_DIR = os.getcwd()
|
43 |
+
|
44 |
+
|
45 |
+
class HTMLReader:
|
46 |
+
def __init__(self):
|
47 |
+
pass
|
48 |
|
49 |
+
def read_url(self, url):
|
50 |
+
response = requests.get(url)
|
51 |
+
if response.status_code == 200:
|
52 |
+
return response.text
|
53 |
+
else:
|
54 |
+
logger.warning(f"Failed to download HTML from URL: {url}")
|
55 |
+
return None
|
56 |
|
57 |
+
def check_links(self, base_url, html_content):
|
58 |
+
soup = bs4.BeautifulSoup(html_content, "html.parser")
|
59 |
+
for link in soup.find_all("a"):
|
60 |
+
href = link.get("href")
|
61 |
|
62 |
+
if not href or href.startswith("#"):
|
63 |
+
continue
|
64 |
+
elif not href.startswith("https"):
|
65 |
+
href = href.replace("http", "https")
|
66 |
|
67 |
+
absolute_url = urljoin(base_url, href)
|
68 |
+
link['href'] = absolute_url
|
|
|
69 |
|
70 |
+
resp = requests.head(absolute_url)
|
71 |
+
if resp.status_code != 200:
|
72 |
+
logger.warning(f"Link {absolute_url} is broken. Status code: {resp.status_code}")
|
73 |
|
74 |
+
return str(soup)
|
75 |
+
|
76 |
+
def html_to_md(self, url, html_content):
|
77 |
+
html_processed = self.check_links(url, html_content)
|
78 |
+
markdown_content = html2text.html2text(html_processed)
|
79 |
+
return markdown_content
|
80 |
+
|
81 |
+
def read_html(self, url):
|
82 |
+
html_content = self.read_url(url)
|
83 |
+
if html_content:
|
84 |
+
return self.html_to_md(url, html_content)
|
85 |
+
else:
|
86 |
+
return None
|
87 |
|
88 |
class FileReader:
|
89 |
+
     def __init__(self, logger, kind):
         self.logger = logger
+        self.kind = kind
+        if kind == "llama":
+            self.pdf_reader = LlamaParser()
+        elif kind == "gpt":
+            self.pdf_reader = GPTParser()
+        else:
+            self.pdf_reader = PDFReader()
+        self.web_reader = HTMLReader()
+        self.logger.info(f"Initialized FileReader with {kind} PDF reader and HTML reader")

     def extract_text_from_pdf(self, pdf_path):
         text = ""
...
             text += page.extract_text()
         return text

     def read_pdf(self, temp_file_path: str):
+        documents = self.pdf_reader.parse(temp_file_path)
         return documents

     def read_txt(self, temp_file_path: str):
...
         return loader.load()

     def read_html(self, url: str):
+        return [Document(page_content=self.web_reader.read_html(url))]

     def read_tex_from_url(self, tex_url):
         response = requests.get(tex_url)
...
         self.document_metadata = {}
         self.document_chunks_full = []

+        if not config['vectorstore']['embedd_files']:
+            self.load_document_data()
+
         if config["splitter_options"]["use_splitter"]:
+            if config["splitter_options"]["chunking_mode"] == "fixed":
+                if config["splitter_options"]["split_by_token"]:
+                    self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                        chunk_size=config["splitter_options"]["chunk_size"],
+                        chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                        separators=config["splitter_options"]["chunk_separators"],
+                        disallowed_special=(),
+                    )
+                else:
+                    self.splitter = RecursiveCharacterTextSplitter(
+                        chunk_size=config["splitter_options"]["chunk_size"],
+                        chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                        separators=config["splitter_options"]["chunk_separators"],
+                        disallowed_special=(),
+                    )
+            else:
+                self.splitter = SemanticChunker(
+                    OpenAIEmbeddings(),
+                    breakpoint_threshold_type="percentile"
                 )
+
         else:
             self.splitter = None
         self.logger.info("ChunkProcessor instance created")
...
     def process_chunks(
         self, documents, file_type="txt", source="", page=0, metadata={}
     ):
+        # TODO: Clear up this pipeline of re-adding metadata
         documents = [Document(page_content=documents, source=source, page=page)]
+        if file_type == "pdf" and self.config["splitter_options"]["chunking_mode"] == "fixed":
+            document_chunks = documents
+        else:
             document_chunks = self.splitter.split_documents(documents)

         # add the source and page number back to the metadata
         for chunk in document_chunks:
...
             "https://dl4ds.github.io/sp2024/lectures/",
             "https://dl4ds.github.io/sp2024/schedule/",
         )  # For any additional metadata

         with ThreadPoolExecutor() as executor:
             executor.map(
                 self.process_file,
...
             page_num = doc.metadata.get("page", 0)
             file_data[page_num] = doc.page_content
+
+            # Create a new dictionary for metadata in each iteration
+            metadata = addl_metadata.get(file_path, {}).copy()
+            metadata["page"] = page_num
+            metadata["source"] = file_path
             file_metadata[page_num] = metadata

         if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
...
     def process_file(self, file_path, file_index, file_reader, addl_metadata):
         file_name = os.path.basename(file_path)

+        file_type = file_name.split(".")[-1]

         read_methods = {
             "pdf": file_reader.read_pdf,
...
             return

         try:
+            if file_path in self.document_data:
+                self.logger.warning(f"File {file_name} already processed")
+                documents = [Document(page_content=content) for content in self.document_data[file_path].values()]
+            else:
+                documents = read_methods[file_type](file_path)
+
             self.process_documents(
                 documents, file_path, file_type, "file", addl_metadata
             )
...
             f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
         ) as json_file:
             self.document_metadata = json.load(json_file)
+        self.logger.info(
+            f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}"
+        )


 class DataLoader:
     def __init__(self, config, logger=None):
+        self.file_reader = FileReader(logger=logger, kind=config["llm_params"]["pdf_reader"])
         self.chunk_processor = ChunkProcessor(config, logger=logger)

     def get_chunks(self, uploaded_files, weblinks):
...
     with open("../code/modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)

+    STORAGE_DIR = os.path.join(BASE_DIR, config['vectorstore']["data_path"])
+    uploaded_files = [
+        os.path.join(STORAGE_DIR, file) for file in os.listdir(STORAGE_DIR) if file != "urls.txt"
+    ]
+
     data_loader = DataLoader(config, logger=logger)
     document_chunks, document_names, documents, document_metadata = (
         data_loader.get_chunks(
+            ["https://dl4ds.github.io/sp2024/static_files/lectures/05_loss_functions_v2.pdf"],
             [],
         )
     )

+    print(document_names[:5])
     print(len(document_chunks))
+
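The `FileReader` above now dispatches on the configured `pdf_reader` kind ("llama", "gpt", or a PyMuPDF fallback). A minimal sketch of how that switch might be driven from the project config; this is an illustration only, the config path and the sample PDF name are hypothetical:

import logging
import yaml
from modules.dataloader.data_loader import FileReader

logger = logging.getLogger(__name__)

# Hypothetical path; the __main__ block above reads ../code/modules/config/config.yml
with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

# "llama" -> LlamaParser, "gpt" -> GPTParser, anything else -> PDFReader (PyMuPDF)
reader = FileReader(logger=logger, kind=config["llm_params"]["pdf_reader"])
docs = reader.read_pdf("sample_lecture.pdf")  # hypothetical file; returns a list of Documents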
code/modules/dataloader/helpers.py
CHANGED
@@ -1,7 +1,7 @@
 import requests
 from bs4 import BeautifulSoup
-from
-
+from urllib.parse import urlparse
+import tempfile

 def get_urls_from_file(file_path: str):
     """
@@ -106,3 +106,23 @@ def get_metadata(lectures_url, schedule_url):
             continue

     return lecture_metadata
+
+
+def download_pdf_from_url(pdf_url):
+    """
+    Function to temporarily download a PDF file from a URL and return the local file path.
+
+    Args:
+        pdf_url (str): The URL of the PDF file to download.
+
+    Returns:
+        str: The local file path of the downloaded PDF file.
+    """
+    response = requests.get(pdf_url)
+    if response.status_code == 200:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
+            temp_file.write(response.content)
+            temp_file_path = temp_file.name
+        return temp_file_path
+    else:
+        return None
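A short usage sketch for the new helper: it returns a temporary file path on HTTP 200 and `None` otherwise, so callers should check the result before handing the path to a parser. The URL below is illustrative:

from modules.dataloader.helpers import download_pdf_from_url

local_path = download_pdf_from_url("https://example.com/slides.pdf")  # illustrative URL
if local_path is not None:
    print(f"PDF saved to {local_path}")  # e.g. a tempfile path such as /tmp/tmpabc123.pdf
else:
    print("Download failed (non-200 response)")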
code/modules/dataloader/pdf_readers/base.py
ADDED
@@ -0,0 +1,14 @@
+from langchain_community.document_loaders import PyMuPDFLoader
+
+
+class PDFReader:
+    def __init__(self):
+        pass
+
+    def get_loader(self, pdf_path):
+        loader = PyMuPDFLoader(pdf_path)
+        return loader
+
+    def parse(self, pdf_path):
+        loader = self.get_loader(pdf_path)
+        return loader.load()
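`PDFReader` is the default fallback and simply wraps LangChain's `PyMuPDFLoader`, so each returned `Document` is one PDF page with PyMuPDF metadata. A minimal sketch; the file path is hypothetical:

from modules.dataloader.pdf_readers.base import PDFReader

reader = PDFReader()
pages = reader.parse("lecture01.pdf")  # hypothetical local file
print(len(pages), pages[0].metadata.get("page"))  # one Document per page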
code/modules/dataloader/pdf_readers/gpt.py
ADDED
@@ -0,0 +1,81 @@
+import base64
+import os
+import requests
+
+from io import BytesIO
+from openai import OpenAI
+from pdf2image import convert_from_path
+from langchain.schema import Document
+
+
+class GPTParser:
+    """
+    This class uses OpenAI's GPT-4o mini model to parse PDFs and extract text, images and equations.
+    It is the most advanced parser in the system and is able to handle complex formats and layouts
+    """
+
+    def __init__(self):
+        self.client = OpenAI()
+        self.api_key = os.getenv("OPENAI_API_KEY")
+        self.prompt = """
+        The provided documents are images of PDFs of lecture slides of deep learning material.
+        They contain LaTeX equations, images, and text.
+        The goal is to extract the text, images and equations from the slides and convert everything to markdown format. Some of the equations may be complicated.
+        The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
+        For images, give a description and if you can, a source. Separate each page with '---'.
+        Just respond with the markdown. Do not include page numbers or any other metadata. Do not try to provide titles. Strictly the content.
+        """
+
+    def parse(self, pdf_path):
+        images = convert_from_path(pdf_path)
+
+        encoded_images = [self.encode_image(image) for image in images]
+
+        chunks = [encoded_images[i:i + 5] for i in range(0, len(encoded_images), 5)]
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
+
+        output = ""
+        for chunk_num, chunk in enumerate(chunks):
+            content = [{"type": "image_url", "image_url": {
+                "url": f"data:image/jpeg;base64,{image}"}} for image in chunk]
+
+            content.insert(0, {"type": "text", "text": self.prompt})
+
+            payload = {
+                "model": "gpt-4o-mini",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": content
+                    }
+                ],
+            }
+
+            response = requests.post(
+                "https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+
+            resp = response.json()
+
+            chunk_output = resp['choices'][0]['message']['content'].replace("```", "").replace("markdown", "").replace("````", "")
+
+            output += chunk_output + "\n---\n"
+
+        output = output.split("\n---\n")
+        output = [doc for doc in output if doc.strip() != ""]
+
+        documents = [
+            Document(
+                page_content=page,
+                metadata={"source": pdf_path, "page": i}
+            ) for i, page in enumerate(output)
+        ]
+        return documents
+
+    def encode_image(self, image):
+        buffered = BytesIO()
+        image.save(buffered, format="JPEG")
+        return base64.b64encode(buffered.getvalue()).decode('utf-8')
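`GPTParser.parse` renders each page with `pdf2image`, sends the page images to the chat-completions endpoint in batches of five, and splits the returned markdown on '---' back into per-page `Document`s. A hedged usage sketch; it assumes poppler is installed (required by `pdf2image`), that `OPENAI_API_KEY` is set to a real key, and the file name is hypothetical:

import os
from modules.dataloader.pdf_readers.gpt import GPTParser

os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder; must be a valid key in practice

parser = GPTParser()
docs = parser.parse("lecture02.pdf")  # hypothetical local file
# Each Document is one '---'-separated page of model-generated markdown
print(docs[0].metadata)  # {'source': 'lecture02.pdf', 'page': 0}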
code/modules/dataloader/pdf_readers/llama.py
ADDED
@@ -0,0 +1,92 @@
+import os
+import requests
+from llama_parse import LlamaParse
+from langchain.schema import Document
+from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
+from modules.dataloader.helpers import download_pdf_from_url
+
+
+
+class LlamaParser:
+    def __init__(self):
+        self.GPT_API_KEY = OPENAI_API_KEY
+        self.LLAMA_CLOUD_API_KEY = LLAMA_CLOUD_API_KEY
+        self.parse_url = "https://api.cloud.llamaindex.ai/api/parsing/upload"
+        self.headers = {
+            'Accept': 'application/json',
+            'Authorization': f'Bearer {LLAMA_CLOUD_API_KEY}'
+        }
+        self.parser = LlamaParse(
+            api_key=LLAMA_CLOUD_API_KEY,
+            result_type="markdown",
+            verbose=True,
+            language="en",
+            gpt4o_mode=False,
+            # gpt4o_api_key=OPENAI_API_KEY,
+            parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source."
+        )
+
+    def parse(self, pdf_path):
+        if not os.path.exists(pdf_path):
+            pdf_path = download_pdf_from_url(pdf_path)
+
+        documents = self.parser.load_data(pdf_path)
+        document = [document.to_langchain_format() for document in documents][0]
+
+        content = document.page_content
+        pages = content.split("\n---\n")
+        pages = [page.strip() for page in pages]
+
+        documents = [
+            Document(
+                page_content=page,
+                metadata={"source": pdf_path, "page": i}
+            ) for i, page in enumerate(pages)
+        ]
+
+        return documents
+
+    def make_request(self, pdf_url):
+        payload = {
+            "gpt4o_mode": "false",
+            "parsing_instruction": "The provided document is a PDF of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides and convert them to markdown format. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$. For images, give a description and if you can, a source.",
+        }
+
+        files = [
+            ('file', ('file', requests.get(pdf_url).content, 'application/octet-stream'))
+        ]
+
+        response = requests.request(
+            "POST", self.parse_url, headers=self.headers, data=payload, files=files)
+
+        return response.json()['id'], response.json()['status']
+
+    async def get_result(self, job_id):
+        url = f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}/result/markdown"
+
+        response = requests.request("GET", url, headers=self.headers, data={})
+
+        return response.json()['markdown']
+
+    async def _parse(self, pdf_path):
+        job_id, status = self.make_request(pdf_path)
+
+        while status != "SUCCESS":
+            url = f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}"
+            response = requests.request("GET", url, headers=self.headers, data={})
+            status = response.json()["status"]
+
+        result = await self.get_result(job_id)
+
+        documents = [
+            Document(
+                page_content=result,
+                metadata={"source": pdf_path}
+            )
+        ]
+
+        return documents
+
+    async def _parse(self, pdf_path):
+        return await self._parse(pdf_path)
+
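`LlamaParser.parse` accepts either a local path or a URL: anything that does not exist on disk is first fetched through `download_pdf_from_url`, then parsed to markdown and split on '---' into per-page `Document`s. A minimal sketch, assuming `LLAMA_CLOUD_API_KEY` (and optionally `OPENAI_API_KEY`) are available via `modules/config/constants.py`; the URL is illustrative:

from modules.dataloader.pdf_readers.llama import LlamaParser

parser = LlamaParser()
# A URL works too, since non-existent paths are downloaded to a temp file first
docs = parser.parse("https://example.com/lecture03.pdf")  # illustrative URL
print(len(docs), docs[0].metadata)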
code/modules/dataloader/webpage_crawler.py
CHANGED
@@ -66,7 +66,6 @@ class WebpageCrawler:
             )
             for link in unchecked_links:
                 dict_links[link] = "Checked"
-                print(f"Checked: {link}")
             dict_links.update(
                 {
                     link: "Not-checked"
code/modules/vectorstore/base.py
CHANGED
@@ -29,5 +29,8 @@ class VectorStoreBase:
         """
         raise NotImplementedError

+    def __len__(self):
+        raise NotImplementedError
+
     def __str__(self):
         return self.__class__.__name__
code/modules/vectorstore/chroma.py
CHANGED
@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):

     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
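The `__len__` additions in `base.py`, `chroma.py`, and the stores below give every vector store a uniform size check, which the store manager later forwards via `len(self.vector_db)`. A small sketch of what that shared protocol enables; the helper name is hypothetical:

# Sketch: any of the stores (Chroma, FAISS, ColBERT, ...) can now be size-checked the same way.
def report_store_size(store):
    """Return the number of indexed chunks, relying only on the new __len__ hooks."""
    try:
        return len(store)
    except NotImplementedError:
        return -1  # base-class behaviour when a store has not implemented __len__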
code/modules/vectorstore/colbert.py
CHANGED
@@ -1,6 +1,67 @@
 from ragatouille import RAGPretrainedModel
 from modules.vectorstore.base import VectorStoreBase
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
+from langchain_core.documents import Document
+from typing import Any, List, Optional, Sequence
 import os
+import json
+
+
+class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
+    model: Any
+    kwargs: dict = {}
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+    async def _aget_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+
+class RAGPretrainedModel(RAGPretrainedModel):
+    """
+    Adding len property to RAGPretrainedModel
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._document_count = 0
+
+    def set_document_count(self, count):
+        self._document_count = count
+
+    def __len__(self):
+        return self._document_count
+
+    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
+        return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)


 class ColbertVectorStore(VectorStoreBase):
@@ -24,16 +85,28 @@ class ColbertVectorStore(VectorStoreBase):
             document_ids=document_names,
             document_metadatas=document_metadata,
         )
+        self.colbert.set_document_count(len(document_names))

     def load_database(self):
         path = os.path.join(
+            os.getcwd(),
             self.config["vectorstore"]["db_path"],
             "db_" + self.config["vectorstore"]["db_option"],
         )
         self.vectorstore = RAGPretrainedModel.from_index(
             f"{path}/colbert/indexes/new_idx"
         )
+
+        index_metadata = json.load(
+            open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
+        )
+        num_documents = index_metadata["num_passages"]
+        self.vectorstore.set_document_count(num_documents)
+
         return self.vectorstore

     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
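`RAGatouilleLangChainRetrieverWithScore` reshapes RAGatouille's `model.search` output into LangChain `Document`s while keeping the relevance score in metadata. A sketch of how the retriever might be obtained and queried once an index exists; the index path and query are illustrative, and the exact call (`invoke` vs. `get_relevant_documents`) depends on the installed langchain-core version:

# Assumes a ColBERT index was already built or loaded as in load_database() above.
model = RAGPretrainedModel.from_index("vector_db/db_RAGatouille/colbert/indexes/new_idx")  # illustrative path
retriever = model.as_langchain_retriever(k=3)  # kwargs are forwarded to model.search
for doc in retriever.invoke("What is a loss function?"):
    print(doc.metadata.get("score"), doc.page_content[:60])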
code/modules/vectorstore/faiss.py
CHANGED
@@ -3,10 +3,21 @@ from modules.vectorstore.base import VectorStoreBase
 import os


+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class FaissVectorStore(VectorStoreBase):
     def __init__(self, config):
         self.config = config
         self._init_vector_db()
+        self.local_path = os.path.join(self.config["vectorstore"]["db_path"],
+                                       "db_" + self.config["vectorstore"]["db_option"]
+                                       + "_" + self.config["vectorstore"]["model"]
+                                       + "_" + config["splitter_options"]["chunking_mode"])

     def _init_vector_db(self):
         self.faiss = FAISS(
@@ -18,24 +29,12 @@ class FaissVectorStore(VectorStoreBase):
             documents=document_chunks, embedding=embedding_model
         )
         self.vectorstore.save_local(
-            os.path.join(
-                self.config["vectorstore"]["db_path"],
-                "db_"
-                + self.config["vectorstore"]["db_option"]
-                + "_"
-                + self.config["vectorstore"]["model"],
-            )
+            self.local_path
         )

     def load_database(self, embedding_model):
         self.vectorstore = self.faiss.load_local(
-            os.path.join(
-                self.config["vectorstore"]["db_path"],
-                "db_"
-                + self.config["vectorstore"]["db_option"]
-                + "_"
-                + self.config["vectorstore"]["model"],
-            ),
+            self.local_path,
             embedding_model,
             allow_dangerous_deserialization=True,
         )
@@ -43,3 +42,6 @@ class FaissVectorStore(VectorStoreBase):

     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
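Centralising the save/load location in `self.local_path` also folds the chunking mode into the folder name, so fixed-chunk and semantic-chunk indexes no longer overwrite each other. A sketch of the resulting path, using illustrative config values:

import os

# Illustrative values; the real ones come from config.yml
db_path, db_option = "vector_db", "FAISS"
model, chunking_mode = "text-embedding-ada-002", "fixed"

local_path = os.path.join(db_path, "db_" + db_option + "_" + model + "_" + chunking_mode)
print(local_path)  # vector_db/db_FAISS_text-embedding-ada-002_fixed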
code/modules/vectorstore/raptor.py
CHANGED
@@ -5,7 +5,7 @@ import os
 import numpy as np
 import pandas as pd
 import umap
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts.chat import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from sklearn.mixture import GaussianMixture
 from langchain_community.chat_models import ChatOpenAI
@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
 RANDOM_SEED = 42


+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class RAPTORVectoreStore(VectorStoreBase):
     def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
         self.documents = documents
code/modules/vectorstore/store_manager.py
CHANGED
@@ -3,6 +3,7 @@ from modules.vectorstore.helpers import *
 from modules.dataloader.webpage_crawler import WebpageCrawler
 from modules.dataloader.data_loader import DataLoader
 from modules.dataloader.helpers import *
+from modules.config.constants import RETRIEVER_HF_PATHS
 from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
 import logging
 import os
@@ -135,14 +136,34 @@ class VectorStoreManager:
             self.embedding_model = self.create_embedding_model()
         else:
             self.embedding_model = None
-
+        try:
+            self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
+        except Exception as e:
+            raise ValueError(
+                f"Error loading database, check if it exists. if not run python -m modules.vectorstore.store_manager / Resteart the HF Space: {e}"
+            )
+        # print(f"Creating database")
+        # self.create_database()
+        # self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
         end_time = time.time()  # End time for loading database
         self.logger.info(
-            f"Time taken to load database: {end_time - start_time} seconds"
+            f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds"
         )
         self.logger.info("Loaded database")
         return self.loaded_vector_db

+    def load_from_HF(self, HF_PATH):
+        start_time = time.time()  # Start time for loading database
+        self.vector_db._load_from_HF(HF_PATH)
+        end_time = time.time()
+        self.logger.info(
+            f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
+        )
+        self.logger.info("Downloaded database")
+
+    def __len__(self):
+        return len(self.vector_db)
+

 if __name__ == "__main__":
     import yaml
@@ -152,7 +173,20 @@ if __name__ == "__main__":
     print(config)
     print(f"Trying to create database with config: {config}")
     vector_db = VectorStoreManager(config)
-
+    if config["vectorstore"]["load_from_HF"]:
+        if config["vectorstore"]["db_option"] in RETRIEVER_HF_PATHS:
+            vector_db.load_from_HF(
+                HF_PATH=RETRIEVER_HF_PATHS[config["vectorstore"]["db_option"]]
+            )
+        else:
+            # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
+            # print("Creating database")
+            # vector_db.create_database()
+            raise ValueError(
+                f"HF_PATH not available for {config['vectorstore']['db_option']}"
+            )
+    else:
+        vector_db.create_database()
     print("Created database")

     print(f"Trying to load the database")
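With `load_from_HF`, the manager either downloads a prebuilt index from Hugging Face (when `load_from_HF` is true and the db option has an entry in `RETRIEVER_HF_PATHS`) or builds one locally, and the new `__len__` chain reports how many chunks ended up in the store. A hedged sketch of that flow; the config path is assumed and the repo id is a placeholder, the real mapping lives in `modules/config/constants.py`:

import yaml
from modules.vectorstore.store_manager import VectorStoreManager

with open("modules/config/config.yml", "r") as f:  # assumed config location
    config = yaml.safe_load(f)

manager = VectorStoreManager(config)
if config["vectorstore"]["load_from_HF"]:
    manager.load_from_HF(HF_PATH="dl4ds/example-faiss-index")  # placeholder repo id
else:
    manager.create_database()
print(len(manager))  # chunk count, via VectorStoreManager.__len__ -> VectorStore.__len__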
code/modules/vectorstore/vectorstore.py
CHANGED
@@ -2,6 +2,9 @@ from modules.vectorstore.faiss import FaissVectorStore
 from modules.vectorstore.chroma import ChromaVectorStore
 from modules.vectorstore.colbert import ColbertVectorStore
 from modules.vectorstore.raptor import RAPTORVectoreStore
+from huggingface_hub import snapshot_download
+import os
+import shutil


 class VectorStore:
@@ -50,8 +53,39 @@ class VectorStore:
         else:
             return self.vectorstore.load_database(embedding_model)

+    def _load_from_HF(self, HF_PATH):
+        # Download the snapshot from Hugging Face Hub
+        # Note: Download goes to the cache directory
+        snapshot_path = snapshot_download(
+            repo_id=HF_PATH,
+            repo_type="dataset",
+            force_download=True,
+        )
+
+        # Move the downloaded files to the desired directory
+        target_path = os.path.join(
+            self.config["vectorstore"]["db_path"],
+            "db_" + self.config["vectorstore"]["db_option"],
+        )
+
+        # Create target path if it doesn't exist
+        os.makedirs(target_path, exist_ok=True)
+
+        # move all files and directories from snapshot_path to target_path
+        # target path is used while loading the database
+        for item in os.listdir(snapshot_path):
+            s = os.path.join(snapshot_path, item)
+            d = os.path.join(target_path, item)
+            if os.path.isdir(s):
+                shutil.copytree(s, d, dirs_exist_ok=True)
+            else:
+                shutil.copy2(s, d)
+
     def _as_retriever(self):
         return self.vectorstore.as_retriever()

     def _get_vectorstore(self):
         return self.vectorstore
+
+    def __len__(self):
+        return self.vectorstore.__len__()
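`_load_from_HF` relies on `huggingface_hub.snapshot_download` writing into the local cache and then copies the snapshot into the expected `db_<option>` folder so the normal load path finds it. A minimal standalone sketch of that download-then-copy pattern; the repo id and target folder are placeholders:

import os
import shutil
from huggingface_hub import snapshot_download

snapshot_path = snapshot_download(repo_id="dl4ds/example-faiss-index",  # placeholder repo id
                                  repo_type="dataset")
target_path = "vector_db/db_FAISS"  # placeholder target folder
os.makedirs(target_path, exist_ok=True)

# Copy everything from the cached snapshot into the folder the loader expects
for item in os.listdir(snapshot_path):
    src, dst = os.path.join(snapshot_path, item), os.path.join(target_path, item)
    if os.path.isdir(src):
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        shutil.copy2(src, dst)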
code/public/test.css
CHANGED
@@ -31,3 +31,13 @@ a[href*='https://github.com/Chainlit/chainlit'] {
 .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
     display: none;
 }
+
+/* Hide the new chat button
+#new-chat-button {
+    display: none;
+} */
+
+/* Hide the open sidebar button
+#open-sidebar-button {
+    display: none;
+} */
requirements.txt
CHANGED
@@ -1,27 +1,25 @@
-langchain
-trulens-eval==0.31.0
-llama-cpp-python==0.2.77
-pymupdf==1.24.5
+aiohttp
+beautifulsoup4
+chainlit
+langchain
+langchain-community
+langchain-core
+literalai
+llama-parse
+numpy
+pandas
+pysrt
+python-dotenv
+PyYAML
+RAGatouille
+requests
+scikit-learn
+torch
+tqdm
+transformers
+trulens_eval
+umap-learn
+llama-cpp-python
+pymupdf
 websockets
+langchain-openai