Commit 8f6647c · Parent(s): 33e5fa6
XThomasBU committed

init commit for chainlit improvements

Browse files
- code/main.py +238 -170
- code/modules/chat/helpers.py +32 -30
- code/modules/chat/llm_tutor.py +15 -10
- code/modules/vectorstore/base.py +3 -0
- code/modules/vectorstore/chroma.py +3 -0
- code/modules/vectorstore/colbert.py +72 -0
- code/modules/vectorstore/faiss.py +10 -0
- code/modules/vectorstore/raptor.py +7 -0
- code/modules/vectorstore/store_manager.py +6 -2
- code/modules/vectorstore/vectorstore.py +3 -0
code/main.py CHANGED

@@ -1,176 +1,244 @@
-
-
-from
-from langchain_community.vectorstores import FAISS
-from langchain.chains import RetrievalQA
+import json
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, no_type_check
 import chainlit as cl
-from
-from
+from chainlit import run_sync
+from chainlit.config import config
 import yaml
-import
-from dotenv import load_dotenv
+import os

 from modules.chat.llm_tutor import LLMTutor
-from modules.config.constants import *
-from modules.chat.helpers import get_sources
 from modules.chat_processor.chat_processor import ChatProcessor
+from modules.config.constants import LLAMA_PATH
+from modules.chat.helpers import get_sources

-[old lines 18-176 not preserved in this view]
+from chainlit.input_widget import Select, Switch, Slider
+
+USER_TIMEOUT = 60_000
+SYSTEM = "System 🖥️"
+LLM = "LLM 🧠"
+AGENT = "Agent <>"
+YOU = "You 😃"
+ERROR = "Error 🚫"
+
+
+class Chatbot:
+    def __init__(self):
+        self.llm_tutor = None
+        self.chain = None
+        self.chat_processor = None
+        self.config = self._load_config()
+
+    def _load_config(self):
+        with open("modules/config/config.yml", "r") as f:
+            config = yaml.safe_load(f)
+        return config
+
+    async def ask_helper(func, **kwargs):
+        res = await func(**kwargs).send()
+        while not res:
+            res = await func(**kwargs).send()
+        return res
+
+    @no_type_check
+    async def setup_llm(self) -> None:
+        """From the session `llm_settings`, create new LLMConfig and LLM objects,
+        save them in session state."""
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        chat_profile = llm_settings.get("chat_model")
+        retriever_method = llm_settings.get("retriever_method")
+        memory_window = llm_settings.get("memory_window")
+
+        self._configure_llm(chat_profile)
+
+        chain = cl.user_session.get("chain")
+        memory = chain.memory
+        self.config["vectorstore"][
+            "db_option"
+        ] = retriever_method  # update the retriever method in the config
+        memory.k = memory_window  # set the memory window
+
+        self.llm_tutor = LLMTutor(self.config)
+        self.chain = self.llm_tutor.qa_bot(memory=memory)
+
+        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+        self.chat_processor = ChatProcessor(self.config, tags=tags)
+
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chat_processor", self.chat_processor)
+
+    @no_type_check
+    async def update_llm(self, new_settings: Dict[str, Any]) -> None:
+        """Update LLMConfig and LLM from settings, and save in session state."""
+        cl.user_session.set("llm_settings", new_settings)
+        await self.inform_llm_settings()
+        await self.setup_llm()
+
+    async def make_llm_settings_widgets(self, config=None):
+        config = config or self.config
+        await cl.ChatSettings(
+            [
+                cl.input_widget.Select(
+                    id="chat_model",
+                    label="Model Name (Default GPT-3)",
+                    values=["llama", "gpt-3.5-turbo-1106", "gpt-4"],
+                    initial_index=0,
+                ),
+                cl.input_widget.Select(
+                    id="retriever_method",
+                    label="Retriever (Default FAISS)",
+                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                    initial_index=0,
+                ),
+                cl.input_widget.Slider(
+                    id="memory_window",
+                    label="Memory Window (Default 3)",
+                    initial=3,
+                    min=0,
+                    max=10,
+                    step=1,
+                ),
+                cl.input_widget.Switch(
+                    id="view_sources", label="View Sources", initial=False
+                ),
+            ]
+        ).send()  # type: ignore
+
+    @no_type_check
+    async def inform_llm_settings(self) -> None:
+        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+        llm_tutor = cl.user_session.get("llm_tutor")
+        settings_dict = dict(
+            model=llm_settings.get("chat_model"),
+            retriever=llm_settings.get("retriever_method"),
+            memory_window=llm_settings.get("memory_window"),
+            num_docs_in_db=len(llm_tutor.vector_db),
+            view_sources=llm_settings.get("view_sources"),
+        )
+        await cl.Message(
+            author=SYSTEM,
+            content="LLM settings have been updated. You can continue with your Query!",
+            elements=[
+                cl.Text(
+                    name="settings",
+                    display="side",
+                    content=json.dumps(settings_dict, indent=4),
+                    language="json",
+                )
+            ],
+        ).send()
+
+    async def set_starters(self):
+        return [
+            cl.Starter(
+                label="recording on CNNs?",
+                message="Where can I find the recording for the lecture on Transformers?",
+                icon="/public/adv-screen-recorder-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="where's the slides?",
+                message="When are the lectures? I can't find the schedule.",
+                icon="/public/alarmy-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Due Date?",
+                message="When is the final project due?",
+                icon="/public/calendar-samsung-17-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Explain backprop.",
+                message="I didn't understand the math behind backprop, could you explain it?",
+                icon="/public/acastusphoton-svgrepo-com.svg",
+            ),
+        ]
+
+    async def chat_profile(self):
+        return [
+            # cl.ChatProfile(
+            #     name="gpt-3.5-turbo-1106",
+            #     markdown_description="Use OpenAI API for **gpt-3.5-turbo-1106**.",
+            # ),
+            # cl.ChatProfile(
+            #     name="gpt-4",
+            #     markdown_description="Use OpenAI API for **gpt-4**.",
+            # ),
+            cl.ChatProfile(
+                name="Llama",
+                markdown_description="Use the local LLM: **Tiny Llama**.",
+            ),
+        ]
+
+    def rename(self, orig_author: str):
+        rename_dict = {"Chatbot": "AI Tutor"}
+        return rename_dict.get(orig_author, orig_author)
+
+    async def start(self):
+        await self.make_llm_settings_widgets(self.config)
+
+        chat_profile = cl.user_session.get("chat_profile")
+        if chat_profile:
+            self._configure_llm(chat_profile)
+
+        self.llm_tutor = LLMTutor(self.config)
+        self.chain = self.llm_tutor.qa_bot()
+        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
+        self.chat_processor = ChatProcessor(self.config, tags=tags)
+
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("counter", 0)
+        cl.user_session.set("chat_processor", self.chat_processor)
+
+    async def on_chat_end(self):
+        await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
+
+    async def main(self, message):
+        user = cl.user_session.get("user")
+        chain = cl.user_session.get("chain")
+        counter = cl.user_session.get("counter")
+        llm_settings = cl.user_session.get("llm_settings")
+
+        counter += 1
+        cl.user_session.set("counter", counter)
+
+        cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
+        cb.answer_reached = True
+
+        processor = cl.user_session.get("chat_processor")
+        res = await processor.rag(message.content, chain, cb)
+        answer = res.get("answer", res.get("result"))
+
+        answer_with_sources, source_elements, sources_dict = get_sources(
+            res, answer, view_sources=llm_settings.get("view_sources")
+        )
+        processor._process(message.content, answer, sources_dict)
+
+        await cl.Message(content=answer_with_sources, elements=source_elements).send()
+
+    def _configure_llm(self, chat_profile):
+        chat_profile = chat_profile.lower()
+        if chat_profile in ["gpt-3.5-turbo-1106", "gpt-4"]:
+            self.config["llm_params"]["llm_loader"] = "openai"
+            self.config["llm_params"]["openai_params"]["model"] = chat_profile
+        elif chat_profile == "llama":
+            self.config["llm_params"]["llm_loader"] = "local_llm"
+            self.config["llm_params"]["local_llm_params"]["model"] = LLAMA_PATH
+            self.config["llm_params"]["local_llm_params"]["model_type"] = "llama"
+        elif chat_profile == "mistral":
+            self.config["llm_params"]["llm_loader"] = "local_llm"
+            self.config["llm_params"]["local_llm_params"]["model"] = MISTRAL_PATH
+            self.config["llm_params"]["local_llm_params"]["model_type"] = "mistral"
+
+
+chatbot = Chatbot()
+
+# Register functions to Chainlit events
+cl.set_starters(chatbot.set_starters)
+cl.set_chat_profiles(chatbot.chat_profile)
+cl.author_rename(chatbot.rename)
+cl.on_chat_start(chatbot.start)
+cl.on_chat_end(chatbot.on_chat_end)
+cl.on_message(chatbot.main)
+cl.on_settings_update(chatbot.update_llm)
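The rewritten main.py hangs every Chainlit hook off one long-lived Chatbot instance instead of module-level decorated functions. A minimal sketch of the same registration pattern, separate from this repo (EchoBot and its behavior are illustrative only):

import chainlit as cl


class EchoBot:
    async def start(self):
        # runs once per session, like Chatbot.start above
        await cl.Message(content="Hi! Ask me anything.").send()

    async def respond(self, message: cl.Message):
        # runs on every user message, like Chatbot.main above
        await cl.Message(content=f"You said: {message.content}").send()


bot = EchoBot()

# Chainlit's decorators are ordinary registration functions, so bound
# methods can be passed to them directly.
cl.on_chat_start(bot.start)
cl.on_message(bot.respond)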
code/modules/chat/helpers.py CHANGED

@@ -3,7 +3,7 @@ import chainlit as cl
 from langchain_core.prompts import PromptTemplate
 
 
-def get_sources(res, answer):
+def get_sources(res, answer, view_sources=False):
     source_elements = []
     source_dict = {}  # Dictionary to store URL elements
 
@@ -40,40 +40,42 @@ def get_sources(res, answer):
     full_answer = "**Answer:**\n"
     full_answer += answer
 
-    full_answer += "\n\n**Sources:**\n"
-    for idx, (url_name, source_data) in enumerate(source_dict.items()):
-        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
-
-        full_answer +=
-        [old lines 50-52 not preserved in this view]
-
-        if source_data["url"].lower().endswith(".pdf"):
-            name = f"Source {idx + 1} PDF\n"
-            full_answer += name
-            [old lines 58-76 not preserved in this view]
+    if view_sources:
+
+        # Then, display the sources
+        full_answer += "\n\n**Sources:**\n"
+        for idx, (url_name, source_data) in enumerate(source_dict.items()):
+            full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+            name = f"Source {idx + 1} Text\n"
+            full_answer += name
+            source_elements.append(
+                cl.Text(name=name, content=source_data["text"], display="side")
+            )
+
+            # Add a PDF element if the source is a PDF file
+            if source_data["url"].lower().endswith(".pdf"):
+                name = f"Source {idx + 1} PDF\n"
+                full_answer += name
+                pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+                source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
+
+        full_answer += "\n**Metadata:**\n"
+        for idx, (url_name, source_data) in enumerate(source_dict.items()):
+            full_answer += f"\nSource {idx + 1} Metadata:\n"
+            source_elements.append(
+                cl.Text(
+                    name=f"Source {idx + 1} Metadata",
+                    content=f"Source: {source_data['url']}\n"
+                    f"Page: {source_data['page']}\n"
+                    f"Type: {source_data['source_type']}\n"
+                    f"Date: {source_data['date']}\n"
+                    f"TL;DR: {source_data['lecture_tldr']}\n"
+                    f"Lecture Recording: {source_data['lecture_recording']}\n"
+                    f"Suggested Readings: {source_data['suggested_readings']}\n",
+                    display="side",
+                )
+            )
 
     return full_answer, source_elements, source_dict
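get_sources now renders sources only when the view_sources setting is on, attaching text, PDF, and metadata as side elements. The PDF link uses a #page= fragment so the viewer opens at the cited page; a small sketch of that element (the URL is a placeholder, and page is zero-indexed in the metadata, hence the +1):

import chainlit as cl

page = 3  # zero-indexed page from the source metadata
pdf_url = f"https://example.com/lecture01.pdf#page={page + 1}"  # placeholder URL

# With display="side", the element opens in the side panel whenever its
# name appears in the message text, which is why get_sources appends the
# element names to full_answer.
element = cl.Pdf(name="Source 1 PDF", url=pdf_url, display="side")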
code/modules/chat/llm_tutor.py CHANGED

@@ -157,18 +157,18 @@ class LLMTutor:
         return prompt
 
     # Retrieval QA Chain
-    def retrieval_qa_chain(self, llm, prompt, db):
+    def retrieval_qa_chain(self, llm, prompt, db, memory=None):
 
         retriever = Retriever(self.config)._return_retriever(db)
 
         if self.config["llm_params"]["use_history"]:
-            memory
-            [old lines 166-171 not preserved in this view]
+            if memory is None:
+                memory = ConversationBufferWindowMemory(
+                    k=self.config["llm_params"]["memory_window"],
+                    memory_key="chat_history",
+                    return_messages=True,
+                    output_key="answer",
+                )
             qa_chain = CustomConversationalRetrievalChain.from_llm(
                 llm=llm,
                 chain_type="stuff",
@@ -195,11 +195,16 @@ class LLMTutor:
         return llm
 
     # QA Model Function
-    def qa_bot(self):
+    def qa_bot(self, memory=None):
         db = self.vector_db.load_database()
+        # sanity check to see if there are any documents in the database
+        if len(db) == 0:
+            raise ValueError(
+                "No documents in the database. Populate the database first."
+            )
         qa_prompt = self.set_custom_prompt()
         qa = self.retrieval_qa_chain(
-            self.llm, qa_prompt, db
+            self.llm, qa_prompt, db, memory
         )  # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
 
         return qa
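qa_bot(memory=...) lets setup_llm in main.py rebuild the chain after a settings change without losing conversation history: the old chain's memory object is passed into the new chain. A sketch of why reusing the object matters, assuming LangChain's ConversationBufferWindowMemory:

from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(
    k=3, memory_key="chat_history", return_messages=True, output_key="answer"
)
memory.save_context({"question": "What is backprop?"}, {"answer": "The chain rule..."})

# A chain rebuilt around this same object still sees the earlier turns;
# the memory-is-None branch above would instead start from an empty buffer.
print(memory.load_memory_variables({})["chat_history"])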
code/modules/vectorstore/base.py CHANGED

@@ -29,5 +29,8 @@ class VectorStoreBase:
         """
         raise NotImplementedError
 
+    def __len__(self):
+        raise NotImplementedError
+
     def __str__(self):
         return self.__class__.__name__
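Raising NotImplementedError from the base __len__ makes the document count part of the vector-store interface that every backend must satisfy. A toy illustration of the contract (ToyStore is hypothetical):

from modules.vectorstore.base import VectorStoreBase


class ToyStore(VectorStoreBase):
    def __init__(self, docs):
        self.docs = docs

    def __len__(self):
        # each concrete store reports its own document/passage count
        return len(self.docs)


assert len(ToyStore(["a", "b"])) == 2  # len() dispatches to __len__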
code/modules/vectorstore/chroma.py CHANGED

@@ -39,3 +39,6 @@ class ChromaVectorStore(VectorStoreBase):
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
code/modules/vectorstore/colbert.py CHANGED

@@ -1,6 +1,67 @@
 from ragatouille import RAGPretrainedModel
 from modules.vectorstore.base import VectorStoreBase
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
+from langchain_core.documents import Document
+from typing import Any, List, Optional, Sequence
 import os
+import json
+
+
+class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
+    model: Any
+    kwargs: dict = {}
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+    async def _aget_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Document]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return [
+            Document(
+                page_content=doc["content"],
+                metadata={**doc.get("document_metadata", {}), "score": doc["score"]},
+            )
+            for doc in docs
+        ]
+
+
+class RAGPretrainedModel(RAGPretrainedModel):
+    """
+    Adding len property to RAGPretrainedModel
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._document_count = 0
+
+    def set_document_count(self, count):
+        self._document_count = count
+
+    def __len__(self):
+        return self._document_count
+
+    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
+        return RAGatouilleLangChainRetrieverWithScore(model=self, kwargs=kwargs)
 
 
 class ColbertVectorStore(VectorStoreBase):
@@ -24,6 +85,7 @@ class ColbertVectorStore(VectorStoreBase):
             document_ids=document_names,
             document_metadatas=document_metadata,
         )
+        self.colbert.set_document_count(len(document_names))
 
     def load_database(self):
         path = os.path.join(
@@ -33,7 +95,17 @@ class ColbertVectorStore(VectorStoreBase):
         self.vectorstore = RAGPretrainedModel.from_index(
             f"{path}/colbert/indexes/new_idx"
         )
+
+        index_metadata = json.load(
+            open(f"{path}/colbert/indexes/new_idx/0.metadata.json")
+        )
+        num_documents = index_metadata["num_passages"]
+        self.vectorstore.set_document_count(num_documents)
+
         return self.vectorstore
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
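RAGatouilleLangChainRetrieverWithScore exists because the stock as_langchain_retriever discards ColBERT's relevance scores; here each hit is rewrapped as a Document with the score kept in metadata. A hypothetical usage sketch (the index path is a placeholder, and k is one of the search kwargs forwarded to model.search):

# assumes an index was already built at this placeholder path
model = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/new_idx")
retriever = model.as_langchain_retriever(k=3)

docs = retriever.invoke("When is the final project due?")
for doc in docs:
    # the score survives in metadata, unlike with the stock retriever
    print(doc.metadata.get("score"), doc.page_content[:60])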
code/modules/vectorstore/faiss.py CHANGED

@@ -3,6 +3,13 @@ from modules.vectorstore.base import VectorStoreBase
 import os
 
 
+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class FaissVectorStore(VectorStoreBase):
     def __init__(self, config):
         self.config = config
@@ -43,3 +50,6 @@ class FaissVectorStore(VectorStoreBase):
 
     def as_retriever(self):
         return self.vectorstore.as_retriever()
+
+    def __len__(self):
+        return len(self.vectorstore)
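The `class FAISS(FAISS)` subclass shadows the imported name, so vector stores constructed later in this module pick up the new __len__; ntotal is FAISS's native count of stored vectors. A standalone sketch against the raw faiss API:

import faiss
import numpy as np

index = faiss.IndexFlatL2(4)                  # index for 4-dimensional vectors
index.add(np.zeros((5, 4), dtype="float32"))  # store five vectors
print(index.ntotal)                           # 5, which is what __len__ returns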
code/modules/vectorstore/raptor.py CHANGED

@@ -16,6 +16,13 @@ from modules.vectorstore.base import VectorStoreBase
 RANDOM_SEED = 42
 
 
+class FAISS(FAISS):
+    """To add length property to FAISS class"""
+
+    def __len__(self):
+        return self.index.ntotal
+
+
 class RAPTORVectoreStore(VectorStoreBase):
     def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
         self.documents = documents
code/modules/vectorstore/store_manager.py CHANGED

@@ -138,7 +138,7 @@ class VectorStoreManager:
         self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
         end_time = time.time()  # End time for loading database
         self.logger.info(
-            f"Time taken to load database: {end_time - start_time} seconds"
+            f"Time taken to load database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
         )
         self.logger.info("Loaded database")
         return self.loaded_vector_db
@@ -148,8 +148,12 @@ class VectorStoreManager:
         self.vector_db._load_from_HF()
         end_time = time.time()
         self.logger.info(
-            f"Time taken to
+            f"Time taken to Download database {self.config['vectorstore']['db_option']} from Hugging Face: {end_time - start_time} seconds"
         )
+        self.logger.info("Downloaded database")
+
+    def __len__(self):
+        return len(self.vector_db)
 
 
 if __name__ == "__main__":
code/modules/vectorstore/vectorstore.py CHANGED

@@ -86,3 +86,6 @@ class VectorStore:
 
     def _get_vectorstore(self):
         return self.vectorstore
+
+    def __len__(self):
+        return self.vectorstore.__len__()
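Taken together, the new __len__ methods form a delegation chain: manager, wrapper, concrete store, underlying index. A toy mimic of the chain (stand-in classes, not the repo's):

class Index:            # stands in for a raw faiss index
    ntotal = 42


class Store:            # stands in for the FAISS subclass above
    index = Index()

    def __len__(self):
        return self.index.ntotal


class Wrapper:          # stands in for VectorStore / FaissVectorStore
    vectorstore = Store()

    def __len__(self):
        return len(self.vectorstore)


class Manager:          # stands in for VectorStoreManager
    vector_db = Wrapper()

    def __len__(self):
        return len(self.vector_db)


assert len(Manager()) == 42  # one len() call walks down to the index count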