init commit
Browse files- .gitignore +3 -0
- code/.chainlit/config.toml +84 -0
- code/chainlit.md +8 -0
- code/config.yml +26 -0
- code/main.py +109 -0
- code/modules/__init__.py +0 -0
- code/modules/chat_model_loader.py +25 -0
- code/modules/constants.py +33 -0
- code/modules/data_loader.py +248 -0
- code/modules/embedding_model_loader.py +23 -0
- code/modules/llm_tutor.py +84 -0
- code/modules/vector_db.py +107 -0
- code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.faiss +0 -0
- code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.pkl +0 -0
- code/vectorstores/db_FAISS_text-embedding-ada-002/index.faiss +0 -0
- code/vectorstores/db_FAISS_text-embedding-ada-002/index.pkl +0 -0
- data/webpage.pdf +0 -0
- requirements.txt +13 -0
.gitignore
CHANGED
@@ -158,3 +158,6 @@ cython_debug/
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
|
|
|
|
|
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
161 |
+
|
162 |
+
# log files
|
163 |
+
*.log
|
code/.chainlit/config.toml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
# Whether to enable telemetry (default: true). No personal data is collected.
|
3 |
+
enable_telemetry = true
|
4 |
+
|
5 |
+
# List of environment variables to be provided by each user to use the app.
|
6 |
+
user_env = []
|
7 |
+
|
8 |
+
# Duration (in seconds) during which the session is saved when the connection is lost
|
9 |
+
session_timeout = 3600
|
10 |
+
|
11 |
+
# Enable third parties caching (e.g LangChain cache)
|
12 |
+
cache = false
|
13 |
+
|
14 |
+
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
|
15 |
+
# follow_symlink = false
|
16 |
+
|
17 |
+
[features]
|
18 |
+
# Show the prompt playground
|
19 |
+
prompt_playground = true
|
20 |
+
|
21 |
+
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
|
22 |
+
unsafe_allow_html = false
|
23 |
+
|
24 |
+
# Process and display mathematical expressions. This can clash with "$" characters in messages.
|
25 |
+
latex = false
|
26 |
+
|
27 |
+
# Authorize users to upload files with messages
|
28 |
+
multi_modal = true
|
29 |
+
|
30 |
+
# Allows user to use speech to text
|
31 |
+
[features.speech_to_text]
|
32 |
+
enabled = false
|
33 |
+
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
|
34 |
+
# language = "en-US"
|
35 |
+
|
36 |
+
[UI]
|
37 |
+
# Name of the app and chatbot.
|
38 |
+
name = "LLM Tutor"
|
39 |
+
|
40 |
+
# Show the readme while the conversation is empty.
|
41 |
+
show_readme_as_default = true
|
42 |
+
|
43 |
+
# Description of the app and chatbot. This is used for HTML tags.
|
44 |
+
# description = ""
|
45 |
+
|
46 |
+
# Large size content are by default collapsed for a cleaner ui
|
47 |
+
default_collapse_content = true
|
48 |
+
|
49 |
+
# The default value for the expand messages settings.
|
50 |
+
default_expand_messages = false
|
51 |
+
|
52 |
+
# Hide the chain of thought details from the user in the UI.
|
53 |
+
hide_cot = false
|
54 |
+
|
55 |
+
# Link to your github repo. This will add a github button in the UI's header.
|
56 |
+
# github = "https://github.com/DL4DS/dl4ds_tutor"
|
57 |
+
|
58 |
+
# Specify a CSS file that can be used to customize the user interface.
|
59 |
+
# The CSS file can be served from the public directory or via an external link.
|
60 |
+
# custom_css = "/public/test.css"
|
61 |
+
|
62 |
+
# Override default MUI light theme. (Check theme.ts)
|
63 |
+
[UI.theme.light]
|
64 |
+
#background = "#FAFAFA"
|
65 |
+
#paper = "#FFFFFF"
|
66 |
+
|
67 |
+
[UI.theme.light.primary]
|
68 |
+
#main = "#F80061"
|
69 |
+
#dark = "#980039"
|
70 |
+
#light = "#FFE7EB"
|
71 |
+
|
72 |
+
# Override default MUI dark theme. (Check theme.ts)
|
73 |
+
[UI.theme.dark]
|
74 |
+
#background = "#FAFAFA"
|
75 |
+
#paper = "#FFFFFF"
|
76 |
+
|
77 |
+
[UI.theme.dark.primary]
|
78 |
+
#main = "#F80061"
|
79 |
+
#dark = "#980039"
|
80 |
+
#light = "#FFE7EB"
|
81 |
+
|
82 |
+
|
83 |
+
[meta]
|
84 |
+
generated_by = "0.7.700"
|
code/chainlit.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Welcome to DL4DS Tutor! 🚀🤖
|
2 |
+
|
3 |
+
Hi there, this is an LLM chatbot designed to help answer questions on the course content, built using Langchain and Chainlit.
|
4 |
+
This is still very much a Work in Progress.
|
5 |
+
|
6 |
+
## Useful Links 🔗
|
7 |
+
|
8 |
+
- **Documentation:** [Chainlit Documentation](https://docs.chainlit.io) 📚
|
code/config.yml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
embedding_options:
|
2 |
+
embedd_files: True # bool
|
3 |
+
persist_directory: null # str or None
|
4 |
+
data_path: '../data' # str
|
5 |
+
db_option : 'FAISS' # str
|
6 |
+
db_path : 'vectorstores' # str
|
7 |
+
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
8 |
+
llm_params:
|
9 |
+
use_history: False # bool
|
10 |
+
llm_loader: 'openai' # str [ctransformers, openai]
|
11 |
+
openai_params:
|
12 |
+
model: 'gpt-4' # str [gpt-3.5-turbo-1106, gpt-4]
|
13 |
+
ctransformers_params:
|
14 |
+
model: "TheBloke/Llama-2-7B-Chat-GGML"
|
15 |
+
model_type: "llama"
|
16 |
+
splitter_options:
|
17 |
+
use_splitter: True # bool
|
18 |
+
split_by_token : True # bool
|
19 |
+
remove_leftover_delimiters: True # bool
|
20 |
+
remove_chunks: False # bool
|
21 |
+
chunk_size : 800 # int
|
22 |
+
chunk_overlap : 80 # int
|
23 |
+
chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
|
24 |
+
front_chunks_to_remove : null # int or None
|
25 |
+
last_chunks_to_remove : null # int or None
|
26 |
+
delimiters_to_remove : ['\t', '\n', ' ', ' '] # list of strings
|
code/main.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
|
2 |
+
from langchain import PromptTemplate
|
3 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
4 |
+
from langchain.vectorstores import FAISS
|
5 |
+
from langchain.chains import RetrievalQA
|
6 |
+
from langchain.llms import CTransformers
|
7 |
+
import chainlit as cl
|
8 |
+
from langchain_community.chat_models import ChatOpenAI
|
9 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
10 |
+
import yaml
|
11 |
+
import logging
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
|
14 |
+
from modules.llm_tutor import LLMTutor
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
logger.setLevel(logging.INFO)
|
19 |
+
|
20 |
+
# Console Handler
|
21 |
+
console_handler = logging.StreamHandler()
|
22 |
+
console_handler.setLevel(logging.INFO)
|
23 |
+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
24 |
+
console_handler.setFormatter(formatter)
|
25 |
+
logger.addHandler(console_handler)
|
26 |
+
|
27 |
+
# File Handler
|
28 |
+
log_file_path = "log_file.log" # Change this to your desired log file path
|
29 |
+
file_handler = logging.FileHandler(log_file_path)
|
30 |
+
file_handler.setLevel(logging.INFO)
|
31 |
+
file_handler.setFormatter(formatter)
|
32 |
+
logger.addHandler(file_handler)
|
33 |
+
|
34 |
+
with open("config.yml", "r") as f:
|
35 |
+
config = yaml.safe_load(f)
|
36 |
+
print(config)
|
37 |
+
logger.info("Config file loaded")
|
38 |
+
logger.info(f"Config: {config}")
|
39 |
+
logger.info("Creating llm_tutor instance")
|
40 |
+
llm_tutor = LLMTutor(config, logger=logger)
|
41 |
+
|
42 |
+
|
43 |
+
# chainlit code
|
44 |
+
@cl.on_chat_start
|
45 |
+
async def start():
|
46 |
+
chain = llm_tutor.qa_bot()
|
47 |
+
msg = cl.Message(content="Starting the bot...")
|
48 |
+
await msg.send()
|
49 |
+
msg.content = "Hey, What Can I Help You With?"
|
50 |
+
await msg.update()
|
51 |
+
|
52 |
+
cl.user_session.set("chain", chain)
|
53 |
+
|
54 |
+
|
55 |
+
@cl.on_message
|
56 |
+
async def main(message):
|
57 |
+
chain = cl.user_session.get("chain")
|
58 |
+
cb = cl.AsyncLangchainCallbackHandler(
|
59 |
+
stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
|
60 |
+
)
|
61 |
+
cb.answer_reached = True
|
62 |
+
# res=await chain.acall(message, callbacks=[cb])
|
63 |
+
res = await chain.acall(message.content, callbacks=[cb])
|
64 |
+
# print(f"response: {res}")
|
65 |
+
try:
|
66 |
+
answer = res["answer"]
|
67 |
+
except:
|
68 |
+
answer = res["result"]
|
69 |
+
print(f"answer: {answer}")
|
70 |
+
source_elements_dict = {}
|
71 |
+
source_elements = []
|
72 |
+
found_sources = []
|
73 |
+
|
74 |
+
for idx, source in enumerate(res["source_documents"]):
|
75 |
+
title = source.metadata["source"]
|
76 |
+
|
77 |
+
if title not in source_elements_dict:
|
78 |
+
source_elements_dict[title] = {
|
79 |
+
"page_number": [source.metadata["page"]],
|
80 |
+
"url": source.metadata["source"],
|
81 |
+
"content": source.page_content,
|
82 |
+
}
|
83 |
+
|
84 |
+
else:
|
85 |
+
source_elements_dict[title]["page_number"].append(source.metadata["page"])
|
86 |
+
source_elements_dict[title][
|
87 |
+
"content_" + str(source.metadata["page"])
|
88 |
+
] = source.page_content
|
89 |
+
# sort the page numbers
|
90 |
+
# source_elements_dict[title]["page_number"].sort()
|
91 |
+
|
92 |
+
for title, source in source_elements_dict.items():
|
93 |
+
# create a string for the page numbers
|
94 |
+
page_numbers = ", ".join([str(x) for x in source["page_number"]])
|
95 |
+
text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
|
96 |
+
source_elements.append(cl.Pdf(name="File", path=title))
|
97 |
+
found_sources.append("File")
|
98 |
+
# for pn in source["page_number"]:
|
99 |
+
# source_elements.append(
|
100 |
+
# cl.Text(name=str(pn), content=source["content_"+str(pn)])
|
101 |
+
# )
|
102 |
+
# found_sources.append(str(pn))
|
103 |
+
|
104 |
+
if found_sources:
|
105 |
+
answer += f"\nSource:{', '.join(found_sources)}"
|
106 |
+
else:
|
107 |
+
answer += f"\nNo source found."
|
108 |
+
|
109 |
+
await cl.Message(content=answer, elements=source_elements).send()
|
code/modules/__init__.py
ADDED
File without changes
|
code/modules/chat_model_loader.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.chat_models import ChatOpenAI
|
2 |
+
from langchain.llms import CTransformers
|
3 |
+
|
4 |
+
|
5 |
+
class ChatModelLoader:
|
6 |
+
def __init__(self, config):
|
7 |
+
self.config = config
|
8 |
+
|
9 |
+
def load_chat_model(self):
|
10 |
+
if self.config["llm_params"]["llm_loader"] == "openai":
|
11 |
+
llm = ChatOpenAI(
|
12 |
+
model_name=self.config["llm_params"]["openai_params"]["model"]
|
13 |
+
)
|
14 |
+
elif self.config["llm_params"]["llm_loader"] == "Ctransformers":
|
15 |
+
llm = CTransformers(
|
16 |
+
model=self.config["llm_params"]["ctransformers_params"]["model"],
|
17 |
+
model_type=self.config["llm_params"]["ctransformers_params"][
|
18 |
+
"model_type"
|
19 |
+
],
|
20 |
+
max_new_tokens=512,
|
21 |
+
temperature=0.5,
|
22 |
+
)
|
23 |
+
else:
|
24 |
+
raise ValueError("Invalid LLM Loader")
|
25 |
+
return llm
|
code/modules/constants.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
import os
|
3 |
+
|
4 |
+
load_dotenv()
|
5 |
+
|
6 |
+
# API Keys - Loaded from the .env file
|
7 |
+
|
8 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
+
|
10 |
+
|
11 |
+
# Prompt Templates
|
12 |
+
|
13 |
+
prompt_template = """Use the following pieces of information to answer the user's question.
|
14 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
15 |
+
|
16 |
+
Context: {context}
|
17 |
+
Question: {question}
|
18 |
+
|
19 |
+
Only return the helpful answer below and nothing else.
|
20 |
+
Helpful answer:
|
21 |
+
"""
|
22 |
+
|
23 |
+
prompt_template_with_history = """Use the following pieces of information to answer the user's question.
|
24 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
25 |
+
Use the history to answer the question if you can.
|
26 |
+
Chat History:
|
27 |
+
{chat_history}
|
28 |
+
Context: {context}
|
29 |
+
Question: {question}
|
30 |
+
|
31 |
+
Only return the helpful answer below and nothing else.
|
32 |
+
Helpful answer:
|
33 |
+
"""
|
code/modules/data_loader.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import pysrt
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
from langchain.document_loaders import (
|
5 |
+
PyMuPDFLoader,
|
6 |
+
Docx2txtLoader,
|
7 |
+
YoutubeLoader,
|
8 |
+
WebBaseLoader,
|
9 |
+
TextLoader,
|
10 |
+
)
|
11 |
+
from langchain.schema import Document
|
12 |
+
from tempfile import NamedTemporaryFile
|
13 |
+
import logging
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class DataLoader:
|
19 |
+
def __init__(self, config):
|
20 |
+
"""
|
21 |
+
Class for handling all data extraction and chunking
|
22 |
+
Inputs:
|
23 |
+
config - dictionary from yaml file, containing all important parameters
|
24 |
+
"""
|
25 |
+
self.config = config
|
26 |
+
self.remove_leftover_delimiters = config["splitter_options"][
|
27 |
+
"remove_leftover_delimiters"
|
28 |
+
]
|
29 |
+
|
30 |
+
# Main list of all documents
|
31 |
+
self.document_chunks_full = []
|
32 |
+
self.document_names = []
|
33 |
+
|
34 |
+
if config["splitter_options"]["use_splitter"]:
|
35 |
+
if config["splitter_options"]["split_by_token"]:
|
36 |
+
self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
37 |
+
chunk_size=config["splitter_options"]["chunk_size"],
|
38 |
+
chunk_overlap=config["splitter_options"]["chunk_overlap"],
|
39 |
+
separators=config["splitter_options"]["chunk_separators"],
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
self.splitter = RecursiveCharacterTextSplitter(
|
43 |
+
chunk_size=config["splitter_options"]["chunk_size"],
|
44 |
+
chunk_overlap=config["splitter_options"]["chunk_overlap"],
|
45 |
+
separators=config["splitter_options"]["chunk_separators"],
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
self.splitter = None
|
49 |
+
logger.info("InfoLoader instance created")
|
50 |
+
|
51 |
+
def get_chunks(self, uploaded_files, weblinks):
|
52 |
+
# Main list of all documents
|
53 |
+
self.document_chunks_full = []
|
54 |
+
self.document_names = []
|
55 |
+
|
56 |
+
def remove_delimiters(document_chunks: list):
|
57 |
+
"""
|
58 |
+
Helper function to remove remaining delimiters in document chunks
|
59 |
+
"""
|
60 |
+
for chunk in document_chunks:
|
61 |
+
for delimiter in self.config["splitter_options"][
|
62 |
+
"delimiters_to_remove"
|
63 |
+
]:
|
64 |
+
chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
|
65 |
+
return document_chunks
|
66 |
+
|
67 |
+
def remove_chunks(document_chunks: list):
|
68 |
+
"""
|
69 |
+
Helper function to remove any unwanted document chunks after splitting
|
70 |
+
"""
|
71 |
+
front = self.config["splitter_options"]["front_chunk_to_remove"]
|
72 |
+
end = self.config["splitter_options"]["last_chunks_to_remove"]
|
73 |
+
# Remove pages
|
74 |
+
for _ in range(front):
|
75 |
+
del document_chunks[0]
|
76 |
+
for _ in range(end):
|
77 |
+
document_chunks.pop()
|
78 |
+
logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
|
79 |
+
return document_chunks
|
80 |
+
|
81 |
+
def get_pdf(temp_file_path: str, title: str):
|
82 |
+
"""
|
83 |
+
Function to process PDF files
|
84 |
+
"""
|
85 |
+
loader = PyMuPDFLoader(
|
86 |
+
temp_file_path
|
87 |
+
) # This loader preserves more metadata
|
88 |
+
|
89 |
+
if self.splitter:
|
90 |
+
document_chunks = self.splitter.split_documents(loader.load())
|
91 |
+
else:
|
92 |
+
document_chunks = loader.load()
|
93 |
+
|
94 |
+
if "title" in document_chunks[0].metadata.keys():
|
95 |
+
title = document_chunks[0].metadata["title"]
|
96 |
+
|
97 |
+
logger.info(
|
98 |
+
f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
|
99 |
+
)
|
100 |
+
|
101 |
+
return title, document_chunks
|
102 |
+
|
103 |
+
def get_txt(temp_file_path: str, title: str):
|
104 |
+
"""
|
105 |
+
Function to process TXT files
|
106 |
+
"""
|
107 |
+
loader = TextLoader(temp_file_path, autodetect_encoding=True)
|
108 |
+
|
109 |
+
if self.splitter:
|
110 |
+
document_chunks = self.splitter.split_documents(loader.load())
|
111 |
+
else:
|
112 |
+
document_chunks = loader.load()
|
113 |
+
|
114 |
+
# Update the metadata
|
115 |
+
for chunk in document_chunks:
|
116 |
+
chunk.metadata["source"] = title
|
117 |
+
chunk.metadata["page"] = "N/A"
|
118 |
+
|
119 |
+
return title, document_chunks
|
120 |
+
|
121 |
+
def get_srt(temp_file_path: str, title: str):
|
122 |
+
"""
|
123 |
+
Function to process SRT files
|
124 |
+
"""
|
125 |
+
subs = pysrt.open(temp_file_path)
|
126 |
+
|
127 |
+
text = ""
|
128 |
+
for sub in subs:
|
129 |
+
text += sub.text
|
130 |
+
document_chunks = [Document(page_content=text)]
|
131 |
+
|
132 |
+
if self.splitter:
|
133 |
+
document_chunks = self.splitter.split_documents(document_chunks)
|
134 |
+
|
135 |
+
# Update the metadata
|
136 |
+
for chunk in document_chunks:
|
137 |
+
chunk.metadata["source"] = title
|
138 |
+
chunk.metadata["page"] = "N/A"
|
139 |
+
|
140 |
+
return title, document_chunks
|
141 |
+
|
142 |
+
def get_docx(temp_file_path: str, title: str):
|
143 |
+
"""
|
144 |
+
Function to process DOCX files
|
145 |
+
"""
|
146 |
+
loader = Docx2txtLoader(temp_file_path)
|
147 |
+
|
148 |
+
if self.splitter:
|
149 |
+
document_chunks = self.splitter.split_documents(loader.load())
|
150 |
+
else:
|
151 |
+
document_chunks = loader.load()
|
152 |
+
|
153 |
+
# Update the metadata
|
154 |
+
for chunk in document_chunks:
|
155 |
+
chunk.metadata["source"] = title
|
156 |
+
chunk.metadata["page"] = "N/A"
|
157 |
+
|
158 |
+
return title, document_chunks
|
159 |
+
|
160 |
+
def get_youtube_transcript(url: str):
|
161 |
+
"""
|
162 |
+
Function to retrieve youtube transcript and process text
|
163 |
+
"""
|
164 |
+
loader = YoutubeLoader.from_youtube_url(
|
165 |
+
url, add_video_info=True, language=["en"], translation="en"
|
166 |
+
)
|
167 |
+
|
168 |
+
if self.splitter:
|
169 |
+
document_chunks = self.splitter.split_documents(loader.load())
|
170 |
+
else:
|
171 |
+
document_chunks = loader.load_and_split()
|
172 |
+
|
173 |
+
# Replace the source with title (for display in st UI later)
|
174 |
+
for chunk in document_chunks:
|
175 |
+
chunk.metadata["source"] = chunk.metadata["title"]
|
176 |
+
logger.info(chunk.metadata["title"])
|
177 |
+
|
178 |
+
return title, document_chunks
|
179 |
+
|
180 |
+
def get_html(url: str):
|
181 |
+
"""
|
182 |
+
Function to process websites via HTML files
|
183 |
+
"""
|
184 |
+
loader = WebBaseLoader(url)
|
185 |
+
|
186 |
+
if self.splitter:
|
187 |
+
document_chunks = self.splitter.split_documents(loader.load())
|
188 |
+
else:
|
189 |
+
document_chunks = loader.load_and_split()
|
190 |
+
|
191 |
+
title = document_chunks[0].metadata["title"]
|
192 |
+
logger.info(document_chunks[0].metadata)
|
193 |
+
|
194 |
+
return title, document_chunks
|
195 |
+
|
196 |
+
# Handle file by file
|
197 |
+
for file_index, file_path in enumerate(uploaded_files):
|
198 |
+
|
199 |
+
file_name = file_path.split("/")[-1]
|
200 |
+
file_type = file_name.split(".")[-1]
|
201 |
+
|
202 |
+
# Handle different file types
|
203 |
+
if file_type == "pdf":
|
204 |
+
title, document_chunks = get_pdf(file_path, file_name)
|
205 |
+
elif file_type == "txt":
|
206 |
+
title, document_chunks = get_txt(file_path, file_name)
|
207 |
+
elif file_type == "docx":
|
208 |
+
title, document_chunks = get_docx(file_path, file_name)
|
209 |
+
elif file_type == "srt":
|
210 |
+
title, document_chunks = get_srt(file_path, file_name)
|
211 |
+
|
212 |
+
# Additional wrangling - Remove leftover delimiters and any specified chunks
|
213 |
+
if self.remove_leftover_delimiters:
|
214 |
+
document_chunks = remove_delimiters(document_chunks)
|
215 |
+
if self.config["splitter_options"]["remove_chunks"]:
|
216 |
+
document_chunks = remove_chunks(document_chunks)
|
217 |
+
|
218 |
+
logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
|
219 |
+
self.document_names.append(title)
|
220 |
+
self.document_chunks_full.extend(document_chunks)
|
221 |
+
|
222 |
+
# Handle youtube links:
|
223 |
+
if weblinks[0] != "":
|
224 |
+
logger.info(f"Splitting weblinks: total of {len(weblinks)}")
|
225 |
+
|
226 |
+
# Handle link by link
|
227 |
+
for link_index, link in enumerate(weblinks):
|
228 |
+
logger.info(f"\tSplitting link {link_index+1} : {link}")
|
229 |
+
if "youtube" in link:
|
230 |
+
title, document_chunks = get_youtube_transcript(link)
|
231 |
+
else:
|
232 |
+
title, document_chunks = get_html(link)
|
233 |
+
|
234 |
+
# Additional wrangling - Remove leftover delimiters and any specified chunks
|
235 |
+
if self.remove_leftover_delimiters:
|
236 |
+
document_chunks = remove_delimiters(document_chunks)
|
237 |
+
if self.config["splitter_options"]["remove_chunks"]:
|
238 |
+
document_chunks = remove_chunks(document_chunks)
|
239 |
+
|
240 |
+
print(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
|
241 |
+
self.document_names.append(title)
|
242 |
+
self.document_chunks_full.extend(document_chunks)
|
243 |
+
|
244 |
+
logger.info(
|
245 |
+
f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
|
246 |
+
)
|
247 |
+
|
248 |
+
return self.document_chunks_full, self.document_names
|
code/modules/embedding_model_loader.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
2 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
+
from modules.constants import *
|
4 |
+
|
5 |
+
|
6 |
+
class EmbeddingModelLoader:
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
|
10 |
+
def load_embedding_model(self):
|
11 |
+
if self.config["embedding_options"]["model"] in ["text-embedding-ada-002"]:
|
12 |
+
embedding_model = OpenAIEmbeddings(
|
13 |
+
deployment="SL-document_embedder",
|
14 |
+
model=self.config["embedding_options"]["model"],
|
15 |
+
show_progress_bar=True,
|
16 |
+
openai_api_key=OPENAI_API_KEY,
|
17 |
+
)
|
18 |
+
else:
|
19 |
+
embedding_model = HuggingFaceEmbeddings(
|
20 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
21 |
+
model_kwargs={"device": "cpu"},
|
22 |
+
)
|
23 |
+
return embedding_model
|
code/modules/llm_tutor.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import PromptTemplate
|
2 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
+
from langchain_community.chat_models import ChatOpenAI
|
4 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
5 |
+
from langchain.vectorstores import FAISS
|
6 |
+
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
7 |
+
from langchain.llms import CTransformers
|
8 |
+
from langchain.memory import ConversationBufferMemory
|
9 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
10 |
+
import os
|
11 |
+
|
12 |
+
from modules.constants import *
|
13 |
+
from modules.chat_model_loader import ChatModelLoader
|
14 |
+
from modules.vector_db import VectorDB
|
15 |
+
|
16 |
+
|
17 |
+
class LLMTutor:
|
18 |
+
def __init__(self, config, logger=None):
|
19 |
+
self.config = config
|
20 |
+
self.vector_db = VectorDB(config, logger=logger)
|
21 |
+
if self.config['embedding_options']['embedd_files']:
|
22 |
+
self.vector_db.create_database()
|
23 |
+
self.vector_db.save_database()
|
24 |
+
|
25 |
+
def set_custom_prompt(self):
|
26 |
+
"""
|
27 |
+
Prompt template for QA retrieval for each vectorstore
|
28 |
+
"""
|
29 |
+
if self.config["llm_params"]["use_history"]:
|
30 |
+
custom_prompt_template = prompt_template_with_history
|
31 |
+
else:
|
32 |
+
custom_prompt_template = prompt_template
|
33 |
+
prompt = PromptTemplate(
|
34 |
+
template=custom_prompt_template,
|
35 |
+
input_variables=["context", "chat_history", "question"],
|
36 |
+
)
|
37 |
+
# prompt = QA_PROMPT
|
38 |
+
|
39 |
+
return prompt
|
40 |
+
|
41 |
+
# Retrieval QA Chain
|
42 |
+
def retrieval_qa_chain(self, llm, prompt, db):
|
43 |
+
if self.config["llm_params"]["use_history"]:
|
44 |
+
memory = ConversationBufferMemory(
|
45 |
+
memory_key="chat_history", return_messages=True, output_key="answer"
|
46 |
+
)
|
47 |
+
qa_chain = ConversationalRetrievalChain.from_llm(
|
48 |
+
llm=llm,
|
49 |
+
chain_type="stuff",
|
50 |
+
retriever=db.as_retriever(search_kwargs={"k": 3}),
|
51 |
+
return_source_documents=True,
|
52 |
+
memory=memory,
|
53 |
+
combine_docs_chain_kwargs={"prompt": prompt},
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
qa_chain = RetrievalQA.from_chain_type(
|
57 |
+
llm=llm,
|
58 |
+
chain_type="stuff",
|
59 |
+
retriever=db.as_retriever(search_kwargs={"k": 3}),
|
60 |
+
return_source_documents=True,
|
61 |
+
chain_type_kwargs={"prompt": prompt},
|
62 |
+
)
|
63 |
+
return qa_chain
|
64 |
+
|
65 |
+
# Loading the model
|
66 |
+
def load_llm(self):
|
67 |
+
chat_model_loader = ChatModelLoader(self.config)
|
68 |
+
llm = chat_model_loader.load_chat_model()
|
69 |
+
return llm
|
70 |
+
|
71 |
+
# QA Model Function
|
72 |
+
def qa_bot(self):
|
73 |
+
db = self.vector_db.load_database()
|
74 |
+
self.llm = self.load_llm()
|
75 |
+
qa_prompt = self.set_custom_prompt()
|
76 |
+
qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
|
77 |
+
|
78 |
+
return qa
|
79 |
+
|
80 |
+
# output function
|
81 |
+
def final_result(query):
|
82 |
+
qa_result = qa_bot()
|
83 |
+
response = qa_result({"query": query})
|
84 |
+
return response
|
code/modules/vector_db.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
from modules.embedding_model_loader import EmbeddingModelLoader
|
6 |
+
from langchain.vectorstores import FAISS
|
7 |
+
from modules.data_loader import DataLoader
|
8 |
+
from modules.constants import *
|
9 |
+
|
10 |
+
|
11 |
+
class VectorDB:
|
12 |
+
def __init__(self, config, logger=None):
|
13 |
+
self.config = config
|
14 |
+
self.db_option = config["embedding_options"]["db_option"]
|
15 |
+
self.document_names = None
|
16 |
+
|
17 |
+
# Set up logging to both console and a file
|
18 |
+
if logger is None:
|
19 |
+
self.logger = logging.getLogger(__name__)
|
20 |
+
self.logger.setLevel(logging.INFO)
|
21 |
+
|
22 |
+
# Console Handler
|
23 |
+
console_handler = logging.StreamHandler()
|
24 |
+
console_handler.setLevel(logging.INFO)
|
25 |
+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
26 |
+
console_handler.setFormatter(formatter)
|
27 |
+
self.logger.addHandler(console_handler)
|
28 |
+
|
29 |
+
# File Handler
|
30 |
+
log_file_path = "vector_db.log" # Change this to your desired log file path
|
31 |
+
file_handler = logging.FileHandler(log_file_path, mode="w")
|
32 |
+
file_handler.setLevel(logging.INFO)
|
33 |
+
file_handler.setFormatter(formatter)
|
34 |
+
self.logger.addHandler(file_handler)
|
35 |
+
else:
|
36 |
+
self.logger = logger
|
37 |
+
|
38 |
+
self.logger.info("VectorDB instance instantiated")
|
39 |
+
|
40 |
+
def load_files(self):
|
41 |
+
files = os.listdir(self.config["embedding_options"]["data_path"])
|
42 |
+
files = [
|
43 |
+
os.path.join(self.config["embedding_options"]["data_path"], file)
|
44 |
+
for file in files
|
45 |
+
]
|
46 |
+
return files
|
47 |
+
|
48 |
+
def create_embedding_model(self):
|
49 |
+
self.logger.info("Creating embedding function")
|
50 |
+
self.embedding_model_loader = EmbeddingModelLoader(self.config)
|
51 |
+
self.embedding_model = self.embedding_model_loader.load_embedding_model()
|
52 |
+
|
53 |
+
def initialize_database(self, document_chunks: list, document_names: list):
|
54 |
+
# Track token usage
|
55 |
+
self.logger.info("Initializing vector_db")
|
56 |
+
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
57 |
+
if self.db_option == "FAISS":
|
58 |
+
self.vector_db = FAISS.from_documents(
|
59 |
+
documents=document_chunks, embedding=self.embedding_model
|
60 |
+
)
|
61 |
+
self.logger.info("Completed initializing vector_db")
|
62 |
+
|
63 |
+
def create_database(self):
|
64 |
+
data_loader = DataLoader(self.config)
|
65 |
+
self.logger.info("Loading data")
|
66 |
+
files = self.load_files()
|
67 |
+
document_chunks, document_names = data_loader.get_chunks(files, [""])
|
68 |
+
self.logger.info("Completed loading data")
|
69 |
+
|
70 |
+
self.create_embedding_model()
|
71 |
+
self.initialize_database(document_chunks, document_names)
|
72 |
+
|
73 |
+
def save_database(self):
|
74 |
+
self.vector_db.save_local(
|
75 |
+
os.path.join(
|
76 |
+
self.config["embedding_options"]["db_path"],
|
77 |
+
"db_"
|
78 |
+
+ self.config["embedding_options"]["db_option"]
|
79 |
+
+ "_"
|
80 |
+
+ self.config["embedding_options"]["model"],
|
81 |
+
)
|
82 |
+
)
|
83 |
+
self.logger.info("Saved database")
|
84 |
+
|
85 |
+
def load_database(self):
|
86 |
+
self.create_embedding_model()
|
87 |
+
self.vector_db = FAISS.load_local(
|
88 |
+
os.path.join(
|
89 |
+
self.config["embedding_options"]["db_path"],
|
90 |
+
"db_"
|
91 |
+
+ self.config["embedding_options"]["db_option"]
|
92 |
+
+ "_"
|
93 |
+
+ self.config["embedding_options"]["model"],
|
94 |
+
),
|
95 |
+
self.embedding_model,
|
96 |
+
)
|
97 |
+
self.logger.info("Loaded database")
|
98 |
+
return self.vector_db
|
99 |
+
|
100 |
+
|
101 |
+
if __name__ == "__main__":
|
102 |
+
with open("config.yml", "r") as f:
|
103 |
+
config = yaml.safe_load(f)
|
104 |
+
print(config)
|
105 |
+
vector_db = VectorDB(config)
|
106 |
+
vector_db.create_database()
|
107 |
+
vector_db.save_database()
|
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.faiss
ADDED
Binary file (6.19 kB). View file
|
|
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.pkl
ADDED
Binary file (9.21 kB). View file
|
|
code/vectorstores/db_FAISS_text-embedding-ada-002/index.faiss
ADDED
Binary file (24.6 kB). View file
|
|
code/vectorstores/db_FAISS_text-embedding-ada-002/index.pkl
ADDED
Binary file (9.21 kB). View file
|
|
data/webpage.pdf
ADDED
Binary file (51.3 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.29.0
|
2 |
+
PyYAML==6.0.1
|
3 |
+
pysrt==1.1.2
|
4 |
+
langchain==0.0.353
|
5 |
+
tiktoken==0.5.2
|
6 |
+
streamlit-chat==0.1.1
|
7 |
+
pypdf==3.17.4
|
8 |
+
sentence-transformers==2.2.2
|
9 |
+
faiss-cpu==1.7.4
|
10 |
+
ctransformers==0.2.27
|
11 |
+
python-dotenv==1.0.0
|
12 |
+
openai==1.6.1
|
13 |
+
pymupdf==1.23.8
|