XThomasBU commited on
Commit
8f6647c
·
1 Parent(s): 33e5fa6

init commit for chainlit improvements

Browse files
code/main.py CHANGED
@@ -1,176 +1,244 @@
1
- from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
2
- from langchain_core.prompts import PromptTemplate
3
- from langchain_community.embeddings import HuggingFaceEmbeddings
4
- from langchain_community.vectorstores import FAISS
5
- from langchain.chains import RetrievalQA
6
  import chainlit as cl
7
- from langchain_community.chat_models import ChatOpenAI
8
- from langchain_community.embeddings import OpenAIEmbeddings
9
  import yaml
10
- import logging
11
- from dotenv import load_dotenv
12
 
13
  from modules.chat.llm_tutor import LLMTutor
14
- from modules.config.constants import *
15
- from modules.chat.helpers import get_sources
16
  from modules.chat_processor.chat_processor import ChatProcessor
 
 
17
 
18
- global logger
19
- # Initialize logger
20
- logger = logging.getLogger(__name__)
21
- logger.setLevel(logging.INFO)
22
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
23
-
24
- # Console Handler
25
- console_handler = logging.StreamHandler()
26
- console_handler.setLevel(logging.INFO)
27
- console_handler.setFormatter(formatter)
28
- logger.addHandler(console_handler)
29
-
30
-
31
- @cl.set_starters
32
- async def set_starters():
33
- return [
34
- cl.Starter(
35
- label="recording on CNNs?",
36
- message="Where can I find the recording for the lecture on Transfromers?",
37
- icon="/public/adv-screen-recorder-svgrepo-com.svg",
38
- ),
39
- cl.Starter(
40
- label="where's the slides?",
41
- message="When are the lectures? I can't find the schedule.",
42
- icon="/public/alarmy-svgrepo-com.svg",
43
- ),
44
- cl.Starter(
45
- label="Due Date?",
46
- message="When is the final project due?",
47
- icon="/public/calendar-samsung-17-svgrepo-com.svg",
48
- ),
49
- cl.Starter(
50
- label="Explain backprop.",
51
- message="I didnt understand the math behind backprop, could you explain it?",
52
- icon="/public/acastusphoton-svgrepo-com.svg",
53
- ),
54
- ]
55
-
56
-
57
- # Adding option to select the chat profile
58
- @cl.set_chat_profiles
59
- async def chat_profile():
60
- return [
61
- # cl.ChatProfile(
62
- # name="Mistral",
63
- # markdown_description="Use the local LLM: **Mistral**.",
64
- # ),
65
- cl.ChatProfile(
66
- name="gpt-3.5-turbo-1106",
67
- markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
68
- ),
69
- cl.ChatProfile(
70
- name="gpt-4",
71
- markdown_description="Use OpenAI API for **gpt-4**.",
72
- ),
73
- cl.ChatProfile(
74
- name="Llama",
75
- markdown_description="Use the local LLM: **Tiny Llama**.",
76
- ),
77
- ]
78
-
79
-
80
- @cl.author_rename
81
- def rename(orig_author: str):
82
- rename_dict = {"Chatbot": "AI Tutor"}
83
- return rename_dict.get(orig_author, orig_author)
84
-
85
-
86
- # chainlit code
87
- @cl.on_chat_start
88
- async def start():
89
- with open("modules/config/config.yml", "r") as f:
90
- config = yaml.safe_load(f)
91
-
92
- # Ensure log directory exists
93
- log_directory = config["log_dir"]
94
- if not os.path.exists(log_directory):
95
- os.makedirs(log_directory)
96
-
97
- # File Handler
98
- log_file_path = (
99
- f"{log_directory}/tutor.log" # Change this to your desired log file path
100
- )
101
- file_handler = logging.FileHandler(log_file_path, mode="w")
102
- file_handler.setLevel(logging.INFO)
103
- file_handler.setFormatter(formatter)
104
- logger.addHandler(file_handler)
105
-
106
- logger.info("Config file loaded")
107
- logger.info(f"Config: {config}")
108
- logger.info("Creating llm_tutor instance")
109
-
110
- chat_profile = cl.user_session.get("chat_profile")
111
- if chat_profile is not None:
112
- if chat_profile.lower() in ["gpt-3.5-turbo-1106", "gpt-4"]:
113
- config["llm_params"]["llm_loader"] = "openai"
114
- config["llm_params"]["openai_params"]["model"] = chat_profile.lower()
115
- elif chat_profile.lower() == "llama":
116
- config["llm_params"]["llm_loader"] = "local_llm"
117
- config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
118
- config["llm_params"]["local_llm_params"]["model_type"] = "llama"
119
- elif chat_profile.lower() == "mistral":
120
- config["llm_params"]["llm_loader"] = "local_llm"
121
- config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
122
- config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
123
-
124
- else:
125
- pass
126
-
127
- llm_tutor = LLMTutor(config, logger=logger)
128
-
129
- chain = llm_tutor.qa_bot()
130
- # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
131
- # await msg.send()
132
- # msg.content = opening_message
133
- # await msg.update()
134
-
135
- tags = [chat_profile, config["vectorstore"]["db_option"]]
136
- chat_processor = ChatProcessor(config, tags=tags)
137
- cl.user_session.set("chain", chain)
138
- cl.user_session.set("counter", 0)
139
- cl.user_session.set("chat_processor", chat_processor)
140
-
141
-
142
- @cl.on_chat_end
143
- async def on_chat_end():
144
- await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
145
-
146
-
147
- @cl.on_message
148
- async def main(message):
149
- global logger
150
- user = cl.user_session.get("user")
151
- chain = cl.user_session.get("chain")
152
-
153
- counter = cl.user_session.get("counter")
154
- counter += 1
155
- cl.user_session.set("counter", counter)
156
-
157
- # if counter >= 3: # Ensure the counter condition is checked
158
- # await cl.Message(content="Your credits are up!").send()
159
- # await on_chat_end() # Call the on_chat_end function to handle the end of the chat
160
- # return # Exit the function to stop further processing
161
- # else:
162
-
163
- cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
164
- cb.answer_reached = True
165
-
166
- processor = cl.user_session.get("chat_processor")
167
- res = await processor.rag(message.content, chain, cb)
168
- try:
169
- answer = res["answer"]
170
- except:
171
- answer = res["result"]
172
-
173
- answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
174
- processor._process(message.content, answer, sources_dict)
175
-
176
- await cl.Message(content=answer_with_sources, elements=source_elements).send()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import textwrap
3
+ from typing import Any, Callable, Dict, List, Literal, Optional, no_type_check
 
 
4
  import chainlit as cl
5
+ from chainlit import run_sync
6
+ from chainlit.config import config
7
  import yaml
8
+ import os
 
9
 
10
  from modules.chat.llm_tutor import LLMTutor
 
 
11
  from modules.chat_processor.chat_processor import ChatProcessor
12
+ from modules.config.constants import LLAMA_PATH
13
+ from modules.chat.helpers import get_sources
14
 
15
+ from chainlit.input_widget import Select, Switch, Slider
16
+
17
+ USER_TIMEOUT = 60_000
18
+ SYSTEM = "System 🖥️"
19
+ LLM = "LLM 🧠"
20
+ AGENT = "Agent <>"
21
+ YOU = "You 😃"
22
+ ERROR = "Error 🚫"
23
+
24
+
25
+ class Chatbot:
26
+ def __init__(self):
27
+ self.llm_tutor = None
28
+ self.chain = None
29
+ self.chat_processor = None
30
+ self.config = self._load_config()
31
+
32
+ def _load_config(self):
33
+ with open("modules/config/config.yml", "r") as f:
34
+ config = yaml.safe_load(f)
35
+ return config
36
+
37
+ async def ask_helper(func, **kwargs):
38
+ res = await func(**kwargs).send()
39
+ while not res:
40
+ res = await func(**kwargs).send()
41
+ return res
42
+
43
+ @no_type_check
44
+ async def setup_llm(self) -> None:
45
+ """From the session `llm_settings`, create new LLMConfig and LLM objects,
46
+ save them in session state."""
47
+
48
+ llm_settings = cl.user_session.get("llm_settings", {})
49
+ chat_profile = llm_settings.get("chat_model")
50
+ retriever_method = llm_settings.get("retriever_method")
51
+ memory_window = llm_settings.get("memory_window")
52
+
53
+ self._configure_llm(chat_profile)
54
+
55
+ chain = cl.user_session.get("chain")
56
+ memory = chain.memory
57
+ self.config["vectorstore"][
58
+ "db_option"
59
+ ] = retriever_method # update the retriever method in the config
60
+ memory.k = memory_window # set the memory window
61
+
62
+ self.llm_tutor = LLMTutor(self.config)
63
+ self.chain = self.llm_tutor.qa_bot(memory=memory)
64
+
65
+ tags = [chat_profile, self.config["vectorstore"]["db_option"]]
66
+ self.chat_processor = ChatProcessor(self.config, tags=tags)
67
+
68
+ cl.user_session.set("chain", self.chain)
69
+ cl.user_session.set("llm_tutor", self.llm_tutor)
70
+ cl.user_session.set("chat_processor", self.chat_processor)
71
+
72
+ @no_type_check
73
+ async def update_llm(self, new_settings: Dict[str, Any]) -> None:
74
+ """Update LLMConfig and LLM from settings, and save in session state."""
75
+ cl.user_session.set("llm_settings", new_settings)
76
+ await self.inform_llm_settings()
77
+ await self.setup_llm()
78
+
79
+ async def make_llm_settings_widgets(self, config=None):
80
+ config = config or self.config
81
+ await cl.ChatSettings(
82
+ [
83
+ cl.input_widget.Select(
84
+ id="chat_model",
85
+ label="Model Name (Default GPT-3)",
86
+ values=["llama", "gpt-3.5-turbo-1106", "gpt-4"],
87
+ initial_index=0,
88
+ ),
89
+ cl.input_widget.Select(
90
+ id="retriever_method",
91
+ label="Retriever (Default FAISS)",
92
+ values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
93
+ initial_index=0,
94
+ ),
95
+ cl.input_widget.Slider(
96
+ id="memory_window",
97
+ label="Memory Window (Default 3)",
98
+ initial=3,
99
+ min=0,
100
+ max=10,
101
+ step=1,
102
+ ),
103
+ cl.input_widget.Switch(
104
+ id="view_sources", label="View Sources", initial=False
105
+ ),
106
+ ]
107
+ ).send() # type: ignore
108
+
109
+ @no_type_check
110
+ async def inform_llm_settings(self) -> None:
111
+ llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
112
+ llm_tutor = cl.user_session.get("llm_tutor")
113
+ settings_dict = dict(
114
+ model=llm_settings.get("chat_model"),
115
+ retriever=llm_settings.get("retriever_method"),
116
+ memory_window=llm_settings.get("memory_window"),
117
+ num_docs_in_db=len(llm_tutor.vector_db),
118
+ view_sources=llm_settings.get("view_sources"),
119
+ )
120
+ await cl.Message(
121
+ author=SYSTEM,
122
+ content="LLM settings have been updated. You can continue with your Query!",
123
+ elements=[
124
+ cl.Text(
125
+ name="settings",
126
+ display="side",
127
+ content=json.dumps(settings_dict, indent=4),
128
+ language="json",
129
+ )
130
+ ],
131
+ ).send()
132
+
133
+ async def set_starters(self):
134
+ return [
135
+ cl.Starter(
136
+ label="recording on CNNs?",
137
+ message="Where can I find the recording for the lecture on Transformers?",
138
+ icon="/public/adv-screen-recorder-svgrepo-com.svg",
139
+ ),
140
+ cl.Starter(
141
+ label="where's the slides?",
142
+ message="When are the lectures? I can't find the schedule.",
143
+ icon="/public/alarmy-svgrepo-com.svg",
144
+ ),
145
+ cl.Starter(
146
+ label="Due Date?",
147
+ message="When is the final project due?",
148
+ icon="/public/calendar-samsung-17-svgrepo-com.svg",
149
+ ),
150
+ cl.Starter(
151
+ label="Explain backprop.",
152
+ message="I didn't understand the math behind backprop, could you explain it?",
153
+ icon="/public/acastusphoton-svgrepo-com.svg",
154
+ ),
155
+ ]
156
+
157
+ async def chat_profile(self):
158
+ return [
159
+ # cl.ChatProfile(
160
+ # name="gpt-3.5-turbo-1106",
161
+ # markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
162
+ # ),
163
+ # cl.ChatProfile(
164
+ # name="gpt-4",
165
+ # markdown_description="Use OpenAI API for **gpt-4**.",
166
+ # ),
167
+ cl.ChatProfile(
168
+ name="Llama",
169
+ markdown_description="Use the local LLM: **Tiny Llama**.",
170
+ ),
171
+ ]
172
+
173
+ def rename(self, orig_author: str):
174
+ rename_dict = {"Chatbot": "AI Tutor"}
175
+ return rename_dict.get(orig_author, orig_author)
176
+
177
+ async def start(self):
178
+ await self.make_llm_settings_widgets(self.config)
179
+
180
+ chat_profile = cl.user_session.get("chat_profile")
181
+ if chat_profile:
182
+ self._configure_llm(chat_profile)
183
+
184
+ self.llm_tutor = LLMTutor(self.config)
185
+ self.chain = self.llm_tutor.qa_bot()
186
+ tags = [chat_profile, self.config["vectorstore"]["db_option"]]
187
+ self.chat_processor = ChatProcessor(self.config, tags=tags)
188
+
189
+ cl.user_session.set("llm_tutor", self.llm_tutor)
190
+ cl.user_session.set("chain", self.chain)
191
+ cl.user_session.set("counter", 0)
192
+ cl.user_session.set("chat_processor", self.chat_processor)
193
+
194
+ async def on_chat_end(self):
195
+ await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
196
+
197
+ async def main(self, message):
198
+ user = cl.user_session.get("user")
199
+ chain = cl.user_session.get("chain")
200
+ counter = cl.user_session.get("counter")
201
+ llm_settings = cl.user_session.get("llm_settings")
202
+
203
+ counter += 1
204
+ cl.user_session.set("counter", counter)
205
+
206
+ cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
207
+ cb.answer_reached = True
208
+
209
+ processor = cl.user_session.get("chat_processor")
210
+ res = await processor.rag(message.content, chain, cb)
211
+ answer = res.get("answer", res.get("result"))
212
+
213
+ answer_with_sources, source_elements, sources_dict = get_sources(
214
+ res, answer, view_sources=llm_settings.get("view_sources")
215
+ )
216
+ processor._process(message.content, answer, sources_dict)
217
+
218
+ await cl.Message(content=answer_with_sources, elements=source_elements).send()
219
+
220
+ def _configure_llm(self, chat_profile):
221
+ chat_profile = chat_profile.lower()
222
+ if chat_profile in ["gpt-3.5-turbo-1106", "gpt-4"]:
223
+ self.config["llm_params"]["llm_loader"] = "openai"
224
+ self.config["llm_params"]["openai_params"]["model"] = chat_profile
225
+ elif chat_profile == "llama":
226
+ self.config["llm_params"]["llm_loader"] = "local_llm"
227
+ self.config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
228
+ self.config["llm_params"]["local_llm_params"]["model_type"] = "llama"
229
+ elif chat_profile == "mistral":
230
+ self.config["llm_params"]["llm_loader"] = "local_llm"
231
+ self.config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
232
+ self.config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
233
+
234
+
235
+ chatbot = Chatbot()
236
+
237
+ # Register functions to Chainlit events
238
+ cl.set_starters(chatbot.set_starters)
239
+ cl.set_chat_profiles(chatbot.chat_profile)
240
+ cl.author_rename(chatbot.rename)
241
+ cl.on_chat_start(chatbot.start)
242
+ cl.on_chat_end(chatbot.on_chat_end)
243
+ cl.on_message(chatbot.main)
244
+ cl.on_settings_update(chatbot.update_llm)
code/modules/chat/helpers.py CHANGED
@@ -3,7 +3,7 @@ import chainlit as cl
3
  from langchain_core.prompts import PromptTemplate
4
 
5
 
6
- def get_sources(res, answer):
7
  source_elements = []
8
  source_dict = {} # Dictionary to store URL elements
9
 
@@ -40,40 +40,42 @@ def get_sources(res, answer):
40
  full_answer = "**Answer:**\n"
41
  full_answer += answer
42
 
43
- # Then, display the sources
44
- full_answer += "\n\n**Sources:**\n"
45
- for idx, (url_name, source_data) in enumerate(source_dict.items()):
46
- full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
47
 
48
- name = f"Source {idx + 1} Text\n"
49
- full_answer += name
50
- source_elements.append(
51
- cl.Text(name=name, content=source_data["text"], display="side")
52
- )
53
 
54
- # Add a PDF element if the source is a PDF file
55
- if source_data["url"].lower().endswith(".pdf"):
56
- name = f"Source {idx + 1} PDF\n"
57
  full_answer += name
58
- pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
59
- source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
 
60
 
61
- full_answer += "\n**Metadata:**\n"
62
- for idx, (url_name, source_data) in enumerate(source_dict.items()):
63
- full_answer += f"\nSource {idx + 1} Metadata:\n"
64
- source_elements.append(
65
- cl.Text(
66
- name=f"Source {idx + 1} Metadata",
67
- content=f"Source: {source_data['url']}\n"
68
- f"Page: {source_data['page']}\n"
69
- f"Type: {source_data['source_type']}\n"
70
- f"Date: {source_data['date']}\n"
71
- f"TL;DR: {source_data['lecture_tldr']}\n"
72
- f"Lecture Recording: {source_data['lecture_recording']}\n"
73
- f"Suggested Readings: {source_data['suggested_readings']}\n",
74
- display="side",
 
 
 
 
 
 
 
 
75
  )
76
- )
77
 
78
  return full_answer, source_elements, source_dict
79
 
 
3
  from langchain_core.prompts import PromptTemplate
4
 
5
 
6
+ def get_sources(res, answer, view_sources=False):
7
  source_elements = []
8
  source_dict = {} # Dictionary to store URL elements
9
 
 
40
  full_answer = "**Answer:**\n"
41
  full_answer += answer
42
 
43
+ if view_sources:
 
 
 
44
 
45
+ # Then, display the sources
46
+ full_answer += "\n\n**Sources:**\n"
47
+ for idx, (url_name, source_data) in enumerate(source_dict.items()):
48
+ full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
 
49
 
50
+ name = f"Source {idx + 1} Text\n"
 
 
51
  full_answer += name
52
+ source_elements.append(
53
+ cl.Text(name=name, content=source_data["text"], display="side")
54
+ )
55
 
56
+ # Add a PDF element if the source is a PDF file
57
+ if source_data["url"].lower().endswith(".pdf"):
58
+ name = f"Source {idx + 1} PDF\n"
59
+ full_answer += name
60
+ pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
61
+ source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
62
+
63
+ full_answer += "\n**Metadata:**\n"
64
+ for idx, (url_name, source_data) in enumerate(source_dict.items()):
65
+ full_answer += f"\nSource {idx + 1} Metadata:\n"
66
+ source_elements.append(
67
+ cl.Text(
68
+ name=f"Source {idx + 1} Metadata",
69
+ content=f"Source: {source_data['url']}\n"
70
+ f"Page: {source_data['page']}\n"
71
+ f"Type: {source_data['source_type']}\n"
72
+ f"Date: {source_data['date']}\n"
73
+ f"TL;DR: {source_data['lecture_tldr']}\n"
74
+ f"Lecture Recording: {source_data['lecture_recording']}\n"
75
+ f"Suggested Readings: {source_data['suggested_readings']}\n",
76
+ display="side",
77
+ )
78
  )
 
79
 
80
  return full_answer, source_elements, source_dict
81
 
code/modules/chat/llm_tutor.py CHANGED
@@ -157,18 +157,18 @@ class LLMTutor:
157
  return prompt
158
 
159
  # Retrieval QA Chain
160
- def retrieval_qa_chain(self, llm, prompt, db):
161
 
162
  retriever = Retriever(self.config)._return_retriever(db)
163
 
164
  if self.config["llm_params"]["use_history"]:
165
- memory = ConversationBufferWindowMemory(
166
- k=self.config["llm_params"]["memory_window"],
167
- memory_key="chat_history",
168
- return_messages=True,
169
- output_key="answer",
170
- max_token_limit=128,
171
- )
172
  qa_chain = CustomConversationalRetrievalChain.from_llm(
173
  llm=llm,
174
  chain_type="stuff",
@@ -195,11 +195,16 @@ class LLMTutor:
195
  return llm
196
 
197
  # QA Model Function
198
- def qa_bot(self):
199
  db = self.vector_db.load_database()
 
 
 
 
 
200
  qa_prompt = self.set_custom_prompt()
201
  qa = self.retrieval_qa_chain(
202
- self.llm, qa_prompt, db
203
  ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
204
 
205
  return qa
 
157
  return prompt
158
 
159
  # Retrieval QA Chain
160
+ def retrieval_qa_chain(self, llm, prompt, db, memory=None):
161
 
162
  retriever = Retriever(self.config)._return_retriever(db)
163
 
164
  if self.config["llm_params"]["use_history"]:
165
+ if memory is None:
166
+ memory = ConversationBufferWindowMemory(
167
+ k=self.config["llm_params"]["memory_window"],
168
+ memory_key="chat_history",
169
+ return_messages=True,
170
+ output_key="answer",
171
+ )
172
  qa_chain = CustomConversationalRetrievalChain.from_llm(
173
  llm=llm,
174
  chain_type="stuff",
 
195
  return llm
196
 
197
  # QA Model Function
198
+ def qa_bot(self, memory=None):
199
  db = self.vector_db.load_database()
200
+ # sanity check to see if there are any documents in the database
201
+ if len(db) == 0:
202
+ raise ValueError(
203
+ "No documents in the database. Populate the database first."
204
+ )
205
  qa_prompt = self.set_custom_prompt()
206
  qa = self.retrieval_qa_chain(
207
+ self.llm, qa_prompt, db, memory
208
  ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
209
 
210
  return qa
code/modules/vectorstore/base.py CHANGED
@@ -29,5 +29,8 @@ class VectorStoreBase:
29
  """
30
  raise NotImplementedError
31
 
 
 
 
32
  def __str__(self):
33
  return self.__class__.__name__
 
29
  """
30
  raise NotImplementedError
31
 
32
+ def __len__(self):
33
+ raise NotImplementedError
34
+
35
  def __str__(self):
36
  return self.__class__.__name__
code/modules/vectorstore/chroma.py CHANGED
@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
 
 
 
 
39
 
40
  def as_retriever(self):
41
  return self.vectorstore.as_retriever()
42
+
43
+ def __len__(self):
44
+ return len(self.vectorstore)
code/modules/vectorstore/colbert.py CHANGED
@@ -1,6 +1,67 @@
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
 
 
 
 
3
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  class ColbertVectorStore(VectorStoreBase):
@@ -24,6 +85,7 @@ class ColbertVectorStore(VectorStoreBase):
24
  document_ids=document_names,
25
  document_metadatas=document_metadata,
26
  )
 
27
 
28
  def load_database(self):
29
  path = os.path.join(
@@ -33,7 +95,17 @@ class ColbertVectorStore(VectorStoreBase):
33
  self.vectorstore = RAGPretrainedModel.from_index(
34
  f"{path}/colbert/indexes/new_idx"
35
  )
 
 
 
 
 
 
 
36
  return self.vectorstore
37
 
38
  def as_retriever(self):
39
  return self.vectorstore.as_retriever()
 
 
 
 
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
3
+ from langchain_core.retrievers import BaseRetriever
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
5
+ from langchain_core.documents import Document
6
+ from typing import Any, List, Optional, Sequence
7
  import os
8
+ import json
9
+
10
+
11
+ class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
12
+ model: Any
13
+ kwargs: dict = {}
14
+
15
+ def _get_relevant_documents(
16
+ self,
17
+ query: str,
18
+ *,
19
+ run_manager: CallbackManagerForRetrieverRun, # noqa
20
+ ) -> List[Document]:
21
+ """Get documents relevant to a query."""
22
+ docs = self.model.search(query, **self.kwargs)
23
+ return [
24
+ Document(
25
+ page_content=doc["content"],
26
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
27
+ )
28
+ for doc in docs
29
+ ]
30
+
31
+ async def _aget_relevant_documents(
32
+ self,
33
+ query: str,
34
+ *,
35
+ run_manager: CallbackManagerForRetrieverRun, # noqa
36
+ ) -> List[Document]:
37
+ """Get documents relevant to a query."""
38
+ docs = self.model.search(query, **self.kwargs)
39
+ return [
40
+ Document(
41
+ page_content=doc["content"],
42
+ metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
43
+ )
44
+ for doc in docs
45
+ ]
46
+
47
+
48
+ class RAGPretrainedModel(RAGPretrainedModel):
49
+ """
50
+ Adding len property to RAGPretrainedModel
51
+ """
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ super().__init__(*args, **kwargs)
55
+ self._document_count = 0
56
+
57
+ def set_document_count(self, count):
58
+ self._document_count = count
59
+
60
+ def __len__(self):
61
+ return self._document_count
62
+
63
+ def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
64
+ return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)
65
 
66
 
67
  class ColbertVectorStore(VectorStoreBase):
 
85
  document_ids=document_names,
86
  document_metadatas=document_metadata,
87
  )
88
+ self.colbert.set_document_count(len(document_names))
89
 
90
  def load_database(self):
91
  path = os.path.join(
 
95
  self.vectorstore = RAGPretrainedModel.from_index(
96
  f"{path}/colbert/indexes/new_idx"
97
  )
98
+
99
+ index_metadata = json.load(
100
+ open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
101
+ )
102
+ num_documents = index_metadata["num_passages"]
103
+ self.vectorstore.set_document_count(num_documents)
104
+
105
  return self.vectorstore
106
 
107
  def as_retriever(self):
108
  return self.vectorstore.as_retriever()
109
+
110
+ def __len__(self):
111
+ return len(self.vectorstore)
code/modules/vectorstore/faiss.py CHANGED
@@ -3,6 +3,13 @@ from modules.vectorstore.base import VectorStoreBase
3
  import os
4
 
5
 
 
 
 
 
 
 
 
6
  class FaissVectorStore(VectorStoreBase):
7
  def __init__(self, config):
8
  self.config = config
@@ -43,3 +50,6 @@ class FaissVectorStore(VectorStoreBase):
43
 
44
  def as_retriever(self):
45
  return self.vectorstore.as_retriever()
 
 
 
 
3
  import os
4
 
5
 
6
+ class FAISS(FAISS):
7
+ """To add length property to FAISS class"""
8
+
9
+ def __len__(self):
10
+ return self.index.ntotal
11
+
12
+
13
  class FaissVectorStore(VectorStoreBase):
14
  def __init__(self, config):
15
  self.config = config
 
50
 
51
  def as_retriever(self):
52
  return self.vectorstore.as_retriever()
53
+
54
+ def __len__(self):
55
+ return len(self.vectorstore)
code/modules/vectorstore/raptor.py CHANGED
@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
16
  RANDOM_SEED = 42
17
 
18
 
 
 
 
 
 
 
 
19
  class RAPTORVectoreStore(VectorStoreBase):
20
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
21
  self.documents = documents
 
16
  RANDOM_SEED = 42
17
 
18
 
19
+ class FAISS(FAISS):
20
+ """To add length property to FAISS class"""
21
+
22
+ def __len__(self):
23
+ return self.index.ntotal
24
+
25
+
26
  class RAPTORVectoreStore(VectorStoreBase):
27
  def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
28
  self.documents = documents
code/modules/vectorstore/store_manager.py CHANGED
@@ -138,7 +138,7 @@ class VectorStoreManager:
138
  self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
139
  end_time = time.time() # End time for loading database
140
  self.logger.info(
141
- f"Time taken to load database: {end_time - start_time} seconds"
142
  )
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
@@ -148,8 +148,12 @@ class VectorStoreManager:
148
  self.vector_db._load_from_HF()
149
  end_time = time.time()
150
  self.logger.info(
151
- f"Time taken to load database from Hugging Face: {end_time - start_time} seconds"
152
  )
 
 
 
 
153
 
154
 
155
  if __name__ == "__main__":
 
138
  self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
139
  end_time = time.time() # End time for loading database
140
  self.logger.info(
141
+ f"Time taken to load database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
142
  )
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
 
148
  self.vector_db._load_from_HF()
149
  end_time = time.time()
150
  self.logger.info(
151
+ f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
152
  )
153
+ self.logger.info("Downloaded database")
154
+
155
+ def __len__(self):
156
+ return len(self.vector_db)
157
 
158
 
159
  if __name__ == "__main__":
code/modules/vectorstore/vectorstore.py CHANGED
@@ -86,3 +86,6 @@ class VectorStore:
86
 
87
  def _get_vectorstore(self):
88
  return self.vectorstore
 
 
 
 
86
 
87
  def _get_vectorstore(self):
88
  return self.vectorstore
89
+
90
+ def __len__(self):
91
+ return self.vectorstore.__len__()