Farid Karimli committed on
Commit 49a1201 · 2 Parent(s): 679cb58 ae2ff9e

Merged chainlit enhancements into text_extraction
Dockerfile.dev CHANGED
@@ -25,7 +25,7 @@ RUN mkdir /.cache && chmod -R 777 /.cache
  WORKDIR /code/code

  # Expose the port the app runs on
- EXPOSE 8051
+ EXPOSE 8000

  # Default command to run the application
- CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
+ CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -76,7 +76,7 @@ The HuggingFace Space is built using the `Dockerfile` in the repository. To run

  ```bash
  docker build --tag dev -f Dockerfile.dev .
- docker run -it --rm -p 8051:8051 dev
+ docker run -it --rm -p 8000:8000 dev
  ```

  ## Contributing
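After rebuilding the image, a quick smoke test confirms the container now answers on the remapped port. This is a minimal sketch, assuming the dev container is running locally and that Chainlit serves plain HTTP on 8000; the URL is illustrative:

```python
# Minimal port smoke test for the 8051 -> 8000 remap (illustrative URL).
import urllib.request

with urllib.request.urlopen("http://localhost:8000", timeout=5) as resp:
    print(resp.status)  # expect 200 once the Chainlit app has started
```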
code/.chainlit/translations/en-US.json ADDED
@@ -0,0 +1,229 @@
+ {
+     "components": {
+         "atoms": {
+             "buttons": {
+                 "userButton": {
+                     "menu": {
+                         "settings": "Settings",
+                         "settingsKey": "S",
+                         "APIKeys": "API Keys",
+                         "logout": "Logout"
+                     }
+                 }
+             }
+         },
+         "molecules": {
+             "newChatButton": {
+                 "newChat": "New Chat"
+             },
+             "tasklist": {
+                 "TaskList": {
+                     "title": "\ud83d\uddd2\ufe0f Task List",
+                     "loading": "Loading...",
+                     "error": "An error occurred"
+                 }
+             },
+             "attachments": {
+                 "cancelUpload": "Cancel upload",
+                 "removeAttachment": "Remove attachment"
+             },
+             "newChatDialog": {
+                 "createNewChat": "Create new chat?",
+                 "clearChat": "This will clear the current messages and start a new chat.",
+                 "cancel": "Cancel",
+                 "confirm": "Confirm"
+             },
+             "settingsModal": {
+                 "settings": "Settings",
+                 "expandMessages": "Expand Messages",
+                 "hideChainOfThought": "Hide Chain of Thought",
+                 "darkMode": "Dark Mode"
+             },
+             "detailsButton": {
+                 "using": "Using",
+                 "used": "Used"
+             },
+             "auth": {
+                 "authLogin": {
+                     "title": "Login to access the app.",
+                     "form": {
+                         "email": "Email address",
+                         "password": "Password",
+                         "noAccount": "Don't have an account?",
+                         "alreadyHaveAccount": "Already have an account?",
+                         "signup": "Sign Up",
+                         "signin": "Sign In",
+                         "or": "OR",
+                         "continue": "Continue",
+                         "forgotPassword": "Forgot password?",
+                         "passwordMustContain": "Your password must contain:",
+                         "emailRequired": "email is a required field",
+                         "passwordRequired": "password is a required field"
+                     },
+                     "error": {
+                         "default": "Unable to sign in.",
+                         "signin": "Try signing in with a different account.",
+                         "oauthsignin": "Try signing in with a different account.",
+                         "redirect_uri_mismatch": "The redirect URI is not matching the oauth app configuration.",
+                         "oauthcallbackerror": "Try signing in with a different account.",
+                         "oauthcreateaccount": "Try signing in with a different account.",
+                         "emailcreateaccount": "Try signing in with a different account.",
+                         "callback": "Try signing in with a different account.",
+                         "oauthaccountnotlinked": "To confirm your identity, sign in with the same account you used originally.",
+                         "emailsignin": "The e-mail could not be sent.",
+                         "emailverify": "Please verify your email, a new email has been sent.",
+                         "credentialssignin": "Sign in failed. Check the details you provided are correct.",
+                         "sessionrequired": "Please sign in to access this page."
+                     }
+                 },
+                 "authVerifyEmail": {
+                     "almostThere": "You're almost there! We've sent an email to ",
+                     "verifyEmailLink": "Please click on the link in that email to complete your signup.",
+                     "didNotReceive": "Can't find the email?",
+                     "resendEmail": "Resend email",
+                     "goBack": "Go Back",
+                     "emailSent": "Email sent successfully.",
+                     "verifyEmail": "Verify your email address"
+                 },
+                 "providerButton": {
+                     "continue": "Continue with {{provider}}",
+                     "signup": "Sign up with {{provider}}"
+                 },
+                 "authResetPassword": {
+                     "newPasswordRequired": "New password is a required field",
+                     "passwordsMustMatch": "Passwords must match",
+                     "confirmPasswordRequired": "Confirm password is a required field",
+                     "newPassword": "New password",
+                     "confirmPassword": "Confirm password",
+                     "resetPassword": "Reset Password"
+                 },
+                 "authForgotPassword": {
+                     "email": "Email address",
+                     "emailRequired": "email is a required field",
+                     "emailSent": "Please check the email address {{email}} for instructions to reset your password.",
+                     "enterEmail": "Enter your email address and we will send you instructions to reset your password.",
+                     "resendEmail": "Resend email",
+                     "continue": "Continue",
+                     "goBack": "Go Back"
+                 }
+             }
+         },
+         "organisms": {
+             "chat": {
+                 "history": {
+                     "index": {
+                         "showHistory": "Show history",
+                         "lastInputs": "Last Inputs",
+                         "noInputs": "Such empty...",
+                         "loading": "Loading..."
+                     }
+                 },
+                 "inputBox": {
+                     "input": {
+                         "placeholder": "Type your message here..."
+                     },
+                     "speechButton": {
+                         "start": "Start recording",
+                         "stop": "Stop recording"
+                     },
+                     "SubmitButton": {
+                         "sendMessage": "Send message",
+                         "stopTask": "Stop Task"
+                     },
+                     "UploadButton": {
+                         "attachFiles": "Attach files"
+                     },
+                     "waterMark": {
+                         "text": "Built with"
+                     }
+                 },
+                 "Messages": {
+                     "index": {
+                         "running": "Running",
+                         "executedSuccessfully": "executed successfully",
+                         "failed": "failed",
+                         "feedbackUpdated": "Feedback updated",
+                         "updating": "Updating"
+                     }
+                 },
+                 "dropScreen": {
+                     "dropYourFilesHere": "Drop your files here"
+                 },
+                 "index": {
+                     "failedToUpload": "Failed to upload",
+                     "cancelledUploadOf": "Cancelled upload of",
+                     "couldNotReachServer": "Could not reach the server",
+                     "continuingChat": "Continuing previous chat"
+                 },
+                 "settings": {
+                     "settingsPanel": "Settings panel",
+                     "reset": "Reset",
+                     "cancel": "Cancel",
+                     "confirm": "Confirm"
+                 }
+             },
+             "threadHistory": {
+                 "sidebar": {
+                     "filters": {
+                         "FeedbackSelect": {
+                             "feedbackAll": "Feedback: All",
+                             "feedbackPositive": "Feedback: Positive",
+                             "feedbackNegative": "Feedback: Negative"
+                         },
+                         "SearchBar": {
+                             "search": "Search"
+                         }
+                     },
+                     "DeleteThreadButton": {
+                         "confirmMessage": "This will delete the thread as well as its messages and elements.",
+                         "cancel": "Cancel",
+                         "confirm": "Confirm",
+                         "deletingChat": "Deleting chat",
+                         "chatDeleted": "Chat deleted"
+                     },
+                     "index": {
+                         "pastChats": "Past Chats"
+                     },
+                     "ThreadList": {
+                         "empty": "Empty...",
+                         "today": "Today",
+                         "yesterday": "Yesterday",
+                         "previous7days": "Previous 7 days",
+                         "previous30days": "Previous 30 days"
+                     },
+                     "TriggerButton": {
+                         "closeSidebar": "Close sidebar",
+                         "openSidebar": "Open sidebar"
+                     }
+                 },
+                 "Thread": {
+                     "backToChat": "Go back to chat",
+                     "chatCreatedOn": "This chat was created on"
+                 }
+             },
+             "header": {
+                 "chat": "Chat",
+                 "readme": "Readme"
+             }
+         }
+     },
+     "hooks": {
+         "useLLMProviders": {
+             "failedToFetchProviders": "Failed to fetch providers:"
+         }
+     },
+     "pages": {
+         "Design": {},
+         "Env": {
+             "savedSuccessfully": "Saved successfully",
+             "requiredApiKeys": "Required API Keys",
+             "requiredApiKeysInfo": "To use this app, the following API keys are required. The keys are stored on your device's local storage."
+         },
+         "Page": {
+             "notPartOfProject": "You are not part of this project."
+         },
+         "ResumeButton": {
+             "resumeChat": "Resume Chat"
+         }
+     }
+ }
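The translation bundle is plain JSON, so a tiny load-time check can catch a truncated or mis-keyed file early. A minimal sketch, assuming the working directory is `code/` as in the Dockerfile; the spot-checked keys mirror the structure above:

```python
# Sanity-check the Chainlit translation bundle (path relative to code/).
import json

with open(".chainlit/translations/en-US.json") as f:
    translations = json.load(f)

# Spot-check a few keys the UI relies on.
assert translations["components"]["molecules"]["newChatButton"]["newChat"] == "New Chat"
assert "authLogin" in translations["components"]["molecules"]["auth"]
print("translation bundle OK")
```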
code/main.py CHANGED
@@ -1,178 +1,460 @@
- from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
- from langchain_core.prompts import PromptTemplate
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.vectorstores import FAISS
- from langchain.chains import RetrievalQA
- import chainlit as cl
- from langchain_community.chat_models import ChatOpenAI
- from langchain_community.embeddings import OpenAIEmbeddings
- import yaml
- import logging
- from dotenv import load_dotenv
  from modules.chat.llm_tutor import LLMTutor
- from modules.config.constants import *
- from modules.chat.helpers import get_sources
- from modules.chat_processor.chat_processor import ChatProcessor
-
- global logger
- # Initialize logger
- logger = logging.getLogger(__name__)
- logger.setLevel(logging.INFO)
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-
- # Console Handler
- console_handler = logging.StreamHandler()
- console_handler.setLevel(logging.INFO)
- console_handler.setFormatter(formatter)
- logger.addHandler(console_handler)
-
-
- @cl.set_starters
- async def set_starters():
-     return [
-         cl.Starter(
-             label="recording on CNNs?",
-             message="Where can I find the recording for the lecture on Transfromers?",
-             icon="/public/adv-screen-recorder-svgrepo-com.svg",
-         ),
-         cl.Starter(
-             label="where's the slides?",
-             message="When are the lectures? I can't find the schedule.",
-             icon="/public/alarmy-svgrepo-com.svg",
-         ),
-         cl.Starter(
-             label="Due Date?",
-             message="When is the final project due?",
-             icon="/public/calendar-samsung-17-svgrepo-com.svg",
-         ),
-         cl.Starter(
-             label="Explain backprop.",
-             message="I didnt understand the math behind backprop, could you explain it?",
-             icon="/public/acastusphoton-svgrepo-com.svg",
-         ),
-     ]
-
-
- # Adding option to select the chat profile
- @cl.set_chat_profiles
- async def chat_profile():
-     return [
-         # cl.ChatProfile(
-         #     name="Mistral",
-         #     markdown_description="Use the local LLM: **Mistral**.",
-         # ),
-         cl.ChatProfile(
-             name="gpt-3.5-turbo-1106",
-             markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
-         ),
-         cl.ChatProfile(
-             name="gpt-4",
-             markdown_description="Use OpenAI API for **gpt-4**.",
-         ),
-         cl.ChatProfile(
-             name="Llama",
-             markdown_description="Use the local LLM: **Tiny Llama**.",
-         ),
-     ]
-
-
- @cl.author_rename
- def rename(orig_author: str):
-     rename_dict = {"Chatbot": "AI Tutor"}
-     return rename_dict.get(orig_author, orig_author)
-
-
- # chainlit code
- @cl.on_chat_start
- async def start():
-     with open("modules/config/config.yml", "r") as f:
-         config = yaml.safe_load(f)
-
-     # Ensure log directory exists
-     log_directory = config["log_dir"]
-     if not os.path.exists(log_directory):
-         os.makedirs(log_directory)
-
-     # File Handler
-     log_file_path = (
-         f"{log_directory}/tutor.log"  # Change this to your desired log file path
-     )
-     file_handler = logging.FileHandler(log_file_path, mode="w")
-     file_handler.setLevel(logging.INFO)
-     file_handler.setFormatter(formatter)
-     logger.addHandler(file_handler)
-
-     logger.info("Config file loaded")
-     logger.info(f"Config: {config}")
-     logger.info("Creating llm_tutor instance")
-
-     chat_profile = cl.user_session.get("chat_profile")
-     if chat_profile is not None:
-         if chat_profile.lower() in ["gpt-3.5-turbo-1106", "gpt-4"]:
-             config["llm_params"]["llm_loader"] = "openai"
-             config["llm_params"]["openai_params"]["model"] = chat_profile.lower()
-         elif chat_profile.lower() == "llama":
-             config["llm_params"]["llm_loader"] = "local_llm"
-             config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
-             config["llm_params"]["local_llm_params"]["model_type"] = "llama"
-         elif chat_profile.lower() == "mistral":
-             config["llm_params"]["llm_loader"] = "local_llm"
-             config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
-             config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
          else:
-             pass
-
-     llm_tutor = LLMTutor(config, logger=logger)
-
-     chain = llm_tutor.qa_bot()
-     # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
-     # await msg.send()
-     # msg.content = opening_message
-     # await msg.update()
-
-     tags = [chat_profile, config["vectorstore"]["db_option"]]
-     chat_processor = ChatProcessor(config, tags=tags)
-     cl.user_session.set("chain", chain)
-     cl.user_session.set("counter", 0)
-     cl.user_session.set("chat_processor", chat_processor)
-
-
- @cl.on_chat_end
- async def on_chat_end():
-     await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
-
-
- @cl.on_message
- async def main(message):
-     global logger
-     user = cl.user_session.get("user")
-     chain = cl.user_session.get("chain")
-
-     counter = cl.user_session.get("counter")
-     counter += 1
-     cl.user_session.set("counter", counter)
-
-     # if counter >= 3:  # Ensure the counter condition is checked
-     #     await cl.Message(content="Your credits are up!").send()
-     #     await on_chat_end()  # Call the on_chat_end function to handle the end of the chat
-     #     return  # Exit the function to stop further processing
-     # else:
-
-     cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
-     cb.answer_reached = True
-
-     processor = cl.user_session.get("chat_processor")
-     res = await processor.rag(message.content, chain, cb)
-     try:
-         answer = res["answer"]
-     except KeyError:
-         answer = res["result"]
-
-     answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
-     processor._process(message.content, answer, sources_dict)
-
-     answer_with_sources = answer_with_sources.replace("$$", "$")
-
-     await cl.Message(content=answer_with_sources, elements=source_elements).send()
+ import chainlit.data as cl_data
+ import asyncio
+ from modules.config.constants import (
+     LLAMA_PATH,
+     LITERAL_API_KEY_LOGGING,
+     LITERAL_API_URL,
+ )
+ from modules.chat_processor.literal_ai import CustomLiteralDataLayer
+
+ import json
+ import yaml
+ import os
+ from typing import Any, Dict, no_type_check
+ import chainlit as cl
  from modules.chat.llm_tutor import LLMTutor
+ from modules.chat.helpers import (
+     get_sources,
+     get_history_chat_resume,
+     get_history_setup_llm,
+ )
+ import copy
+ from typing import Optional
+ from chainlit.types import ThreadDict
+ import time
+
+ USER_TIMEOUT = 60_000
+ SYSTEM = "System 🖥️"
+ LLM = "LLM 🧠"
+ AGENT = "Agent <>"
+ YOU = "You 😃"
+ ERROR = "Error 🚫"
+
+ with open("modules/config/config.yml", "r") as f:
+     config = yaml.safe_load(f)
+
+
+ async def setup_data_layer():
+     """
+     Set up the data layer for chat logging.
+     """
+     if config["chat_logging"]["log_chat"]:
+         data_layer = CustomLiteralDataLayer(
+             api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
+         )
+     else:
+         data_layer = None
+
+     return data_layer
+
+
+ class Chatbot:
+     def __init__(self, config):
+         """
+         Initialize the Chatbot class.
+         """
+         self.config = config
+
+     def _load_config(self):
+         """
+         Load the configuration from a YAML file.
+         """
+         with open("modules/config/config.yml", "r") as f:
+             return yaml.safe_load(f)
+
+     @no_type_check
+     async def setup_llm(self):
+         """
+         Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.
+         """
+         start_time = time.time()
+
+         llm_settings = cl.user_session.get("llm_settings", {})
+         chat_profile, retriever_method, memory_window, llm_style, generate_follow_up = (
+             llm_settings.get("chat_model"),
+             llm_settings.get("retriever_method"),
+             llm_settings.get("memory_window"),
+             llm_settings.get("llm_style"),
+             llm_settings.get("follow_up_questions"),
+         )
+
+         chain = cl.user_session.get("chain")
+         memory_list = cl.user_session.get(
+             "memory",
+             (
+                 list(chain.store.values())[0].messages
+                 if len(chain.store.values()) > 0
+                 else []
+             ),
+         )
+         conversation_list = get_history_setup_llm(memory_list)
+
+         old_config = copy.deepcopy(self.config)
+         self.config["vectorstore"]["db_option"] = retriever_method
+         self.config["llm_params"]["memory_window"] = memory_window
+         self.config["llm_params"]["llm_style"] = llm_style
+         self.config["llm_params"]["llm_loader"] = chat_profile
+         self.config["llm_params"]["generate_follow_up"] = generate_follow_up
+
+         self.llm_tutor.update_llm(
+             old_config, self.config
+         )  # update only llm attributes that are changed
+         self.chain = self.llm_tutor.qa_bot(
+             memory=conversation_list,
+             callbacks=(
+                 [cl.LangchainCallbackHandler()]
+                 if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                 else None
+             ),
+         )
+
+         tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+
+         cl.user_session.set("chain", self.chain)
+         cl.user_session.set("llm_tutor", self.llm_tutor)
+
+         print("Time taken to setup LLM: ", time.time() - start_time)
+
+     @no_type_check
+     async def update_llm(self, new_settings: Dict[str, Any]):
+         """
+         Update the LLM settings and reinitialize the LLM with the new settings.
+
+         Args:
+             new_settings (Dict[str, Any]): The new settings to update.
+         """
+         cl.user_session.set("llm_settings", new_settings)
+         await self.inform_llm_settings()
+         await self.setup_llm()
+
+     async def make_llm_settings_widgets(self, config=None):
+         """
+         Create and send the widgets for LLM settings configuration.
+
+         Args:
+             config: The configuration to use for setting up the widgets.
+         """
+         config = config or self.config
+         await cl.ChatSettings(
+             [
+                 cl.input_widget.Select(
+                     id="chat_model",
+                     label="Model Name (Default GPT-3)",
+                     values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
+                     initial_index=[
+                         "local_llm",
+                         "gpt-3.5-turbo-1106",
+                         "gpt-4",
+                         "gpt-4o-mini",
+                     ].index(config["llm_params"]["llm_loader"]),
+                 ),
+                 cl.input_widget.Select(
+                     id="retriever_method",
+                     label="Retriever (Default FAISS)",
+                     values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                     initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
+                         config["vectorstore"]["db_option"]
+                     ),
+                 ),
+                 cl.input_widget.Slider(
+                     id="memory_window",
+                     label="Memory Window (Default 3)",
+                     initial=3,
+                     min=0,
+                     max=10,
+                     step=1,
+                 ),
+                 cl.input_widget.Switch(
+                     id="view_sources", label="View Sources", initial=False
+                 ),
+                 cl.input_widget.Switch(
+                     id="stream_response",
+                     label="Stream response",
+                     initial=config["llm_params"]["stream"],
+                 ),
+                 cl.input_widget.Switch(
+                     id="follow_up_questions",
+                     label="Generate follow up questions",
+                     initial=False,
+                 ),
+                 cl.input_widget.Select(
+                     id="llm_style",
+                     label="Type of Conversation (Default Normal)",
+                     values=["Normal", "ELI5"],
+                     initial_index=0,
+                 ),
+             ]
+         ).send()
+
+     @no_type_check
+     async def inform_llm_settings(self):
+         """
+         Inform the user about the updated LLM settings and display them as a message.
+         """
+         llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+         llm_tutor = cl.user_session.get("llm_tutor")
+         settings_dict = {
+             "model": llm_settings.get("chat_model"),
+             "retriever": llm_settings.get("retriever_method"),
+             "memory_window": llm_settings.get("memory_window"),
+             "num_docs_in_db": (
+                 len(llm_tutor.vector_db)
+                 if llm_tutor and hasattr(llm_tutor, "vector_db")
+                 else 0
+             ),
+             "view_sources": llm_settings.get("view_sources"),
+             "follow_up_questions": llm_settings.get("follow_up_questions"),
+         }
+         await cl.Message(
+             author=SYSTEM,
+             content="LLM settings have been updated. You can continue with your Query!",
+             elements=[
+                 cl.Text(
+                     name="settings",
+                     display="side",
+                     content=json.dumps(settings_dict, indent=4),
+                     language="json",
+                 ),
+             ],
+         ).send()
+
+     async def set_starters(self):
+         """
+         Set starter messages for the chatbot.
+         """
+         # Return Starters only if the chat is new
+
+         try:
+             thread = cl_data._data_layer.get_thread(
+                 cl.context.session.thread_id
+             )  # see if the thread has any steps
+             if thread.steps or len(thread.steps) > 0:
+                 return None
+         except Exception:
+             return [
+                 cl.Starter(
+                     label="recording on CNNs?",
+                     message="Where can I find the recording for the lecture on Transformers?",
+                     icon="/public/adv-screen-recorder-svgrepo-com.svg",
+                 ),
+                 cl.Starter(
+                     label="where's the slides?",
+                     message="When are the lectures? I can't find the schedule.",
+                     icon="/public/alarmy-svgrepo-com.svg",
+                 ),
+                 cl.Starter(
+                     label="Due Date?",
+                     message="When is the final project due?",
+                     icon="/public/calendar-samsung-17-svgrepo-com.svg",
+                 ),
+                 cl.Starter(
+                     label="Explain backprop.",
+                     message="I didn't understand the math behind backprop, could you explain it?",
+                     icon="/public/acastusphoton-svgrepo-com.svg",
+                 ),
+             ]
+
+     def rename(self, orig_author: str):
+         """
+         Rename the original author to a more user-friendly name.
+
+         Args:
+             orig_author (str): The original author's name.
+
+         Returns:
+             str: The renamed author.
+         """
+         rename_dict = {"Chatbot": "AI Tutor"}
+         return rename_dict.get(orig_author, orig_author)
+
+     async def start(self):
+         """
+         Start the chatbot, initialize settings widgets,
+         and display and load previous conversation if chat logging is enabled.
+         """
+
+         start_time = time.time()
+
+         await self.make_llm_settings_widgets(self.config)
+         user = cl.user_session.get("user")
+         self.user = {
+             "user_id": user.identifier,
+             "session_id": cl.context.session.thread_id,
+         }
+
+         memory = cl.user_session.get("memory", [])
+
+         cl.user_session.set("user", self.user)
+         self.llm_tutor = LLMTutor(self.config, user=self.user)
+
+         self.chain = self.llm_tutor.qa_bot(
+             memory=memory,
+             callbacks=(
+                 [cl.LangchainCallbackHandler()]
+                 if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                 else None
+             ),
+         )
+         self.question_generator = self.llm_tutor.question_generator
+         cl.user_session.set("llm_tutor", self.llm_tutor)
+         cl.user_session.set("chain", self.chain)
+
+         print("Time taken to start LLM: ", time.time() - start_time)
+
+     async def stream_response(self, response):
+         """
+         Stream the response from the LLM.
+
+         Args:
+             response: The response from the LLM.
+         """
+         msg = cl.Message(content="")
+         await msg.send()
+
+         output = {}
+         for chunk in response:
+             if "answer" in chunk:
+                 await msg.stream_token(chunk["answer"])
+
+             for key in chunk:
+                 if key not in output:
+                     output[key] = chunk[key]
+                 else:
+                     output[key] += chunk[key]
+         return output
+
+     async def main(self, message):
+         """
+         Process and display the conversation.
+
+         Args:
+             message: The incoming chat message.
+         """
+
+         start_time = time.time()
+
+         chain = cl.user_session.get("chain")
+
+         llm_settings = cl.user_session.get("llm_settings", {})
+         view_sources = llm_settings.get("view_sources", False)
+         stream = llm_settings.get("stream_response", False)
+         stream = False  # Fix streaming: force-disabled for now
+         user_query_dict = {"input": message.content}
+         # Define the base configuration
+         chain_config = {
+             "configurable": {
+                 "user_id": self.user["user_id"],
+                 "conversation_id": self.user["session_id"],
+                 "memory_window": self.config["llm_params"]["memory_window"],
+             }
+         }
+
+         if stream:
+             res = chain.stream(user_query=user_query_dict, config=chain_config)
+             res = await self.stream_response(res)
          else:
+             res = await chain.invoke(
+                 user_query=user_query_dict,
+                 config=chain_config,
+             )
+
+         answer = res.get("answer", res.get("result"))
+
+         if cl_data._data_layer is not None:
+             with cl_data._data_layer.client.step(
+                 type="run",
+                 name="step_info",
+                 thread_id=cl.context.session.thread_id,
+                 # tags=self.tags,
+             ) as step:
+
+                 step.input = {"question": user_query_dict["input"]}
+
+                 step.output = {
+                     "chat_history": res.get("chat_history"),
+                     "context": res.get("context"),
+                     "answer": answer,
+                     "rephrase_prompt": res.get("rephrase_prompt"),
+                     "qa_prompt": res.get("qa_prompt"),
+                 }
+                 step.metadata = self.config
+
+         answer_with_sources, source_elements, sources_dict = get_sources(
+             res, answer, stream=stream, view_sources=view_sources
+         )
+         answer_with_sources = answer_with_sources.replace("$$", "$")
+
+         print("Time taken to process the message: ", time.time() - start_time)
+
+         actions = []
+
+         if self.config["llm_params"]["generate_follow_up"]:
+             start_time = time.time()
+             list_of_questions = self.question_generator.generate_questions(
+                 query=user_query_dict["input"],
+                 response=answer,
+                 chat_history=res.get("chat_history"),
+                 context=res.get("context"),
+             )
+
+             for question in list_of_questions:
+                 actions.append(
+                     cl.Action(
+                         name="follow up question",
+                         value="example_value",
+                         description=question,
+                         label=question,
+                     )
+                 )
+
+             print("Time taken to generate questions: ", time.time() - start_time)
+
+         await cl.Message(
+             content=answer_with_sources,
+             elements=source_elements,
+             author=LLM,
+             actions=actions,
+         ).send()
+
+     async def on_chat_resume(self, thread: ThreadDict):
+         steps = thread["steps"]
+         k = self.config["llm_params"]["memory_window"]
+         conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
+         cl.user_session.set("memory", conversation_list)
+         await self.start()
+
+     @cl.oauth_callback
+     def auth_callback(
+         provider_id: str,
+         token: str,
+         raw_user_data: Dict[str, str],
+         default_user: cl.User,
+     ) -> Optional[cl.User]:
+         return default_user
+
+     async def on_follow_up(self, action: cl.Action):
+         message = await cl.Message(
+             content=action.description,
+             type="user_message",
+             author=self.user["user_id"],
+         ).send()
+         await self.main(message)
+
+
+ chatbot = Chatbot(config=config)
+
+
+ async def start_app():
+     cl_data._data_layer = await setup_data_layer()
+     chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
+     cl.set_starters(chatbot.set_starters)
+     cl.author_rename(chatbot.rename)
+     cl.on_chat_start(chatbot.start)
+     cl.on_chat_resume(chatbot.on_chat_resume)
+     cl.on_message(chatbot.main)
+     cl.on_settings_update(chatbot.update_llm)
+     cl.action_callback("follow up question")(chatbot.on_follow_up)
+
+
+ asyncio.run(start_app())
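The rewritten `main.py` drives almost all of its behavior from `modules/config/config.yml`. A minimal sketch of the shape it expects, reconstructed only from the keys the code above actually reads; the values shown are illustrative assumptions, not the repository's real defaults:

```python
# Keys accessed by main.py and Chatbot; values here are illustrative only.
config_shape = {
    "vectorstore": {"db_option": "FAISS"},  # FAISS | Chroma | RAGatouille | RAPTOR
    "llm_params": {
        "llm_loader": "gpt-3.5-turbo-1106",  # or "local_llm"
        "llm_style": "Normal",               # Normal | ELI5
        "memory_window": 3,
        "stream": False,
        "generate_follow_up": False,
        "use_history": True,                 # read by helpers.get_prompt below
    },
    "chat_logging": {"log_chat": False, "callbacks": False},
}
```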
code/modules/chat/base.py ADDED
@@ -0,0 +1,13 @@
+ class BaseRAG:
+     """
+     Base class for RAG chatbot.
+     """
+
+     def __init__(self):
+         pass
+
+     def invoke(self):
+         """
+         Invoke the RAG chatbot.
+         """
+         pass
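Concrete chains subclass this stub, and the app calls `invoke` (and, for the V2 chain, `stream`) on them. A minimal sketch of a conforming subclass, assuming the async `invoke(user_query, config)` signature used by `Langchain_RAG_V2` later in this commit; the echo logic is purely illustrative:

```python
from modules.chat.base import BaseRAG


class EchoRAG(BaseRAG):
    """Toy subclass showing the interface the app expects (illustrative)."""

    async def invoke(self, user_query, config):
        # A real chain would retrieve context and call an LLM here.
        return {"answer": f"echo: {user_query['input']}", "context": []}
```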
code/modules/chat/chat_model_loader.py CHANGED
@@ -1,4 +1,4 @@
- from langchain_community.chat_models import ChatOpenAI
+ from langchain_openai import ChatOpenAI
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
  from transformers import AutoTokenizer, TextStreamer
  from langchain_community.llms import LlamaCpp
@@ -7,6 +7,7 @@ import transformers
  import os
  from langchain.callbacks.manager import CallbackManager
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from modules.config.constants import LLAMA_PATH


  class ChatModelLoader:
@@ -15,15 +16,16 @@ class ChatModelLoader:
          self.huggingface_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

      def load_chat_model(self):
-         if self.config["llm_params"]["llm_loader"] == "openai":
-             llm = ChatOpenAI(
-                 model_name=self.config["llm_params"]["openai_params"]["model"]
-             )
+         if self.config["llm_params"]["llm_loader"] in [
+             "gpt-3.5-turbo-1106",
+             "gpt-4",
+             "gpt-4o-mini",
+         ]:
+             llm = ChatOpenAI(model_name=self.config["llm_params"]["llm_loader"])
          elif self.config["llm_params"]["llm_loader"] == "local_llm":
              n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
-             model_path = self.config["llm_params"]["local_llm_params"]["model"]
              llm = LlamaCpp(
-                 model_path=model_path,
+                 model_path=LLAMA_PATH,
                  n_batch=n_batch,
                  n_ctx=2048,
                  f16_kv=True,
@@ -34,5 +36,7 @@ class ChatModelLoader:
              ],
          )
          else:
-             raise ValueError("Invalid LLM Loader")
+             raise ValueError(
+                 f"Invalid LLM Loader: {self.config['llm_params']['llm_loader']}"
+             )
          return llm
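A usage sketch for the updated loader: the branch now keys on the model name itself ("gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini") or on "local_llm". The `ChatModelLoader.__init__` signature is abbreviated in this diff, so the constructor call below is an assumption:

```python
import yaml
from modules.chat.chat_model_loader import ChatModelLoader

with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

loader = ChatModelLoader(config)  # constructor args assumed from context
llm = loader.load_chat_model()    # ChatOpenAI or LlamaCpp, per llm_loader
```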
code/modules/chat/helpers.py CHANGED
@@ -1,13 +1,12 @@
- from modules.config.constants import *
+ from modules.config.prompts import prompts
  import chainlit as cl
- from langchain_core.prompts import PromptTemplate


- def get_sources(res, answer):
+ def get_sources(res, answer, stream=True, view_sources=False):
      source_elements = []
      source_dict = {}  # Dictionary to store URL elements

-     for idx, source in enumerate(res["source_documents"]):
+     for idx, source in enumerate(res["context"]):
          source_metadata = source.metadata
          url = source_metadata.get("source", "N/A")
          score = source_metadata.get("score", "N/A")
@@ -36,69 +35,130 @@ def get_sources(res, answer):
          else:
              source_dict[url_name]["text"] += f"\n\n{source.page_content}"

-     # First, display the answer
-     full_answer = "**Answer:**\n"
-     full_answer += answer
-
-     # Then, display the sources
-     full_answer += "\n\n**Sources:**\n"
-     for idx, (url_name, source_data) in enumerate(source_dict.items()):
-         full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
-
-         name = f"Source {idx + 1} Text\n"
-         full_answer += name
-         source_elements.append(
-             cl.Text(name=name, content=source_data["text"], display="side")
-         )
-
-         # Add a PDF element if the source is a PDF file
-         if source_data["url"].lower().endswith(".pdf"):
-             name = f"Source {idx + 1} PDF\n"
-             full_answer += name
-             pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
-             source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
-
-     full_answer += "\n**Metadata:**\n"
-     for idx, (url_name, source_data) in enumerate(source_dict.items()):
-         full_answer += f"\nSource {idx + 1} Metadata:\n"
-         source_elements.append(
-             cl.Text(
-                 name=f"Source {idx + 1} Metadata",
-                 content=f"Source: {source_data['url']}\n"
-                 f"Page: {source_data['page']}\n"
-                 f"Type: {source_data['source_type']}\n"
-                 f"Date: {source_data['date']}\n"
-                 f"TL;DR: {source_data['lecture_tldr']}\n"
-                 f"Lecture Recording: {source_data['lecture_recording']}\n"
-                 f"Suggested Readings: {source_data['suggested_readings']}\n",
-                 display="side",
-             )
-         )
+     full_answer = ""  # Not to include the answer again if streaming
+
+     if not stream:  # First, display the answer if not streaming
+         full_answer = "**Answer:**\n"
+         full_answer += answer
+
+     if view_sources:
+         # Then, display the sources
+         # check if the answer has sources
+         if len(source_dict) == 0:
+             full_answer += "\n\n**No sources found.**"
+             return full_answer, source_elements, source_dict
+         else:
+             full_answer += "\n\n**Sources:**\n"
+             for idx, (url_name, source_data) in enumerate(source_dict.items()):
+                 full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+                 name = f"Source {idx + 1} Text\n"
+                 full_answer += name
+                 source_elements.append(
+                     cl.Text(name=name, content=source_data["text"], display="side")
+                 )
+
+                 # Add a PDF element if the source is a PDF file
+                 if source_data["url"].lower().endswith(".pdf"):
+                     name = f"Source {idx + 1} PDF\n"
+                     full_answer += name
+                     pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+                     source_elements.append(
+                         cl.Pdf(name=name, url=pdf_url, display="side")
+                     )
+
+             full_answer += "\n**Metadata:**\n"
+             for idx, (url_name, source_data) in enumerate(source_dict.items()):
+                 full_answer += f"\nSource {idx + 1} Metadata:\n"
+                 source_elements.append(
+                     cl.Text(
+                         name=f"Source {idx + 1} Metadata",
+                         content=f"Source: {source_data['url']}\n"
+                         f"Page: {source_data['page']}\n"
+                         f"Type: {source_data['source_type']}\n"
+                         f"Date: {source_data['date']}\n"
+                         f"TL;DR: {source_data['lecture_tldr']}\n"
+                         f"Lecture Recording: {source_data['lecture_recording']}\n"
+                         f"Suggested Readings: {source_data['suggested_readings']}\n",
+                         display="side",
+                     )
+                 )

      return full_answer, source_elements, source_dict


- def get_prompt(config):
-     if config["llm_params"]["use_history"]:
-         if config["llm_params"]["llm_loader"] == "local_llm":
-             custom_prompt_template = tinyllama_prompt_template_with_history
-         elif config["llm_params"]["llm_loader"] == "openai":
-             custom_prompt_template = openai_prompt_template_with_history
-         # else:
-         #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
-         prompt = PromptTemplate(
-             template=custom_prompt_template,
-             input_variables=["context", "chat_history", "question"],
-         )
-     else:
-         if config["llm_params"]["llm_loader"] == "local_llm":
-             custom_prompt_template = tinyllama_prompt_template
-         elif config["llm_params"]["llm_loader"] == "openai":
-             custom_prompt_template = openai_prompt_template
-         # else:
-         #     custom_prompt_template = tinyllama_prompt_template
-         prompt = PromptTemplate(
-             template=custom_prompt_template,
-             input_variables=["context", "question"],
-         )
-     return prompt
+ def get_prompt(config, prompt_type):
+     llm_params = config["llm_params"]
+     llm_loader = llm_params["llm_loader"]
+     use_history = llm_params["use_history"]
+     llm_style = llm_params["llm_style"].lower()
+
+     if prompt_type == "qa":
+         if llm_loader == "local_llm":
+             if use_history:
+                 return prompts["tiny_llama"]["prompt_with_history"]
+             else:
+                 return prompts["tiny_llama"]["prompt_no_history"]
+         else:
+             if use_history:
+                 return prompts["openai"]["prompt_with_history"][llm_style]
+             else:
+                 return prompts["openai"]["prompt_no_history"]
+     elif prompt_type == "rephrase":
+         return prompts["openai"]["rephrase_prompt"]
+
+
+ def get_history_chat_resume(steps, k, SYSTEM, LLM):
+     conversation_list = []
+     count = 0
+     for step in reversed(steps):
+         if step["name"] not in [SYSTEM]:
+             if step["type"] == "user_message":
+                 conversation_list.append(
+                     {"type": "user_message", "content": step["output"]}
+                 )
+             elif step["type"] == "assistant_message":
+                 if step["name"] == LLM:
+                     conversation_list.append(
+                         {"type": "ai_message", "content": step["output"]}
+                     )
+             else:
+                 raise ValueError("Invalid message type")
+             count += 1
+             if count >= 2 * k:  # 2 * k to account for both user and assistant messages
+                 break
+     conversation_list = conversation_list[::-1]
+     return conversation_list
+
+
+ def get_history_setup_llm(memory_list):
+     conversation_list = []
+     for message in memory_list:
+         message_dict = message.to_dict() if hasattr(message, "to_dict") else message
+
+         # Check if the type attribute is present as a key or attribute
+         message_type = (
+             message_dict.get("type", None)
+             if isinstance(message_dict, dict)
+             else getattr(message, "type", None)
+         )
+
+         # Check if content is present as a key or attribute
+         message_content = (
+             message_dict.get("content", None)
+             if isinstance(message_dict, dict)
+             else getattr(message, "content", None)
+         )
+
+         if message_type in ["ai", "ai_message"]:
+             conversation_list.append({"type": "ai_message", "content": message_content})
+         elif message_type in ["human", "user_message"]:
+             conversation_list.append(
+                 {"type": "user_message", "content": message_content}
+             )
+         else:
+             raise ValueError("Invalid message type")
+
+     return conversation_list
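`get_history_chat_resume` walks the logged steps newest-first and keeps at most `2 * k` messages, so a window of `k` user/assistant exchanges survives a chat resume. A toy run, with hypothetical step dicts shaped like the Chainlit thread steps the function expects:

```python
from modules.chat.helpers import get_history_chat_resume

SYSTEM, LLM = "System 🖥️", "LLM 🧠"
steps = [
    {"name": "student", "type": "user_message", "output": "What is backprop?"},
    {"name": LLM, "type": "assistant_message", "output": "Backprop computes gradients..."},
    {"name": "student", "type": "user_message", "output": "And the chain rule?"},
    {"name": LLM, "type": "assistant_message", "output": "It composes derivatives..."},
]

# k=1 keeps only the most recent user/assistant pair, in chronological order.
print(get_history_chat_resume(steps, 1, SYSTEM, LLM))
# [{'type': 'user_message', 'content': 'And the chain rule?'},
#  {'type': 'ai_message', 'content': 'It composes derivatives...'}]
```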
code/modules/chat/langchain/langchain_rag.py ADDED
@@ -0,0 +1,274 @@
+ from langchain_core.prompts import ChatPromptTemplate
+
+ from modules.chat.langchain.utils import *
+ from langchain.memory import ChatMessageHistory
+ from modules.chat.base import BaseRAG
+ from langchain_core.prompts import PromptTemplate
+ from langchain.memory import (
+     ConversationBufferWindowMemory,
+     ConversationSummaryBufferMemory,
+ )
+
+ import chainlit as cl
+ from langchain_community.chat_models import ChatOpenAI
+
+
+ class Langchain_RAG_V1(BaseRAG):
+
+     def __init__(
+         self,
+         llm,
+         memory,
+         retriever,
+         qa_prompt: str,
+         rephrase_prompt: str,
+         config: dict,
+         callbacks=None,
+     ):
+         """
+         Initialize the Langchain_RAG class.
+
+         Args:
+             llm (LanguageModelLike): The language model instance.
+             memory (BaseChatMessageHistory): The chat message history instance.
+             retriever (BaseRetriever): The retriever instance.
+             qa_prompt (str): The QA prompt string.
+             rephrase_prompt (str): The rephrase prompt string.
+         """
+         self.llm = llm
+         self.config = config
+         # self.memory = self.add_history_from_list(memory)
+         self.memory = ConversationBufferWindowMemory(
+             k=self.config["llm_params"]["memory_window"],
+             memory_key="chat_history",
+             return_messages=True,
+             output_key="answer",
+             max_token_limit=128,
+         )
+         self.retriever = retriever
+         self.qa_prompt = qa_prompt
+         self.rephrase_prompt = rephrase_prompt
+         self.store = {}
+
+         self.qa_prompt = PromptTemplate(
+             template=self.qa_prompt,
+             input_variables=["context", "chat_history", "input"],
+         )
+
+         self.rag_chain = CustomConversationalRetrievalChain.from_llm(
+             llm=llm,
+             chain_type="stuff",
+             retriever=retriever,
+             return_source_documents=True,
+             memory=self.memory,
+             combine_docs_chain_kwargs={"prompt": self.qa_prompt},
+             response_if_no_docs_found="No context found",
+         )
+
+     def add_history_from_list(self, history_list):
+         """
+         TODO: Add messages from a list to the chat history.
+         """
+         history = []
+
+         return history
+
+     async def invoke(self, user_query, config):
+         """
+         Invoke the chain.
+
+         Args:
+             kwargs: The input variables.
+
+         Returns:
+             dict: The output variables.
+         """
+         res = await self.rag_chain.acall(user_query["input"])
+         return res
+
+
+ class QuestionGenerator:
+     """
+     Generate follow-up questions from the LLM's response, the user's input, and past conversations.
+     """
+
+     def __init__(self):
+         pass
+
+     def generate_questions(self, query, response, chat_history, context):
+         questions = return_questions(query, response, chat_history, context)
+         return questions
+
+
+ class Langchain_RAG_V2(BaseRAG):
+     def __init__(
+         self,
+         llm,
+         memory,
+         retriever,
+         qa_prompt: str,
+         rephrase_prompt: str,
+         config: dict,
+         callbacks=None,
+     ):
+         """
+         Initialize the Langchain_RAG class.
+
+         Args:
+             llm (LanguageModelLike): The language model instance.
+             memory (BaseChatMessageHistory): The chat message history instance.
+             retriever (BaseRetriever): The retriever instance.
+             qa_prompt (str): The QA prompt string.
+             rephrase_prompt (str): The rephrase prompt string.
+         """
+         self.llm = llm
+         self.memory = self.add_history_from_list(memory)
+         self.retriever = retriever
+         self.qa_prompt = qa_prompt
+         self.rephrase_prompt = rephrase_prompt
+         self.store = {}
+
+         # Contextualize question prompt
+         contextualize_q_system_prompt = rephrase_prompt or (
+             "Given a chat history and the latest user question "
+             "which might reference context in the chat history, "
+             "formulate a standalone question which can be understood "
+             "without the chat history. Do NOT answer the question, just "
+             "reformulate it if needed and otherwise return it as is."
+         )
+         self.contextualize_q_prompt = ChatPromptTemplate.from_template(
+             contextualize_q_system_prompt
+         )
+
+         # History-aware retriever
+         self.history_aware_retriever = create_history_aware_retriever(
+             self.llm, self.retriever, self.contextualize_q_prompt
+         )
+
+         # Answer question prompt
+         qa_system_prompt = qa_prompt or (
+             "You are an assistant for question-answering tasks. Use "
+             "the following pieces of retrieved context to answer the "
+             "question. If you don't know the answer, just say that you "
+             "don't know. Use three sentences maximum and keep the answer "
+             "concise."
+             "\n\n"
+             "{context}"
+         )
+         self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)
+
+         # Question-answer chain
+         self.question_answer_chain = create_stuff_documents_chain(
+             self.llm, self.qa_prompt_template
+         )
+
+         # Final retrieval chain
+         self.rag_chain = create_retrieval_chain(
+             self.history_aware_retriever, self.question_answer_chain
+         )
+
+         self.rag_chain = CustomRunnableWithHistory(
+             self.rag_chain,
+             get_session_history=self.get_session_history,
+             input_messages_key="input",
+             history_messages_key="chat_history",
+             output_messages_key="answer",
+             history_factory_config=[
+                 ConfigurableFieldSpec(
+                     id="user_id",
+                     annotation=str,
+                     name="User ID",
+                     description="Unique identifier for the user.",
+                     default="",
+                     is_shared=True,
+                 ),
+                 ConfigurableFieldSpec(
+                     id="conversation_id",
+                     annotation=str,
+                     name="Conversation ID",
+                     description="Unique identifier for the conversation.",
+                     default="",
+                     is_shared=True,
+                 ),
+                 ConfigurableFieldSpec(
+                     id="memory_window",
+                     annotation=int,
+                     name="Number of Conversations",
+                     description="Number of conversations to consider for context.",
+                     default=1,
+                     is_shared=True,
+                 ),
+             ],
+         )
+
+         if callbacks is not None:
+             self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
+
+     def get_session_history(
+         self, user_id: str, conversation_id: str, memory_window: int
+     ) -> BaseChatMessageHistory:
+         """
+         Get the session history for a user and conversation.
+
+         Args:
+             user_id (str): The user identifier.
+             conversation_id (str): The conversation identifier.
+             memory_window (int): The number of conversations to consider for context.
+
+         Returns:
+             BaseChatMessageHistory: The chat message history.
+         """
+         if (user_id, conversation_id) not in self.store:
+             self.store[(user_id, conversation_id)] = InMemoryHistory()
+             self.store[(user_id, conversation_id)].add_messages(
+                 self.memory.messages
+             )  # add previous messages to the store. Note: the store is in-memory.
+         return self.store[(user_id, conversation_id)]
+
+     async def invoke(self, user_query, config, **kwargs):
+         """
+         Invoke the chain.
+
+         Args:
+             kwargs: The input variables.
+
+         Returns:
+             dict: The output variables.
+         """
+         res = await self.rag_chain.ainvoke(user_query, config, **kwargs)
+         res["rephrase_prompt"] = self.rephrase_prompt
+         res["qa_prompt"] = self.qa_prompt
+         return res
+
+     def stream(self, user_query, config):
+         res = self.rag_chain.stream(user_query, config)
+         return res
+
+     def add_history_from_list(self, conversation_list):
+         """
+         Add messages from a list to the chat history.
+
+         Args:
+             messages (list): The list of messages to add.
+         """
+         history = ChatMessageHistory()
+
+         for idx, message in enumerate(conversation_list):
+             message_type = (
+                 message.get("type", None)
+                 if isinstance(message, dict)
+                 else getattr(message, "type", None)
+             )
+
+             message_content = (
+                 message.get("content", None)
+                 if isinstance(message, dict)
+                 else getattr(message, "content", None)
+             )
+
+             if message_type in ["human", "user_message"]:
+                 history.add_user_message(message_content)
+             elif message_type in ["ai", "ai_message"]:
+                 history.add_ai_message(message_content)
+
+         return history
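How the pieces fit together: `Langchain_RAG_V2.invoke` is awaited with a per-session `configurable` block, matching the `chain_config` built in `main.py`. A minimal sketch, assuming an already-constructed chain instance (the llm/retriever/prompt constructor arguments are elided); the IDs are illustrative:

```python
async def ask(chain):
    # The IDs scope the in-memory history per user and conversation.
    cfg = {
        "configurable": {
            "user_id": "student@example.edu",
            "conversation_id": "thread-123",
            "memory_window": 3,
        }
    }
    res = await chain.invoke({"input": "What is backprop?"}, cfg)
    return res["answer"]
```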
code/modules/chat/langchain/utils.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union, Tuple, Optional
2
+ from langchain_core.messages import (
3
+ BaseMessage,
4
+ AIMessage,
5
+ FunctionMessage,
6
+ HumanMessage,
7
+ )
8
+
9
+ from langchain_core.prompts.base import BasePromptTemplate, format_document
10
+ from langchain_core.prompts.chat import MessagesPlaceholder
11
+ from langchain_core.output_parsers import StrOutputParser
12
+ from langchain_core.output_parsers.base import BaseOutputParser
13
+ from langchain_core.retrievers import BaseRetriever, RetrieverOutput
14
+ from langchain_core.language_models import LanguageModelLike
15
+ from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
16
+ from langchain_core.runnables.history import RunnableWithMessageHistory
17
+ from langchain_core.runnables.utils import ConfigurableFieldSpec
18
+ from langchain_core.chat_history import BaseChatMessageHistory
19
+ from langchain_core.pydantic_v1 import BaseModel, Field
20
+ from langchain.chains.combine_documents.base import (
21
+ DEFAULT_DOCUMENT_PROMPT,
22
+ DEFAULT_DOCUMENT_SEPARATOR,
23
+ DOCUMENTS_KEY,
24
+ BaseCombineDocumentsChain,
25
+ _validate_prompt,
26
+ )
27
+ from langchain.chains.llm import LLMChain
28
+ from langchain_core.callbacks import Callbacks
29
+ from langchain_core.documents import Document
30
+
31
+
32
+ CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
33
+
34
+ from langchain_core.runnables.config import RunnableConfig
35
+ from langchain_core.messages import BaseMessage
36
+
37
+
38
+ from langchain_core.output_parsers import StrOutputParser
39
+ from langchain_core.prompts import ChatPromptTemplate
40
+ from langchain_community.chat_models import ChatOpenAI
41
+
42
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
43
+ from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
44
+
45
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
46
+ from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
47
+ import inspect
48
+ from langchain.chains.conversational_retrieval.base import _get_chat_history
49
+ from langchain_core.messages import BaseMessage
50
+
51
+
52
+ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
53
+
54
+ def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
55
+ _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
56
+ buffer = ""
57
+ for dialogue_turn in chat_history:
58
+ if isinstance(dialogue_turn, BaseMessage):
59
+ role_prefix = _ROLE_MAP.get(
60
+ dialogue_turn.type, f"{dialogue_turn.type}: "
61
+ )
62
+ buffer += f"\n{role_prefix}{dialogue_turn.content}"
63
+ elif isinstance(dialogue_turn, tuple):
64
+ human = "Student: " + dialogue_turn[0]
65
+ ai = "AI Tutor: " + dialogue_turn[1]
66
+ buffer += "\n" + "\n".join([human, ai])
67
+ else:
68
+ raise ValueError(
69
+ f"Unsupported chat history format: {type(dialogue_turn)}."
70
+ f" Full chat history: {chat_history} "
71
+ )
72
+ return buffer
73
+
74
+ async def _acall(
75
+ self,
76
+ inputs: Dict[str, Any],
77
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
78
+ ) -> Dict[str, Any]:
79
+ _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
80
+ question = inputs["question"]
81
+ get_chat_history = self._get_chat_history
82
+ chat_history_str = get_chat_history(inputs["chat_history"])
83
+ if chat_history_str:
84
+ # callbacks = _run_manager.get_child()
85
+ # new_question = await self.question_generator.arun(
86
+ # question=question, chat_history=chat_history_str, callbacks=callbacks
87
+ # )
88
+ system = (
89
+ "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
90
+ "Incorporate relevant details from the chat history to make the question clearer and more specific."
91
+ "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
92
+ "If the question is conversational and doesn't require context, do not rephrase it. "
93
+ "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
94
+ "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
95
+ "Chat history: \n{chat_history_str}\n"
96
+ "Rephrase the following question only if necessary: '{input}'"
97
+ )
98
+
99
+ prompt = ChatPromptTemplate.from_messages(
100
+ [
101
+ ("system", system),
102
+ ("human", "{input}, {chat_history_str}"),
103
+ ]
104
+ )
105
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
106
+ step_back = prompt | llm | StrOutputParser()
107
+ new_question = step_back.invoke(
108
+ {"input": question, "chat_history_str": chat_history_str}
109
+ )
110
+ else:
111
+ new_question = question
112
+ accepts_run_manager = (
113
+ "run_manager" in inspect.signature(self._aget_docs).parameters
114
+ )
115
+ if accepts_run_manager:
116
+ docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
117
+ else:
118
+ docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
119
+
120
+ output: Dict[str, Any] = {}
121
+ output["original_question"] = question
122
+ if self.response_if_no_docs_found is not None and len(docs) == 0:
123
+ output[self.output_key] = self.response_if_no_docs_found
124
+ else:
125
+ new_inputs = inputs.copy()
126
+ if self.rephrase_question:
127
+ new_inputs["question"] = new_question
128
+ new_inputs["chat_history"] = chat_history_str
129
+
130
+ # Prepare the final prompt with metadata
131
+ context = "\n\n".join(
132
+ [
133
+ f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
134
+ for idx, doc in enumerate(docs)
135
+ ]
136
+ )
137
+ final_prompt = (
138
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
139
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
140
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
141
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
142
+ f"Chat History:\n{chat_history_str}\n\n"
143
+ f"Context:\n{context}\n\n"
144
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
145
+ f"Student: {input}\n"
146
+ "AI Tutor:"
147
+ )
148
+
149
+ new_inputs["input"] = final_prompt
150
+ # new_inputs["question"] = final_prompt
151
+ # output["final_prompt"] = final_prompt
152
+
153
+ answer = await self.combine_docs_chain.arun(
154
+ input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
155
+ )
156
+ output[self.output_key] = answer
157
+
158
+ if self.return_source_documents:
159
+ output["source_documents"] = docs
160
+ output["rephrased_question"] = new_question
161
+ output["context"] = output["source_documents"]
162
+ return output
163
+
164
+
165
+ class CustomRunnableWithHistory(RunnableWithMessageHistory):
166
+
167
+ def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
168
+ _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
169
+ buffer = ""
170
+ for dialogue_turn in chat_history:
171
+ if isinstance(dialogue_turn, BaseMessage):
172
+ role_prefix = _ROLE_MAP.get(
173
+ dialogue_turn.type, f"{dialogue_turn.type}: "
174
+ )
175
+ buffer += f"\n{role_prefix}{dialogue_turn.content}"
176
+ elif isinstance(dialogue_turn, tuple):
177
+ human = "Student: " + dialogue_turn[0]
178
+ ai = "AI Tutor: " + dialogue_turn[1]
179
+ buffer += "\n" + "\n".join([human, ai])
180
+ else:
181
+ raise ValueError(
182
+ f"Unsupported chat history format: {type(dialogue_turn)}."
183
+ f" Full chat history: {chat_history} "
184
+ )
185
+ return buffer
186
+
187
+ async def _aenter_history(
188
+ self, input: Any, config: RunnableConfig
189
+ ) -> str:  # returns the formatted history string, not raw messages
190
+ """
191
+ Get the last k conversations from the message history.
192
+
193
+ Args:
194
+ input (Any): The input data.
195
+ config (RunnableConfig): The runnable configuration.
196
+
197
+ Returns:
198
+ str: The last k conversations, formatted as a single chat-history string.
199
+ """
200
+ hist: BaseChatMessageHistory = config["configurable"]["message_history"]
201
+ messages = (await hist.aget_messages()).copy()
202
+ if not self.history_messages_key:
203
+ # return all messages
204
+ input_val = (
205
+ input if not self.input_messages_key else input[self.input_messages_key]
206
+ )
207
+ messages += self._get_input_messages(input_val)
208
+
209
+ # return last k conversations
210
+ if config["configurable"]["memory_window"] == 0: # if k is 0, return empty list
211
+ messages = []
212
+ else:
213
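+ # each turn contributes two messages (Student + AI Tutor), so keep the last 2*k messages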
+ messages = messages[-2 * config["configurable"]["memory_window"] :]
214
+
215
+ messages = self._get_chat_history(messages)
216
+
217
+ return messages
218
+
219
+
220
+ class InMemoryHistory(BaseChatMessageHistory, BaseModel):
221
+ """In-memory implementation of chat message history."""
222
+
223
+ messages: List[BaseMessage] = Field(default_factory=list)
224
+
225
+ def add_messages(self, messages: List[BaseMessage]) -> None:
226
+ """Add a list of messages to the store."""
227
+ self.messages.extend(messages)
228
+
229
+ def clear(self) -> None:
230
+ """Clear the message history."""
231
+ self.messages = []
232
+
233
+ def __len__(self) -> int:
234
+ """Return the number of messages."""
235
+ return len(self.messages)
236
+
237
+
238
+ def create_history_aware_retriever(
239
+ llm: LanguageModelLike,
240
+ retriever: BaseRetriever,
241
+ prompt: BasePromptTemplate,
242
+ ) -> Runnable[Dict[str, Any], RetrieverOutput]:
243
+ """Create a chain that takes conversation history and returns documents."""
244
+ if "input" not in prompt.input_variables:
245
+ raise ValueError(
246
+ "Expected `input` to be a prompt variable, "
247
+ f"but got {prompt.input_variables}"
248
+ )
249
+
250
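+ # If there is no chat history, pass the raw input straight to the retriever; otherwise rephrase it first via prompt | llm.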
+ retrieve_documents = RunnableBranch(
251
+ (
252
+ lambda x: not x["chat_history"],
253
+ (lambda x: x["input"]) | retriever,
254
+ ),
255
+ prompt | llm | StrOutputParser() | retriever,
256
+ ).with_config(run_name="chat_retriever_chain")
257
+
258
+ return retrieve_documents
259
+
260
+
261
+ def create_stuff_documents_chain(
262
+ llm: LanguageModelLike,
263
+ prompt: BasePromptTemplate,
264
+ output_parser: Optional[BaseOutputParser] = None,
265
+ document_prompt: Optional[BasePromptTemplate] = None,
266
+ document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
267
+ ) -> Runnable[Dict[str, Any], Any]:
268
+ """Create a chain for passing a list of Documents to a model."""
269
+ _validate_prompt(prompt)
270
+ _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
271
+ _output_parser = output_parser or StrOutputParser()
272
+
273
+ def format_docs(inputs: dict) -> str:
274
+ return document_separator.join(
275
+ format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
276
+ )
277
+
278
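+ # Pipeline: inject the formatted documents, fill the prompt, call the model, parse the output to a string.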
+ return (
279
+ RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
280
+ run_name="format_inputs"
281
+ )
282
+ | prompt
283
+ | llm
284
+ | _output_parser
285
+ ).with_config(run_name="stuff_documents_chain")
286
+
287
+
288
+ def create_retrieval_chain(
289
+ retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
290
+ combine_docs_chain: Runnable[Dict[str, Any], str],
291
+ ) -> Runnable:
292
+ """Create retrieval chain that retrieves documents and then passes them on."""
293
+ if not isinstance(retriever, BaseRetriever):
294
+ retrieval_docs = retriever
295
+ else:
296
+ retrieval_docs = (lambda x: x["input"]) | retriever
297
+
298
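+ # Attach the retrieved documents as "context", then derive "answer" from the combine-docs chain.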
+ retrieval_chain = (
299
+ RunnablePassthrough.assign(
300
+ context=retrieval_docs.with_config(run_name="retrieve_documents"),
301
+ ).assign(answer=combine_docs_chain)
302
+ ).with_config(run_name="retrieval_chain")
303
+
304
+ return retrieval_chain
305
+
306
+
307
+ def return_questions(query, response, chat_history_str, context):
308
+
309
+ system = (
310
+ "You are someone that suggests a question based on the student's input and chat history. "
311
+ "Generate a question that is relevant to the student's input and chat history. "
312
+ "Incorporate relevant details from the chat history to make the question clearer and more specific. "
313
+ "Chat history: \n{chat_history_str}\n"
314
+ "Use the context to generate a question that is relevant to the student's input and chat history: Context: {context}"
315
+ "Generate 3 short and concise questions from the students voice based on the following input and response: "
316
+ "The 3 short and concise questions should be sperated by dots. Example: 'What is the capital of France?...What is the population of France?...What is the currency of France?'"
317
+ "User Query: {query}"
318
+ "AI Response: {response}"
319
+ "The 3 short and concise questions seperated by dots (...) are:"
320
+ )
321
+
322
+ prompt = ChatPromptTemplate.from_messages(
323
+ [
324
+ ("system", system),
325
+ ("human", "{chat_history_str}, {context}, {query}, {response}"),
326
+ ]
327
+ )
328
+ llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
329
+ question_generator = prompt | llm | StrOutputParser()
330
+ new_questions = question_generator.invoke(
331
+ {
332
+ "chat_history_str": chat_history_str,
333
+ "context": context,
334
+ "query": query,
335
+ "response": response,
336
+ }
337
+ )
338
+
339
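+ # The prompt asks for three questions separated by "...", so split on that delimiter.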
+ list_of_questions = new_questions.split("...")
340
+ return list_of_questions
code/modules/chat/llm_tutor.py CHANGED
@@ -1,211 +1,163 @@
1
- from langchain.chains import RetrievalQA, ConversationalRetrievalChain
2
- from langchain.memory import (
3
- ConversationBufferWindowMemory,
4
- ConversationSummaryBufferMemory,
5
- )
6
- from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
7
- import os
8
- from modules.config.constants import *
9
  from modules.chat.helpers import get_prompt
10
  from modules.chat.chat_model_loader import ChatModelLoader
11
  from modules.vectorstore.store_manager import VectorStoreManager
12
-
13
  from modules.retriever.retriever import Retriever
14
-
15
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
- from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
- import inspect
18
- from langchain.chains.conversational_retrieval.base import _get_chat_history
19
- from langchain_core.messages import BaseMessage
20
-
21
- CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
22
-
23
- from langchain_core.output_parsers import StrOutputParser
24
- from langchain_core.prompts import ChatPromptTemplate
25
- from langchain_community.chat_models import ChatOpenAI
26
-
27
-
28
- class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
29
-
30
- def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
31
- _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
32
- buffer = ""
33
- for dialogue_turn in chat_history:
34
- if isinstance(dialogue_turn, BaseMessage):
35
- role_prefix = _ROLE_MAP.get(
36
- dialogue_turn.type, f"{dialogue_turn.type}: "
37
- )
38
- buffer += f"\n{role_prefix}{dialogue_turn.content}"
39
- elif isinstance(dialogue_turn, tuple):
40
- human = "Student: " + dialogue_turn[0]
41
- ai = "AI Tutor: " + dialogue_turn[1]
42
- buffer += "\n" + "\n".join([human, ai])
43
- else:
44
- raise ValueError(
45
- f"Unsupported chat history format: {type(dialogue_turn)}."
46
- f" Full chat history: {chat_history} "
47
- )
48
- return buffer
49
-
50
- async def _acall(
51
- self,
52
- inputs: Dict[str, Any],
53
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
54
- ) -> Dict[str, Any]:
55
- _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
56
- question = inputs["question"]
57
- get_chat_history = self._get_chat_history
58
- chat_history_str = get_chat_history(inputs["chat_history"])
59
- if chat_history_str:
60
- # callbacks = _run_manager.get_child()
61
- # new_question = await self.question_generator.arun(
62
- # question=question, chat_history=chat_history_str, callbacks=callbacks
63
- # )
64
- system = (
65
- "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
66
- "Incorporate relevant details from the chat history to make the question clearer and more specific."
67
- "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
68
- "If the question is conversational and doesn't require context, do not rephrase it. "
69
- "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
70
- "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
71
- "Chat history: \n{chat_history_str}\n"
72
- "Rephrase the following question only if necessary: '{question}'"
73
- )
74
-
75
- prompt = ChatPromptTemplate.from_messages(
76
- [
77
- ("system", system),
78
- ("human", "{question}, {chat_history_str}"),
79
- ]
80
- )
81
- llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
82
- step_back = prompt | llm | StrOutputParser()
83
- new_question = step_back.invoke(
84
- {"question": question, "chat_history_str": chat_history_str}
85
- )
86
- else:
87
- new_question = question
88
- accepts_run_manager = (
89
- "run_manager" in inspect.signature(self._aget_docs).parameters
90
- )
91
- if accepts_run_manager:
92
- docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
93
- else:
94
- docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
95
-
96
- output: Dict[str, Any] = {}
97
- output["original_question"] = question
98
- if self.response_if_no_docs_found is not None and len(docs) == 0:
99
- output[self.output_key] = self.response_if_no_docs_found
100
- else:
101
- new_inputs = inputs.copy()
102
- if self.rephrase_question:
103
- new_inputs["question"] = new_question
104
- new_inputs["chat_history"] = chat_history_str
105
-
106
- # Prepare the final prompt with metadata
107
- context = "\n\n".join(
108
- [
109
- f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
110
- for idx, doc in enumerate(docs)
111
- ]
112
- )
113
- final_prompt = (
114
- "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
115
- "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
116
- "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
117
- "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
118
- f"Chat History:\n{chat_history_str}\n\n"
119
- f"Context:\n{context}\n\n"
120
- "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
121
- f"Student: {question}\n"
122
- "AI Tutor:"
123
- )
124
-
125
- # new_inputs["input"] = final_prompt
126
- new_inputs["question"] = final_prompt
127
- # output["final_prompt"] = final_prompt
128
-
129
- answer = await self.combine_docs_chain.arun(
130
- input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
131
- )
132
- output[self.output_key] = answer
133
-
134
- if self.return_source_documents:
135
- output["source_documents"] = docs
136
- output["rephrased_question"] = new_question
137
- return output
138
 
139
 
140
  class LLMTutor:
141
- def __init__(self, config, logger=None):
 
 
 
 
 
 
 
 
142
  self.config = config
143
  self.llm = self.load_llm()
 
144
  self.logger = logger
145
- self.vector_db = VectorStoreManager(config, logger=self.logger)
 
 
 
 
146
  if self.config["vectorstore"]["embedd_files"]:
147
  self.vector_db.create_database()
148
  self.vector_db.save_database()
149
 
150
- def set_custom_prompt(self):
151
  """
152
- Prompt template for QA retrieval for each vectorstore
 
 
 
153
  """
154
- prompt = get_prompt(self.config)
155
- # prompt = QA_PROMPT
 
 
156
 
157
- return prompt
 
 
 
 
 
 
158
 
159
- # Retrieval QA Chain
160
- def retrieval_qa_chain(self, llm, prompt, db):
 
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  retriever = Retriever(self.config)._return_retriever(db)
163
 
164
- if self.config["llm_params"]["use_history"]:
165
- memory = ConversationBufferWindowMemory(
166
- k=self.config["llm_params"]["memory_window"],
167
- memory_key="chat_history",
168
- return_messages=True,
169
- output_key="answer",
170
- max_token_limit=128,
171
- )
172
- qa_chain = CustomConversationalRetrievalChain.from_llm(
173
  llm=llm,
174
- chain_type="stuff",
175
- retriever=retriever,
176
- return_source_documents=True,
177
  memory=memory,
178
- combine_docs_chain_kwargs={"prompt": prompt},
179
- response_if_no_docs_found="No context found",
 
 
 
180
  )
 
 
181
  else:
182
- qa_chain = RetrievalQA.from_chain_type(
183
- llm=llm,
184
- chain_type="stuff",
185
- retriever=retriever,
186
- return_source_documents=True,
187
- chain_type_kwargs={"prompt": prompt},
188
  )
189
- return qa_chain
190
 
191
- # Loading the model
192
  def load_llm(self):
 
 
 
 
 
 
193
  chat_model_loader = ChatModelLoader(self.config)
194
  llm = chat_model_loader.load_chat_model()
195
  return llm
196
 
197
- # QA Model Function
198
- def qa_bot(self):
199
- db = self.vector_db.load_database()
200
- qa_prompt = self.set_custom_prompt()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  qa = self.retrieval_qa_chain(
202
- self.llm, qa_prompt, db
203
- ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
 
 
 
 
 
204
 
205
  return qa
206
-
207
- # output function
208
- def final_result(query):
209
- qa_result = qa_bot()
210
- response = qa_result({"query": query})
211
- return response
 
 
 
 
 
 
 
 
 
1
  from modules.chat.helpers import get_prompt
2
  from modules.chat.chat_model_loader import ChatModelLoader
3
  from modules.vectorstore.store_manager import VectorStoreManager
 
4
  from modules.retriever.retriever import Retriever
5
+ from modules.chat.langchain.langchain_rag import (
6
+ Langchain_RAG_V1,
7
+ Langchain_RAG_V2,
8
+ QuestionGenerator,
9
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  class LLMTutor:
13
+ def __init__(self, config, user, logger=None):
14
+ """
15
+ Initialize the LLMTutor class.
16
+
17
+ Args:
18
+ config (dict): Configuration dictionary.
19
+ user (str): User identifier.
20
+ logger (Logger, optional): Logger instance. Defaults to None.
21
+ """
22
  self.config = config
23
  self.llm = self.load_llm()
24
+ self.user = user
25
  self.logger = logger
26
+ self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
27
+ self.qa_prompt = get_prompt(config, "qa") # Initialize qa_prompt
28
+ self.rephrase_prompt = get_prompt(
29
+ config, "rephrase"
30
+ ) # Initialize rephrase_prompt
31
  if self.config["vectorstore"]["embedd_files"]:
32
  self.vector_db.create_database()
33
  self.vector_db.save_database()
34
 
35
+ def update_llm(self, old_config, new_config):
36
  """
37
+ Update the LLM and VectorStoreManager based on new configuration.
38
+
39
+ Args:
40
+ old_config (dict): Previous configuration dictionary.
+ new_config (dict): New configuration dictionary.
41
  """
42
+ changes = self.get_config_changes(old_config, new_config)
43
+
44
+ if "llm_params.llm_loader" in changes:
45
+ self.llm = self.load_llm() # Reinitialize LLM if chat_model changes
46
 
47
+ if "vectorstore.db_option" in changes:
48
+ self.vector_db = VectorStoreManager(
49
+ self.config, logger=self.logger
50
+ ).load_database() # Reinitialize VectorStoreManager if vectorstore changes
51
+ if self.config["vectorstore"]["embedd_files"]:
52
+ self.vector_db.create_database()
53
+ self.vector_db.save_database()
54
 
55
+ if "llm_params.llm_style" in changes:
56
+ self.qa_prompt = get_prompt(
57
+ self.config, "qa"
58
+ ) # Update qa_prompt if ELI5 changes
59
 
60
+ def get_config_changes(self, old_config, new_config):
61
+ """
62
+ Get the changes between the old and new configuration.
63
+
64
+ Args:
65
+ old_config (dict): Old configuration dictionary.
66
+ new_config (dict): New configuration dictionary.
67
+
68
+ Returns:
69
+ dict: Dictionary containing the changes.
70
+ """
71
+ changes = {}
72
+
73
+ def compare_dicts(old, new, parent_key=""):
74
+ for key in new:
75
+ full_key = f"{parent_key}.{key}" if parent_key else key
76
+ if isinstance(new[key], dict) and isinstance(old.get(key), dict):
77
+ compare_dicts(old.get(key, {}), new[key], full_key)
78
+ elif old.get(key) != new[key]:
79
+ changes[full_key] = (old.get(key), new[key])
80
+ # Include keys that are in old but not in new
81
+ for key in old:
82
+ if key not in new:
83
+ full_key = f"{parent_key}.{key}" if parent_key else key
84
+ changes[full_key] = (old[key], None)
85
+
86
+ compare_dicts(old_config, new_config)
87
+ return changes
88
+
89
+ def retrieval_qa_chain(
90
+ self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
91
+ ):
92
+ """
93
+ Create a Retrieval QA Chain.
94
+
95
+ Args:
96
+ llm (LLM): The language model instance.
97
+ qa_prompt (str): The QA prompt string.
98
+ rephrase_prompt (str): The rephrase prompt string.
99
+ db (VectorStore): The vector store instance.
100
+ memory (Memory, optional): Memory instance. Defaults to None.
101
+
102
+ Returns:
103
+ Chain: The retrieval QA chain instance.
104
+ """
105
  retriever = Retriever(self.config)._return_retriever(db)
106
 
107
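+ # Only the "langchain" architecture is wired up; any other llm_arch value raises below.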
+ if self.config["llm_params"]["llm_arch"] == "langchain":
108
+ self.qa_chain = Langchain_RAG_V2(
 
 
 
 
 
 
 
109
  llm=llm,
 
 
 
110
  memory=memory,
111
+ retriever=retriever,
112
+ qa_prompt=qa_prompt,
113
+ rephrase_prompt=rephrase_prompt,
114
+ config=self.config,
115
+ callbacks=callbacks,
116
  )
117
+
118
+ self.question_generator = QuestionGenerator()
119
  else:
120
+ raise ValueError(
121
+ f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
 
 
 
 
122
  )
123
+ return self.qa_chain
124
 
 
125
  def load_llm(self):
126
+ """
127
+ Load the language model.
128
+
129
+ Returns:
130
+ LLM: The loaded language model instance.
131
+ """
132
  chat_model_loader = ChatModelLoader(self.config)
133
  llm = chat_model_loader.load_chat_model()
134
  return llm
135
 
136
+ def qa_bot(self, memory=None, callbacks=None):
137
+ """
138
+ Create a QA bot instance.
139
+
140
+ Args:
141
+ memory (Memory, optional): Memory instance. Defaults to None.
142
+ callbacks (list, optional): Callback handlers passed to the chain. Defaults to None.
143
+ Prompts are taken from self.qa_prompt and self.rephrase_prompt.
144
+
145
+ Returns:
146
+ Chain: The QA bot chain instance.
147
+ """
148
+ # sanity check to see if there are any documents in the database
149
+ if len(self.vector_db) == 0:
150
+ raise ValueError(
151
+ "No documents in the database. Populate the database first."
152
+ )
153
+
154
  qa = self.retrieval_qa_chain(
155
+ self.llm,
156
+ self.qa_prompt,
157
+ self.rephrase_prompt,
158
+ self.vector_db,
159
+ memory,
160
+ callbacks=callbacks,
161
+ )
162
 
163
  return qa
 
 
 
 
 
 
code/modules/chat_processor/base.py DELETED
@@ -1,12 +0,0 @@
1
- # Template for chat processor classes
2
-
3
-
4
- class ChatProcessorBase:
5
- def __init__(self, config):
6
- self.config = config
7
-
8
- def process(self, message):
9
- """
10
- Processes and Logs the message
11
- """
12
- raise NotImplementedError("process method not implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
code/modules/chat_processor/chat_processor.py DELETED
@@ -1,30 +0,0 @@
1
- from modules.chat_processor.literal_ai import LiteralaiChatProcessor
2
-
3
-
4
- class ChatProcessor:
5
- def __init__(self, config, tags=None):
6
- self.chat_processor_type = config["chat_logging"]["platform"]
7
- self.logging = config["chat_logging"]["log_chat"]
8
- self.tags = tags
9
- if self.logging:
10
- self._init_processor()
11
-
12
- def _init_processor(self):
13
- if self.chat_processor_type == "literalai":
14
- self.processor = LiteralaiChatProcessor(self.tags)
15
- else:
16
- raise ValueError(
17
- f"Chat processor type {self.chat_processor_type} not supported"
18
- )
19
-
20
- def _process(self, user_message, assistant_message, source_dict):
21
- if self.logging:
22
- return self.processor.process(user_message, assistant_message, source_dict)
23
- else:
24
- pass
25
-
26
- async def rag(self, user_query: str, chain, cb):
27
- if self.logging:
28
- return await self.processor.rag(user_query, chain, cb)
29
- else:
30
- return await chain.acall(user_query, callbacks=[cb])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
code/modules/chat_processor/literal_ai.py CHANGED
@@ -1,37 +1,44 @@
1
- from literalai import LiteralClient
2
- import os
3
- from .base import ChatProcessorBase
4
 
5
 
6
- class LiteralaiChatProcessor(ChatProcessorBase):
7
- def __init__(self, tags=None):
8
- self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
9
- self.literal_client.reset_context()
10
- with self.literal_client.thread(name="TEST") as thread:
11
- self.thread_id = thread.id
12
- self.thread = thread
13
- if tags is not None and type(tags) == list:
14
- self.thread.tags = tags
15
- print(f"Thread ID: {self.thread}")
16
 
17
- def process(self, user_message, assistant_message, source_dict):
18
- with self.literal_client.thread(thread_id=self.thread_id) as thread:
19
- self.literal_client.message(
20
- content=user_message,
21
- type="user_message",
22
- name="User",
23
- )
24
- self.literal_client.message(
25
- content=assistant_message,
26
- type="assistant_message",
27
- name="AI_Tutor",
28
- )
29
 
30
- async def rag(self, user_query: str, chain, cb):
31
- with self.literal_client.step(
32
- type="retrieval", name="RAG", thread_id=self.thread_id
33
- ) as step:
34
- step.input = {"question": user_query}
35
- res = await chain.acall(user_query, callbacks=[cb])
36
- step.output = res
37
- return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from chainlit.data import ChainlitDataLayer, queue_until_user_message
+ from chainlit.step import StepDict
+ from literalai.step import StepDict as LiteralStepDict  # assumed import path, mirroring the referenced Chainlit data layer
 
 
2
 
3
 
4
+ # update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
5
+ class CustomLiteralDataLayer(ChainlitDataLayer):
6
+ def __init__(self, **kwargs):
7
+ super().__init__(**kwargs)
 
 
 
 
 
 
8
 
9
+ @queue_until_user_message()
10
+ async def create_step(self, step_dict: "StepDict"):
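+ # Translate the Chainlit step dict into the Literal AI step schema before sending it.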
11
+ metadata = dict(
12
+ step_dict.get("metadata", {}),
13
+ **{
14
+ "waitForAnswer": step_dict.get("waitForAnswer"),
15
+ "language": step_dict.get("language"),
16
+ "showInput": step_dict.get("showInput"),
17
+ },
18
+ )
 
 
19
 
20
+ step: LiteralStepDict = {
21
+ "createdAt": step_dict.get("createdAt"),
22
+ "startTime": step_dict.get("start"),
23
+ "endTime": step_dict.get("end"),
24
+ "generation": step_dict.get("generation"),
25
+ "id": step_dict.get("id"),
26
+ "parentId": step_dict.get("parentId"),
27
+ "name": step_dict.get("name"),
28
+ "threadId": step_dict.get("threadId"),
29
+ "type": step_dict.get("type"),
30
+ "tags": step_dict.get("tags"),
31
+ "metadata": metadata,
32
+ }
33
+ if step_dict.get("input"):
34
+ step["input"] = {"content": step_dict.get("input")}
35
+ if step_dict.get("output"):
36
+ step["output"] = {"content": step_dict.get("output")}
37
+ if step_dict.get("isError"):
38
+ step["error"] = step_dict.get("output")
39
+
40
+ # print("\n\n\n")
41
+ # print("Step: ", step)
42
+ # print("\n\n\n")
43
+
44
+ await self.client.api.send_steps([step])
code/modules/config/config.yml CHANGED
@@ -3,6 +3,7 @@ log_chunk_dir: '../storage/logs/chunks' # str
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
 
6
  embedd_files: False # bool
7
  data_path: '../storage/data' # str
8
  url_file_path: '../storage/data/urls.txt' # str
@@ -14,7 +15,7 @@ vectorstore:
14
  score_threshold : 0.2 # float
15
 
16
  faiss_params: # Not used as of now
17
- index_path: '../vectorstores/faiss.index' # str
18
  index_type: 'Flat' # str [Flat, HNSW, IVF]
19
  index_dimension: 384 # int
20
  index_nlist: 100 # int
@@ -24,19 +25,23 @@ vectorstore:
24
  index_name: "new_idx" # str
25
 
26
  llm_params:
 
27
  use_history: True # bool
 
28
  memory_window: 3 # int
29
- llm_loader: 'openai' # str [local_llm, openai]
 
30
  openai_params:
31
- model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
32
  local_llm_params:
33
- model: 'tiny-llama'
34
- temperature: 0.7
35
  pdf_reader: 'gpt' # str [llama, pymupdf, gpt]
36
 
37
  chat_logging:
38
- log_chat: False # bool
39
  platform: 'literalai'
 
40
 
41
  splitter_options:
42
  use_splitter: True # bool
 
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
6
+ load_from_HF: True # bool
7
  embedd_files: False # bool
8
  data_path: '../storage/data' # str
9
  url_file_path: '../storage/data/urls.txt' # str
 
15
  score_threshold : 0.2 # float
16
 
17
  faiss_params: # Not used as of now
18
+ index_path: 'vectorstores/faiss.index' # str
19
  index_type: 'Flat' # str [Flat, HNSW, IVF]
20
  index_dimension: 384 # int
21
  index_nlist: 100 # int
 
25
  index_name: "new_idx" # str
26
 
27
  llm_params:
28
+ llm_arch: 'langchain' # [langchain]
29
  use_history: True # bool
30
+ generate_follow_up: False # bool
31
  memory_window: 3 # int
32
+ llm_style: 'Normal' # str [Normal, ELI5]
33
+ llm_loader: 'gpt-4o-mini' # str [local_llm, gpt-3.5-turbo-1106, gpt-4, gpt-4o-mini]
34
  openai_params:
35
+ temperature: 0.7 # float
36
  local_llm_params:
37
+ temperature: 0.7 # float
38
+ stream: False # bool
39
  pdf_reader: 'gpt' # str [llama, pymupdf, gpt]
40
 
41
  chat_logging:
42
+ log_chat: True # bool
43
  platform: 'literalai'
44
+ callbacks: False # bool
45
 
46
  splitter_options:
47
  use_splitter: True # bool
code/modules/config/constants.py CHANGED
@@ -8,83 +8,16 @@ load_dotenv()
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
  LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
10
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
11
- LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
 
12
 
13
- opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
14
-
15
- # Prompt Templates
16
-
17
- openai_prompt_template = """Use the following pieces of information to answer the user's question.
18
- You are an intelligent chatbot designed to help students with questions regarding the course.
19
- Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context.
20
- Be sure to explain the parameters and variables in the equations.
21
- If you don't know the answer, just say that you don't know.
22
-
23
- Context: {context}
24
- Question: {question}
25
-
26
- Only return the helpful answer below and nothing else.
27
- Helpful answer:
28
- """
29
-
30
- openai_prompt_template_with_history = """Use the following pieces of information to answer the user's question.
31
- You are an intelligent chatbot designed to help students with questions regarding the course.
32
- Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context.
33
- Be sure to explain the parameters and variables in the equations.
34
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
35
-
36
- Use the history to answer the question if you can.
37
- Chat History:
38
- {chat_history}
39
- Context: {context}
40
- Question: {question}
41
-
42
- Only return the helpful answer below and nothing else.
43
- Helpful answer:
44
- """
45
-
46
- tinyllama_prompt_template = """
47
- <|im_start|>system
48
- Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a brief and concise answer to the question. When asked for formulas, give a brief description of the formula and output math equations in LaTeX format between $ signs.
49
-
50
- Context:
51
- {context}
52
- <|im_end|>
53
- <|im_start|>user
54
- Question: Who is the instructor for this course?
55
- <|im_end|>
56
- <|im_start|>assistant
57
- The instructor for this course is Prof. Thomas Gardos.
58
- <|im_end|>
59
- <|im_start|>user
60
- Question: {question}
61
- <|im_end|>
62
- <|im_start|>assistant
63
- """
64
-
65
- tinyllama_prompt_template_with_history = """
66
- <|im_start|>system
67
- Assistant is an intelligent chatbot designed to help students with questions regarding the course. Only answer questions using the context below and if you're not sure of an answer, you can say "I don't know". Always give a brief and concise answer to the question. Output math equations in LaTeX format between $ signs. Use the history to answer the question if you can.
68
-
69
- Chat History:
70
- {chat_history}
71
- Context:
72
- {context}
73
- <|im_end|>
74
- <|im_start|>user
75
- Question: Who is the instructor for this course?
76
- <|im_end|>
77
- <|im_start|>assistant
78
- The instructor for this course is Prof. Thomas Gardos.
79
- <|im_end|>
80
- <|im_start|>user
81
- Question: {question}
82
- <|im_end|>
83
- <|im_start|>assistant
84
- """
85
 
 
86
 
87
  # Model Paths
88
 
89
  LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
90
- MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
 
 
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
  LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
10
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
11
+ LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
12
+ LITERAL_API_URL = os.getenv("LITERAL_API_URL")
13
 
14
+ OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
15
+ OAUTH_GOOGLE_CLIENT_SECRET = os.getenv("OAUTH_GOOGLE_CLIENT_SECRET")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
18
 
19
  # Model Paths
20
 
21
  LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
22
+
23
+ RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
code/modules/config/prompts.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts = {
2
+ "openai": {
3
+ "rephrase_prompt": (
4
+ "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
5
+ "Incorporate relevant details from the chat history to make the question clearer and more specific. "
6
+ "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
7
+ "If the question is conversational and doesn't require context, do not rephrase it. "
8
+ "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. "
9
+ "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project' "
10
+ "Chat history: \n{chat_history}\n"
11
+ "Rephrase the following question only if necessary: '{input}'"
12
+ "Rephrased Question:'"
13
+ ),
14
+ "prompt_with_history": {
15
+ "normal": (
16
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
17
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
18
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
19
+ "Render math equations in LaTeX format between $ or $$ signs, stick to the parameter and variable icons found in your context. Be sure to explain the parameters and variables in the equations."
20
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
21
+ "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
22
+ "Chat History:\n{chat_history}\n\n"
23
+ "Context:\n{context}\n\n"
24
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
25
+ "Student: {input}\n"
26
+ "AI Tutor:"
27
+ ),
28
+ "eli5": (
29
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Your job is to explain things in the simplest and most engaging way possible, just like the 'Explain Like I'm 5' (ELI5) concept."
30
+ "If you don't know the answer, do your best without making things up. Keep your explanations straightforward and very easy to understand."
31
+ "Use the chat history and context to help you, but avoid repeating past responses. Provide links from the source_file metadata when they're helpful."
32
+ "Use very simple language and examples to explain any math equations, and put the equations in LaTeX format between $ or $$ signs."
33
+ "Be friendly and engaging, like you're chatting with a young child who's curious and eager to learn. Avoid complex terms and jargon."
34
+ "Include simple and clear examples wherever you can to make things easier to understand."
35
+ "Do not get influenced by the style of conversation in the chat history. Follow the instructions given here."
36
+ "Chat History:\n{chat_history}\n\n"
37
+ "Context:\n{context}\n\n"
38
+ "Answer the student's question below in a friendly, simple, and engaging way, just like the ELI5 concept. Use the context and history only if they're relevant, otherwise, just have a natural conversation."
39
+ "Give a clear and detailed explanation with simple examples to make it easier to understand. Remember, your goal is to break down complex topics into very simple terms, just like ELI5."
40
+ "Student: {input}\n"
41
+ "AI Tutor:"
42
+ ),
43
+ "socratic": (
44
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Engage the student in a Socratic dialogue to help them discover answers on their own. Use the provided context to guide your questioning."
45
+ "If you don't know the answer, do your best without making things up. Keep the conversation engaging and inquisitive."
46
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata when relevant. Use the source context that is most relevant."
47
+ "Speak in a friendly and engaging manner, encouraging critical thinking and self-discovery."
48
+ "Use questions to lead the student to explore the topic and uncover answers."
49
+ "Chat History:\n{chat_history}\n\n"
50
+ "Context:\n{context}\n\n"
51
+ "Answer the student's question below by guiding them through a series of questions and insights that lead to deeper understanding. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation."
52
+ "Foster an inquisitive mindset and help the student discover answers through dialogue."
53
+ "Student: {input}\n"
54
+ "AI Tutor:"
55
+ ),
56
+ },
57
+ "prompt_no_history": (
58
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
59
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
60
+ "Provide links from the source_file metadata. Use the source context that is most relevant. "
61
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
62
+ "Context:\n{context}\n\n"
63
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
64
+ "Student: {input}\n"
65
+ "AI Tutor:"
66
+ ),
67
+ },
68
+ "tiny_llama": {
69
+ "prompt_no_history": (
70
+ "system\n"
71
+ "Assistant is an intelligent chatbot designed to help students with questions regarding the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance.\n"
72
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally.\n"
73
+ "Provide links from the source_file metadata. Use the source context that is most relevant.\n"
74
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
75
+ "\n\n"
76
+ "user\n"
77
+ "Context:\n{context}\n\n"
78
+ "Question: {input}\n"
79
+ "\n\n"
80
+ "assistant"
81
+ ),
82
+ "prompt_with_history": (
83
+ "system\n"
84
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
85
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
86
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
87
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n"
88
+ "\n\n"
89
+ "user\n"
90
+ "Chat History:\n{chat_history}\n\n"
91
+ "Context:\n{context}\n\n"
92
+ "Question: {input}\n"
93
+ "\n\n"
94
+ "assistant"
95
+ ),
96
+ },
97
+ }
code/modules/dataloader/data_loader.py CHANGED
@@ -277,11 +277,11 @@ class ChunkProcessor:
277
 
278
  page_num = doc.metadata.get("page", 0)
279
  file_data[page_num] = doc.page_content
280
- metadata = (
281
- addl_metadata.get(file_path, {})
282
- if metadata_source == "file"
283
- else {"source": file_path, "page": page_num}
284
- )
285
  file_metadata[page_num] = metadata
286
 
287
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
 
277
 
278
  page_num = doc.metadata.get("page", 0)
279
  file_data[page_num] = doc.page_content
280
+
281
+ # Create a new dictionary for metadata in each iteration
282
+ metadata = addl_metadata.get(file_path, {}).copy()
283
+ metadata["page"] = page_num
284
+ metadata["source"] = file_path
285
  file_metadata[page_num] = metadata
286
 
287
  if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
code/modules/vectorstore/base.py CHANGED
@@ -29,5 +29,8 @@ class VectorStoreBase:
29
  """
30
  raise NotImplementedError
31
 
 
 
 
32
  def __str__(self):
33
  return self.__class__.__name__
 
29
  """
30
  raise NotImplementedError
31
 
32
+ def __len__(self):
33
+ raise NotImplementedError
34
+
35
  def __str__(self):
36
  return self.__class__.__name__
code/modules/vectorstore/chroma.py CHANGED
@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
 
 
 
 
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
42
+
43
+ def __len__(self):
44
+ return self.vectorstore._collection.count()  # langchain's Chroma has no __len__; count via the underlying collection
code/modules/vectorstore/colbert.py CHANGED
@@ -1,6 +1,67 @@
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
 
 
 
 
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  class ColbertVectorStore(VectorStoreBase):
@@ -24,16 +85,28 @@ class ColbertVectorStore(VectorStoreBase):
24
  document_ids=document_names,
25
  document_metadatas=document_metadata,
26
  )
 
27
 
28
  def load_database(self):
29
  path = os.path.join(
 
30
  self.config["vectorstore"]["db_path"],
31
  "db_" + self.config["vectorstore"]["db_option"],
32
  )
33
  self.vectorstore = RAGPretrainedModel.from_index(
34
  f"{path}/colbert/indexes/new_idx"
35
  )
 
 
 
 
 
 
 
36
  return self.vectorstore
37
 
38
  def as_retriever(self):
39
  return self.vectorstore.as_retriever()
 
 
 
 
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
3
+ from langchain_core.retrievers import BaseRetriever
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
5
+ from langchain_core.documents import Document
6
+ from typing import Any, List, Optional, Sequence
7
  import os
8
+ import json
9
+
10
+
11
+ class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
12
+ model: Any
13
+ kwargs: dict = {}
14
+
15
+ def _get_relevant_documents(
16
+ self,
17
+ query: str,
18
+ *,
19
+ run_manager: CallbackManagerForRetrieverRun, # noqa
20
+ ) -> List[Document]:
21
+ """Get documents relevant to a query."""
22
+ docs = self.model.search(query, **self.kwargs)
23
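+ # Wrap each raw search hit as a LangChain Document, keeping the retrieval score in metadata.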
+ return [
24
+ Document(
25
+ page_content=doc["content"],
26
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
27
+ )
28
+ for doc in docs
29
+ ]
30
+
31
+ async def _aget_relevant_documents(
32
+ self,
33
+ query: str,
34
+ *,
35
+ run_manager: CallbackManagerForRetrieverRun, # noqa
36
+ ) -> List[Document]:
37
+ """Get documents relevant to a query."""
38
+ docs = self.model.search(query, **self.kwargs)
39
+ return [
40
+ Document(
41
+ page_content=doc["content"],
42
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
43
+ )
44
+ for doc in docs
45
+ ]
46
+
47
+
48
+ class RAGPretrainedModel(RAGPretrainedModel):
49
+ """
50
+ Adding len property to RAGPretrainedModel
51
+ """
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ super().__init__(*args, **kwargs)
55
+ self._document_count = 0
56
+
57
+ def set_document_count(self, count):
58
+ self._document_count = count
59
+
60
+ def __len__(self):
61
+ return self._document_count
62
+
63
+ def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
64
+ return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)
65
 
66
 
67
  class ColbertVectorStore(VectorStoreBase):
 
85
  document_ids=document_names,
86
  document_metadatas=document_metadata,
87
  )
88
+ self.colbert.set_document_count(len(document_names))
89
 
90
  def load_database(self):
91
  path = os.path.join(
92
+ os.getcwd(),
93
  self.config["vectorstore"]["db_path"],
94
  "db_" + self.config["vectorstore"]["db_option"],
95
  )
96
  self.vectorstore = RAGPretrainedModel.from_index(
97
  f"{path}/colbert/indexes/new_idx"
98
  )
99
+
100
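+ # Read the passage count from the index metadata so __len__ works on a freshly loaded index.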
+ index_metadata = json.load(
101
+ open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
102
+ )
103
+ num_documents = index_metadata["num_passages"]
104
+ self.vectorstore.set_document_count(num_documents)
105
+
106
  return self.vectorstore
107
 
108
  def as_retriever(self):
109
  return self.vectorstore.as_retriever()
110
+
111
+ def __len__(self):
112
+ return len(self.vectorstore)
code/modules/vectorstore/faiss.py CHANGED
@@ -3,6 +3,13 @@ from modules.vectorstore.base import VectorStoreBase
3
  import os
4
 
5
 
 
 
 
 
 
 
 
6
  class FaissVectorStore(VectorStoreBase):
7
  def __init__(self, config):
8
  self.config = config
@@ -35,3 +42,6 @@ class FaissVectorStore(VectorStoreBase):
35
 
36
  def as_retriever(self):
37
  return self.vectorstore.as_retriever()
 
 
 
 
3
  import os
4
 
5
 
6
+ class FAISS(FAISS):
7
+ """To add length property to FAISS class"""
8
+
9
+ def __len__(self):
10
+ return self.index.ntotal
11
+
12
+
13
  class FaissVectorStore(VectorStoreBase):
14
  def __init__(self, config):
15
  self.config = config
 
42
 
43
  def as_retriever(self):
44
  return self.vectorstore.as_retriever()
45
+
46
+ def __len__(self):
47
+ return len(self.vectorstore)
code/modules/vectorstore/raptor.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import numpy as np
6
  import pandas as pd
7
  import umap
8
- from langchain_core.prompts import ChatPromptTemplate
9
  from langchain_core.output_parsers import StrOutputParser
10
  from sklearn.mixture import GaussianMixture
11
  from langchain_community.chat_models import ChatOpenAI
@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
16
  RANDOM_SEED = 42
17
 
18
 
 
 
 
 
 
 
 
19
  class RAPTORVectoreStore(VectorStoreBase):
20
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
21
  self.documents = documents
 
5
  import numpy as np
6
  import pandas as pd
7
  import umap
8
+ from langchain_core.prompts.chat import ChatPromptTemplate
9
  from langchain_core.output_parsers import StrOutputParser
10
  from sklearn.mixture import GaussianMixture
11
  from langchain_community.chat_models import ChatOpenAI
 
16
  RANDOM_SEED = 42
17
 
18
 
19
+ class FAISS(FAISS):
20
+ """To add length property to FAISS class"""
21
+
22
+ def __len__(self):
23
+ return self.index.ntotal
24
+
25
+
26
  class RAPTORVectoreStore(VectorStoreBase):
27
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
28
  self.documents = documents
code/modules/vectorstore/store_manager.py CHANGED
@@ -3,6 +3,7 @@ from modules.vectorstore.helpers import *
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
5
  from modules.dataloader.helpers import *
 
6
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
7
  import logging
8
  import os
@@ -135,14 +136,34 @@ class VectorStoreManager:
135
  self.embedding_model = self.create_embedding_model()
136
  else:
137
  self.embedding_model = None
138
- self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
 
 
 
 
 
 
 
 
139
  end_time = time.time() # End time for loading database
140
  self.logger.info(
141
- f"Time taken to load database: {end_time - start_time} seconds"
142
  )
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  import yaml
@@ -152,7 +173,20 @@ if __name__ == "__main__":
152
  print(config)
153
  print(f"Trying to create database with config: {config}")
154
  vector_db = VectorStoreManager(config)
155
- vector_db.create_database()
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  print("Created database")
157
 
158
  print(f"Trying to load the database")
 
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
5
  from modules.dataloader.helpers import *
6
+ from modules.config.constants import RETRIEVER_HF_PATHS
7
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
8
  import logging
9
  import os
 
136
  self.embedding_model = self.create_embedding_model()
137
  else:
138
  self.embedding_model = None
139
+ try:
140
+ self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
141
+ except Exception as e:
142
+ raise ValueError(
143
+ f"Error loading database, check if it exists. if not run python -m modules.vectorstore.store_manager / Resteart the HF Space: {e}"
144
+ )
145
+ # print(f"Creating database")
146
+ # self.create_database()
147
+ # self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
148
  end_time = time.time() # End time for loading database
149
  self.logger.info(
150
+ f"Time taken to load database {self.config['vectorstore']['db_option']}: {end_time - start_time} seconds"
151
  )
152
  self.logger.info("Loaded database")
153
  return self.loaded_vector_db
154
 
155
+ def load_from_HF(self, HF_PATH):
156
+ start_time = time.time() # Start time for loading database
157
+ self.vector_db._load_from_HF(HF_PATH)
158
+ end_time = time.time()
159
+ self.logger.info(
160
+ f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
161
+ )
162
+ self.logger.info("Downloaded database")
163
+
164
+ def __len__(self):
165
+ return len(self.vector_db)
166
+
167
 
168
  if __name__ == "__main__":
169
  import yaml
 
173
  print(config)
174
  print(f"Trying to create database with config: {config}")
175
  vector_db = VectorStoreManager(config)
176
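+ # Prefer downloading a prebuilt index from Hugging Face when one is published; otherwise build locally.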
+ if config["vectorstore"]["load_from_HF"]:
177
+ if config["vectorstore"]["db_option"] in RETRIEVER_HF_PATHS:
178
+ vector_db.load_from_HF(
179
+ HF_PATH=RETRIEVER_HF_PATHS[config["vectorstore"]["db_option"]]
180
+ )
181
+ else:
182
+ # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
183
+ # print("Creating database")
184
+ # vector_db.create_database()
185
+ raise ValueError(
186
+ f"HF_PATH not available for {config['vectorstore']['db_option']}"
187
+ )
188
+ else:
189
+ vector_db.create_database()
190
  print("Created database")
191
 
192
  print(f"Trying to load the database")
code/modules/vectorstore/vectorstore.py CHANGED
@@ -2,6 +2,9 @@ from modules.vectorstore.faiss import FaissVectorStore
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
 
 
 
5
 
6
 
7
  class VectorStore:
@@ -50,8 +53,39 @@ class VectorStore:
50
  else:
51
  return self.vectorstore.load_database(embedding_model)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def _as_retriever(self):
54
  return self.vectorstore.as_retriever()
55
 
56
  def _get_vectorstore(self):
57
  return self.vectorstore
 
 
 
 
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
5
+ from huggingface_hub import snapshot_download
6
+ import os
7
+ import shutil
8
 
9
 
10
  class VectorStore:
 
53
  else:
54
  return self.vectorstore.load_database(embedding_model)
55
 
56
+ def _load_from_HF(self, HF_PATH):
57
+ # Download the snapshot from Hugging Face Hub
58
+ # Note: Download goes to the cache directory
59
+ snapshot_path = snapshot_download(
60
+ repo_id=HF_PATH,
61
+ repo_type="dataset",
62
+ force_download=True,
63
+ )
64
+
65
+ # Move the downloaded files to the desired directory
66
+ target_path = os.path.join(
67
+ self.config["vectorstore"]["db_path"],
68
+ "db_" + self.config["vectorstore"]["db_option"],
69
+ )
70
+
71
+ # Create target path if it doesn't exist
72
+ os.makedirs(target_path, exist_ok=True)
73
+
74
+ # move all files and directories from snapshot_path to target_path
75
+ # target path is used while loading the database
76
+ for item in os.listdir(snapshot_path):
77
+ s = os.path.join(snapshot_path, item)
78
+ d = os.path.join(target_path, item)
79
+ if os.path.isdir(s):
80
+ shutil.copytree(s, d, dirs_exist_ok=True)
81
+ else:
82
+ shutil.copy2(s, d)
83
+
84
  def _as_retriever(self):
85
  return self.vectorstore.as_retriever()
86
 
87
  def _get_vectorstore(self):
88
  return self.vectorstore
89
+
90
+ def __len__(self):
91
+ return len(self.vectorstore)
code/public/test.css CHANGED
@@ -31,3 +31,13 @@ a[href*='https://github.com/Chainlit/chainlit'] {
31
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
32
  display: none;
33
  }
 
 
 
 
 
 
 
 
 
 
 
31
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
32
  display: none;
33
  }
34
+
35
+ /* Hide the new chat button
36
+ #new-chat-button {
37
+ display: none;
38
+ } */
39
+
40
+ /* Hide the open sidebar button
41
+ #open-sidebar-button {
42
+ display: none;
43
+ } */
requirements.txt CHANGED
@@ -1,27 +1,25 @@
1
- # Automatically generated by https://github.com/damnever/pigar.
2
-
3
- aiohttp==3.9.5
4
- beautifulsoup4==4.12.3
5
- chainlit==1.1.304
6
- langchain==0.1.20
7
- langchain-community==0.0.38
8
- langchain-core==0.1.52
9
- literalai==0.0.604
10
- llama-parse==0.4.4
11
- numpy==1.26.4
12
- pandas==2.2.2
13
- pysrt==1.1.2
14
- python-dotenv==1.0.1
15
- PyYAML==6.0.1
16
- RAGatouille==0.0.8.post2
17
- requests==2.32.3
18
- scikit-learn==1.5.0
19
- torch==2.3.1
20
- tqdm==4.66.4
21
- transformers==4.41.2
22
- trulens_eval==0.31.0
23
- umap-learn==0.5.6
24
- trulens-eval==0.31.0
25
- llama-cpp-python==0.2.77
26
- pymupdf==1.24.5
27
  websockets
 
 
1
+ aiohttp
2
+ beautifulsoup4
3
+ chainlit
4
+ langchain
5
+ langchain-community
6
+ langchain-core
7
+ literalai
8
+ llama-parse
9
+ numpy
10
+ pandas
11
+ pysrt
12
+ python-dotenv
13
+ PyYAML
14
+ RAGatouille
15
+ requests
16
+ scikit-learn
17
+ torch
18
+ tqdm
19
+ transformers
20
+ trulens_eval
21
+ umap-learn
22
+ llama-cpp-python
23
+ pymupdf
 
 
 
24
  websockets
25
+ langchain-openai