XThomasBU
commited on
Commit
·
f0018f2
1
Parent(s):
40de40e
Code to add metadata to the chunks
Browse files- .chainlit/config.toml +1 -1
- code/config.yml +4 -4
- code/modules/data_loader.py +115 -46
- code/modules/embedding_model_loader.py +2 -2
- code/modules/helpers.py +121 -53
- code/modules/llm_tutor.py +15 -9
- code/modules/vector_db.py +34 -9
- requirements.txt +1 -0
- storage/data/urls.txt +2 -0
.chainlit/config.toml
CHANGED
@@ -22,7 +22,7 @@ prompt_playground = true
|
|
22 |
unsafe_allow_html = false
|
23 |
|
24 |
# Process and display mathematical expressions. This can clash with "$" characters in messages.
|
25 |
-
latex =
|
26 |
|
27 |
# Authorize users to upload files with messages
|
28 |
multi_modal = true
|
|
|
22 |
unsafe_allow_html = false
|
23 |
|
24 |
# Process and display mathematical expressions. This can clash with "$" characters in messages.
|
25 |
+
latex = true
|
26 |
|
27 |
# Authorize users to upload files with messages
|
28 |
multi_modal = true
|
code/config.yml
CHANGED
@@ -2,14 +2,14 @@ embedding_options:
|
|
2 |
embedd_files: False # bool
|
3 |
data_path: 'storage/data' # str
|
4 |
url_file_path: 'storage/data/urls.txt' # str
|
5 |
-
expand_urls:
|
6 |
-
db_option : '
|
7 |
db_path : 'vectorstores' # str
|
8 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
9 |
search_top_k : 3 # int
|
10 |
-
score_threshold : 0.
|
11 |
llm_params:
|
12 |
-
use_history:
|
13 |
memory_window: 3 # int
|
14 |
llm_loader: 'local_llm' # str [local_llm, openai]
|
15 |
openai_params:
|
|
|
2 |
embedd_files: False # bool
|
3 |
data_path: 'storage/data' # str
|
4 |
url_file_path: 'storage/data/urls.txt' # str
|
5 |
+
expand_urls: False # bool
|
6 |
+
db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
|
7 |
db_path : 'vectorstores' # str
|
8 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
9 |
search_top_k : 3 # int
|
10 |
+
score_threshold : 0.2 # float
|
11 |
llm_params:
|
12 |
+
use_history: False # bool
|
13 |
memory_window: 3 # int
|
14 |
llm_loader: 'local_llm' # str [local_llm, openai]
|
15 |
openai_params:
|
code/modules/data_loader.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
import re
|
3 |
import requests
|
4 |
import pysrt
|
5 |
-
from
|
6 |
PyMuPDFLoader,
|
7 |
Docx2txtLoader,
|
8 |
YoutubeLoader,
|
@@ -16,6 +16,15 @@ import logging
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain_experimental.text_splitter import SemanticChunker
|
18 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
@@ -58,23 +67,6 @@ class FileReader:
|
|
58 |
return None
|
59 |
|
60 |
def read_pdf(self, temp_file_path: str):
|
61 |
-
# parser = LlamaParse(
|
62 |
-
# api_key="",
|
63 |
-
# result_type="markdown",
|
64 |
-
# num_workers=4,
|
65 |
-
# verbose=True,
|
66 |
-
# language="en",
|
67 |
-
# )
|
68 |
-
# documents = parser.load_data(temp_file_path)
|
69 |
-
|
70 |
-
# with open("temp/output.md", "a") as f:
|
71 |
-
# for doc in documents:
|
72 |
-
# f.write(doc.text + "\n")
|
73 |
-
|
74 |
-
# markdown_path = "temp/output.md"
|
75 |
-
# loader = UnstructuredMarkdownLoader(markdown_path)
|
76 |
-
# loader = PyMuPDFLoader(temp_file_path) # This loader preserves more metadata
|
77 |
-
# return loader.load()
|
78 |
loader = self.pdf_reader.get_loader(temp_file_path)
|
79 |
documents = self.pdf_reader.get_documents(loader)
|
80 |
return documents
|
@@ -108,8 +100,6 @@ class FileReader:
|
|
108 |
class ChunkProcessor:
|
109 |
def __init__(self, config):
|
110 |
self.config = config
|
111 |
-
self.document_chunks_full = []
|
112 |
-
self.document_names = []
|
113 |
|
114 |
if config["splitter_options"]["use_splitter"]:
|
115 |
if config["splitter_options"]["split_by_token"]:
|
@@ -130,6 +120,17 @@ class ChunkProcessor:
|
|
130 |
self.splitter = None
|
131 |
logger.info("ChunkProcessor instance created")
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
def remove_delimiters(self, document_chunks: list):
|
134 |
for chunk in document_chunks:
|
135 |
for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
|
@@ -146,11 +147,23 @@ class ChunkProcessor:
|
|
146 |
logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
|
147 |
return document_chunks
|
148 |
|
149 |
-
def process_chunks(
|
150 |
-
|
|
|
|
|
|
|
151 |
document_chunks = self.splitter.split_documents(documents)
|
152 |
-
|
153 |
-
document_chunks = documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
if self.config["splitter_options"]["remove_leftover_delimiters"]:
|
156 |
document_chunks = self.remove_delimiters(document_chunks)
|
@@ -161,38 +174,77 @@ class ChunkProcessor:
|
|
161 |
|
162 |
def get_chunks(self, file_reader, uploaded_files, weblinks):
|
163 |
self.document_chunks_full = []
|
164 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
for file_index, file_path in enumerate(uploaded_files):
|
167 |
file_name = os.path.basename(file_path)
|
168 |
file_type = file_name.split(".")[-1].lower()
|
169 |
|
170 |
-
try:
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
else:
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
self.
|
185 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
-
except Exception as e:
|
188 |
-
|
189 |
|
190 |
self.process_weblinks(file_reader, weblinks)
|
191 |
|
192 |
logger.info(
|
193 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
194 |
)
|
195 |
-
return
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
def process_weblinks(self, file_reader, weblinks):
|
198 |
if weblinks[0] != "":
|
@@ -206,9 +258,26 @@ class ChunkProcessor:
|
|
206 |
else:
|
207 |
documents = file_reader.read_html(link)
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
except Exception as e:
|
213 |
logger.error(
|
214 |
f"Error splitting link {link_index+1} : {link}: {str(e)}"
|
|
|
2 |
import re
|
3 |
import requests
|
4 |
import pysrt
|
5 |
+
from langchain_community.document_loaders import (
|
6 |
PyMuPDFLoader,
|
7 |
Docx2txtLoader,
|
8 |
YoutubeLoader,
|
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain_experimental.text_splitter import SemanticChunker
|
18 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
19 |
+
from ragatouille import RAGPretrainedModel
|
20 |
+
from langchain.chains import LLMChain
|
21 |
+
from langchain.llms import OpenAI
|
22 |
+
from langchain import PromptTemplate
|
23 |
+
|
24 |
+
try:
|
25 |
+
from modules.helpers import get_lecture_metadata
|
26 |
+
except:
|
27 |
+
from helpers import get_lecture_metadata
|
28 |
|
29 |
logger = logging.getLogger(__name__)
|
30 |
|
|
|
67 |
return None
|
68 |
|
69 |
def read_pdf(self, temp_file_path: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
loader = self.pdf_reader.get_loader(temp_file_path)
|
71 |
documents = self.pdf_reader.get_documents(loader)
|
72 |
return documents
|
|
|
100 |
class ChunkProcessor:
|
101 |
def __init__(self, config):
|
102 |
self.config = config
|
|
|
|
|
103 |
|
104 |
if config["splitter_options"]["use_splitter"]:
|
105 |
if config["splitter_options"]["split_by_token"]:
|
|
|
120 |
self.splitter = None
|
121 |
logger.info("ChunkProcessor instance created")
|
122 |
|
123 |
+
# def extract_metadata(self, document_content):
|
124 |
+
|
125 |
+
# llm = OpenAI()
|
126 |
+
# prompt_template = PromptTemplate(
|
127 |
+
# input_variables=["document_content"],
|
128 |
+
# template="Extract metadata for this document:\n\n{document_content}\n\nMetadata:",
|
129 |
+
# )
|
130 |
+
# chain = LLMChain(llm=llm, prompt=prompt_template)
|
131 |
+
# metadata = chain.run(document_content=document_content)
|
132 |
+
# return metadata
|
133 |
+
|
134 |
def remove_delimiters(self, document_chunks: list):
|
135 |
for chunk in document_chunks:
|
136 |
for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
|
|
|
147 |
logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
|
148 |
return document_chunks
|
149 |
|
150 |
+
def process_chunks(
|
151 |
+
self, documents, file_type="txt", source="", page=0, metadata={}
|
152 |
+
):
|
153 |
+
documents = [Document(page_content=documents, source=source, page=page)]
|
154 |
+
if file_type == "txt":
|
155 |
document_chunks = self.splitter.split_documents(documents)
|
156 |
+
elif file_type == "pdf":
|
157 |
+
document_chunks = documents # Full page for now
|
158 |
+
|
159 |
+
# add the source and page number back to the metadata
|
160 |
+
for chunk in document_chunks:
|
161 |
+
chunk.metadata["source"] = source
|
162 |
+
chunk.metadata["page"] = page
|
163 |
+
|
164 |
+
# add the metadata extracted from the document
|
165 |
+
for key, value in metadata.items():
|
166 |
+
chunk.metadata[key] = value
|
167 |
|
168 |
if self.config["splitter_options"]["remove_leftover_delimiters"]:
|
169 |
document_chunks = self.remove_delimiters(document_chunks)
|
|
|
174 |
|
175 |
def get_chunks(self, file_reader, uploaded_files, weblinks):
|
176 |
self.document_chunks_full = []
|
177 |
+
self.parent_document_names = []
|
178 |
+
self.child_document_names = []
|
179 |
+
self.documents = []
|
180 |
+
self.document_metadata = []
|
181 |
+
|
182 |
+
lecture_metadata = get_lecture_metadata(
|
183 |
+
"https://dl4ds.github.io/sp2024/lectures/"
|
184 |
+
) # TODO: Use more efficiently
|
185 |
|
186 |
for file_index, file_path in enumerate(uploaded_files):
|
187 |
file_name = os.path.basename(file_path)
|
188 |
file_type = file_name.split(".")[-1].lower()
|
189 |
|
190 |
+
# try:
|
191 |
+
if file_type == "pdf":
|
192 |
+
documents = file_reader.read_pdf(file_path)
|
193 |
+
elif file_type == "txt":
|
194 |
+
documents = file_reader.read_txt(file_path)
|
195 |
+
elif file_type == "docx":
|
196 |
+
documents = file_reader.read_docx(file_path)
|
197 |
+
elif file_type == "srt":
|
198 |
+
documents = file_reader.read_srt(file_path)
|
199 |
+
else:
|
200 |
+
logger.warning(f"Unsupported file type: {file_type}")
|
201 |
+
continue
|
202 |
+
|
203 |
+
# full_text = ""
|
204 |
+
# for doc in documents:
|
205 |
+
# full_text += doc.page_content
|
206 |
+
# break # getting only first page for now
|
207 |
+
|
208 |
+
# extracted_metadata = self.extract_metadata(full_text)
|
209 |
+
|
210 |
+
for doc in documents:
|
211 |
+
page_num = doc.metadata.get("page", 0)
|
212 |
+
self.documents.append(doc.page_content)
|
213 |
+
self.document_metadata.append({"source": file_path, "page": page_num})
|
214 |
+
if "lecture" in file_path.lower():
|
215 |
+
metadata = lecture_metadata.get(file_path, {})
|
216 |
+
metadata["source_type"] = "lecture"
|
217 |
+
self.document_metadata[-1].update(metadata)
|
218 |
else:
|
219 |
+
metadata = {"source_type": "other"}
|
220 |
+
|
221 |
+
self.child_document_names.append(f"{file_name}_{page_num}")
|
222 |
+
|
223 |
+
self.parent_document_names.append(file_name)
|
224 |
+
if self.config["embedding_options"]["db_option"] not in ["RAGatouille"]:
|
225 |
+
document_chunks = self.process_chunks(
|
226 |
+
self.documents[-1],
|
227 |
+
file_type,
|
228 |
+
source=file_path,
|
229 |
+
page=page_num,
|
230 |
+
metadata=metadata,
|
231 |
+
)
|
232 |
+
self.document_chunks_full.extend(document_chunks)
|
233 |
|
234 |
+
# except Exception as e:
|
235 |
+
# logger.error(f"Error processing file {file_name}: {str(e)}")
|
236 |
|
237 |
self.process_weblinks(file_reader, weblinks)
|
238 |
|
239 |
logger.info(
|
240 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
241 |
)
|
242 |
+
return (
|
243 |
+
self.document_chunks_full,
|
244 |
+
self.child_document_names,
|
245 |
+
self.documents,
|
246 |
+
self.document_metadata,
|
247 |
+
)
|
248 |
|
249 |
def process_weblinks(self, file_reader, weblinks):
|
250 |
if weblinks[0] != "":
|
|
|
258 |
else:
|
259 |
documents = file_reader.read_html(link)
|
260 |
|
261 |
+
for doc in documents:
|
262 |
+
page_num = doc.metadata.get("page", 0)
|
263 |
+
self.documents.append(doc.page_content)
|
264 |
+
self.document_metadata.append(
|
265 |
+
{"source": link, "page": page_num}
|
266 |
+
)
|
267 |
+
self.child_document_names.append(f"{link}")
|
268 |
+
|
269 |
+
self.parent_document_names.append(link)
|
270 |
+
if self.config["embedding_options"]["db_option"] not in [
|
271 |
+
"RAGatouille"
|
272 |
+
]:
|
273 |
+
document_chunks = self.process_chunks(
|
274 |
+
self.documents[-1],
|
275 |
+
"txt",
|
276 |
+
source=link,
|
277 |
+
page=0,
|
278 |
+
metadata={"source_type": "webpage"},
|
279 |
+
)
|
280 |
+
self.document_chunks_full.extend(document_chunks)
|
281 |
except Exception as e:
|
282 |
logger.error(
|
283 |
f"Error splitting link {link_index+1} : {link}: {str(e)}"
|
code/modules/embedding_model_loader.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from langchain_community.embeddings import OpenAIEmbeddings
|
2 |
-
from
|
3 |
-
from
|
4 |
|
5 |
try:
|
6 |
from modules.constants import *
|
|
|
1 |
from langchain_community.embeddings import OpenAIEmbeddings
|
2 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
3 |
+
from langchain_community.embeddings import LlamaCppEmbeddings
|
4 |
|
5 |
try:
|
6 |
from modules.constants import *
|
code/modules/helpers.py
CHANGED
@@ -4,6 +4,8 @@ from tqdm import tqdm
|
|
4 |
from urllib.parse import urlparse
|
5 |
import chainlit as cl
|
6 |
from langchain import PromptTemplate
|
|
|
|
|
7 |
|
8 |
try:
|
9 |
from modules.constants import *
|
@@ -138,67 +140,133 @@ def get_prompt(config):
|
|
138 |
|
139 |
|
140 |
def get_sources(res, answer):
|
141 |
-
source_elements_dict = {}
|
142 |
source_elements = []
|
143 |
-
found_sources = []
|
144 |
-
|
145 |
source_dict = {} # Dictionary to store URL elements
|
146 |
|
147 |
for idx, source in enumerate(res["source_documents"]):
|
148 |
source_metadata = source.metadata
|
149 |
url = source_metadata["source"]
|
150 |
score = source_metadata.get("score", "N/A")
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
else:
|
155 |
-
source_dict[
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
full_text += f"Source {url_idx + 1} (Score: {score}):\n{text}\n\n\n"
|
161 |
-
source_elements.append(cl.Text(name=url, content=full_text))
|
162 |
-
found_sources.append(f"{url} (Score: {score})")
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
#
|
181 |
-
#
|
182 |
-
#
|
183 |
-
#
|
184 |
-
#
|
185 |
-
#
|
186 |
-
|
187 |
-
#
|
188 |
-
#
|
189 |
-
#
|
190 |
-
#
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from urllib.parse import urlparse
|
5 |
import chainlit as cl
|
6 |
from langchain import PromptTemplate
|
7 |
+
import requests
|
8 |
+
from bs4 import BeautifulSoup
|
9 |
|
10 |
try:
|
11 |
from modules.constants import *
|
|
|
140 |
|
141 |
|
142 |
def get_sources(res, answer):
|
|
|
143 |
source_elements = []
|
|
|
|
|
144 |
source_dict = {} # Dictionary to store URL elements
|
145 |
|
146 |
for idx, source in enumerate(res["source_documents"]):
|
147 |
source_metadata = source.metadata
|
148 |
url = source_metadata["source"]
|
149 |
score = source_metadata.get("score", "N/A")
|
150 |
+
page = source_metadata.get("page", 1)
|
151 |
+
|
152 |
+
lecture_tldr = source_metadata.get("tldr", "N/A")
|
153 |
+
lecture_recording = source_metadata.get("lecture_recording", "N/A")
|
154 |
+
suggested_readings = source_metadata.get("suggested_readings", "N/A")
|
155 |
|
156 |
+
source_type = source_metadata.get("source_type", "N/A")
|
157 |
+
|
158 |
+
url_name = f"{url}_{page}"
|
159 |
+
if url_name not in source_dict:
|
160 |
+
source_dict[url_name] = {
|
161 |
+
"text": source.page_content,
|
162 |
+
"url": url,
|
163 |
+
"score": score,
|
164 |
+
"page": page,
|
165 |
+
"lecture_tldr": lecture_tldr,
|
166 |
+
"lecture_recording": lecture_recording,
|
167 |
+
"suggested_readings": suggested_readings,
|
168 |
+
"source_type": source_type,
|
169 |
+
}
|
170 |
else:
|
171 |
+
source_dict[url_name]["text"] += f"\n\n{source.page_content}"
|
172 |
|
173 |
+
# First, display the answer
|
174 |
+
full_answer = "**Answer:**\n"
|
175 |
+
full_answer += answer
|
|
|
|
|
|
|
176 |
|
177 |
+
# Then, display the sources
|
178 |
+
full_answer += "\n\n**Sources:**\n"
|
179 |
+
for idx, (url_name, source_data) in enumerate(source_dict.items()):
|
180 |
+
full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
|
181 |
+
|
182 |
+
name = f"Source {idx + 1} Text\n"
|
183 |
+
full_answer += name
|
184 |
+
source_elements.append(cl.Text(name=name, content=source_data["text"]))
|
185 |
+
|
186 |
+
# Add a PDF element if the source is a PDF file
|
187 |
+
if source_data["url"].lower().endswith(".pdf"):
|
188 |
+
name = f"Source {idx + 1} PDF\n"
|
189 |
+
full_answer += name
|
190 |
+
pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
|
191 |
+
source_elements.append(cl.Pdf(name=name, url=pdf_url))
|
192 |
+
|
193 |
+
# Finally, include lecture metadata for each unique source
|
194 |
+
# displayed_urls = set()
|
195 |
+
# full_answer += "\n**Metadata:**\n"
|
196 |
+
# for url_name, source_data in source_dict.items():
|
197 |
+
# if source_data["url"] not in displayed_urls:
|
198 |
+
# full_answer += f"\nSource: {source_data['url']}\n"
|
199 |
+
# full_answer += f"Type: {source_data['source_type']}\n"
|
200 |
+
# full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
|
201 |
+
# full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
|
202 |
+
# full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
|
203 |
+
# displayed_urls.add(source_data["url"])
|
204 |
+
full_answer += "\n**Metadata:**\n"
|
205 |
+
for url_name, source_data in source_dict.items():
|
206 |
+
full_answer += f"\nSource: {source_data['url']}\n"
|
207 |
+
full_answer += f"Page: {source_data['page']}\n"
|
208 |
+
full_answer += f"Type: {source_data['source_type']}\n"
|
209 |
+
full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
|
210 |
+
full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
|
211 |
+
full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
|
212 |
+
|
213 |
+
return full_answer, source_elements
|
214 |
+
|
215 |
+
|
216 |
+
def get_lecture_metadata(schedule_url):
|
217 |
+
"""
|
218 |
+
Function to get the lecture metadata from the schedule URL.
|
219 |
+
"""
|
220 |
+
lecture_metadata = {}
|
221 |
+
|
222 |
+
# Get the main schedule page content
|
223 |
+
r = requests.get(schedule_url)
|
224 |
+
soup = BeautifulSoup(r.text, "html.parser")
|
225 |
+
|
226 |
+
# Find all lecture blocks
|
227 |
+
lecture_blocks = soup.find_all("div", class_="lecture-container")
|
228 |
+
|
229 |
+
for block in lecture_blocks:
|
230 |
+
try:
|
231 |
+
# Extract the lecture title
|
232 |
+
title = block.find("span", style="font-weight: bold;").text.strip()
|
233 |
+
|
234 |
+
# Extract the TL;DR
|
235 |
+
tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
|
236 |
+
|
237 |
+
# Extract the link to the slides
|
238 |
+
slides_link_tag = block.find("a", title="Download slides")
|
239 |
+
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
240 |
+
|
241 |
+
# Extract the link to the lecture recording
|
242 |
+
recording_link_tag = block.find("a", title="Download lecture recording")
|
243 |
+
recording_link = (
|
244 |
+
recording_link_tag["href"].strip() if recording_link_tag else None
|
245 |
+
)
|
246 |
+
|
247 |
+
# Extract suggested readings or summary if available
|
248 |
+
suggested_readings_tag = block.find("p", text="Suggested Readings:")
|
249 |
+
if suggested_readings_tag:
|
250 |
+
suggested_readings = suggested_readings_tag.find_next_sibling("ul")
|
251 |
+
if suggested_readings:
|
252 |
+
suggested_readings = suggested_readings.get_text(
|
253 |
+
separator="\n"
|
254 |
+
).strip()
|
255 |
+
else:
|
256 |
+
suggested_readings = "No specific readings provided."
|
257 |
+
else:
|
258 |
+
suggested_readings = "No specific readings provided."
|
259 |
+
|
260 |
+
# Add to the dictionary
|
261 |
+
slides_link = f"https://dl4ds.github.io{slides_link}"
|
262 |
+
lecture_metadata[slides_link] = {
|
263 |
+
"tldr": tldr,
|
264 |
+
"title": title,
|
265 |
+
"lecture_recording": recording_link,
|
266 |
+
"suggested_readings": suggested_readings,
|
267 |
+
}
|
268 |
+
except Exception as e:
|
269 |
+
print(f"Error processing block: {e}")
|
270 |
+
continue
|
271 |
+
|
272 |
+
return lecture_metadata
|
code/modules/llm_tutor.py
CHANGED
@@ -8,7 +8,6 @@ from langchain.llms import CTransformers
|
|
8 |
from langchain.memory import ConversationBufferWindowMemory
|
9 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
10 |
import os
|
11 |
-
|
12 |
from modules.constants import *
|
13 |
from modules.helpers import get_prompt
|
14 |
from modules.chat_model_loader import ChatModelLoader
|
@@ -34,14 +33,21 @@ class LLMTutor:
|
|
34 |
|
35 |
# Retrieval QA Chain
|
36 |
def retrieval_qa_chain(self, llm, prompt, db):
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
if self.config["llm_params"]["use_history"]:
|
46 |
memory = ConversationBufferWindowMemory(
|
47 |
k=self.config["llm_params"]["memory_window"],
|
|
|
8 |
from langchain.memory import ConversationBufferWindowMemory
|
9 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
10 |
import os
|
|
|
11 |
from modules.constants import *
|
12 |
from modules.helpers import get_prompt
|
13 |
from modules.chat_model_loader import ChatModelLoader
|
|
|
33 |
|
34 |
# Retrieval QA Chain
|
35 |
def retrieval_qa_chain(self, llm, prompt, db):
|
36 |
+
if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
|
37 |
+
retriever = VectorDBScore(
|
38 |
+
vectorstore=db,
|
39 |
+
search_type="similarity_score_threshold",
|
40 |
+
search_kwargs={
|
41 |
+
"score_threshold": self.config["embedding_options"][
|
42 |
+
"score_threshold"
|
43 |
+
],
|
44 |
+
"k": self.config["embedding_options"]["search_top_k"],
|
45 |
+
},
|
46 |
+
)
|
47 |
+
elif self.config["embedding_options"]["db_option"] == "RAGatouille":
|
48 |
+
retriever = db.as_langchain_retriever(
|
49 |
+
k=self.config["embedding_options"]["search_top_k"]
|
50 |
+
)
|
51 |
if self.config["llm_params"]["use_history"]:
|
52 |
memory = ConversationBufferWindowMemory(
|
53 |
k=self.config["llm_params"]["memory_window"],
|
code/modules/vector_db.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
import yaml
|
4 |
-
from
|
5 |
from langchain.schema.vectorstore import VectorStoreRetriever
|
6 |
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
7 |
from langchain.schema.document import Document
|
8 |
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
|
|
9 |
|
10 |
try:
|
11 |
from modules.embedding_model_loader import EmbeddingModelLoader
|
@@ -25,7 +26,7 @@ class VectorDBScore(VectorStoreRetriever):
|
|
25 |
|
26 |
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
27 |
def _get_relevant_documents(
|
28 |
-
|
29 |
) -> List[Document]:
|
30 |
docs_and_similarities = (
|
31 |
self.vectorstore.similarity_search_with_relevance_scores(
|
@@ -55,7 +56,6 @@ class VectorDBScore(VectorStoreRetriever):
|
|
55 |
return docs
|
56 |
|
57 |
|
58 |
-
|
59 |
class VectorDB:
|
60 |
def __init__(self, config, logger=None):
|
61 |
self.config = config
|
@@ -116,7 +116,15 @@ class VectorDB:
|
|
116 |
self.embedding_model_loader = EmbeddingModelLoader(self.config)
|
117 |
self.embedding_model = self.embedding_model_loader.load_embedding_model()
|
118 |
|
119 |
-
def initialize_database(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
# Track token usage
|
121 |
self.logger.info("Initializing vector_db")
|
122 |
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
@@ -136,6 +144,14 @@ class VectorDB:
|
|
136 |
+ self.config["embedding_options"]["model"],
|
137 |
),
|
138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
self.logger.info("Completed initializing vector_db")
|
140 |
|
141 |
def create_database(self):
|
@@ -146,11 +162,13 @@ class VectorDB:
|
|
146 |
files += lecture_pdfs
|
147 |
if "storage/data/urls.txt" in files:
|
148 |
files.remove("storage/data/urls.txt")
|
149 |
-
document_chunks, document_names =
|
|
|
|
|
150 |
self.logger.info("Completed loading data")
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
|
155 |
def save_database(self):
|
156 |
if self.db_option == "FAISS":
|
@@ -166,6 +184,9 @@ class VectorDB:
|
|
166 |
elif self.db_option == "Chroma":
|
167 |
# db is saved in the persist directory during initialization
|
168 |
pass
|
|
|
|
|
|
|
169 |
self.logger.info("Saved database")
|
170 |
|
171 |
def load_database(self):
|
@@ -180,7 +201,7 @@ class VectorDB:
|
|
180 |
+ self.config["embedding_options"]["model"],
|
181 |
),
|
182 |
self.embedding_model,
|
183 |
-
|
184 |
)
|
185 |
elif self.db_option == "Chroma":
|
186 |
self.vector_db = Chroma(
|
@@ -193,6 +214,10 @@ class VectorDB:
|
|
193 |
),
|
194 |
embedding_function=self.embedding_model,
|
195 |
)
|
|
|
|
|
|
|
|
|
196 |
self.logger.info("Loaded database")
|
197 |
return self.vector_db
|
198 |
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
import yaml
|
4 |
+
from langchain_community.vectorstores import FAISS, Chroma
|
5 |
from langchain.schema.vectorstore import VectorStoreRetriever
|
6 |
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
7 |
from langchain.schema.document import Document
|
8 |
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
9 |
+
from ragatouille import RAGPretrainedModel
|
10 |
|
11 |
try:
|
12 |
from modules.embedding_model_loader import EmbeddingModelLoader
|
|
|
26 |
|
27 |
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
28 |
def _get_relevant_documents(
|
29 |
+
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
30 |
) -> List[Document]:
|
31 |
docs_and_similarities = (
|
32 |
self.vectorstore.similarity_search_with_relevance_scores(
|
|
|
56 |
return docs
|
57 |
|
58 |
|
|
|
59 |
class VectorDB:
|
60 |
def __init__(self, config, logger=None):
|
61 |
self.config = config
|
|
|
116 |
self.embedding_model_loader = EmbeddingModelLoader(self.config)
|
117 |
self.embedding_model = self.embedding_model_loader.load_embedding_model()
|
118 |
|
119 |
+
def initialize_database(
|
120 |
+
self,
|
121 |
+
document_chunks: list,
|
122 |
+
document_names: list,
|
123 |
+
documents: list,
|
124 |
+
document_metadata: list,
|
125 |
+
):
|
126 |
+
if self.db_option in ["FAISS", "Chroma"]:
|
127 |
+
self.create_embedding_model()
|
128 |
# Track token usage
|
129 |
self.logger.info("Initializing vector_db")
|
130 |
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
|
|
144 |
+ self.config["embedding_options"]["model"],
|
145 |
),
|
146 |
)
|
147 |
+
elif self.db_option == "RAGatouille":
|
148 |
+
self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
|
149 |
+
index_path = self.RAG.index(
|
150 |
+
index_name="new_idx",
|
151 |
+
collection=documents,
|
152 |
+
document_ids=document_names,
|
153 |
+
document_metadatas=document_metadata,
|
154 |
+
)
|
155 |
self.logger.info("Completed initializing vector_db")
|
156 |
|
157 |
def create_database(self):
|
|
|
162 |
files += lecture_pdfs
|
163 |
if "storage/data/urls.txt" in files:
|
164 |
files.remove("storage/data/urls.txt")
|
165 |
+
document_chunks, document_names, documents, document_metadata = (
|
166 |
+
data_loader.get_chunks(files, urls)
|
167 |
+
)
|
168 |
self.logger.info("Completed loading data")
|
169 |
+
self.initialize_database(
|
170 |
+
document_chunks, document_names, documents, document_metadata
|
171 |
+
)
|
172 |
|
173 |
def save_database(self):
|
174 |
if self.db_option == "FAISS":
|
|
|
184 |
elif self.db_option == "Chroma":
|
185 |
# db is saved in the persist directory during initialization
|
186 |
pass
|
187 |
+
elif self.db_option == "RAGatouille":
|
188 |
+
# index is saved during initialization
|
189 |
+
pass
|
190 |
self.logger.info("Saved database")
|
191 |
|
192 |
def load_database(self):
|
|
|
201 |
+ self.config["embedding_options"]["model"],
|
202 |
),
|
203 |
self.embedding_model,
|
204 |
+
allow_dangerous_deserialization=True,
|
205 |
)
|
206 |
elif self.db_option == "Chroma":
|
207 |
self.vector_db = Chroma(
|
|
|
214 |
),
|
215 |
embedding_function=self.embedding_model,
|
216 |
)
|
217 |
+
elif self.db_option == "RAGatouille":
|
218 |
+
self.vector_db = RAGPretrainedModel.from_index(
|
219 |
+
".ragatouille/colbert/indexes/new_idx"
|
220 |
+
)
|
221 |
self.logger.info("Loaded database")
|
222 |
return self.vector_db
|
223 |
|
requirements.txt
CHANGED
@@ -17,3 +17,4 @@ fake-useragent==1.4.0
|
|
17 |
git+https://github.com/huggingface/accelerate.git
|
18 |
llama-cpp-python
|
19 |
PyPDF2==3.0.1
|
|
|
|
17 |
git+https://github.com/huggingface/accelerate.git
|
18 |
llama-cpp-python
|
19 |
PyPDF2==3.0.1
|
20 |
+
ragatouille==0.0.8.post2
|
storage/data/urls.txt
CHANGED
@@ -1 +1,3 @@
|
|
1 |
https://dl4ds.github.io/sp2024/
|
|
|
|
|
|
1 |
https://dl4ds.github.io/sp2024/
|
2 |
+
https://dl4ds.github.io/sp2024/static_files/lectures/15_RAG_CoT.pdf
|
3 |
+
https://dl4ds.github.io/sp2024/static_files/lectures/21_RL_RLHF_v2.pdf
|