XThomasBU commited on
Commit
6158da4
·
1 Parent(s): b5be549

init commit

Browse files
.gitignore CHANGED
@@ -158,3 +158,6 @@ cython_debug/
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
 
 
 
 
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
+
162
+ # log files
163
+ *.log
code/.chainlit/config.toml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ # Whether to enable telemetry (default: true). No personal data is collected.
3
+ enable_telemetry = true
4
+
5
+ # List of environment variables to be provided by each user to use the app.
6
+ user_env = []
7
+
8
+ # Duration (in seconds) during which the session is saved when the connection is lost
9
+ session_timeout = 3600
10
+
11
+ # Enable third parties caching (e.g LangChain cache)
12
+ cache = false
13
+
14
+ # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
15
+ # follow_symlink = false
16
+
17
+ [features]
18
+ # Show the prompt playground
19
+ prompt_playground = true
20
+
21
+ # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
22
+ unsafe_allow_html = false
23
+
24
+ # Process and display mathematical expressions. This can clash with "$" characters in messages.
25
+ latex = false
26
+
27
+ # Authorize users to upload files with messages
28
+ multi_modal = true
29
+
30
+ # Allows user to use speech to text
31
+ [features.speech_to_text]
32
+ enabled = false
33
+ # See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
34
+ # language = "en-US"
35
+
36
+ [UI]
37
+ # Name of the app and chatbot.
38
+ name = "LLM Tutor"
39
+
40
+ # Show the readme while the conversation is empty.
41
+ show_readme_as_default = true
42
+
43
+ # Description of the app and chatbot. This is used for HTML tags.
44
+ # description = ""
45
+
46
+ # Large size content are by default collapsed for a cleaner ui
47
+ default_collapse_content = true
48
+
49
+ # The default value for the expand messages settings.
50
+ default_expand_messages = false
51
+
52
+ # Hide the chain of thought details from the user in the UI.
53
+ hide_cot = false
54
+
55
+ # Link to your github repo. This will add a github button in the UI's header.
56
+ # github = "https://github.com/DL4DS/dl4ds_tutor"
57
+
58
+ # Specify a CSS file that can be used to customize the user interface.
59
+ # The CSS file can be served from the public directory or via an external link.
60
+ # custom_css = "/public/test.css"
61
+
62
+ # Override default MUI light theme. (Check theme.ts)
63
+ [UI.theme.light]
64
+ #background = "#FAFAFA"
65
+ #paper = "#FFFFFF"
66
+
67
+ [UI.theme.light.primary]
68
+ #main = "#F80061"
69
+ #dark = "#980039"
70
+ #light = "#FFE7EB"
71
+
72
+ # Override default MUI dark theme. (Check theme.ts)
73
+ [UI.theme.dark]
74
+ #background = "#FAFAFA"
75
+ #paper = "#FFFFFF"
76
+
77
+ [UI.theme.dark.primary]
78
+ #main = "#F80061"
79
+ #dark = "#980039"
80
+ #light = "#FFE7EB"
81
+
82
+
83
+ [meta]
84
+ generated_by = "0.7.700"
code/chainlit.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Welcome to DL4DS Tutor! 🚀🤖
2
+
3
+ Hi there, this is an LLM chatbot designed to help answer questions on the course content, built using Langchain and Chainlit.
4
+ This is still very much a Work in Progress.
5
+
6
+ ## Useful Links 🔗
7
+
8
+ - **Documentation:** [Chainlit Documentation](https://docs.chainlit.io) 📚
code/config.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ embedding_options:
2
+ embedd_files: True # bool
3
+ persist_directory: null # str or None
4
+ data_path: '../data' # str
5
+ db_option : 'FAISS' # str
6
+ db_path : 'vectorstores' # str
7
+ model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
8
+ llm_params:
9
+ use_history: False # bool
10
+ llm_loader: 'openai' # str [ctransformers, openai]
11
+ openai_params:
12
+ model: 'gpt-4' # str [gpt-3.5-turbo-1106, gpt-4]
13
+ ctransformers_params:
14
+ model: "TheBloke/Llama-2-7B-Chat-GGML"
15
+ model_type: "llama"
16
+ splitter_options:
17
+ use_splitter: True # bool
18
+ split_by_token : True # bool
19
+ remove_leftover_delimiters: True # bool
20
+ remove_chunks: False # bool
21
+ chunk_size : 800 # int
22
+ chunk_overlap : 80 # int
23
+ chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
24
+ front_chunks_to_remove : null # int or None
25
+ last_chunks_to_remove : null # int or None
26
+ delimiters_to_remove : ['\t', '\n', ' ', ' '] # list of strings
code/main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
2
+ from langchain import PromptTemplate
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.vectorstores import FAISS
5
+ from langchain.chains import RetrievalQA
6
+ from langchain.llms import CTransformers
7
+ import chainlit as cl
8
+ from langchain_community.chat_models import ChatOpenAI
9
+ from langchain_community.embeddings import OpenAIEmbeddings
10
+ import yaml
11
+ import logging
12
+ from dotenv import load_dotenv
13
+
14
+ from modules.llm_tutor import LLMTutor
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logger.setLevel(logging.INFO)
19
+
20
+ # Console Handler
21
+ console_handler = logging.StreamHandler()
22
+ console_handler.setLevel(logging.INFO)
23
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
24
+ console_handler.setFormatter(formatter)
25
+ logger.addHandler(console_handler)
26
+
27
+ # File Handler
28
+ log_file_path = "log_file.log" # Change this to your desired log file path
29
+ file_handler = logging.FileHandler(log_file_path)
30
+ file_handler.setLevel(logging.INFO)
31
+ file_handler.setFormatter(formatter)
32
+ logger.addHandler(file_handler)
33
+
34
+ with open("config.yml", "r") as f:
35
+ config = yaml.safe_load(f)
36
+ print(config)
37
+ logger.info("Config file loaded")
38
+ logger.info(f"Config: {config}")
39
+ logger.info("Creating llm_tutor instance")
40
+ llm_tutor = LLMTutor(config, logger=logger)
41
+
42
+
43
+ # chainlit code
44
+ @cl.on_chat_start
45
+ async def start():
46
+ chain = llm_tutor.qa_bot()
47
+ msg = cl.Message(content="Starting the bot...")
48
+ await msg.send()
49
+ msg.content = "Hey, What Can I Help You With?"
50
+ await msg.update()
51
+
52
+ cl.user_session.set("chain", chain)
53
+
54
+
55
+ @cl.on_message
56
+ async def main(message):
57
+ chain = cl.user_session.get("chain")
58
+ cb = cl.AsyncLangchainCallbackHandler(
59
+ stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
60
+ )
61
+ cb.answer_reached = True
62
+ # res=await chain.acall(message, callbacks=[cb])
63
+ res = await chain.acall(message.content, callbacks=[cb])
64
+ # print(f"response: {res}")
65
+ try:
66
+ answer = res["answer"]
67
+ except:
68
+ answer = res["result"]
69
+ print(f"answer: {answer}")
70
+ source_elements_dict = {}
71
+ source_elements = []
72
+ found_sources = []
73
+
74
+ for idx, source in enumerate(res["source_documents"]):
75
+ title = source.metadata["source"]
76
+
77
+ if title not in source_elements_dict:
78
+ source_elements_dict[title] = {
79
+ "page_number": [source.metadata["page"]],
80
+ "url": source.metadata["source"],
81
+ "content": source.page_content,
82
+ }
83
+
84
+ else:
85
+ source_elements_dict[title]["page_number"].append(source.metadata["page"])
86
+ source_elements_dict[title][
87
+ "content_" + str(source.metadata["page"])
88
+ ] = source.page_content
89
+ # sort the page numbers
90
+ # source_elements_dict[title]["page_number"].sort()
91
+
92
+ for title, source in source_elements_dict.items():
93
+ # create a string for the page numbers
94
+ page_numbers = ", ".join([str(x) for x in source["page_number"]])
95
+ text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
96
+ source_elements.append(cl.Pdf(name="File", path=title))
97
+ found_sources.append("File")
98
+ # for pn in source["page_number"]:
99
+ # source_elements.append(
100
+ # cl.Text(name=str(pn), content=source["content_"+str(pn)])
101
+ # )
102
+ # found_sources.append(str(pn))
103
+
104
+ if found_sources:
105
+ answer += f"\nSource:{', '.join(found_sources)}"
106
+ else:
107
+ answer += f"\nNo source found."
108
+
109
+ await cl.Message(content=answer, elements=source_elements).send()
code/modules/__init__.py ADDED
File without changes
code/modules/chat_model_loader.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.chat_models import ChatOpenAI
2
+ from langchain.llms import CTransformers
3
+
4
+
5
+ class ChatModelLoader:
6
+ def __init__(self, config):
7
+ self.config = config
8
+
9
+ def load_chat_model(self):
10
+ if self.config["llm_params"]["llm_loader"] == "openai":
11
+ llm = ChatOpenAI(
12
+ model_name=self.config["llm_params"]["openai_params"]["model"]
13
+ )
14
+ elif self.config["llm_params"]["llm_loader"] == "Ctransformers":
15
+ llm = CTransformers(
16
+ model=self.config["llm_params"]["ctransformers_params"]["model"],
17
+ model_type=self.config["llm_params"]["ctransformers_params"][
18
+ "model_type"
19
+ ],
20
+ max_new_tokens=512,
21
+ temperature=0.5,
22
+ )
23
+ else:
24
+ raise ValueError("Invalid LLM Loader")
25
+ return llm
code/modules/constants.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+
4
+ load_dotenv()
5
+
6
+ # API Keys - Loaded from the .env file
7
+
8
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
+
10
+
11
+ # Prompt Templates
12
+
13
+ prompt_template = """Use the following pieces of information to answer the user's question.
14
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
15
+
16
+ Context: {context}
17
+ Question: {question}
18
+
19
+ Only return the helpful answer below and nothing else.
20
+ Helpful answer:
21
+ """
22
+
23
+ prompt_template_with_history = """Use the following pieces of information to answer the user's question.
24
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
25
+ Use the history to answer the question if you can.
26
+ Chat History:
27
+ {chat_history}
28
+ Context: {context}
29
+ Question: {question}
30
+
31
+ Only return the helpful answer below and nothing else.
32
+ Helpful answer:
33
+ """
code/modules/data_loader.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pysrt
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain.document_loaders import (
5
+ PyMuPDFLoader,
6
+ Docx2txtLoader,
7
+ YoutubeLoader,
8
+ WebBaseLoader,
9
+ TextLoader,
10
+ )
11
+ from langchain.schema import Document
12
+ from tempfile import NamedTemporaryFile
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class DataLoader:
19
+ def __init__(self, config):
20
+ """
21
+ Class for handling all data extraction and chunking
22
+ Inputs:
23
+ config - dictionary from yaml file, containing all important parameters
24
+ """
25
+ self.config = config
26
+ self.remove_leftover_delimiters = config["splitter_options"][
27
+ "remove_leftover_delimiters"
28
+ ]
29
+
30
+ # Main list of all documents
31
+ self.document_chunks_full = []
32
+ self.document_names = []
33
+
34
+ if config["splitter_options"]["use_splitter"]:
35
+ if config["splitter_options"]["split_by_token"]:
36
+ self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
37
+ chunk_size=config["splitter_options"]["chunk_size"],
38
+ chunk_overlap=config["splitter_options"]["chunk_overlap"],
39
+ separators=config["splitter_options"]["chunk_separators"],
40
+ )
41
+ else:
42
+ self.splitter = RecursiveCharacterTextSplitter(
43
+ chunk_size=config["splitter_options"]["chunk_size"],
44
+ chunk_overlap=config["splitter_options"]["chunk_overlap"],
45
+ separators=config["splitter_options"]["chunk_separators"],
46
+ )
47
+ else:
48
+ self.splitter = None
49
+ logger.info("InfoLoader instance created")
50
+
51
+ def get_chunks(self, uploaded_files, weblinks):
52
+ # Main list of all documents
53
+ self.document_chunks_full = []
54
+ self.document_names = []
55
+
56
+ def remove_delimiters(document_chunks: list):
57
+ """
58
+ Helper function to remove remaining delimiters in document chunks
59
+ """
60
+ for chunk in document_chunks:
61
+ for delimiter in self.config["splitter_options"][
62
+ "delimiters_to_remove"
63
+ ]:
64
+ chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
65
+ return document_chunks
66
+
67
+ def remove_chunks(document_chunks: list):
68
+ """
69
+ Helper function to remove any unwanted document chunks after splitting
70
+ """
71
+ front = self.config["splitter_options"]["front_chunk_to_remove"]
72
+ end = self.config["splitter_options"]["last_chunks_to_remove"]
73
+ # Remove pages
74
+ for _ in range(front):
75
+ del document_chunks[0]
76
+ for _ in range(end):
77
+ document_chunks.pop()
78
+ logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
79
+ return document_chunks
80
+
81
+ def get_pdf(temp_file_path: str, title: str):
82
+ """
83
+ Function to process PDF files
84
+ """
85
+ loader = PyMuPDFLoader(
86
+ temp_file_path
87
+ ) # This loader preserves more metadata
88
+
89
+ if self.splitter:
90
+ document_chunks = self.splitter.split_documents(loader.load())
91
+ else:
92
+ document_chunks = loader.load()
93
+
94
+ if "title" in document_chunks[0].metadata.keys():
95
+ title = document_chunks[0].metadata["title"]
96
+
97
+ logger.info(
98
+ f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
99
+ )
100
+
101
+ return title, document_chunks
102
+
103
+ def get_txt(temp_file_path: str, title: str):
104
+ """
105
+ Function to process TXT files
106
+ """
107
+ loader = TextLoader(temp_file_path, autodetect_encoding=True)
108
+
109
+ if self.splitter:
110
+ document_chunks = self.splitter.split_documents(loader.load())
111
+ else:
112
+ document_chunks = loader.load()
113
+
114
+ # Update the metadata
115
+ for chunk in document_chunks:
116
+ chunk.metadata["source"] = title
117
+ chunk.metadata["page"] = "N/A"
118
+
119
+ return title, document_chunks
120
+
121
+ def get_srt(temp_file_path: str, title: str):
122
+ """
123
+ Function to process SRT files
124
+ """
125
+ subs = pysrt.open(temp_file_path)
126
+
127
+ text = ""
128
+ for sub in subs:
129
+ text += sub.text
130
+ document_chunks = [Document(page_content=text)]
131
+
132
+ if self.splitter:
133
+ document_chunks = self.splitter.split_documents(document_chunks)
134
+
135
+ # Update the metadata
136
+ for chunk in document_chunks:
137
+ chunk.metadata["source"] = title
138
+ chunk.metadata["page"] = "N/A"
139
+
140
+ return title, document_chunks
141
+
142
+ def get_docx(temp_file_path: str, title: str):
143
+ """
144
+ Function to process DOCX files
145
+ """
146
+ loader = Docx2txtLoader(temp_file_path)
147
+
148
+ if self.splitter:
149
+ document_chunks = self.splitter.split_documents(loader.load())
150
+ else:
151
+ document_chunks = loader.load()
152
+
153
+ # Update the metadata
154
+ for chunk in document_chunks:
155
+ chunk.metadata["source"] = title
156
+ chunk.metadata["page"] = "N/A"
157
+
158
+ return title, document_chunks
159
+
160
+ def get_youtube_transcript(url: str):
161
+ """
162
+ Function to retrieve youtube transcript and process text
163
+ """
164
+ loader = YoutubeLoader.from_youtube_url(
165
+ url, add_video_info=True, language=["en"], translation="en"
166
+ )
167
+
168
+ if self.splitter:
169
+ document_chunks = self.splitter.split_documents(loader.load())
170
+ else:
171
+ document_chunks = loader.load_and_split()
172
+
173
+ # Replace the source with title (for display in st UI later)
174
+ for chunk in document_chunks:
175
+ chunk.metadata["source"] = chunk.metadata["title"]
176
+ logger.info(chunk.metadata["title"])
177
+
178
+ return title, document_chunks
179
+
180
+ def get_html(url: str):
181
+ """
182
+ Function to process websites via HTML files
183
+ """
184
+ loader = WebBaseLoader(url)
185
+
186
+ if self.splitter:
187
+ document_chunks = self.splitter.split_documents(loader.load())
188
+ else:
189
+ document_chunks = loader.load_and_split()
190
+
191
+ title = document_chunks[0].metadata["title"]
192
+ logger.info(document_chunks[0].metadata)
193
+
194
+ return title, document_chunks
195
+
196
+ # Handle file by file
197
+ for file_index, file_path in enumerate(uploaded_files):
198
+
199
+ file_name = file_path.split("/")[-1]
200
+ file_type = file_name.split(".")[-1]
201
+
202
+ # Handle different file types
203
+ if file_type == "pdf":
204
+ title, document_chunks = get_pdf(file_path, file_name)
205
+ elif file_type == "txt":
206
+ title, document_chunks = get_txt(file_path, file_name)
207
+ elif file_type == "docx":
208
+ title, document_chunks = get_docx(file_path, file_name)
209
+ elif file_type == "srt":
210
+ title, document_chunks = get_srt(file_path, file_name)
211
+
212
+ # Additional wrangling - Remove leftover delimiters and any specified chunks
213
+ if self.remove_leftover_delimiters:
214
+ document_chunks = remove_delimiters(document_chunks)
215
+ if self.config["splitter_options"]["remove_chunks"]:
216
+ document_chunks = remove_chunks(document_chunks)
217
+
218
+ logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
219
+ self.document_names.append(title)
220
+ self.document_chunks_full.extend(document_chunks)
221
+
222
+ # Handle youtube links:
223
+ if weblinks[0] != "":
224
+ logger.info(f"Splitting weblinks: total of {len(weblinks)}")
225
+
226
+ # Handle link by link
227
+ for link_index, link in enumerate(weblinks):
228
+ logger.info(f"\tSplitting link {link_index+1} : {link}")
229
+ if "youtube" in link:
230
+ title, document_chunks = get_youtube_transcript(link)
231
+ else:
232
+ title, document_chunks = get_html(link)
233
+
234
+ # Additional wrangling - Remove leftover delimiters and any specified chunks
235
+ if self.remove_leftover_delimiters:
236
+ document_chunks = remove_delimiters(document_chunks)
237
+ if self.config["splitter_options"]["remove_chunks"]:
238
+ document_chunks = remove_chunks(document_chunks)
239
+
240
+ print(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
241
+ self.document_names.append(title)
242
+ self.document_chunks_full.extend(document_chunks)
243
+
244
+ logger.info(
245
+ f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
246
+ )
247
+
248
+ return self.document_chunks_full, self.document_names
code/modules/embedding_model_loader.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.embeddings import OpenAIEmbeddings
2
+ from langchain.embeddings import HuggingFaceEmbeddings
3
+ from modules.constants import *
4
+
5
+
6
+ class EmbeddingModelLoader:
7
+ def __init__(self, config):
8
+ self.config = config
9
+
10
+ def load_embedding_model(self):
11
+ if self.config["embedding_options"]["model"] in ["text-embedding-ada-002"]:
12
+ embedding_model = OpenAIEmbeddings(
13
+ deployment="SL-document_embedder",
14
+ model=self.config["embedding_options"]["model"],
15
+ show_progress_bar=True,
16
+ openai_api_key=OPENAI_API_KEY,
17
+ )
18
+ else:
19
+ embedding_model = HuggingFaceEmbeddings(
20
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
21
+ model_kwargs={"device": "cpu"},
22
+ )
23
+ return embedding_model
code/modules/llm_tutor.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import PromptTemplate
2
+ from langchain.embeddings import HuggingFaceEmbeddings
3
+ from langchain_community.chat_models import ChatOpenAI
4
+ from langchain_community.embeddings import OpenAIEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
7
+ from langchain.llms import CTransformers
8
+ from langchain.memory import ConversationBufferMemory
9
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
10
+ import os
11
+
12
+ from modules.constants import *
13
+ from modules.chat_model_loader import ChatModelLoader
14
+ from modules.vector_db import VectorDB
15
+
16
+
17
+ class LLMTutor:
18
+ def __init__(self, config, logger=None):
19
+ self.config = config
20
+ self.vector_db = VectorDB(config, logger=logger)
21
+ if self.config['embedding_options']['embedd_files']:
22
+ self.vector_db.create_database()
23
+ self.vector_db.save_database()
24
+
25
+ def set_custom_prompt(self):
26
+ """
27
+ Prompt template for QA retrieval for each vectorstore
28
+ """
29
+ if self.config["llm_params"]["use_history"]:
30
+ custom_prompt_template = prompt_template_with_history
31
+ else:
32
+ custom_prompt_template = prompt_template
33
+ prompt = PromptTemplate(
34
+ template=custom_prompt_template,
35
+ input_variables=["context", "chat_history", "question"],
36
+ )
37
+ # prompt = QA_PROMPT
38
+
39
+ return prompt
40
+
41
+ # Retrieval QA Chain
42
+ def retrieval_qa_chain(self, llm, prompt, db):
43
+ if self.config["llm_params"]["use_history"]:
44
+ memory = ConversationBufferMemory(
45
+ memory_key="chat_history", return_messages=True, output_key="answer"
46
+ )
47
+ qa_chain = ConversationalRetrievalChain.from_llm(
48
+ llm=llm,
49
+ chain_type="stuff",
50
+ retriever=db.as_retriever(search_kwargs={"k": 3}),
51
+ return_source_documents=True,
52
+ memory=memory,
53
+ combine_docs_chain_kwargs={"prompt": prompt},
54
+ )
55
+ else:
56
+ qa_chain = RetrievalQA.from_chain_type(
57
+ llm=llm,
58
+ chain_type="stuff",
59
+ retriever=db.as_retriever(search_kwargs={"k": 3}),
60
+ return_source_documents=True,
61
+ chain_type_kwargs={"prompt": prompt},
62
+ )
63
+ return qa_chain
64
+
65
+ # Loading the model
66
+ def load_llm(self):
67
+ chat_model_loader = ChatModelLoader(self.config)
68
+ llm = chat_model_loader.load_chat_model()
69
+ return llm
70
+
71
+ # QA Model Function
72
+ def qa_bot(self):
73
+ db = self.vector_db.load_database()
74
+ self.llm = self.load_llm()
75
+ qa_prompt = self.set_custom_prompt()
76
+ qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
77
+
78
+ return qa
79
+
80
+ # output function
81
+ def final_result(query):
82
+ qa_result = qa_bot()
83
+ response = qa_result({"query": query})
84
+ return response
code/modules/vector_db.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import yaml
4
+
5
+ from modules.embedding_model_loader import EmbeddingModelLoader
6
+ from langchain.vectorstores import FAISS
7
+ from modules.data_loader import DataLoader
8
+ from modules.constants import *
9
+
10
+
11
+ class VectorDB:
12
+ def __init__(self, config, logger=None):
13
+ self.config = config
14
+ self.db_option = config["embedding_options"]["db_option"]
15
+ self.document_names = None
16
+
17
+ # Set up logging to both console and a file
18
+ if logger is None:
19
+ self.logger = logging.getLogger(__name__)
20
+ self.logger.setLevel(logging.INFO)
21
+
22
+ # Console Handler
23
+ console_handler = logging.StreamHandler()
24
+ console_handler.setLevel(logging.INFO)
25
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
26
+ console_handler.setFormatter(formatter)
27
+ self.logger.addHandler(console_handler)
28
+
29
+ # File Handler
30
+ log_file_path = "vector_db.log" # Change this to your desired log file path
31
+ file_handler = logging.FileHandler(log_file_path, mode="w")
32
+ file_handler.setLevel(logging.INFO)
33
+ file_handler.setFormatter(formatter)
34
+ self.logger.addHandler(file_handler)
35
+ else:
36
+ self.logger = logger
37
+
38
+ self.logger.info("VectorDB instance instantiated")
39
+
40
+ def load_files(self):
41
+ files = os.listdir(self.config["embedding_options"]["data_path"])
42
+ files = [
43
+ os.path.join(self.config["embedding_options"]["data_path"], file)
44
+ for file in files
45
+ ]
46
+ return files
47
+
48
+ def create_embedding_model(self):
49
+ self.logger.info("Creating embedding function")
50
+ self.embedding_model_loader = EmbeddingModelLoader(self.config)
51
+ self.embedding_model = self.embedding_model_loader.load_embedding_model()
52
+
53
+ def initialize_database(self, document_chunks: list, document_names: list):
54
+ # Track token usage
55
+ self.logger.info("Initializing vector_db")
56
+ self.logger.info("\tUsing {} as db_option".format(self.db_option))
57
+ if self.db_option == "FAISS":
58
+ self.vector_db = FAISS.from_documents(
59
+ documents=document_chunks, embedding=self.embedding_model
60
+ )
61
+ self.logger.info("Completed initializing vector_db")
62
+
63
+ def create_database(self):
64
+ data_loader = DataLoader(self.config)
65
+ self.logger.info("Loading data")
66
+ files = self.load_files()
67
+ document_chunks, document_names = data_loader.get_chunks(files, [""])
68
+ self.logger.info("Completed loading data")
69
+
70
+ self.create_embedding_model()
71
+ self.initialize_database(document_chunks, document_names)
72
+
73
+ def save_database(self):
74
+ self.vector_db.save_local(
75
+ os.path.join(
76
+ self.config["embedding_options"]["db_path"],
77
+ "db_"
78
+ + self.config["embedding_options"]["db_option"]
79
+ + "_"
80
+ + self.config["embedding_options"]["model"],
81
+ )
82
+ )
83
+ self.logger.info("Saved database")
84
+
85
+ def load_database(self):
86
+ self.create_embedding_model()
87
+ self.vector_db = FAISS.load_local(
88
+ os.path.join(
89
+ self.config["embedding_options"]["db_path"],
90
+ "db_"
91
+ + self.config["embedding_options"]["db_option"]
92
+ + "_"
93
+ + self.config["embedding_options"]["model"],
94
+ ),
95
+ self.embedding_model,
96
+ )
97
+ self.logger.info("Loaded database")
98
+ return self.vector_db
99
+
100
+
101
+ if __name__ == "__main__":
102
+ with open("config.yml", "r") as f:
103
+ config = yaml.safe_load(f)
104
+ print(config)
105
+ vector_db = VectorDB(config)
106
+ vector_db.create_database()
107
+ vector_db.save_database()
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.faiss ADDED
Binary file (6.19 kB). View file
 
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.pkl ADDED
Binary file (9.21 kB). View file
 
code/vectorstores/db_FAISS_text-embedding-ada-002/index.faiss ADDED
Binary file (24.6 kB). View file
 
code/vectorstores/db_FAISS_text-embedding-ada-002/index.pkl ADDED
Binary file (9.21 kB). View file
 
data/webpage.pdf ADDED
Binary file (51.3 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.29.0
2
+ PyYAML==6.0.1
3
+ pysrt==1.1.2
4
+ langchain==0.0.353
5
+ tiktoken==0.5.2
6
+ streamlit-chat==0.1.1
7
+ pypdf==3.17.4
8
+ sentence-transformers==2.2.2
9
+ faiss-cpu==1.7.4
10
+ ctransformers==0.2.27
11
+ python-dotenv==1.0.0
12
+ openai==1.6.1
13
+ pymupdf==1.23.8