Thomas (Tom) Gardos committed
Commit 652745e · 2 Parent(s): 0f566b9 166d2a9

Merge pull request #51 from DL4DS/dev_branch

.github/workflows/push_to_hf_space_prototype.yml CHANGED
@@ -1,20 +1,21 @@
 name: Push Prototype to HuggingFace
 
 on:
-  pull_request:
-    branches:
-      - dev_branch
-
+  push:
+    branches: [dev_branch]
+
+  # run this workflow manually from the Actions tab
+  workflow_dispatch:
 
 jobs:
-  build:
+  sync-to-hub:
     runs-on: ubuntu-latest
     steps:
-      - name: Deploy Prototype to HuggingFace
-        uses: nateraw/[email protected]
-        with:
-          github_repo_id: DL4DS/dl4ds_tutor
-          huggingface_repo_id: dl4ds/tutor_dev
-          repo_type: space
-          space_sdk: static
-          hf_token: ${{ secrets.HF_TOKEN }}
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          lfs: true
+      - name: Deploy Prototype to HuggingFace
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: git push https://trgardos:$HF_TOKEN@huggingface.co/spaces/dl4ds/tutor_dev dev_branch:main
.vscode/launch.json ADDED
@@ -0,0 +1,35 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Chainlit run main.py",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${workspaceFolder}/.venv/bin/chainlit",
+            "console": "integratedTerminal",
+            "args": ["run", "main.py"],
+            "cwd": "${workspaceFolder}/code",
+            "justMyCode": true
+        },
+        {
+            "name": "Python Debugger: Module store_manager",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "modules.vectorstore.store_manager",
+            "env": {"PYTHONPATH": "${workspaceFolder}/code"},
+            "cwd": "${workspaceFolder}/code",
+            "justMyCode": true
+        },
+        {
+            "name": "Python Debugger: Module data_loader",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "modules.dataloader.data_loader",
+            "env": {"PYTHONPATH": "${workspaceFolder}/code"},
+            "cwd": "${workspaceFolder}/code",
+            "justMyCode": true
+        }
+    ]
+}
.vscode/tasks.json ADDED
@@ -0,0 +1,13 @@
+{
+    // See https://go.microsoft.com/fwlink/?LinkId=733558
+    // for the documentation about the tasks.json format
+    "version": "2.0.0",
+    "tasks": [
+        {
+            "label": "echo",
+            "type": "shell",
+            "command": "echo ${workspaceFolder}; ls ${workspaceFolder}/code",
+            "problemMatcher": []
+        }
+    ]
+}
Dockerfile.dev CHANGED
@@ -25,7 +25,7 @@ RUN mkdir /.cache && chmod -R 777 /.cache
 WORKDIR /code/code
 
 # Expose the port the app runs on
-EXPOSE 8051
+EXPOSE 8000
 
 # Default command to run the application
-CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -76,7 +76,7 @@ The HuggingFace Space is built using the `Dockerfile` in the repository. To run
 
 ```bash
 docker build --tag dev -f Dockerfile.dev .
-docker run -it --rm -p 8051:8051 dev
+docker run -it --rm -p 8000:8000 dev
 ```
 
 ## Contributing
code/.chainlit/config.toml CHANGED
@@ -23,7 +23,7 @@ allow_origins = ["*"]
 unsafe_allow_html = false
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
-latex = false
+latex = true
 
 # Automatically tag threads with the current chat profile (if a chat profile is used)
 auto_tag_thread = true
@@ -85,31 +85,34 @@ custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/
 # custom_build = "./public/build"
 
 [UI.theme]
-default = "light"
+default = "dark"
 #layout = "wide"
 #font_family = "Inter, sans-serif"
 # Override default MUI light theme. (Check theme.ts)
 [UI.theme.light]
-background = "#FAFAFA"
-paper = "#FFFFFF"
+#background = "#FAFAFA"
+#paper = "#FFFFFF"
 
 [UI.theme.light.primary]
-main = "#b22222" # Brighter shade of red
-dark = "#8b0000" # Darker shade of the brighter red
-light = "#ff6347" # Lighter shade of the brighter red
+#main = "#F80061"
+#dark = "#980039"
+#light = "#FFE7EB"
 [UI.theme.light.text]
-primary = "#212121"
-secondary = "#616161"
+#primary = "#212121"
+#secondary = "#616161"
+
 # Override default MUI dark theme. (Check theme.ts)
 [UI.theme.dark]
-background = "#1C1C1C" # Slightly lighter dark background color
-paper = "#2A2A2A" # Slightly lighter dark paper color
+#background = "#FAFAFA"
+#paper = "#FFFFFF"
 
 [UI.theme.dark.primary]
-main = "#89CFF0" # Primary color
-dark = "#3700B3" # Dark variant of primary color
-light = "#CFBCFF" # Lighter variant of primary color
-
+#main = "#F80061"
+#dark = "#980039"
+#light = "#FFE7EB"
+[UI.theme.dark.text]
+#primary = "#EEEEEE"
+#secondary = "#BDBDBD"
 
 [meta]
-generated_by = "1.1.302"
+generated_by = "1.1.304"
code/main.py CHANGED
@@ -1,176 +1,465 @@
-from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
-from langchain_core.prompts import PromptTemplate
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.vectorstores import FAISS
-from langchain.chains import RetrievalQA
-import chainlit as cl
-from langchain_community.chat_models import ChatOpenAI
-from langchain_community.embeddings import OpenAIEmbeddings
-import yaml
-import logging
-from dotenv import load_dotenv
 from modules.chat.llm_tutor import LLMTutor
-from modules.config.constants import *
-from modules.chat.helpers import get_sources
-from modules.chat_processor.chat_processor import ChatProcessor
-
-global logger
-# Initialize logger
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-
-# Console Handler
-console_handler = logging.StreamHandler()
-console_handler.setLevel(logging.INFO)
-console_handler.setFormatter(formatter)
-logger.addHandler(console_handler)
-
-
-@cl.set_starters
-async def set_starters():
-    return [
-        cl.Starter(
-            label="recording on CNNs?",
-            message="Where can I find the recording for the lecture on Transfromers?",
-            icon="/public/adv-screen-recorder-svgrepo-com.svg",
-        ),
-        cl.Starter(
-            label="where's the slides?",
-            message="When are the lectures? I can't find the schedule.",
-            icon="/public/alarmy-svgrepo-com.svg",
-        ),
-        cl.Starter(
-            label="Due Date?",
-            message="When is the final project due?",
-            icon="/public/calendar-samsung-17-svgrepo-com.svg",
-        ),
-        cl.Starter(
-            label="Explain backprop.",
-            message="I didnt understand the math behind backprop, could you explain it?",
-            icon="/public/acastusphoton-svgrepo-com.svg",
-        ),
-    ]
-
-
-# Adding option to select the chat profile
-@cl.set_chat_profiles
-async def chat_profile():
-    return [
-        # cl.ChatProfile(
-        #     name="Mistral",
-        #     markdown_description="Use the local LLM: **Mistral**.",
-        # ),
-        cl.ChatProfile(
-            name="gpt-3.5-turbo-1106",
-            markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
-        ),
-        cl.ChatProfile(
-            name="gpt-4",
-            markdown_description="Use OpenAI API for **gpt-4**.",
-        ),
-        cl.ChatProfile(
-            name="Llama",
-            markdown_description="Use the local LLM: **Tiny Llama**.",
-        ),
-    ]
-
-
-@cl.author_rename
-def rename(orig_author: str):
-    rename_dict = {"Chatbot": "AI Tutor"}
-    return rename_dict.get(orig_author, orig_author)
-
-
-# chainlit code
-@cl.on_chat_start
-async def start():
-    with open("modules/config/config.yml", "r") as f:
-        config = yaml.safe_load(f)
-
-    # Ensure log directory exists
-    log_directory = config["log_dir"]
-    if not os.path.exists(log_directory):
-        os.makedirs(log_directory)
-
-    # File Handler
-    log_file_path = (
-        f"{log_directory}/tutor.log"  # Change this to your desired log file path
-    )
-    file_handler = logging.FileHandler(log_file_path, mode="w")
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(formatter)
-    logger.addHandler(file_handler)
-
-    logger.info("Config file loaded")
-    logger.info(f"Config: {config}")
-    logger.info("Creating llm_tutor instance")
-
-    chat_profile = cl.user_session.get("chat_profile")
-    if chat_profile is not None:
-        if chat_profile.lower() in ["gpt-3.5-turbo-1106", "gpt-4"]:
-            config["llm_params"]["llm_loader"] = "openai"
-            config["llm_params"]["openai_params"]["model"] = chat_profile.lower()
-        elif chat_profile.lower() == "llama":
-            config["llm_params"]["llm_loader"] = "local_llm"
-            config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
-            config["llm_params"]["local_llm_params"]["model_type"] = "llama"
-        elif chat_profile.lower() == "mistral":
-            config["llm_params"]["llm_loader"] = "local_llm"
-            config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
-            config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
         else:
-            pass
-
-    llm_tutor = LLMTutor(config, logger=logger)
-
-    chain = llm_tutor.qa_bot()
-    # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
-    # await msg.send()
-    # msg.content = opening_message
-    # await msg.update()
-
-    tags = [chat_profile, config["vectorstore"]["db_option"]]
-    chat_processor = ChatProcessor(config, tags=tags)
-    cl.user_session.set("chain", chain)
-    cl.user_session.set("counter", 0)
-    cl.user_session.set("chat_processor", chat_processor)
-
-
-@cl.on_chat_end
-async def on_chat_end():
-    await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
-
-
-@cl.on_message
-async def main(message):
-    global logger
-    user = cl.user_session.get("user")
-    chain = cl.user_session.get("chain")
-
-    counter = cl.user_session.get("counter")
-    counter += 1
-    cl.user_session.set("counter", counter)
-
-    # if counter >= 3:  # Ensure the counter condition is checked
-    #     await cl.Message(content="Your credits are up!").send()
-    #     await on_chat_end()  # Call the on_chat_end function to handle the end of the chat
-    #     return  # Exit the function to stop further processing
-    # else:
-
-    cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
-    cb.answer_reached = True
-
-    processor = cl.user_session.get("chat_processor")
-    res = await processor.rag(message.content, chain, cb)
-    try:
-        answer = res["answer"]
-    except:
-        answer = res["result"]
-
-    answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
-    processor._process(message.content, answer, sources_dict)
-
-    await cl.Message(content=answer_with_sources, elements=source_elements).send()
+import chainlit.data as cl_data
+import asyncio
+from modules.config.constants import (
+    LLAMA_PATH,
+    LITERAL_API_KEY_LOGGING,
+    LITERAL_API_URL,
+)
+from modules.chat_processor.literal_ai import CustomLiteralDataLayer
+
+import json
+import yaml
+import os
+from typing import Any, Dict, no_type_check
+import chainlit as cl
 from modules.chat.llm_tutor import LLMTutor
+from modules.chat.helpers import (
+    get_sources,
+    get_history_chat_resume,
+    get_history_setup_llm,
+    get_last_config,
+)
+import copy
+from typing import Optional
+from chainlit.types import ThreadDict
+import time
+
+USER_TIMEOUT = 60_000
+SYSTEM = "System 🖥️"
+LLM = "LLM 🧠"
+AGENT = "Agent <>"
+YOU = "You 😃"
+ERROR = "Error 🚫"
+
+with open("modules/config/config.yml", "r") as f:
+    config = yaml.safe_load(f)
+
+
+async def setup_data_layer():
+    """
+    Set up the data layer for chat logging.
+    """
+    if config["chat_logging"]["log_chat"]:
+        data_layer = CustomLiteralDataLayer(
+            api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
+        )
+    else:
+        data_layer = None
+
+    return data_layer
+
+
+class Chatbot:
+    def __init__(self, config):
+        """
+        Initialize the Chatbot class.
+        """
+        self.config = config
+
+    async def _load_config(self):
+        """
+        Load the configuration from a YAML file.
+        """
+        with open("modules/config/config.yml", "r") as f:
+            return yaml.safe_load(f)
+
+    @no_type_check
+    async def setup_llm(self):
+        """
+        Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.
+
+        #TODO: Clean this up.
+        """
+        start_time = time.time()
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        chat_profile, retriever_method, memory_window, llm_style, generate_follow_up, chunking_mode = (
+            llm_settings.get("chat_model"),
+            llm_settings.get("retriever_method"),
+            llm_settings.get("memory_window"),
+            llm_settings.get("llm_style"),
+            llm_settings.get("follow_up_questions"),
+            llm_settings.get("chunking_mode"),
+        )
+
+        chain = cl.user_session.get("chain")
+        memory_list = cl.user_session.get(
+            "memory",
+            (
+                list(chain.store.values())[0].messages
+                if len(chain.store.values()) > 0
+                else []
+            ),
+        )
+        conversation_list = get_history_setup_llm(memory_list)
+
+        old_config = copy.deepcopy(self.config)
+        self.config["vectorstore"]["db_option"] = retriever_method
+        self.config["llm_params"]["memory_window"] = memory_window
+        self.config["llm_params"]["llm_style"] = llm_style
+        self.config["llm_params"]["llm_loader"] = chat_profile
+        self.config["llm_params"]["generate_follow_up"] = generate_follow_up
+        self.config["splitter_options"]["chunking_mode"] = chunking_mode
+
+        self.llm_tutor.update_llm(
+            old_config, self.config
+        )  # update only llm attributes that are changed
+        self.chain = self.llm_tutor.qa_bot(
+            memory=conversation_list,
+            callbacks=(
+                [cl.LangchainCallbackHandler()]
+                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                else None
+            ),
+        )
+
+        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+
+        print("Time taken to setup LLM: ", time.time() - start_time)
+
+    @no_type_check
+    async def update_llm(self, new_settings: Dict[str, Any]):
+        """
+        Update the LLM settings and reinitialize the LLM with the new settings.
+
+        Args:
+            new_settings (Dict[str, Any]): The new settings to update.
+        """
+        cl.user_session.set("llm_settings", new_settings)
+        await self.inform_llm_settings()
+        await self.setup_llm()
+
+    async def make_llm_settings_widgets(self, config=None):
+        """
+        Create and send the widgets for LLM settings configuration.
+
+        Args:
+            config: The configuration to use for setting up the widgets.
+        """
+        config = config or self.config
+        await cl.ChatSettings(
+            [
+                cl.input_widget.Select(
+                    id="chat_model",
+                    label="Model Name (Default GPT-3)",
+                    values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
+                    initial_index=[
+                        "local_llm",
+                        "gpt-3.5-turbo-1106",
+                        "gpt-4",
+                        "gpt-4o-mini",
+                    ].index(config["llm_params"]["llm_loader"]),
+                ),
+                cl.input_widget.Select(
+                    id="retriever_method",
+                    label="Retriever (Default FAISS)",
+                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                    initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
+                        config["vectorstore"]["db_option"]
+                    ),
+                ),
+                cl.input_widget.Slider(
+                    id="memory_window",
+                    label="Memory Window (Default 3)",
+                    initial=3,
+                    min=0,
+                    max=10,
+                    step=1,
+                ),
+                cl.input_widget.Switch(
+                    id="view_sources", label="View Sources", initial=False
+                ),
+                cl.input_widget.Switch(
+                    id="stream_response",
+                    label="Stream response",
+                    initial=config["llm_params"]["stream"],
+                ),
+                cl.input_widget.Select(
+                    id="chunking_mode",
+                    label="Chunking mode",
+                    values=["fixed", "semantic"],
+                    initial_index=1,
+                ),
+                cl.input_widget.Switch(
+                    id="follow_up_questions",
+                    label="Generate follow up questions",
+                    initial=False,
+                ),
+                cl.input_widget.Select(
+                    id="llm_style",
+                    label="Type of Conversation (Default Normal)",
+                    values=["Normal", "ELI5"],
+                    initial_index=0,
+                ),
+            ]
+        ).send()
+
+    @no_type_check
+    async def inform_llm_settings(self):
+        """
+        Inform the user about the updated LLM settings and display them as a message.
+        """
+        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+        llm_tutor = cl.user_session.get("llm_tutor")
+        settings_dict = {
+            "model": llm_settings.get("chat_model"),
+            "retriever": llm_settings.get("retriever_method"),
+            "memory_window": llm_settings.get("memory_window"),
+            "num_docs_in_db": (
+                len(llm_tutor.vector_db)
+                if llm_tutor and hasattr(llm_tutor, "vector_db")
+                else 0
+            ),
+            "view_sources": llm_settings.get("view_sources"),
+            "follow_up_questions": llm_settings.get("follow_up_questions"),
+        }
+        await cl.Message(
+            author=SYSTEM,
+            content="LLM settings have been updated. You can continue with your Query!",
+            elements=[
+                cl.Text(
+                    name="settings",
+                    display="side",
+                    content=json.dumps(settings_dict, indent=4),
+                    language="json",
+                ),
+            ],
+        ).send()
+
+    async def set_starters(self):
+        """
+        Set starter messages for the chatbot.
+        """
+        # Return Starters only if the chat is new
+
+        try:
+            thread = cl_data._data_layer.get_thread(
+                cl.context.session.thread_id
+            )  # see if the thread has any steps
+            if thread.steps or len(thread.steps) > 0:
+                return None
+        except:
+            return [
+                cl.Starter(
+                    label="recording on CNNs?",
+                    message="Where can I find the recording for the lecture on Transformers?",
+                    icon="/public/adv-screen-recorder-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="where's the slides?",
+                    message="When are the lectures? I can't find the schedule.",
+                    icon="/public/alarmy-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="Due Date?",
+                    message="When is the final project due?",
+                    icon="/public/calendar-samsung-17-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="Explain backprop.",
+                    message="I didn't understand the math behind backprop, could you explain it?",
+                    icon="/public/acastusphoton-svgrepo-com.svg",
+                ),
+            ]
+
+    def rename(self, orig_author: str):
+        """
+        Rename the original author to a more user-friendly name.
+
+        Args:
+            orig_author (str): The original author's name.
+
+        Returns:
+            str: The renamed author.
+        """
+        rename_dict = {"Chatbot": "AI Tutor"}
+        return rename_dict.get(orig_author, orig_author)
+
+    async def start(self, config=None):
+        """
+        Start the chatbot, initialize settings widgets,
+        and display and load previous conversation if chat logging is enabled.
+        """
+
+        start_time = time.time()
+
+        self.config = (
+            await self._load_config() if config is None else config
+        )  # Reload the configuration on chat resume
+
+        await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets
+
+        user = cl.user_session.get("user")
+        self.user = {
+            "user_id": user.identifier,
+            "session_id": cl.context.session.thread_id,
+        }
+
+        memory = cl.user_session.get("memory", [])
+
+        cl.user_session.set("user", self.user)
+        self.llm_tutor = LLMTutor(self.config, user=self.user)
+
+        self.chain = self.llm_tutor.qa_bot(
+            memory=memory,
+            callbacks=(
+                [cl.LangchainCallbackHandler()]
+                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                else None
+            ),
+        )
+        self.question_generator = self.llm_tutor.question_generator
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chain", self.chain)
+
+        print("Time taken to start LLM: ", time.time() - start_time)
+
+    async def stream_response(self, response):
+        """
+        Stream the response from the LLM.
+
+        Args:
+            response: The response from the LLM.
+        """
+        msg = cl.Message(content="")
+        await msg.send()
+
+        output = {}
+        for chunk in response:
+            if "answer" in chunk:
+                await msg.stream_token(chunk["answer"])
+
+            for key in chunk:
+                if key not in output:
+                    output[key] = chunk[key]
+                else:
+                    output[key] += chunk[key]
+        return output
+
+    async def main(self, message):
+        """
+        Process and display the conversation.
+
+        Args:
+            message: The incoming chat message.
+        """
+
+        start_time = time.time()
+
+        chain = cl.user_session.get("chain")
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        view_sources = llm_settings.get("view_sources", False)
+        stream = llm_settings.get("stream_response", False)
+        stream = False  # Fix streaming
+        user_query_dict = {"input": message.content}
+        # Define the base configuration
+        chain_config = {
+            "configurable": {
+                "user_id": self.user["user_id"],
+                "conversation_id": self.user["session_id"],
+                "memory_window": self.config["llm_params"]["memory_window"],
+            }
+        }
+
+        if stream:
+            res = chain.stream(user_query=user_query_dict, config=chain_config)
+            res = await self.stream_response(res)
         else:
+            res = await chain.invoke(
+                user_query=user_query_dict,
+                config=chain_config,
+            )
+
+        answer = res.get("answer", res.get("result"))
+
+        answer_with_sources, source_elements, sources_dict = get_sources(
+            res, answer, stream=stream, view_sources=view_sources
+        )
+        answer_with_sources = answer_with_sources.replace("$$", "$")
+
+        print("Time taken to process the message: ", time.time() - start_time)
+
+        actions = []
+
+        if self.config["llm_params"]["generate_follow_up"]:
+            start_time = time.time()
+            list_of_questions = self.question_generator.generate_questions(
+                query=user_query_dict["input"],
+                response=answer,
+                chat_history=res.get("chat_history"),
+                context=res.get("context"),
+            )
+
+            for question in list_of_questions:
+                actions.append(
+                    cl.Action(
+                        name="follow up question",
+                        value="example_value",
+                        description=question,
+                        label=question,
+                    )
+                )
+
+            print("Time taken to generate questions: ", time.time() - start_time)
+
+        await cl.Message(
+            content=answer_with_sources,
+            elements=source_elements,
+            author=LLM,
+            actions=actions,
+            metadata=self.config,
+        ).send()
+
+    async def on_chat_resume(self, thread: ThreadDict):
+        thread_config = None
+        steps = thread["steps"]
+        k = self.config["llm_params"][
+            "memory_window"
+        ]  # on resume, always use the default memory window
+        conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
+        thread_config = get_last_config(
+            steps
+        )  # TODO: Returns None for now - which causes config to be reloaded with default values
+        cl.user_session.set("memory", conversation_list)
+        await self.start(config=thread_config)
+
+    @cl.oauth_callback
+    def auth_callback(
+        provider_id: str,
+        token: str,
+        raw_user_data: Dict[str, str],
+        default_user: cl.User,
+    ) -> Optional[cl.User]:
+        return default_user
+
+    async def on_follow_up(self, action: cl.Action):
+        message = await cl.Message(
+            content=action.description,
+            type="user_message",
+            author=self.user["user_id"],
+        ).send()
+        await self.main(message)
+
+
+chatbot = Chatbot(config=config)
+
+
+async def start_app():
+    cl_data._data_layer = await setup_data_layer()
+    chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
+    cl.set_starters(chatbot.set_starters)
+    cl.author_rename(chatbot.rename)
+    cl.on_chat_start(chatbot.start)
+    cl.on_chat_resume(chatbot.on_chat_resume)
+    cl.on_message(chatbot.main)
+    cl.on_settings_update(chatbot.update_llm)
+    cl.action_callback("follow up question")(chatbot.on_follow_up)
+
+
+asyncio.run(start_app())
code/modules/chat/base.py ADDED
@@ -0,0 +1,13 @@
+class BaseRAG:
+    """
+    Base class for RAG chatbot.
+    """
+
+    def __init__(self):
+        pass
+
+    def invoke(self):
+        """
+        Invoke the RAG chatbot.
+        """
+        pass
code/modules/chat/chat_model_loader.py CHANGED
@@ -1,12 +1,15 @@
-from langchain_community.chat_models import ChatOpenAI
+from langchain_openai import ChatOpenAI
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from transformers import AutoTokenizer, TextStreamer
 from langchain_community.llms import LlamaCpp
 import torch
 import transformers
 import os
+from pathlib import Path
+from huggingface_hub import hf_hub_download
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from modules.config.constants import LLAMA_PATH
 
 
 class ChatModelLoader:
@@ -14,16 +17,28 @@ class ChatModelLoader:
         self.config = config
         self.huggingface_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
+    def _verify_model_cache(self, model_cache_path):
+        hf_hub_download(
+            repo_id=self.config["llm_params"]["local_llm_params"]["repo_id"],
+            filename=self.config["llm_params"]["local_llm_params"]["filename"],
+            cache_dir=model_cache_path,
+        )
+        return str(list(Path(model_cache_path).glob("*/snapshots/*/*.gguf"))[0])
+
     def load_chat_model(self):
-        if self.config["llm_params"]["llm_loader"] == "openai":
-            llm = ChatOpenAI(
-                model_name=self.config["llm_params"]["openai_params"]["model"]
-            )
+        if self.config["llm_params"]["llm_loader"] in [
+            "gpt-3.5-turbo-1106",
+            "gpt-4",
+            "gpt-4o-mini",
+        ]:
+            llm = ChatOpenAI(model_name=self.config["llm_params"]["llm_loader"])
         elif self.config["llm_params"]["llm_loader"] == "local_llm":
            n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
-            model_path = self.config["llm_params"]["local_llm_params"]["model"]
+            model_path = self._verify_model_cache(
+                self.config["llm_params"]["local_llm_params"]["model"]
+            )
            llm = LlamaCpp(
-                model_path=model_path,
+                model_path=LLAMA_PATH,
                n_batch=n_batch,
                n_ctx=2048,
                f16_kv=True,
@@ -34,5 +49,7 @@ class ChatModelLoader:
                ],
            )
        else:
-            raise ValueError("Invalid LLM Loader")
+            raise ValueError(
+                f"Invalid LLM Loader: {self.config['llm_params']['llm_loader']}"
+            )
        return llm
code/modules/chat/helpers.py CHANGED
@@ -1,13 +1,12 @@
-from modules.config.constants import *
+from modules.config.prompts import prompts
 import chainlit as cl
-from langchain_core.prompts import PromptTemplate
 
 
-def get_sources(res, answer):
+def get_sources(res, answer, stream=True, view_sources=False):
     source_elements = []
     source_dict = {}  # Dictionary to store URL elements
 
-    for idx, source in enumerate(res["source_documents"]):
+    for idx, source in enumerate(res["context"]):
         source_metadata = source.metadata
         url = source_metadata.get("source", "N/A")
         score = source_metadata.get("score", "N/A")
@@ -36,69 +35,135 @@ def get_sources(res, answer):
     else:
         source_dict[url_name]["text"] += f"\n\n{source.page_content}"
 
-    # First, display the answer
-    full_answer = "**Answer:**\n"
-    full_answer += answer
-
-    # Then, display the sources
-    full_answer += "\n\n**Sources:**\n"
-    for idx, (url_name, source_data) in enumerate(source_dict.items()):
-        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
-
-        name = f"Source {idx + 1} Text\n"
-        full_answer += name
-        source_elements.append(
-            cl.Text(name=name, content=source_data["text"], display="side")
-        )
-
-        # Add a PDF element if the source is a PDF file
-        if source_data["url"].lower().endswith(".pdf"):
-            name = f"Source {idx + 1} PDF\n"
-            full_answer += name
-            pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
-            source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
-
-    full_answer += "\n**Metadata:**\n"
-    for idx, (url_name, source_data) in enumerate(source_dict.items()):
-        full_answer += f"\nSource {idx + 1} Metadata:\n"
-        source_elements.append(
-            cl.Text(
-                name=f"Source {idx + 1} Metadata",
-                content=f"Source: {source_data['url']}\n"
-                f"Page: {source_data['page']}\n"
-                f"Type: {source_data['source_type']}\n"
-                f"Date: {source_data['date']}\n"
-                f"TL;DR: {source_data['lecture_tldr']}\n"
-                f"Lecture Recording: {source_data['lecture_recording']}\n"
-                f"Suggested Readings: {source_data['suggested_readings']}\n",
-                display="side",
-            )
-        )
+    full_answer = ""  # Not to include the answer again if streaming
+
+    if not stream:  # First, display the answer if not streaming
+        full_answer = "**Answer:**\n"
+        full_answer += answer
+
+    if view_sources:
+
+        # Then, display the sources
+        # check if the answer has sources
+        if len(source_dict) == 0:
+            full_answer += "\n\n**No sources found.**"
+            return full_answer, source_elements, source_dict
+        else:
+            full_answer += "\n\n**Sources:**\n"
+            for idx, (url_name, source_data) in enumerate(source_dict.items()):
+
+                full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+                name = f"Source {idx + 1} Text\n"
+                full_answer += name
+                source_elements.append(
+                    cl.Text(name=name, content=source_data["text"], display="side")
+                )
+
+                # Add a PDF element if the source is a PDF file
+                if source_data["url"].lower().endswith(".pdf"):
+                    name = f"Source {idx + 1} PDF\n"
+                    full_answer += name
+                    pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+                    source_elements.append(
+                        cl.Pdf(name=name, url=pdf_url, display="side")
+                    )
+
+            full_answer += "\n**Metadata:**\n"
+            for idx, (url_name, source_data) in enumerate(source_dict.items()):
+                full_answer += f"\nSource {idx + 1} Metadata:\n"
+                source_elements.append(
+                    cl.Text(
+                        name=f"Source {idx + 1} Metadata",
+                        content=f"Source: {source_data['url']}\n"
+                        f"Page: {source_data['page']}\n"
+                        f"Type: {source_data['source_type']}\n"
+                        f"Date: {source_data['date']}\n"
+                        f"TL;DR: {source_data['lecture_tldr']}\n"
+                        f"Lecture Recording: {source_data['lecture_recording']}\n"
+                        f"Suggested Readings: {source_data['suggested_readings']}\n",
+                        display="side",
+                    )
+                )
 
     return full_answer, source_elements, source_dict
 
 
-def get_prompt(config):
-    if config["llm_params"]["use_history"]:
-        if config["llm_params"]["llm_loader"] == "local_llm":
-            custom_prompt_template = tinyllama_prompt_template_with_history
-        elif config["llm_params"]["llm_loader"] == "openai":
-            custom_prompt_template = openai_prompt_template_with_history
-        # else:
-        #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
-        prompt = PromptTemplate(
-            template=custom_prompt_template,
-            input_variables=["context", "chat_history", "question"],
-        )
-    else:
-        if config["llm_params"]["llm_loader"] == "local_llm":
-            custom_prompt_template = tinyllama_prompt_template
-        elif config["llm_params"]["llm_loader"] == "openai":
-            custom_prompt_template = openai_prompt_template
-        # else:
-        #     custom_prompt_template = tinyllama_prompt_template
-        prompt = PromptTemplate(
-            template=custom_prompt_template,
-            input_variables=["context", "question"],
-        )
-    return prompt
+def get_prompt(config, prompt_type):
+    llm_params = config["llm_params"]
+    llm_loader = llm_params["llm_loader"]
+    use_history = llm_params["use_history"]
+    llm_style = llm_params["llm_style"].lower()
+
+    if prompt_type == "qa":
+        if llm_loader == "local_llm":
+            if use_history:
+                return prompts["tiny_llama"]["prompt_with_history"]
+            else:
+                return prompts["tiny_llama"]["prompt_no_history"]
+        else:
+            if use_history:
+                return prompts["openai"]["prompt_with_history"][llm_style]
+            else:
+                return prompts["openai"]["prompt_no_history"]
+    elif prompt_type == "rephrase":
+        return prompts["openai"]["rephrase_prompt"]
+
+
+def get_history_chat_resume(steps, k, SYSTEM, LLM):
+    conversation_list = []
+    count = 0
+    for step in reversed(steps):
+        if step["name"] not in [SYSTEM]:
+            if step["type"] == "user_message":
+                conversation_list.append(
+                    {"type": "user_message", "content": step["output"]}
+                )
+            elif step["type"] == "assistant_message":
+                if step["name"] == LLM:
+                    conversation_list.append(
+                        {"type": "ai_message", "content": step["output"]}
+                    )
+            else:
+                raise ValueError("Invalid message type")
+            count += 1
+            if count >= 2 * k:  # 2 * k to account for both user and assistant messages
+                break
+    conversation_list = conversation_list[::-1]
+    return conversation_list
+
+
+def get_history_setup_llm(memory_list):
+    conversation_list = []
+    for message in memory_list:
+        message_dict = message.to_dict() if hasattr(message, "to_dict") else message
+
+        # Check if the type attribute is present as a key or attribute
+        message_type = (
+            message_dict.get("type", None)
+            if isinstance(message_dict, dict)
+            else getattr(message, "type", None)
+        )
+
+        # Check if content is present as a key or attribute
+        message_content = (
+            message_dict.get("content", None)
+            if isinstance(message_dict, dict)
+            else getattr(message, "content", None)
+        )
+
+        if message_type in ["ai", "ai_message"]:
+            conversation_list.append({"type": "ai_message", "content": message_content})
+        elif message_type in ["human", "user_message"]:
+            conversation_list.append(
+                {"type": "user_message", "content": message_content}
+            )
+        else:
+            raise ValueError("Invalid message type")
+
+    return conversation_list
+
+
+def get_last_config(steps):
+    # TODO: Implement this function
+    return None
code/modules/chat/langchain/langchain_rag.py ADDED
@@ -0,0 +1,274 @@
+from langchain_core.prompts import ChatPromptTemplate
+
+from modules.chat.langchain.utils import *
+from langchain.memory import ChatMessageHistory
+from modules.chat.base import BaseRAG
+from langchain_core.prompts import PromptTemplate
+from langchain.memory import (
+    ConversationBufferWindowMemory,
+    ConversationSummaryBufferMemory,
+)
+
+import chainlit as cl
+from langchain_community.chat_models import ChatOpenAI
+
+
+class Langchain_RAG_V1(BaseRAG):
+
+    def __init__(
+        self,
+        llm,
+        memory,
+        retriever,
+        qa_prompt: str,
+        rephrase_prompt: str,
+        config: dict,
+        callbacks=None,
+    ):
+        """
+        Initialize the Langchain_RAG class.
+
+        Args:
+            llm (LanguageModelLike): The language model instance.
+            memory (BaseChatMessageHistory): The chat message history instance.
+            retriever (BaseRetriever): The retriever instance.
+            qa_prompt (str): The QA prompt string.
+            rephrase_prompt (str): The rephrase prompt string.
+        """
+        self.llm = llm
+        self.config = config
+        # self.memory = self.add_history_from_list(memory)
+        self.memory = ConversationBufferWindowMemory(
+            k=self.config["llm_params"]["memory_window"],
+            memory_key="chat_history",
+            return_messages=True,
+            output_key="answer",
+            max_token_limit=128,
+        )
+        self.retriever = retriever
+        self.qa_prompt = qa_prompt
+        self.rephrase_prompt = rephrase_prompt
+        self.store = {}
+
+        self.qa_prompt = PromptTemplate(
+            template=self.qa_prompt,
+            input_variables=["context", "chat_history", "input"],
+        )
+
+        self.rag_chain = CustomConversationalRetrievalChain.from_llm(
+            llm=llm,
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True,
+            memory=self.memory,
+            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
+            response_if_no_docs_found="No context found",
+        )
+
+    def add_history_from_list(self, history_list):
+        """
+        TODO: Add messages from a list to the chat history.
+        """
+        history = []
+
+        return history
+
+    async def invoke(self, user_query, config):
+        """
+        Invoke the chain.
+
+        Args:
+            kwargs: The input variables.
+
+        Returns:
+            dict: The output variables.
+        """
+        res = await self.rag_chain.acall(user_query["input"])
+        return res
+
+
+class QuestionGenerator:
+    """
+    Generate a question from the LLMs response and users input and past conversations.
+    """
+
+    def __init__(self):
+        pass
+
+    def generate_questions(self, query, response, chat_history, context):
+        questions = return_questions(query, response, chat_history, context)
+        return questions
+
+
+class Langchain_RAG_V2(BaseRAG):
+    def __init__(
+        self,
+        llm,
+        memory,
+        retriever,
+        qa_prompt: str,
+        rephrase_prompt: str,
+        config: dict,
+        callbacks=None,
+    ):
+        """
+        Initialize the Langchain_RAG class.
+
+        Args:
+            llm (LanguageModelLike): The language model instance.
+            memory (BaseChatMessageHistory): The chat message history instance.
+            retriever (BaseRetriever): The retriever instance.
+            qa_prompt (str): The QA prompt string.
+            rephrase_prompt (str): The rephrase prompt string.
+        """
+        self.llm = llm
+        self.memory = self.add_history_from_list(memory)
+        self.retriever = retriever
+        self.qa_prompt = qa_prompt
+        self.rephrase_prompt = rephrase_prompt
+        self.store = {}
+
+        # Contextualize question prompt
+        contextualize_q_system_prompt = rephrase_prompt or (
+            "Given a chat history and the latest user question "
+            "which might reference context in the chat history, "
+            "formulate a standalone question which can be understood "
+            "without the chat history. Do NOT answer the question, just "
+            "reformulate it if needed and otherwise return it as is."
+        )
+        self.contextualize_q_prompt = ChatPromptTemplate.from_template(
+            contextualize_q_system_prompt
+        )
+
+        # History-aware retriever
+        self.history_aware_retriever = create_history_aware_retriever(
+            self.llm, self.retriever, self.contextualize_q_prompt
+        )
+
+        # Answer question prompt
+        qa_system_prompt = qa_prompt or (
+            "You are an assistant for question-answering tasks. Use "
+            "the following pieces of retrieved context to answer the "
+            "question. If you don't know the answer, just say that you "
+            "don't know. Use three sentences maximum and keep the answer "
+            "concise."
+            "\n\n"
+            "{context}"
+        )
+        self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)
+
+        # Question-answer chain
+        self.question_answer_chain = create_stuff_documents_chain(
+            self.llm, self.qa_prompt_template
+        )
+
+        # Final retrieval chain
+        self.rag_chain = create_retrieval_chain(
+            self.history_aware_retriever, self.question_answer_chain
+        )
+
+        self.rag_chain = CustomRunnableWithHistory(
+            self.rag_chain,
+            get_session_history=self.get_session_history,
+            input_messages_key="input",
+            history_messages_key="chat_history",
+            output_messages_key="answer",
+            history_factory_config=[
+                ConfigurableFieldSpec(
+                    id="user_id",
+                    annotation=str,
+                    name="User ID",
+                    description="Unique identifier for the user.",
+                    default="",
+                    is_shared=True,
+                ),
+                ConfigurableFieldSpec(
+                    id="conversation_id",
+                    annotation=str,
+                    name="Conversation ID",
+                    description="Unique identifier for the conversation.",
+                    default="",
+                    is_shared=True,
+                ),
+                ConfigurableFieldSpec(
+                    id="memory_window",
+                    annotation=int,
+                    name="Number of Conversations",
+                    description="Number of conversations to consider for context.",
+                    default=1,
+                    is_shared=True,
+                ),
+            ],
+        )
+
+        if callbacks is not None:
+            self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
+
+    def get_session_history(
+        self, user_id: str, conversation_id: str, memory_window: int
+    ) -> BaseChatMessageHistory:
+        """
+        Get the session history for a user and conversation.
+
+        Args:
+            user_id (str): The user identifier.
+            conversation_id (str): The conversation identifier.
+            memory_window (int): The number of conversations to consider for context.
+
+        Returns:
+            BaseChatMessageHistory: The chat message history.
+        """
+        if (user_id, conversation_id) not in self.store:
+            self.store[(user_id, conversation_id)] = InMemoryHistory()
+            self.store[(user_id, conversation_id)].add_messages(
+                self.memory.messages
+            )  # add previous messages to the store. Note: the store is in-memory.
+        return self.store[(user_id, conversation_id)]
+
+    async def invoke(self, user_query, config, **kwargs):
+        """
+        Invoke the chain.
+
+        Args:
+            kwargs: The input variables.
+
+        Returns:
+            dict: The output variables.
+        """
+        res = await self.rag_chain.ainvoke(user_query, config, **kwargs)
+        res["rephrase_prompt"] = self.rephrase_prompt
+        res["qa_prompt"] = self.qa_prompt
+        return res
+
+    def stream(self, user_query, config):
+        res = self.rag_chain.stream(user_query, config)
+        return res
+
+    def add_history_from_list(self, conversation_list):
+        """
+        Add messages from a list to the chat history.
+
+        Args:
+            messages (list): The list of messages to add.
+        """
+        history = ChatMessageHistory()
+
+        for idx, message in enumerate(conversation_list):
+            message_type = (
+                message.get("type", None)
+                if isinstance(message, dict)
+                else getattr(message, "type", None)
+            )
+
+            message_content = (
+                message.get("content", None)
+                if isinstance(message, dict)
+                else getattr(message, "content", None)
+            )
+
+            if message_type in ["human", "user_message"]:
+                history.add_user_message(message_content)
+            elif message_type in ["ai", "ai_message"]:
+                history.add_ai_message(message_content)
+
+        return history
code/modules/chat/langchain/utils.py ADDED
@@ -0,0 +1,340 @@
+from typing import Any, Dict, List, Union, Tuple, Optional
+from langchain_core.messages import (
+    BaseMessage,
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+)
+
+from langchain_core.prompts.base import BasePromptTemplate, format_document
+from langchain_core.prompts.chat import MessagesPlaceholder
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.output_parsers.base import BaseOutputParser
+from langchain_core.retrievers import BaseRetriever, RetrieverOutput
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.runnables.utils import ConfigurableFieldSpec
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain.chains.combine_documents.base import (
+    DEFAULT_DOCUMENT_PROMPT,
+    DEFAULT_DOCUMENT_SEPARATOR,
+    DOCUMENTS_KEY,
+    BaseCombineDocumentsChain,
+    _validate_prompt,
+)
+from langchain.chains.llm import LLMChain
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import Document
+
+
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.messages import BaseMessage
+
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatOpenAI
+
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+import inspect
+from langchain.chains.conversational_retrieval.base import _get_chat_history
+from langchain_core.messages import BaseMessage
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+
+    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
+        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
+        buffer = ""
+        for dialogue_turn in chat_history:
+            if isinstance(dialogue_turn, BaseMessage):
+                role_prefix = _ROLE_MAP.get(
+                    dialogue_turn.type, f"{dialogue_turn.type}: "
+                )
+                buffer += f"\n{role_prefix}{dialogue_turn.content}"
+            elif isinstance(dialogue_turn, tuple):
+                human = "Student: " + dialogue_turn[0]
+                ai = "AI Tutor: " + dialogue_turn[1]
+                buffer += "\n" + "\n".join([human, ai])
+            else:
+                raise ValueError(
+                    f"Unsupported chat history format: {type(dialogue_turn)}."
+                    f" Full chat history: {chat_history} "
+                )
+        return buffer
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self._get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+        if chat_history_str:
+            # callbacks = _run_manager.get_child()
+            # new_question = await self.question_generator.arun(
+            #     question=question, chat_history=chat_history_str, callbacks=callbacks
+            # )
+            system = (
+                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+                "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+                "If the question is conversational and doesn't require context, do not rephrase it. "
+                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation?'. "
+                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'. "
+                "Chat history: \n{chat_history_str}\n"
+                "Rephrase the following question only if necessary: '{input}'"
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", system),
+                    ("human", "{input}, {chat_history_str}"),
+                ]
+            )
+            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+            step_back = prompt | llm | StrOutputParser()
+            new_question = step_back.invoke(
+                {"input": question, "chat_history_str": chat_history_str}
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._aget_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
+
+        output: Dict[str, Any] = {}
+        output["original_question"] = question
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+
+            # Prepare the final prompt with metadata
+            context = "\n\n".join(
+                [
+                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
+                    for idx, doc in enumerate(docs)
+                ]
+            )
+            final_prompt = (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+                f"Chat History:\n{chat_history_str}\n\n"
+                f"Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+                f"Student: {input}\n"
+                "AI Tutor:"
+            )
+
+            new_inputs["input"] = final_prompt
+            # new_inputs["question"] = final_prompt
+            # output["final_prompt"] = final_prompt
+
+            answer = await self.combine_docs_chain.arun(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
+        if self.return_source_documents:
+            output["source_documents"] = docs
+        output["rephrased_question"] = new_question
+        output["context"] = output["source_documents"]
+        return output
+
+
+class CustomRunnableWithHistory(RunnableWithMessageHistory):
+
+    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
+        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
+        buffer = ""
+        for dialogue_turn in chat_history:
+            if isinstance(dialogue_turn, BaseMessage):
+                role_prefix = _ROLE_MAP.get(
+                    dialogue_turn.type, f"{dialogue_turn.type}: "
+                )
+                buffer += f"\n{role_prefix}{dialogue_turn.content}"
+            elif isinstance(dialogue_turn, tuple):
+                human = "Student: " + dialogue_turn[0]
+                ai = "AI Tutor: " + dialogue_turn[1]
+                buffer += "\n" + "\n".join([human, ai])
+            else:
+                raise ValueError(
+                    f"Unsupported chat history format: {type(dialogue_turn)}."
+                    f" Full chat history: {chat_history} "
+                )
+        return buffer
+
+    async def _aenter_history(
+        self, input: Any, config: RunnableConfig
+    ) -> List[BaseMessage]:
+        """
+        Get the last k conversations from the message history.
+
+        Args:
+            input (Any): The input data.
+            config (RunnableConfig): The runnable configuration.
+
+        Returns:
+            List[BaseMessage]: The last k conversations.
+        """
+        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
+        messages = (await hist.aget_messages()).copy()
+        if not self.history_messages_key:
+            # return all messages
+            input_val = (
+                input if not self.input_messages_key else input[self.input_messages_key]
+            )
+            messages += self._get_input_messages(input_val)
+
+        # return last k conversations
+        if config["configurable"]["memory_window"] == 0:  # if k is 0, return empty list
+            messages = []
+        else:
+            messages = messages[-2 * config["configurable"]["memory_window"] :]
+
+        messages = self._get_chat_history(messages)
+
+        return messages
+
+
+class InMemoryHistory(BaseChatMessageHistory, BaseModel):
+    """In-memory implementation of chat message history."""
+
+    messages: List[BaseMessage] = Field(default_factory=list)
+
+    def add_messages(self, messages: List[BaseMessage]) -> None:
+        """Add a list of messages to the store."""
+        self.messages.extend(messages)
+
+    def clear(self) -> None:
+        """Clear the message history."""
+        self.messages = []
+
+    def __len__(self) -> int:
+        """Return the number of messages."""
+        return len(self.messages)
+
+
+def create_history_aware_retriever(
+    llm: LanguageModelLike,
+    retriever: BaseRetriever,
+    prompt: BasePromptTemplate,
+) -> Runnable[Dict[str, Any], RetrieverOutput]:
+    """Create a chain that takes conversation history and returns documents."""
+    if "input" not in prompt.input_variables:
+        raise ValueError(
+            "Expected `input` to be a prompt variable, "
+            f"but got {prompt.input_variables}"
+        )
+
+    retrieve_documents = RunnableBranch(
+        (
+            lambda x: not x["chat_history"],
+            (lambda x: x["input"]) | retriever,
+        ),
+        prompt | llm | StrOutputParser() | retriever,
+    ).with_config(run_name="chat_retriever_chain")
+
+    return retrieve_documents
+
+
+def create_stuff_documents_chain(
+    llm: LanguageModelLike,
+    prompt: BasePromptTemplate,
+    output_parser: Optional[BaseOutputParser] = None,
+    document_prompt: Optional[BasePromptTemplate] = None,
+    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
+) -> Runnable[Dict[str, Any], Any]:
+    """Create a chain for passing a list of Documents to a model."""
+    _validate_prompt(prompt)
+    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
+    _output_parser = output_parser or StrOutputParser()
+
+    def format_docs(inputs: dict) -> str:
+        return document_separator.join(
+            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
+        )
+
+    return (
+        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
+            run_name="format_inputs"
+        )
+        | prompt
+        | llm
+        | _output_parser
+    ).with_config(run_name="stuff_documents_chain")
+
+
+def create_retrieval_chain(
+    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
+    combine_docs_chain: Runnable[Dict[str, Any], str],
291
+ ) -> Runnable:
292
+ """Create retrieval chain that retrieves documents and then passes them on."""
293
+ if not isinstance(retriever, BaseRetriever):
294
+ retrieval_docs = retriever
295
+ else:
296
+ retrieval_docs = (lambda x: x["input"]) | retriever
297
+
298
+ retrieval_chain = (
299
+ RunnablePassthrough.assign(
300
+ context=retrieval_docs.with_config(run_name="retrieve_documents"),
301
+ ).assign(answer=combine_docs_chain)
302
+ ).with_config(run_name="retrieval_chain")
303
+
304
+ return retrieval_chain
305
+
306
+
307
+ def return_questions(query, response, chat_history_str, context):
308
+
309
+ system = (
310
+ "You are someone that suggests a question based on the student's input and chat history. "
311
+ "Generate a question that is relevant to the student's input and chat history. "
312
+ "Incorporate relevant details from the chat history to make the question clearer and more specific. "
313
+ "Chat history: \n{chat_history_str}\n"
314
+ "Use the context to generate a question that is relevant to the student's input and chat history: Context: {context}"
315
+ "Generate 3 short and concise questions from the students voice based on the following input and response: "
316
+ "The 3 short and concise questions should be sperated by dots. Example: 'What is the capital of France?...What is the population of France?...What is the currency of France?'"
317
+ "User Query: {query}"
318
+ "AI Response: {response}"
319
+ "The 3 short and concise questions seperated by dots (...) are:"
320
+ )
321
+
322
+ prompt = ChatPromptTemplate.from_messages(
323
+ [
324
+ ("system", system),
325
+ ("human", "{chat_history_str}, {context}, {query}, {response}"),
326
+ ]
327
+ )
328
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
329
+ question_generator = prompt | llm | StrOutputParser()
330
+ new_questions = question_generator.invoke(
331
+ {
332
+ "chat_history_str": chat_history_str,
333
+ "context": context,
334
+ "query": query,
335
+ "response": response,
336
+ }
337
+ )
338
+
339
+ list_of_questions = new_questions.split("...")
340
+ return list_of_questions
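Note on the follow-up generator above: `return_questions` asks the model for three suggestions in one string separated by `...` and then splits on that delimiter. A minimal sketch of the consuming side, with a made-up model response, looks like this:

```python
# Made-up model output illustrating the "..." convention used by return_questions.
raw = "What is backpropagation?...How is the gradient computed?...Why do we need a learning rate?"

# Split on the delimiter, trim whitespace, and drop empty fragments.
list_of_questions = [q.strip() for q in raw.split("...") if q.strip()]
print(list_of_questions)
# ['What is backpropagation?', 'How is the gradient computed?', 'Why do we need a learning rate?']
```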
code/modules/chat/llm_tutor.py CHANGED
@@ -1,211 +1,163 @@
1
- from langchain.chains import RetrievalQA, ConversationalRetrievalChain
2
- from langchain.memory import (
3
- ConversationBufferWindowMemory,
4
- ConversationSummaryBufferMemory,
5
- )
6
- from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
7
- import os
8
- from modules.config.constants import *
9
  from modules.chat.helpers import get_prompt
10
  from modules.chat.chat_model_loader import ChatModelLoader
11
  from modules.vectorstore.store_manager import VectorStoreManager
12
-
13
  from modules.retriever.retriever import Retriever
14
-
15
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
- from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
- import inspect
18
- from langchain.chains.conversational_retrieval.base import _get_chat_history
19
- from langchain_core.messages import BaseMessage
20
-
21
- CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
22
-
23
- from langchain_core.output_parsers import StrOutputParser
24
- from langchain_core.prompts import ChatPromptTemplate
25
- from langchain_community.chat_models import ChatOpenAI
26
-
27
-
28
- class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
29
-
30
- def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
31
- _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
32
- buffer = ""
33
- for dialogue_turn in chat_history:
34
- if isinstance(dialogue_turn, BaseMessage):
35
- role_prefix = _ROLE_MAP.get(
36
- dialogue_turn.type, f"{dialogue_turn.type}: "
37
- )
38
- buffer += f"\n{role_prefix}{dialogue_turn.content}"
39
- elif isinstance(dialogue_turn, tuple):
40
- human = "Student: " + dialogue_turn[0]
41
- ai = "AI Tutor: " + dialogue_turn[1]
42
- buffer += "\n" + "\n".join([human, ai])
43
- else:
44
- raise ValueError(
45
- f"Unsupported chat history format: {type(dialogue_turn)}."
46
- f" Full chat history: {chat_history} "
47
- )
48
- return buffer
49
-
50
- async def _acall(
51
- self,
52
- inputs: Dict[str, Any],
53
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
54
- ) -> Dict[str, Any]:
55
- _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
56
- question = inputs["question"]
57
- get_chat_history = self._get_chat_history
58
- chat_history_str = get_chat_history(inputs["chat_history"])
59
- if chat_history_str:
60
- # callbacks = _run_manager.get_child()
61
- # new_question = await self.question_generator.arun(
62
- # question=question, chat_history=chat_history_str, callbacks=callbacks
63
- # )
64
- system = (
65
- "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
66
- "Incorporate relevant details from the chat history to make the question clearer and more specific."
67
- "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
68
- "If the question is conversational and doesn't require context, do not rephrase it. "
69
- "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
70
- "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
71
- "Chat history: \n{chat_history_str}\n"
72
- "Rephrase the following question only if necessary: '{question}'"
73
- )
74
-
75
- prompt = ChatPromptTemplate.from_messages(
76
- [
77
- ("system", system),
78
- ("human", "{question}, {chat_history_str}"),
79
- ]
80
- )
81
- llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
82
- step_back = prompt | llm | StrOutputParser()
83
- new_question = step_back.invoke(
84
- {"question": question, "chat_history_str": chat_history_str}
85
- )
86
- else:
87
- new_question = question
88
- accepts_run_manager = (
89
- "run_manager" in inspect.signature(self._aget_docs).parameters
90
- )
91
- if accepts_run_manager:
92
- docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
93
- else:
94
- docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
95
-
96
- output: Dict[str, Any] = {}
97
- output["original_question"] = question
98
- if self.response_if_no_docs_found is not None and len(docs) == 0:
99
- output[self.output_key] = self.response_if_no_docs_found
100
- else:
101
- new_inputs = inputs.copy()
102
- if self.rephrase_question:
103
- new_inputs["question"] = new_question
104
- new_inputs["chat_history"] = chat_history_str
105
-
106
- # Prepare the final prompt with metadata
107
- context = "\n\n".join(
108
- [
109
- f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
110
- for idx, doc in enumerate(docs)
111
- ]
112
- )
113
- final_prompt = (
114
- "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
115
- "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
116
- "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
117
- "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
118
- f"Chat History:\n{chat_history_str}\n\n"
119
- f"Context:\n{context}\n\n"
120
- "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
121
- f"Student: {question}\n"
122
- "AI Tutor:"
123
- )
124
-
125
- # new_inputs["input"] = final_prompt
126
- new_inputs["question"] = final_prompt
127
- # output["final_prompt"] = final_prompt
128
-
129
- answer = await self.combine_docs_chain.arun(
130
- input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
131
- )
132
- output[self.output_key] = answer
133
-
134
- if self.return_source_documents:
135
- output["source_documents"] = docs
136
- output["rephrased_question"] = new_question
137
- return output
138
 
139
 
140
  class LLMTutor:
141
- def __init__(self, config, logger=None):
 
 
 
 
 
 
 
 
142
  self.config = config
143
  self.llm = self.load_llm()
 
144
  self.logger = logger
145
- self.vector_db = VectorStoreManager(config, logger=self.logger)
 
 
 
 
146
  if self.config["vectorstore"]["embedd_files"]:
147
  self.vector_db.create_database()
148
  self.vector_db.save_database()
149
 
150
- def set_custom_prompt(self):
151
  """
152
- Prompt template for QA retrieval for each vectorstore
 
 
 
153
  """
154
- prompt = get_prompt(self.config)
155
- # prompt = QA_PROMPT
 
 
156
 
157
- return prompt
 
 
 
 
 
 
158
 
159
- # Retrieval QA Chain
160
- def retrieval_qa_chain(self, llm, prompt, db):
 
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  retriever = Retriever(self.config)._return_retriever(db)
163
 
164
- if self.config["llm_params"]["use_history"]:
165
- memory = ConversationBufferWindowMemory(
166
- k=self.config["llm_params"]["memory_window"],
167
- memory_key="chat_history",
168
- return_messages=True,
169
- output_key="answer",
170
- max_token_limit=128,
171
- )
172
- qa_chain = CustomConversationalRetrievalChain.from_llm(
173
  llm=llm,
174
- chain_type="stuff",
175
- retriever=retriever,
176
- return_source_documents=True,
177
  memory=memory,
178
- combine_docs_chain_kwargs={"prompt": prompt},
179
- response_if_no_docs_found="No context found",
 
 
 
180
  )
 
 
181
  else:
182
- qa_chain = RetrievalQA.from_chain_type(
183
- llm=llm,
184
- chain_type="stuff",
185
- retriever=retriever,
186
- return_source_documents=True,
187
- chain_type_kwargs={"prompt": prompt},
188
  )
189
- return qa_chain
190
 
191
- # Loading the model
192
  def load_llm(self):
 
 
 
 
 
 
193
  chat_model_loader = ChatModelLoader(self.config)
194
  llm = chat_model_loader.load_chat_model()
195
  return llm
196
 
197
- # QA Model Function
198
- def qa_bot(self):
199
- db = self.vector_db.load_database()
200
- qa_prompt = self.set_custom_prompt()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  qa = self.retrieval_qa_chain(
202
- self.llm, qa_prompt, db
203
- ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
 
 
 
 
 
204
 
205
  return qa
206
-
207
- # output function
208
- def final_result(query):
209
- qa_result = qa_bot()
210
- response = qa_result({"query": query})
211
- return response
 
 
 
 
 
 
 
 
 
1
  from modules.chat.helpers import get_prompt
2
  from modules.chat.chat_model_loader import ChatModelLoader
3
  from modules.vectorstore.store_manager import VectorStoreManager
 
4
  from modules.retriever.retriever import Retriever
5
+ from modules.chat.langchain.langchain_rag import (
6
+ Langchain_RAG_V1,
7
+ Langchain_RAG_V2,
8
+ QuestionGenerator,
9
+ )
 
10
 
11
 
12
  class LLMTutor:
13
+ def __init__(self, config, user, logger=None):
14
+ """
15
+ Initialize the LLMTutor class.
16
+
17
+ Args:
18
+ config (dict): Configuration dictionary.
19
+ user (str): User identifier.
20
+ logger (Logger, optional): Logger instance. Defaults to None.
21
+ """
22
  self.config = config
23
  self.llm = self.load_llm()
24
+ self.user = user
25
  self.logger = logger
26
+ self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
27
+ self.qa_prompt = get_prompt(config, "qa") # Initialize qa_prompt
28
+ self.rephrase_prompt = get_prompt(
29
+ config, "rephrase"
30
+ ) # Initialize rephrase_prompt
31
  if self.config["vectorstore"]["embedd_files"]:
32
  self.vector_db.create_database()
33
  self.vector_db.save_database()
34
 
35
+ def update_llm(self, old_config, new_config):
36
  """
37
+ Update the LLM and VectorStoreManager based on new configuration.
38
+
39
+ Args:
40
+ old_config (dict): Old configuration dictionary.
+ new_config (dict): New configuration dictionary.
41
  """
42
+ changes = self.get_config_changes(old_config, new_config)
43
+
44
+ if "llm_params.llm_loader" in changes:
45
+ self.llm = self.load_llm() # Reinitialize LLM if chat_model changes
46
 
47
+ if "vectorstore.db_option" in changes:
48
+ self.vector_db = VectorStoreManager(
49
+ self.config, logger=self.logger
50
+ ).load_database() # Reinitialize VectorStoreManager if vectorstore changes
51
+ if self.config["vectorstore"]["embedd_files"]:
52
+ self.vector_db.create_database()
53
+ self.vector_db.save_database()
54
 
55
+ if "llm_params.llm_style" in changes:
56
+ self.qa_prompt = get_prompt(
57
+ self.config, "qa"
58
+ ) # Update qa_prompt if ELI5 changes
59
 
60
+ def get_config_changes(self, old_config, new_config):
61
+ """
62
+ Get the changes between the old and new configuration.
63
+
64
+ Args:
65
+ old_config (dict): Old configuration dictionary.
66
+ new_config (dict): New configuration dictionary.
67
+
68
+ Returns:
69
+ dict: Dictionary containing the changes.
70
+ """
71
+ changes = {}
72
+
73
+ def compare_dicts(old, new, parent_key=""):
74
+ for key in new:
75
+ full_key = f"{parent_key}.{key}" if parent_key else key
76
+ if isinstance(new[key], dict) and isinstance(old.get(key), dict):
77
+ compare_dicts(old.get(key, {}), new[key], full_key)
78
+ elif old.get(key) != new[key]:
79
+ changes[full_key] = (old.get(key), new[key])
80
+ # Include keys that are in old but not in new
81
+ for key in old:
82
+ if key not in new:
83
+ full_key = f"{parent_key}.{key}" if parent_key else key
84
+ changes[full_key] = (old[key], None)
85
+
86
+ compare_dicts(old_config, new_config)
87
+ return changes
88
+
89
+ def retrieval_qa_chain(
90
+ self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
91
+ ):
92
+ """
93
+ Create a Retrieval QA Chain.
94
+
95
+ Args:
96
+ llm (LLM): The language model instance.
97
+ qa_prompt (str): The QA prompt string.
98
+ rephrase_prompt (str): The rephrase prompt string.
99
+ db (VectorStore): The vector store instance.
100
+ memory (Memory, optional): Memory instance. Defaults to None.
101
+
102
+ Returns:
103
+ Chain: The retrieval QA chain instance.
104
+ """
105
  retriever = Retriever(self.config)._return_retriever(db)
106
 
107
+ if self.config["llm_params"]["llm_arch"] == "langchain":
108
+ self.qa_chain = Langchain_RAG_V2(
 
 
 
 
 
 
 
109
  llm=llm,
 
 
 
110
  memory=memory,
111
+ retriever=retriever,
112
+ qa_prompt=qa_prompt,
113
+ rephrase_prompt=rephrase_prompt,
114
+ config=self.config,
115
+ callbacks=callbacks,
116
  )
117
+
118
+ self.question_generator = QuestionGenerator()
119
  else:
120
+ raise ValueError(
121
+ f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
 
 
 
 
122
  )
123
+ return self.qa_chain
124
 
 
125
  def load_llm(self):
126
+ """
127
+ Load the language model.
128
+
129
+ Returns:
130
+ LLM: The loaded language model instance.
131
+ """
132
  chat_model_loader = ChatModelLoader(self.config)
133
  llm = chat_model_loader.load_chat_model()
134
  return llm
135
 
136
+ def qa_bot(self, memory=None, callbacks=None):
137
+ """
138
+ Create a QA bot instance.
139
+
140
+ Args:
141
+ memory (Memory, optional): Memory instance. Defaults to None.
142
+ qa_prompt (str, optional): QA prompt string. Defaults to None.
143
+ rephrase_prompt (str, optional): Rephrase prompt string. Defaults to None.
144
+
145
+ Returns:
146
+ Chain: The QA bot chain instance.
147
+ """
148
+ # sanity check to see if there are any documents in the database
149
+ if len(self.vector_db) == 0:
150
+ raise ValueError(
151
+ "No documents in the database. Populate the database first."
152
+ )
153
+
154
  qa = self.retrieval_qa_chain(
155
+ self.llm,
156
+ self.qa_prompt,
157
+ self.rephrase_prompt,
158
+ self.vector_db,
159
+ memory,
160
+ callbacks=callbacks,
161
+ )
162
 
163
  return qa
 
 
 
 
 
 
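`LLMTutor.update_llm` decides what to rebuild by flattening both configs into dotted keys. A self-contained sketch of that comparison (the same logic as `get_config_changes` above, with illustrative sample configs):

```python
# Sketch of the dotted-key config diff used by LLMTutor.update_llm.
def get_config_changes(old_config, new_config):
    changes = {}

    def compare_dicts(old, new, parent_key=""):
        for key in new:
            full_key = f"{parent_key}.{key}" if parent_key else key
            if isinstance(new[key], dict) and isinstance(old.get(key), dict):
                compare_dicts(old.get(key, {}), new[key], full_key)
            elif old.get(key) != new[key]:
                changes[full_key] = (old.get(key), new[key])
        # keys present only in the old config are reported as removed
        for key in old:
            if key not in new:
                full_key = f"{parent_key}.{key}" if parent_key else key
                changes[full_key] = (old[key], None)

    compare_dicts(old_config, new_config)
    return changes


old = {"llm_params": {"llm_loader": "gpt-4o-mini", "llm_style": "Normal"}}
new = {"llm_params": {"llm_loader": "local_llm", "llm_style": "Normal"}}
print(get_config_changes(old, new))
# {'llm_params.llm_loader': ('gpt-4o-mini', 'local_llm')}
```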
code/modules/chat_processor/base.py DELETED
@@ -1,12 +0,0 @@
1
- # Template for chat processor classes
2
-
3
-
4
- class ChatProcessorBase:
5
- def __init__(self, config):
6
- self.config = config
7
-
8
- def process(self, message):
9
- """
10
- Processes and Logs the message
11
- """
12
- raise NotImplementedError("process method not implemented")
 
code/modules/chat_processor/chat_processor.py DELETED
@@ -1,30 +0,0 @@
1
- from modules.chat_processor.literal_ai import LiteralaiChatProcessor
2
-
3
-
4
- class ChatProcessor:
5
- def __init__(self, config, tags=None):
6
- self.chat_processor_type = config["chat_logging"]["platform"]
7
- self.logging = config["chat_logging"]["log_chat"]
8
- self.tags = tags
9
- if self.logging:
10
- self._init_processor()
11
-
12
- def _init_processor(self):
13
- if self.chat_processor_type == "literalai":
14
- self.processor = LiteralaiChatProcessor(self.tags)
15
- else:
16
- raise ValueError(
17
- f"Chat processor type {self.chat_processor_type} not supported"
18
- )
19
-
20
- def _process(self, user_message, assistant_message, source_dict):
21
- if self.logging:
22
- return self.processor.process(user_message, assistant_message, source_dict)
23
- else:
24
- pass
25
-
26
- async def rag(self, user_query: str, chain, cb):
27
- if self.logging:
28
- return await self.processor.rag(user_query, chain, cb)
29
- else:
30
- return await chain.acall(user_query, callbacks=[cb])
 
code/modules/chat_processor/literal_ai.py CHANGED
@@ -1,37 +1,44 @@
1
- from literalai import LiteralClient
2
- import os
3
- from .base import ChatProcessorBase
4
 
5
 
6
- class LiteralaiChatProcessor(ChatProcessorBase):
7
- def __init__(self, tags=None):
8
- self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
9
- self.literal_client.reset_context()
10
- with self.literal_client.thread(name="TEST") as thread:
11
- self.thread_id = thread.id
12
- self.thread = thread
13
- if tags is not None and type(tags) == list:
14
- self.thread.tags = tags
15
- print(f"Thread ID: {self.thread}")
16
 
17
- def process(self, user_message, assistant_message, source_dict):
18
- with self.literal_client.thread(thread_id=self.thread_id) as thread:
19
- self.literal_client.message(
20
- content=user_message,
21
- type="user_message",
22
- name="User",
23
- )
24
- self.literal_client.message(
25
- content=assistant_message,
26
- type="assistant_message",
27
- name="AI_Tutor",
28
- )
29
 
30
- async def rag(self, user_query: str, chain, cb):
31
- with self.literal_client.step(
32
- type="retrieval", name="RAG", thread_id=self.thread_id
33
- ) as step:
34
- step.input = {"question": user_query}
35
- res = await chain.acall(user_query, callbacks=[cb])
36
- step.output = res
37
- return res
 
1
+ from chainlit.data import ChainlitDataLayer, queue_until_user_message
 
 
2
 
3
 
4
+ # update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
5
+ class CustomLiteralDataLayer(ChainlitDataLayer):
6
+ def __init__(self, **kwargs):
7
+ super().__init__(**kwargs)
 
 
 
 
 
 
8
 
9
+ @queue_until_user_message()
10
+ async def create_step(self, step_dict: "StepDict"):
11
+ metadata = dict(
12
+ step_dict.get("metadata", {}),
13
+ **{
14
+ "waitForAnswer": step_dict.get("waitForAnswer"),
15
+ "language": step_dict.get("language"),
16
+ "showInput": step_dict.get("showInput"),
17
+ },
18
+ )
 
 
19
 
20
+ step: LiteralStepDict = {
21
+ "createdAt": step_dict.get("createdAt"),
22
+ "startTime": step_dict.get("start"),
23
+ "endTime": step_dict.get("end"),
24
+ "generation": step_dict.get("generation"),
25
+ "id": step_dict.get("id"),
26
+ "parentId": step_dict.get("parentId"),
27
+ "name": step_dict.get("name"),
28
+ "threadId": step_dict.get("threadId"),
29
+ "type": step_dict.get("type"),
30
+ "tags": step_dict.get("tags"),
31
+ "metadata": metadata,
32
+ }
33
+ if step_dict.get("input"):
34
+ step["input"] = {"content": step_dict.get("input")}
35
+ if step_dict.get("output"):
36
+ step["output"] = {"content": step_dict.get("output")}
37
+ if step_dict.get("isError"):
38
+ step["error"] = step_dict.get("output")
39
+
40
+ # print("\n\n\n")
41
+ # print("Step: ", step)
42
+ # print("\n\n\n")
43
+
44
+ await self.client.api.send_steps([step])
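The override above mainly repacks the Chainlit step dict before it is sent on. A small, dependency-free sketch of the metadata merge it performs (the sample `step_dict` values are invented for the example):

```python
# Dict-only illustration of the metadata merge in CustomLiteralDataLayer.create_step.
step_dict = {
    "metadata": {"course": "DS598"},
    "waitForAnswer": False,
    "language": "en",
    "showInput": True,
}

metadata = dict(
    step_dict.get("metadata", {}),
    **{
        "waitForAnswer": step_dict.get("waitForAnswer"),
        "language": step_dict.get("language"),
        "showInput": step_dict.get("showInput"),
    },
)
print(metadata)
# {'course': 'DS598', 'waitForAnswer': False, 'language': 'en', 'showInput': True}
```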
code/modules/config/config.yml CHANGED
@@ -3,18 +3,19 @@ log_chunk_dir: '../storage/logs/chunks' # str
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
 
6
  embedd_files: False # bool
7
  data_path: '../storage/data' # str
8
  url_file_path: '../storage/data/urls.txt' # str
9
  expand_urls: True # bool
10
- db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
11
  db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
14
  score_threshold : 0.2 # float
15
 
16
  faiss_params: # Not used as of now
17
- index_path: '../vectorstores/faiss.index' # str
18
  index_type: 'Flat' # str [Flat, HNSW, IVF]
19
  index_dimension: 384 # int
20
  index_nlist: 100 # int
@@ -24,27 +25,36 @@ vectorstore:
24
  index_name: "new_idx" # str
25
 
26
  llm_params:
 
27
  use_history: True # bool
 
28
  memory_window: 3 # int
29
- llm_loader: 'openai' # str [local_llm, openai]
 
30
  openai_params:
31
- model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
32
  local_llm_params:
33
- model: 'tiny-llama'
34
- temperature: 0.7
 
 
 
 
35
 
36
  chat_logging:
37
- log_chat: False # bool
38
  platform: 'literalai'
 
39
 
40
  splitter_options:
41
  use_splitter: True # bool
42
  split_by_token : True # bool
43
  remove_leftover_delimiters: True # bool
44
  remove_chunks: False # bool
 
45
  chunk_size : 300 # int
46
  chunk_overlap : 30 # int
47
  chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
48
  front_chunks_to_remove : null # int or None
49
  last_chunks_to_remove : null # int or None
50
- delimiters_to_remove : ['\t', '\n', ' ', ' '] # list of strings
 
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
6
+ load_from_HF: True # bool
7
  embedd_files: False # bool
8
  data_path: '../storage/data' # str
9
  url_file_path: '../storage/data/urls.txt' # str
10
  expand_urls: True # bool
11
+ db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille, RAPTOR]
12
  db_path : '../vectorstores' # str
13
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
14
  search_top_k : 3 # int
15
  score_threshold : 0.2 # float
16
 
17
  faiss_params: # Not used as of now
18
+ index_path: 'vectorstores/faiss.index' # str
19
  index_type: 'Flat' # str [Flat, HNSW, IVF]
20
  index_dimension: 384 # int
21
  index_nlist: 100 # int
 
25
  index_name: "new_idx" # str
26
 
27
  llm_params:
28
+ llm_arch: 'langchain' # [langchain]
29
  use_history: True # bool
30
+ generate_follow_up: False # bool
31
  memory_window: 3 # int
32
+ llm_style: 'Normal' # str [Normal, ELI5]
33
+ llm_loader: 'gpt-4o-mini' # str [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
34
  openai_params:
35
+ temperature: 0.7 # float
36
  local_llm_params:
37
+ temperature: 0.7 # float
38
+ repo_id: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF' # HuggingFace repo id
39
+ filename: 'tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Specific name of gguf file in the repo
40
41
+ stream: False # bool
42
+ pdf_reader: 'gpt' # str [llama, pymupdf, gpt]
43
 
44
  chat_logging:
45
+ log_chat: True # bool
46
  platform: 'literalai'
47
+ callbacks: False # bool
48
 
49
  splitter_options:
50
  use_splitter: True # bool
51
  split_by_token : True # bool
52
  remove_leftover_delimiters: True # bool
53
  remove_chunks: False # bool
54
+ chunking_mode: 'semantic' # str [fixed, semantic]
55
  chunk_size : 300 # int
56
  chunk_overlap : 30 # int
57
  chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
58
  front_chunks_to_remove : null # int or None
59
  last_chunks_to_remove : null # int or None
60
+ delimiters_to_remove : ['\t', '\n', ' ', ' '] # list of strings
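For reference, the new keys above are read with `yaml.safe_load` elsewhere in the repo. A minimal sketch, assuming the working directory is `code/` as in the Dockerfile:

```python
# Minimal sketch of reading the new config keys; the relative path assumes
# the working directory is code/, matching how the app is run in the Dockerfile.
import yaml

with open("modules/config/config.yml") as f:
    config = yaml.safe_load(f)

print(config["llm_params"]["llm_arch"])             # 'langchain'
print(config["llm_params"]["pdf_reader"])           # 'gpt'
print(config["splitter_options"]["chunking_mode"])  # 'semantic'
print(config["vectorstore"]["load_from_HF"])        # True
```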
code/modules/config/constants.py CHANGED
@@ -6,77 +6,18 @@ load_dotenv()
6
  # API Keys - Loaded from the .env file
7
 
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
9
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
10
- LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
 
11
 
12
- opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
13
-
14
- # Prompt Templates
15
-
16
- openai_prompt_template = """Use the following pieces of information to answer the user's question.
17
- If you don't know the answer, just say that you don't know.
18
-
19
- Context: {context}
20
- Question: {question}
21
-
22
- Only return the helpful answer below and nothing else.
23
- Helpful answer:
24
- """
25
-
26
- openai_prompt_template_with_history = """Use the following pieces of information to answer the user's question.
27
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
28
- Use the history to answer the question if you can.
29
- Chat History:
30
- {chat_history}
31
- Context: {context}
32
- Question: {question}
33
-
34
- Only return the helpful answer below and nothing else.
35
- Helpful answer:
36
- """
37
-
38
- tinyllama_prompt_template = """
39
- <|im_start|>system
40
- Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a breif and concise answer to the question. Use the history to answer the question if you can.
41
-
42
- Context:
43
- {context}
44
- <|im_end|>
45
- <|im_start|>user
46
- Question: Who is the instructor for this course?
47
- <|im_end|>
48
- <|im_start|>assistant
49
- The instructor for this course is Prof. Thomas Gardos.
50
- <|im_end|>
51
- <|im_start|>user
52
- Question: {question}
53
- <|im_end|>
54
- <|im_start|>assistant
55
- """
56
-
57
- tinyllama_prompt_template_with_history = """
58
- <|im_start|>system
59
- Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a breif and concise answer to the question.
60
-
61
- Chat History:
62
- {chat_history}
63
- Context:
64
- {context}
65
- <|im_end|>
66
- <|im_start|>user
67
- Question: Who is the instructor for this course?
68
- <|im_end|>
69
- <|im_start|>assistant
70
- The instructor for this course is Prof. Thomas Gardos.
71
- <|im_end|>
72
- <|im_start|>user
73
- Question: {question}
74
- <|im_end|>
75
- <|im_start|>assistant
76
- """
77
 
 
78
 
79
  # Model Paths
80
 
81
- LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
82
- MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
 
 
6
  # API Keys - Loaded from the .env file
7
 
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
+ LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
10
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
11
+ LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
12
+ LITERAL_API_URL = os.getenv("LITERAL_API_URL")
13
 
14
+ OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
15
+ OAUTH_GOOGLE_CLIENT_SECRET = os.getenv("OAUTH_GOOGLE_CLIENT_SECRET")
 
16
 
17
+ opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
18
 
19
  # Model Paths
20
 
21
+ LLAMA_PATH = "../storage/models/tinyllama"
22
+
23
+ RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
code/modules/config/prompts.py ADDED
@@ -0,0 +1,97 @@
 
1
+ prompts = {
2
+ "openai": {
3
+ "rephrase_prompt": (
4
+ "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
5
+ "Incorporate relevant details from the chat history to make the question clearer and more specific. "
6
+ "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
7
+ "If the question is conversational and doesn't require context, do not rephrase it. "
8
+ "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. "
9
+ "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project' "
10
+ "Chat history: \n{chat_history}\n"
11
+ "Rephrase the following question only if necessary: '{input}'"
12
+ "Rephrased Question:'"
13
+ ),
14
+ "prompt_with_history": {
15
+ "normal": (
16
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
17
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
18
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
19
+ "Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context. Be sure to explain the parameters and variables in the equations."
20
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
21
+ "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
22
+ "Chat History:\n{chat_history}\n\n"
23
+ "Context:\n{context}\n\n"
24
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
25
+ "Student: {input}\n"
26
+ "AI Tutor:"
27
+ ),
28
+ "eli5": (
29
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Your job is to explain things in the simplest and most engaging way possible, just like the 'Explain Like I'm 5' (ELI5) concept."
30
+ "If you don't know the answer, do your best without making things up. Keep your explanations straightforward and very easy to understand."
31
+ "Use the chat history and context to help you, but avoid repeating past responses. Provide links from the source_file metadata when they're helpful."
32
+ "Use very simple language and examples to explain any math equations, and put the equations in LaTeX format between $ or $$ signs."
33
+ "Be friendly and engaging, like you're chatting with a young child who's curious and eager to learn. Avoid complex terms and jargon."
34
+ "Include simple and clear examples wherever you can to make things easier to understand."
35
+ "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
36
+ "Chat History:\n{chat_history}\n\n"
37
+ "Context:\n{context}\n\n"
38
+ "Answer the student's question below in a friendly, simple, and engaging way, just like the ELI5 concept. Use the context and history only if they're relevant, otherwise, just have a natural conversation."
39
+ "Give a clear and detailed explanation with simple examples to make it easier to understand. Remember, your goal is to break down complex topics into very simple terms, just like ELI5."
40
+ "Student: {input}\n"
41
+ "AI Tutor:"
42
+ ),
43
+ "socratic": (
44
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Engage the student in a Socratic dialogue to help them discover answers on their own. Use the provided context to guide your questioning."
45
+ "If you don't know the answer, do your best without making things up. Keep the conversation engaging and inquisitive."
46
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata when relevant. Use the source context that is most relevant."
47
+ "Speak in a friendly and engaging manner, encouraging critical thinking and self-discovery."
48
+ "Use questions to lead the student to explore the topic and uncover answers."
49
+ "Chat History:\n{chat_history}\n\n"
50
+ "Context:\n{context}\n\n"
51
+ "Answer the student's question below by guiding them through a series of questions and insights that lead to deeper understanding. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation."
52
+ "Foster an inquisitive mindset and help the student discover answers through dialogue."
53
+ "Student: {input}\n"
54
+ "AI Tutor:"
55
+ ),
56
+ },
57
+ "prompt_no_history": (
58
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
59
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
60
+ "Provide links from the source_file metadata. Use the source context that is most relevant. "
61
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
62
+ "Context:\n{context}\n\n"
63
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
64
+ "Student: {input}\n"
65
+ "AI Tutor:"
66
+ ),
67
+ },
68
+ "tiny_llama": {
69
+ "prompt_no_history": (
70
+ "system\n"
71
+ "Assistant is an intelligent chatbot designed to help students with questions regarding the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance.\n"
72
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally.\n"
73
+ "Provide links from the source_file metadata. Use the source context that is most relevant.\n"
74
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
75
+ "\n\n"
76
+ "user\n"
77
+ "Context:\n{context}\n\n"
78
+ "Question: {input}\n"
79
+ "\n\n"
80
+ "assistant"
81
+ ),
82
+ "prompt_with_history": (
83
+ "system\n"
84
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
85
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
86
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
87
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
88
+ "\n\n"
89
+ "user\n"
90
+ "Chat History:\n{chat_history}\n\n"
91
+ "Context:\n{context}\n\n"
92
+ "Question: {input}\n"
93
+ "\n\n"
94
+ "assistant"
95
+ ),
96
+ },
97
+ }
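The `get_prompt` helper that consumes this dict is not part of this diff, so the lookup below is only an illustrative guess at how the nesting is indexed from the config (model family, history flag, then `llm_style`):

```python
# Illustrative lookup into the prompts dict above; the real get_prompt helper is not
# shown in this diff, so this selection logic is an assumption.
from modules.config.prompts import prompts

llm_style = "Normal"   # from config["llm_params"]["llm_style"]
use_history = True     # from config["llm_params"]["use_history"]

if use_history:
    qa_prompt = prompts["openai"]["prompt_with_history"][llm_style.lower()]
else:
    qa_prompt = prompts["openai"]["prompt_no_history"]

rephrase_prompt = prompts["openai"]["rephrase_prompt"]
print(qa_prompt[:60])
```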
code/modules/dataloader/data_loader.py CHANGED
@@ -14,32 +14,89 @@ from llama_parse import LlamaParse
14
  from langchain.schema import Document
15
  import logging
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
17
  from ragatouille import RAGPretrainedModel
18
  from langchain.chains import LLMChain
19
  from langchain_community.llms import OpenAI
20
  from langchain import PromptTemplate
21
  import json
22
  from concurrent.futures import ThreadPoolExecutor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- from modules.dataloader.helpers import get_metadata
 
 
 
 
 
 
25
 
 
 
 
 
26
 
27
- class PDFReader:
28
- def __init__(self):
29
- pass
 
30
 
31
- def get_loader(self, pdf_path):
32
- loader = PyMuPDFLoader(pdf_path)
33
- return loader
34
 
35
- def get_documents(self, loader):
36
- return loader.load()
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  class FileReader:
40
- def __init__(self, logger):
41
- self.pdf_reader = PDFReader()
42
  self.logger = logger
 
 
 
 
 
 
 
 
 
43
 
44
  def extract_text_from_pdf(self, pdf_path):
45
  text = ""
@@ -51,20 +108,8 @@ class FileReader:
51
  text += page.extract_text()
52
  return text
53
 
54
- def download_pdf_from_url(self, pdf_url):
55
- response = requests.get(pdf_url)
56
- if response.status_code == 200:
57
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
58
- temp_file.write(response.content)
59
- temp_file_path = temp_file.name
60
- return temp_file_path
61
- else:
62
- self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
63
- return None
64
-
65
  def read_pdf(self, temp_file_path: str):
66
- loader = self.pdf_reader.get_loader(temp_file_path)
67
- documents = self.pdf_reader.get_documents(loader)
68
  return documents
69
 
70
  def read_txt(self, temp_file_path: str):
@@ -89,8 +134,7 @@ class FileReader:
89
  return loader.load()
90
 
91
  def read_html(self, url: str):
92
- loader = WebBaseLoader(url)
93
- return loader.load()
94
 
95
  def read_tex_from_url(self, tex_url):
96
  response = requests.get(tex_url)
@@ -110,21 +154,31 @@ class ChunkProcessor:
110
  self.document_metadata = {}
111
  self.document_chunks_full = []
112
 
 
 
 
113
  if config["splitter_options"]["use_splitter"]:
114
- if config["splitter_options"]["split_by_token"]:
115
- self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
116
- chunk_size=config["splitter_options"]["chunk_size"],
117
- chunk_overlap=config["splitter_options"]["chunk_overlap"],
118
- separators=config["splitter_options"]["chunk_separators"],
119
- disallowed_special=(),
120
- )
 
 
 
 
 
 
 
 
121
  else:
122
- self.splitter = RecursiveCharacterTextSplitter(
123
- chunk_size=config["splitter_options"]["chunk_size"],
124
- chunk_overlap=config["splitter_options"]["chunk_overlap"],
125
- separators=config["splitter_options"]["chunk_separators"],
126
- disallowed_special=(),
127
  )
 
128
  else:
129
  self.splitter = None
130
  self.logger.info("ChunkProcessor instance created")
@@ -147,16 +201,12 @@ class ChunkProcessor:
147
  def process_chunks(
148
  self, documents, file_type="txt", source="", page=0, metadata={}
149
  ):
 
150
  documents = [Document(page_content=documents, source=source, page=page)]
151
- if (
152
- file_type == "txt"
153
- or file_type == "docx"
154
- or file_type == "srt"
155
- or file_type == "tex"
156
- ):
157
  document_chunks = self.splitter.split_documents(documents)
158
- elif file_type == "pdf":
159
- document_chunks = documents # Full page for now
160
 
161
  # add the source and page number back to the metadata
162
  for chunk in document_chunks:
@@ -179,7 +229,6 @@ class ChunkProcessor:
179
  "https://dl4ds.github.io/sp2024/lectures/",
180
  "https://dl4ds.github.io/sp2024/schedule/",
181
  ) # For any additional metadata
182
-
183
  with ThreadPoolExecutor() as executor:
184
  executor.map(
185
  self.process_file,
@@ -228,11 +277,11 @@ class ChunkProcessor:
228
 
229
  page_num = doc.metadata.get("page", 0)
230
  file_data[page_num] = doc.page_content
231
- metadata = (
232
- addl_metadata.get(file_path, {})
233
- if metadata_source == "file"
234
- else {"source": file_path, "page": page_num}
235
- )
236
  file_metadata[page_num] = metadata
237
 
238
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
@@ -250,11 +299,8 @@ class ChunkProcessor:
250
 
251
  def process_file(self, file_path, file_index, file_reader, addl_metadata):
252
  file_name = os.path.basename(file_path)
253
- if file_name in self.document_data:
254
- return
255
 
256
- file_type = file_name.split(".")[-1].lower()
257
- self.logger.info(f"Reading file {file_index + 1}: {file_path}")
258
 
259
  read_methods = {
260
  "pdf": file_reader.read_pdf,
@@ -268,7 +314,13 @@ class ChunkProcessor:
268
  return
269
 
270
  try:
271
- documents = read_methods[file_type](file_path)
 
 
 
 
 
 
272
  self.process_documents(
273
  documents, file_path, file_type, "file", addl_metadata
274
  )
@@ -326,11 +378,14 @@ class ChunkProcessor:
326
  f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
327
  ) as json_file:
328
  self.document_metadata = json.load(json_file)
 
 
 
329
 
330
 
331
  class DataLoader:
332
  def __init__(self, config, logger=None):
333
- self.file_reader = FileReader(logger=logger)
334
  self.chunk_processor = ChunkProcessor(config, logger=logger)
335
 
336
  def get_chunks(self, uploaded_files, weblinks):
@@ -348,13 +403,19 @@ if __name__ == "__main__":
348
  with open("../code/modules/config/config.yml", "r") as f:
349
  config = yaml.safe_load(f)
350
 
 
 
 
 
 
351
  data_loader = DataLoader(config, logger=logger)
352
  document_chunks, document_names, documents, document_metadata = (
353
  data_loader.get_chunks(
 
354
  [],
355
- ["https://dl4ds.github.io/sp2024/"],
356
  )
357
  )
358
 
359
- print(document_names)
360
  print(len(document_chunks))
 
 
14
  from langchain.schema import Document
15
  import logging
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
+ from langchain_experimental.text_splitter import SemanticChunker
18
+ from langchain_openai.embeddings import OpenAIEmbeddings
19
  from ragatouille import RAGPretrainedModel
20
  from langchain.chains import LLMChain
21
  from langchain_community.llms import OpenAI
22
  from langchain import PromptTemplate
23
  import json
24
  from concurrent.futures import ThreadPoolExecutor
25
+ from urllib.parse import urljoin
26
+ import html2text
27
+ import bs4
28
+ import tempfile
29
+ import PyPDF2
30
+ from modules.dataloader.pdf_readers.base import PDFReader
31
+ from modules.dataloader.pdf_readers.llama import LlamaParser
32
+ from modules.dataloader.pdf_readers.gpt import GPTParser
33
+
34
+ try:
35
+ from modules.dataloader.helpers import get_metadata, download_pdf_from_url
36
+ from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
37
+ except ImportError:
38
+ from dataloader.helpers import get_metadata, download_pdf_from_url
39
+ from config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
40
+
41
+ logger = logging.getLogger(__name__)
42
+ BASE_DIR = os.getcwd()
43
+
44
+
45
+ class HTMLReader:
46
+ def __init__(self):
47
+ pass
48
 
49
+ def read_url(self, url):
50
+ response = requests.get(url)
51
+ if response.status_code == 200:
52
+ return response.text
53
+ else:
54
+ logger.warning(f"Failed to download HTML from URL: {url}")
55
+ return None
56
 
57
+ def check_links(self, base_url, html_content):
58
+ soup = bs4.BeautifulSoup(html_content, "html.parser")
59
+ for link in soup.find_all("a"):
60
+ href = link.get("href")
61
 
62
+ if not href or href.startswith("#"):
63
+ continue
64
+ elif not href.startswith("https"):
65
+ href = href.replace("http", "https")
66
 
67
+ absolute_url = urljoin(base_url, href)
68
+ link['href'] = absolute_url
 
69
 
70
+ resp = requests.head(absolute_url)
71
+ if resp.status_code != 200:
72
+ logger.warning(f"Link {absolute_url} is broken. Status code: {resp.status_code}")
73
 
74
+ return str(soup)
75
+
76
+ def html_to_md(self, url, html_content):
77
+ html_processed = self.check_links(url, html_content)
78
+ markdown_content = html2text.html2text(html_processed)
79
+ return markdown_content
80
+
81
+ def read_html(self, url):
82
+ html_content = self.read_url(url)
83
+ if html_content:
84
+ return self.html_to_md(url, html_content)
85
+ else:
86
+ return None
87
 
88
  class FileReader:
89
+ def __init__(self, logger, kind):
 
90
  self.logger = logger
91
+ self.kind = kind
92
+ if kind == "llama":
93
+ self.pdf_reader = LlamaParser()
94
+ elif kind == "gpt":
95
+ self.pdf_reader = GPTParser()
96
+ else:
97
+ self.pdf_reader = PDFReader()
98
+ self.web_reader = HTMLReader()
99
+ self.logger.info(f"Initialized FileReader with {kind} PDF reader and HTML reader")
100
 
101
  def extract_text_from_pdf(self, pdf_path):
102
  text = ""
 
108
  text += page.extract_text()
109
  return text
110
 
 
 
 
 
 
 
 
 
 
 
 
111
  def read_pdf(self, temp_file_path: str):
112
+ documents = self.pdf_reader.parse(temp_file_path)
 
113
  return documents
114
 
115
  def read_txt(self, temp_file_path: str):
 
134
  return loader.load()
135
 
136
  def read_html(self, url: str):
137
+ return [Document(page_content=self.web_reader.read_html(url))]
 
138
 
139
  def read_tex_from_url(self, tex_url):
140
  response = requests.get(tex_url)
 
154
  self.document_metadata = {}
155
  self.document_chunks_full = []
156
 
157
+ if not config['vectorstore']['embedd_files']:
158
+ self.load_document_data()
159
+
160
  if config["splitter_options"]["use_splitter"]:
161
+ if config["splitter_options"]["chunking_mode"] == "fixed":
162
+ if config["splitter_options"]["split_by_token"]:
163
+ self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
164
+ chunk_size=config["splitter_options"]["chunk_size"],
165
+ chunk_overlap=config["splitter_options"]["chunk_overlap"],
166
+ separators=config["splitter_options"]["chunk_separators"],
167
+ disallowed_special=(),
168
+ )
169
+ else:
170
+ self.splitter = RecursiveCharacterTextSplitter(
171
+ chunk_size=config["splitter_options"]["chunk_size"],
172
+ chunk_overlap=config["splitter_options"]["chunk_overlap"],
173
+ separators=config["splitter_options"]["chunk_separators"],
174
+ disallowed_special=(),
175
+ )
176
  else:
177
+ self.splitter = SemanticChunker(
178
+ OpenAIEmbeddings(),
179
+ breakpoint_threshold_type="percentile"
 
 
180
  )
181
+
182
  else:
183
  self.splitter = None
184
  self.logger.info("ChunkProcessor instance created")
 
201
  def process_chunks(
202
  self, documents, file_type="txt", source="", page=0, metadata={}
203
  ):
204
+ # TODO: Clear up this pipeline of re-adding metadata
205
  documents = [Document(page_content=documents, source=source, page=page)]
206
+ if file_type == "pdf" and self.config["splitter_options"]["chunking_mode"] == "fixed":
207
+ document_chunks = documents
208
+ else:
 
 
 
209
  document_chunks = self.splitter.split_documents(documents)
 
 
210
 
211
  # add the source and page number back to the metadata
212
  for chunk in document_chunks:
 
229
  "https://dl4ds.github.io/sp2024/lectures/",
230
  "https://dl4ds.github.io/sp2024/schedule/",
231
  ) # For any additional metadata
 
232
  with ThreadPoolExecutor() as executor:
233
  executor.map(
234
  self.process_file,
 
277
 
278
  page_num = doc.metadata.get("page", 0)
279
  file_data[page_num] = doc.page_content
280
+
281
+ # Create a new dictionary for metadata in each iteration
282
+ metadata = addl_metadata.get(file_path, {}).copy()
283
+ metadata["page"] = page_num
284
+ metadata["source"] = file_path
285
  file_metadata[page_num] = metadata
286
 
287
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
 
299
 
300
  def process_file(self, file_path, file_index, file_reader, addl_metadata):
301
  file_name = os.path.basename(file_path)
 
 
302
 
303
+ file_type = file_name.split(".")[-1]
 
304
 
305
  read_methods = {
306
  "pdf": file_reader.read_pdf,
 
314
  return
315
 
316
  try:
317
+
318
+ if file_path in self.document_data:
319
+ self.logger.warning(f"File {file_name} already processed")
320
+ documents = [Document(page_content=content) for content in self.document_data[file_path].values()]
321
+ else:
322
+ documents = read_methods[file_type](file_path)
323
+
324
  self.process_documents(
325
  documents, file_path, file_type, "file", addl_metadata
326
  )
 
378
  f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
379
  ) as json_file:
380
  self.document_metadata = json.load(json_file)
381
+ self.logger.info(
382
+ f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}"
383
+ )
384
 
385
 
386
  class DataLoader:
387
  def __init__(self, config, logger=None):
388
+ self.file_reader = FileReader(logger=logger, kind=config["llm_params"]["pdf_reader"])
389
  self.chunk_processor = ChunkProcessor(config, logger=logger)
390
 
391
  def get_chunks(self, uploaded_files, weblinks):
 
403
  with open("../code/modules/config/config.yml", "r") as f:
404
  config = yaml.safe_load(f)
405
 
406
+ STORAGE_DIR = os.path.join(BASE_DIR, config['vectorstore']["data_path"])
407
+ uploaded_files = [
408
+ os.path.join(STORAGE_DIR, file) for file in os.listdir(STORAGE_DIR) if file != "urls.txt"
409
+ ]
410
+
411
  data_loader = DataLoader(config, logger=logger)
412
  document_chunks, document_names, documents, document_metadata = (
413
  data_loader.get_chunks(
414
+ ["https://dl4ds.github.io/sp2024/static_files/lectures/05_loss_functions_v2.pdf"],
415
  [],
 
416
  )
417
  )
418
 
419
+ print(document_names[:5])
420
  print(len(document_chunks))
421
+
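A usage sketch for the reworked `FileReader`: the `kind` argument comes from `llm_params.pdf_reader` and selects between the PyMuPDF, LlamaParse, and GPT-based parsers. The PDF path below is a placeholder:

```python
# Hedged usage sketch for the new FileReader; kind="pymupdf" selects the local
# PyMuPDF parser, so no API keys are required. The file path is hypothetical.
import logging

from modules.dataloader.data_loader import FileReader

logger = logging.getLogger("dataloader")
reader = FileReader(logger=logger, kind="pymupdf")
documents = reader.read_pdf("../storage/data/example_lecture.pdf")  # hypothetical file
print(len(documents), "pages loaded")
```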
code/modules/dataloader/helpers.py CHANGED
@@ -1,7 +1,7 @@
1
  import requests
2
  from bs4 import BeautifulSoup
3
- from tqdm import tqdm
4
-
5
 
6
  def get_urls_from_file(file_path: str):
7
  """
@@ -106,3 +106,23 @@ def get_metadata(lectures_url, schedule_url):
106
  continue
107
 
108
  return lecture_metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  from bs4 import BeautifulSoup
3
+ from urllib.parse import urlparse
4
+ import tempfile
5
 
6
  def get_urls_from_file(file_path: str):
7
  """
 
106
  continue
107
 
108
  return lecture_metadata
109
+
110
+
111
+ def download_pdf_from_url(pdf_url):
112
+ """
113
+ Function to temporarily download a PDF file from a URL and return the local file path.
114
+
115
+ Args:
116
+ pdf_url (str): The URL of the PDF file to download.
117
+
118
+ Returns:
119
+ str: The local file path of the downloaded PDF file.
120
+ """
121
+ response = requests.get(pdf_url)
122
+ if response.status_code == 200:
123
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
124
+ temp_file.write(response.content)
125
+ temp_file_path = temp_file.name
126
+ return temp_file_path
127
+ else:
128
+ return None
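Usage sketch for the new `download_pdf_from_url` helper; the URL is the lecture PDF already used in the `__main__` block of `data_loader.py`:

```python
# download_pdf_from_url returns a temporary local path, or None if the request fails.
from modules.dataloader.helpers import download_pdf_from_url

url = "https://dl4ds.github.io/sp2024/static_files/lectures/05_loss_functions_v2.pdf"
local_path = download_pdf_from_url(url)

if local_path is None:
    print("download failed")
else:
    print("PDF saved to", local_path)
```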
code/modules/dataloader/pdf_readers/base.py ADDED
@@ -0,0 +1,14 @@
 
1
+ from langchain_community.document_loaders import PyMuPDFLoader
2
+
3
+
4
+ class PDFReader:
5
+ def __init__(self):
6
+ pass
7
+
8
+ def get_loader(self, pdf_path):
9
+ loader = PyMuPDFLoader(pdf_path)
10
+ return loader
11
+
12
+ def parse(self, pdf_path):
13
+ loader = self.get_loader(pdf_path)
14
+ return loader.load()
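A minimal sketch of the default reader above; `parse` returns one langchain `Document` per page (the path is a placeholder):

```python
# Minimal usage sketch for the PyMuPDF-backed PDFReader; the path is hypothetical.
from modules.dataloader.pdf_readers.base import PDFReader

reader = PDFReader()
documents = reader.parse("../storage/data/example_lecture.pdf")  # hypothetical file
for doc in documents[:2]:
    print(doc.metadata.get("page"), doc.page_content[:60])
```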
code/modules/dataloader/pdf_readers/gpt.py ADDED
@@ -0,0 +1,81 @@
 
1
+ import base64
2
+ import os
3
+ import requests
4
+
5
+ from io import BytesIO
6
+ from openai import OpenAI
7
+ from pdf2image import convert_from_path
8
+ from langchain.schema import Document
9
+
10
+
11
+ class GPTParser:
12
+ """
13
+ This class uses OpenAI's GPT-4o mini model to parse PDFs and extract text, images and equations.
14
+ It is the most advanced parser in the system and is able to handle complex formats and layouts
15
+ """
16
+
17
+ def __init__(self):
18
+ self.client = OpenAI()
19
+ self.api_key = os.getenv("OPENAI_API_KEY")
20
+ self.prompt = """
21
+ The provided documents are images of PDFs of lecture slides of deep learning material.
22
+ They contain LaTeX equations, images, and text.
23
+ The goal is to extract the text, images and equations from the slides and convert everything to markdown format. Some of the equations may be complicated.
24
+ The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
25
+ For images, give a description and if you can, a source. Separate each page with '---'.
26
+ Just respond with the markdown. Do not include page numbers or any other metadata. Do not try to provide titles. Strictly the content.
27
+ """
28
+
29
+ def parse(self, pdf_path):
30
+ images = convert_from_path(pdf_path)
31
+
32
+ encoded_images = [self.encode_image(image) for image in images]
33
+
34
+ chunks = [encoded_images[i:i + 5] for i in range(0, len(encoded_images), 5)]
35
+
36
+ headers = {
37
+ "Content-Type": "application/json",
38
+ "Authorization": f"Bearer {self.api_key}"
39
+ }
40
+
41
+ output = ""
42
+ for chunk_num, chunk in enumerate(chunks):
43
+ content = [{"type": "image_url", "image_url": {
44
+ "url": f"data:image/jpeg;base64,{image}"}} for image in chunk]
45
+
46
+ content.insert(0, {"type": "text", "text": self.prompt})
47
+
48
+ payload = {
49
+ "model": "gpt-4o-mini",
50
+ "messages": [
51
+ {
52
+ "role": "user",
53
+ "content": content
54
+ }
55
+ ],
56
+ }
57
+
58
+ response = requests.post(
59
+ "https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
60
+
61
+ resp = response.json()
62
+
63
+ chunk_output = resp['choices'][0]['message']['content'].replace("```", "").replace("markdown", "").replace("````", "")
64
+
65
+ output += chunk_output + "\n---\n"
66
+
67
+ output = output.split("\n---\n")
68
+ output = [doc for doc in output if doc.strip() != ""]
69
+
70
+ documents = [
71
+ Document(
72
+ page_content=page,
73
+ metadata={"source": pdf_path, "page": i}
74
+ ) for i, page in enumerate(output)
75
+ ]
76
+ return documents
77
+
78
+ def encode_image(self, image):
79
+ buffered = BytesIO()
80
+ image.save(buffered, format="JPEG")
81
+ return base64.b64encode(buffered.getvalue()).decode('utf-8')
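
`GPTParser` renders each PDF page to an image with `pdf2image`, sends the pages to the Chat Completions API in batches of five, and rebuilds per-page `Document`s from the returned markdown. A usage sketch, assuming `OPENAI_API_KEY` is exported and poppler (required by `pdf2image`) is installed; the file path is hypothetical.

```python
# Sketch: parse a slide deck with the GPT-based reader
# (assumes OPENAI_API_KEY is set and poppler is installed for pdf2image)
from modules.dataloader.pdf_readers.gpt import GPTParser

parser = GPTParser()
documents = parser.parse("slides/05_loss_functions_v2.pdf")  # hypothetical path
for doc in documents[:2]:
    print(doc.metadata["page"], doc.page_content[:80])
```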
code/modules/dataloader/pdf_readers/llama.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from llama_parse import LlamaParse
4
+ from langchain.schema import Document
5
+ from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
6
+ from modules.dataloader.helpers import download_pdf_from_url
7
+
8
+
9
+
10
+ class LlamaParser:
11
+ def __init__(self):
12
+ self.GPT_API_KEY = OPENAI_API_KEY
13
+ self.LLAMA_CLOUD_API_KEY = LLAMA_CLOUD_API_KEY
14
+ self.parse_url = "https://api.cloud.llamaindex.ai/api/parsing/upload"
15
+ self.headers = {
16
+ 'Accept': 'application/json',
17
+ 'Authorization': f'Bearer {LLAMA_CLOUD_API_KEY}'
18
+ }
19
+ self.parser = LlamaParse(
20
+ api_key=LLAMA_CLOUD_API_KEY,
21
+ result_type="markdown",
22
+ verbose=True,
23
+ language="en",
24
+ gpt4o_mode=False,
25
+ # gpt4o_api_key=OPENAI_API_KEY,
26
+ parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source."
27
+ )
28
+
29
+ def parse(self, pdf_path):
30
+ if not os.path.exists(pdf_path):
31
+ pdf_path = download_pdf_from_url(pdf_path)
32
+
33
+ documents = self.parser.load_data(pdf_path)
34
+ document = [document.to_langchain_format() for document in documents][0]
35
+
36
+ content = document.page_content
37
+ pages = content.split("\n---\n")
38
+ pages = [page.strip() for page in pages]
39
+
40
+ documents = [
41
+ Document(
42
+ page_content=page,
43
+ metadata={"source": pdf_path, "page": i}
44
+ ) for i, page in enumerate(pages)
45
+ ]
46
+
47
+ return documents
48
+
49
+ def make_request(self, pdf_url):
50
+ payload = {
51
+ "gpt4o_mode": "false",
52
+ "parsing_instruction": "The provided document is a PDF of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides and convert them to markdown format. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$. For images, give a description and if you can, a source.",
53
+ }
54
+
55
+ files = [
56
+ ('file', ('file', requests.get(pdf_url).content, 'application/octet-stream'))
57
+ ]
58
+
59
+ response = requests.request(
60
+ "POST", self.parse_url, headers=self.headers, data=payload, files=files)
61
+
62
+ return response.json()['id'], response.json()['status']
63
+
64
+ async def get_result(self, job_id):
65
+ url = f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}/result/markdown"
66
+
67
+ response = requests.request("GET", url, headers=self.headers, data={})
68
+
69
+ return response.json()['markdown']
70
+
71
+ async def _parse(self, pdf_path):
72
+ job_id, status = self.make_request(pdf_path)
73
+
74
+ while status != "SUCCESS":
75
+ url = f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}"
76
+ response = requests.request("GET", url, headers=self.headers, data={})
77
+ status = response.json()["status"]
78
+
79
+ result = await self.get_result(job_id)
80
+
81
+ documents = [
82
+ Document(
83
+ page_content=result,
84
+ metadata={"source": pdf_path}
85
+ )
86
+ ]
87
+
88
+ return documents
89
+
90
+
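
`LlamaParser` takes LlamaParse's markdown output and splits it on the `\n---\n` page separator; if the input path does not exist locally it falls back to `download_pdf_from_url`. A sketch of the synchronous path, assuming `LLAMA_CLOUD_API_KEY` is configured in `modules.config.constants`.

```python
# Sketch: parse a remote PDF via LlamaParse (assumes LLAMA_CLOUD_API_KEY is configured)
from modules.dataloader.pdf_readers.llama import LlamaParser

parser = LlamaParser()
documents = parser.parse(
    "https://dl4ds.github.io/sp2024/static_files/lectures/05_loss_functions_v2.pdf"
)
print(len(documents), "pages parsed")
```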
code/modules/dataloader/webpage_crawler.py CHANGED
@@ -66,7 +66,6 @@ class WebpageCrawler:
66
  )
67
  for link in unchecked_links:
68
  dict_links[link] = "Checked"
69
- print(f"Checked: {link}")
70
  dict_links.update(
71
  {
72
  link: "Not-checked"
 
66
  )
67
  for link in unchecked_links:
68
  dict_links[link] = "Checked"
 
69
  dict_links.update(
70
  {
71
  link: "Not-checked"
code/modules/vectorstore/base.py CHANGED
@@ -29,5 +29,8 @@ class VectorStoreBase:
29
  """
30
  raise NotImplementedError
31
 
 
 
 
32
  def __str__(self):
33
  return self.__class__.__name__
 
29
  """
30
  raise NotImplementedError
31
 
32
+ def __len__(self):
33
+ raise NotImplementedError
34
+
35
  def __str__(self):
36
  return self.__class__.__name__
code/modules/vectorstore/chroma.py CHANGED
@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
 
 
 
 
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
42
+
43
+ def __len__(self):
44
+ return len(self.vectorstore)
code/modules/vectorstore/colbert.py CHANGED
@@ -1,6 +1,67 @@
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
 
 
 
 
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  class ColbertVectorStore(VectorStoreBase):
@@ -24,16 +85,28 @@ class ColbertVectorStore(VectorStoreBase):
24
  document_ids=document_names,
25
  document_metadatas=document_metadata,
26
  )
 
27
 
28
  def load_database(self):
29
  path = os.path.join(
 
30
  self.config["vectorstore"]["db_path"],
31
  "db_" + self.config["vectorstore"]["db_option"],
32
  )
33
  self.vectorstore = RAGPretrainedModel.from_index(
34
  f"{path}/colbert/indexes/new_idx"
35
  )
 
 
 
 
 
 
 
36
  return self.vectorstore
37
 
38
  def as_retriever(self):
39
  return self.vectorstore.as_retriever()
 
 
 
 
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
3
+ from langchain_core.retrievers import BaseRetriever
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
5
+ from langchain_core.documents import Document
6
+ from typing import Any, List, Optional, Sequence
7
  import os
8
+ import json
9
+
10
+
11
+ class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
12
+ model: Any
13
+ kwargs: dict = {}
14
+
15
+ def _get_relevant_documents(
16
+ self,
17
+ query: str,
18
+ *,
19
+ run_manager: CallbackManagerForRetrieverRun, # noqa
20
+ ) -> List[Document]:
21
+ """Get documents relevant to a query."""
22
+ docs = self.model.search(query, **self.kwargs)
23
+ return [
24
+ Document(
25
+ page_content=doc["content"],
26
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
27
+ )
28
+ for doc in docs
29
+ ]
30
+
31
+ async def _aget_relevant_documents(
32
+ self,
33
+ query: str,
34
+ *,
35
+ run_manager: CallbackManagerForRetrieverRun, # noqa
36
+ ) -> List[Document]:
37
+ """Get documents relevant to a query."""
38
+ docs = self.model.search(query, **self.kwargs)
39
+ return [
40
+ Document(
41
+ page_content=doc["content"],
42
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
43
+ )
44
+ for doc in docs
45
+ ]
46
+
47
+
48
+ class RAGPretrainedModel(RAGPretrainedModel):
49
+ """
50
+ Adding len property to RAGPretrainedModel
51
+ """
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ super().__init__(*args, **kwargs)
55
+ self._document_count = 0
56
+
57
+ def set_document_count(self, count):
58
+ self._document_count = count
59
+
60
+ def __len__(self):
61
+ return self._document_count
62
+
63
+ def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
64
+ return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)
65
 
66
 
67
  class ColbertVectorStore(VectorStoreBase):
 
85
  document_ids=document_names,
86
  document_metadatas=document_metadata,
87
  )
88
+ self.colbert.set_document_count(len(document_names))
89
 
90
  def load_database(self):
91
  path = os.path.join(
92
+ os.getcwd(),
93
  self.config["vectorstore"]["db_path"],
94
  "db_" + self.config["vectorstore"]["db_option"],
95
  )
96
  self.vectorstore = RAGPretrainedModel.from_index(
97
  f"{path}/colbert/indexes/new_idx"
98
  )
99
+
100
+ index_metadata = json.load(
101
+ open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
102
+ )
103
+ num_documents = index_metadata["num_passages"]
104
+ self.vectorstore.set_document_count(num_documents)
105
+
106
  return self.vectorstore
107
 
108
  def as_retriever(self):
109
  return self.vectorstore.as_retriever()
110
+
111
+ def __len__(self):
112
+ return len(self.vectorstore)
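
The ColBERT changes add a score-aware LangChain retriever and a document count so that `len()` works on the store. A sketch of querying through the new retriever class, assuming an index has already been built; the index path and query string are hypothetical.

```python
# Sketch: query a previously built ColBERT index through the score-aware retriever
# (index path and query are hypothetical)
from ragatouille import RAGPretrainedModel as RAGBase
from modules.vectorstore.colbert import RAGatouilleLangChainRetrieverWithScore

model = RAGBase.from_index("vectorstores/db_colbert/colbert/indexes/new_idx")
retriever = RAGatouilleLangChainRetrieverWithScore(model=model, kwargs={"k": 3})
for doc in retriever.invoke("What is a loss function?"):
    print(doc.metadata.get("score"), doc.page_content[:60])
```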
code/modules/vectorstore/faiss.py CHANGED
@@ -3,10 +3,21 @@ from modules.vectorstore.base import VectorStoreBase
3
  import os
4
 
5
 
 
 
 
 
 
 
 
6
  class FaissVectorStore(VectorStoreBase):
7
  def __init__(self, config):
8
  self.config = config
9
  self._init_vector_db()
 
 
 
 
10
 
11
  def _init_vector_db(self):
12
  self.faiss = FAISS(
@@ -18,24 +29,12 @@ class FaissVectorStore(VectorStoreBase):
18
  documents=document_chunks, embedding=embedding_model
19
  )
20
  self.vectorstore.save_local(
21
- os.path.join(
22
- self.config["vectorstore"]["db_path"],
23
- "db_"
24
- + self.config["vectorstore"]["db_option"]
25
- + "_"
26
- + self.config["vectorstore"]["model"],
27
- )
28
  )
29
 
30
  def load_database(self, embedding_model):
31
  self.vectorstore = self.faiss.load_local(
32
- os.path.join(
33
- self.config["vectorstore"]["db_path"],
34
- "db_"
35
- + self.config["vectorstore"]["db_option"]
36
- + "_"
37
- + self.config["vectorstore"]["model"],
38
- ),
39
  embedding_model,
40
  allow_dangerous_deserialization=True,
41
  )
@@ -43,3 +42,6 @@ class FaissVectorStore(VectorStoreBase):
43
 
44
  def as_retriever(self):
45
  return self.vectorstore.as_retriever()
 
 
 
 
3
  import os
4
 
5
 
6
+ class FAISS(FAISS):
7
+ """To add length property to FAISS class"""
8
+
9
+ def __len__(self):
10
+ return self.index.ntotal
11
+
12
+
13
  class FaissVectorStore(VectorStoreBase):
14
  def __init__(self, config):
15
  self.config = config
16
  self._init_vector_db()
17
+ self.local_path = os.path.join(self.config["vectorstore"]["db_path"],
18
+ "db_" + self.config["vectorstore"]["db_option"]
19
+ + "_" + self.config["vectorstore"]["model"]
20
+ + "_" + config["splitter_options"]["chunking_mode"])
21
 
22
  def _init_vector_db(self):
23
  self.faiss = FAISS(
 
29
  documents=document_chunks, embedding=embedding_model
30
  )
31
  self.vectorstore.save_local(
32
+ self.local_path
 
 
 
 
 
 
33
  )
34
 
35
  def load_database(self, embedding_model):
36
  self.vectorstore = self.faiss.load_local(
37
+ self.local_path,
 
 
 
 
 
 
38
  embedding_model,
39
  allow_dangerous_deserialization=True,
40
  )
 
42
 
43
  def as_retriever(self):
44
  return self.vectorstore.as_retriever()
45
+
46
+ def __len__(self):
47
+ return len(self.vectorstore)
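
The FAISS wrapper now derives its save/load path from the db option, embedding model, and chunking mode, and reports its size via `index.ntotal`. A small sketch of the "add `__len__` by subclassing" pattern in isolation, assuming `faiss-cpu` and `langchain-community` are installed.

```python
# Sketch: the subclassing pattern used here, shown standalone
from langchain_community.vectorstores import FAISS as LCFAISS

class FAISS(LCFAISS):
    """FAISS with a length, reported from the underlying index."""

    def __len__(self):
        return self.index.ntotal

# Usage (hypothetical objects):
# store = FAISS.from_documents(document_chunks, embedding_model)
# print(len(store))  # number of indexed vectors
```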
code/modules/vectorstore/raptor.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import numpy as np
6
  import pandas as pd
7
  import umap
8
- from langchain_core.prompts import ChatPromptTemplate
9
  from langchain_core.output_parsers import StrOutputParser
10
  from sklearn.mixture import GaussianMixture
11
  from langchain_community.chat_models import ChatOpenAI
@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
16
  RANDOM_SEED = 42
17
 
18
 
 
 
 
 
 
 
 
19
  class RAPTORVectoreStore(VectorStoreBase):
20
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
21
  self.documents = documents
 
5
  import numpy as np
6
  import pandas as pd
7
  import umap
8
+ from langchain_core.prompts.chat import ChatPromptTemplate
9
  from langchain_core.output_parsers import StrOutputParser
10
  from sklearn.mixture import GaussianMixture
11
  from langchain_community.chat_models import ChatOpenAI
 
16
  RANDOM_SEED = 42
17
 
18
 
19
+ class FAISS(FAISS):
20
+ """To add length property to FAISS class"""
21
+
22
+ def __len__(self):
23
+ return self.index.ntotal
24
+
25
+
26
  class RAPTORVectoreStore(VectorStoreBase):
27
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
28
  self.documents = documents
code/modules/vectorstore/store_manager.py CHANGED
@@ -3,6 +3,7 @@ from modules.vectorstore.helpers import *
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
5
  from modules.dataloader.helpers import *
 
6
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
7
  import logging
8
  import os
@@ -135,14 +136,34 @@ class VectorStoreManager:
135
  self.embedding_model = self.create_embedding_model()
136
  else:
137
  self.embedding_model = None
138
- self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
 
 
 
 
 
 
 
 
139
  end_time = time.time() # End time for loading database
140
  self.logger.info(
141
- f"Time taken to load database: {end_time - start_time} seconds"
142
  )
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  import yaml
@@ -152,7 +173,20 @@ if __name__ == "__main__":
152
  print(config)
153
  print(f"Trying to create database with config: {config}")
154
  vector_db = VectorStoreManager(config)
155
- vector_db.create_database()
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  print("Created database")
157
 
158
  print(f"Trying to load the database")
 
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
5
  from modules.dataloader.helpers import *
6
+ from modules.config.constants import RETRIEVER_HF_PATHS
7
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
8
  import logging
9
  import os
 
136
  self.embedding_model = self.create_embedding_model()
137
  else:
138
  self.embedding_model = None
139
+ try:
140
+ self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
141
+ except Exception as e:
142
+ raise ValueError(
143
+ f"Error loading database, check if it exists. if not run python -m modules.vectorstore.store_manager / Resteart the HF Space: {e}"
144
+ )
145
+ # print(f"Creating database")
146
+ # self.create_database()
147
+ # self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
148
  end_time = time.time() # End time for loading database
149
  self.logger.info(
150
+ f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds"
151
  )
152
  self.logger.info("Loaded database")
153
  return self.loaded_vector_db
154
 
155
+ def load_from_HF(self, HF_PATH):
156
+ start_time = time.time() # Start time for loading database
157
+ self.vector_db._load_from_HF(HF_PATH)
158
+ end_time = time.time()
159
+ self.logger.info(
160
+ f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
161
+ )
162
+ self.logger.info("Downloaded database")
163
+
164
+ def __len__(self):
165
+ return len(self.vector_db)
166
+
167
 
168
  if __name__ == "__main__":
169
  import yaml
 
173
  print(config)
174
  print(f"Trying to create database with config: {config}")
175
  vector_db = VectorStoreManager(config)
176
+ if config["vectorstore"]["load_from_HF"]:
177
+ if config["vectorstore"]["db_option"] in RETRIEVER_HF_PATHS:
178
+ vector_db.load_from_HF(
179
+ HF_PATH=RETRIEVER_HF_PATHS[config["vectorstore"]["db_option"]]
180
+ )
181
+ else:
182
+ # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
183
+ # print("Creating database")
184
+ # vector_db.create_database()
185
+ raise ValueError(
186
+ f"HF_PATH not available for {config['vectorstore']['db_option']}"
187
+ )
188
+ else:
189
+ vector_db.create_database()
190
  print("Created database")
191
 
192
  print(f"Trying to load the database")
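
With these changes the `__main__` block either downloads a prebuilt index from Hugging Face (when `load_from_HF` is set and the db option has an entry in `RETRIEVER_HF_PATHS`) or builds one locally. A hypothetical config fragment, written as a Python dict; only the keys referenced in the diff are certain, the values are assumptions.

```python
# Hypothetical config fragment (only the keys referenced in the diff are certain)
config = {
    "vectorstore": {
        "load_from_HF": True,   # download a prebuilt index instead of building one
        "db_option": "FAISS",   # must appear in RETRIEVER_HF_PATHS when load_from_HF is True
        "db_path": "vectorstores",
        "model": "sentence-transformers/all-MiniLM-L6-v2",  # hypothetical embedding model
    },
    "splitter_options": {"chunking_mode": "fixed"},          # hypothetical value
}
```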
code/modules/vectorstore/vectorstore.py CHANGED
@@ -2,6 +2,9 @@ from modules.vectorstore.faiss import FaissVectorStore
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
 
 
 
5
 
6
 
7
  class VectorStore:
@@ -50,8 +53,39 @@ class VectorStore:
50
  else:
51
  return self.vectorstore.load_database(embedding_model)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def _as_retriever(self):
54
  return self.vectorstore.as_retriever()
55
 
56
  def _get_vectorstore(self):
57
  return self.vectorstore
 
 
 
 
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
5
+ from huggingface_hub import snapshot_download
6
+ import os
7
+ import shutil
8
 
9
 
10
  class VectorStore:
 
53
  else:
54
  return self.vectorstore.load_database(embedding_model)
55
 
56
+ def _load_from_HF(self, HF_PATH):
57
+ # Download the snapshot from Hugging Face Hub
58
+ # Note: Download goes to the cache directory
59
+ snapshot_path = snapshot_download(
60
+ repo_id=HF_PATH,
61
+ repo_type="dataset",
62
+ force_download=True,
63
+ )
64
+
65
+ # Move the downloaded files to the desired directory
66
+ target_path = os.path.join(
67
+ self.config["vectorstore"]["db_path"],
68
+ "db_" + self.config["vectorstore"]["db_option"],
69
+ )
70
+
71
+ # Create target path if it doesn't exist
72
+ os.makedirs(target_path, exist_ok=True)
73
+
74
+ # move all files and directories from snapshot_path to target_path
75
+ # target path is used while loading the database
76
+ for item in os.listdir(snapshot_path):
77
+ s = os.path.join(snapshot_path, item)
78
+ d = os.path.join(target_path, item)
79
+ if os.path.isdir(s):
80
+ shutil.copytree(s, d, dirs_exist_ok=True)
81
+ else:
82
+ shutil.copy2(s, d)
83
+
84
  def _as_retriever(self):
85
  return self.vectorstore.as_retriever()
86
 
87
  def _get_vectorstore(self):
88
  return self.vectorstore
89
+
90
+ def __len__(self):
91
+ return self.vectorstore.__len__()
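
`_load_from_HF` downloads a dataset snapshot into the Hub cache and copies it into the local `db_<option>` directory that `load_database` expects. A sketch of the same flow as a standalone script; the dataset repo id and target directory are hypothetical.

```python
# Sketch: mirror a Hugging Face dataset snapshot into a local db directory
# (the repo id and target directory are hypothetical)
import os
import shutil
from huggingface_hub import snapshot_download

snapshot_path = snapshot_download(repo_id="dl4ds/tutor_faiss_index", repo_type="dataset")
target_path = os.path.join("vectorstores", "db_FAISS")
os.makedirs(target_path, exist_ok=True)
for item in os.listdir(snapshot_path):
    src = os.path.join(snapshot_path, item)
    dst = os.path.join(target_path, item)
    if os.path.isdir(src):
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        shutil.copy2(src, dst)
```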
code/public/test.css CHANGED
@@ -31,3 +31,13 @@ a[href*='https://github.com/Chainlit/chainlit'] {
31
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
32
  display: none;
33
  }
 
 
 
 
 
 
 
 
 
 
 
31
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
32
  display: none;
33
  }
34
+
35
+ /* Hide the new chat button
36
+ #new-chat-button {
37
+ display: none;
38
+ } */
39
+
40
+ /* Hide the open sidebar button
41
+ #open-sidebar-button {
42
+ display: none;
43
+ } */
requirements.txt CHANGED
@@ -1,27 +1,25 @@
1
- # Automatically generated by https://github.com/damnever/pigar.
2
-
3
- aiohttp==3.9.5
4
- beautifulsoup4==4.12.3
5
- chainlit==1.1.304
6
- langchain==0.1.20
7
- langchain-community==0.0.38
8
- langchain-core==0.1.52
9
- literalai==0.0.604
10
- llama-parse==0.4.4
11
- numpy==1.26.4
12
- pandas==2.2.2
13
- pysrt==1.1.2
14
- python-dotenv==1.0.1
15
- PyYAML==6.0.1
16
- RAGatouille==0.0.8.post2
17
- requests==2.32.3
18
- scikit-learn==1.5.0
19
- torch==2.3.1
20
- tqdm==4.66.4
21
- transformers==4.41.2
22
- trulens_eval==0.31.0
23
- umap-learn==0.5.6
24
- trulens-eval==0.31.0
25
- llama-cpp-python==0.2.77
26
- pymupdf==1.24.5
27
  websockets
 
 
1
+ aiohttp
2
+ beautifulsoup4
3
+ chainlit
4
+ langchain
5
+ langchain-community
6
+ langchain-core
7
+ literalai
8
+ llama-parse
9
+ numpy
10
+ pandas
11
+ pysrt
12
+ python-dotenv
13
+ PyYAML
14
+ RAGatouille
15
+ requests
16
+ scikit-learn
17
+ torch
18
+ tqdm
19
+ transformers
20
+ trulens_eval
21
+ umap-learn
22
+ llama-cpp-python
23
+ pymupdf
 
 
 
24
  websockets
25
+ langchain-openai