XThomasBU committed
Commit 0f566b9 · 2 Parent(s): d5cdfe3 6e49b76

Merge pull request #28 from DL4DS/dev_branch

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .github/workflows/push_to_hf_space_prototype.yml +20 -0
  2. .gitignore +9 -1
  3. Dockerfile +9 -7
  4. Dockerfile.dev +31 -0
  5. README.md +71 -22
  6. {.chainlit → code/.chainlit}/config.toml +64 -33
  7. code/__init__.py +1 -0
  8. chainlit.md → code/chainlit.md +0 -0
  9. code/main.py +91 -39
  10. code/modules/chat/__init__.py +0 -0
  11. code/modules/{chat_model_loader.py → chat/chat_model_loader.py} +2 -3
  12. code/modules/chat/helpers.py +104 -0
  13. code/modules/chat/llm_tutor.py +211 -0
  14. code/modules/chat_processor/__init__.py +0 -0
  15. code/modules/chat_processor/base.py +12 -0
  16. code/modules/chat_processor/chat_processor.py +30 -0
  17. code/modules/chat_processor/literal_ai.py +37 -0
  18. code/modules/config/__init__.py +0 -0
  19. code/{config.yml → modules/config/config.yml} +31 -12
  20. code/modules/{constants.py → config/constants.py} +4 -1
  21. code/modules/data_loader.py +0 -287
  22. code/modules/dataloader/__init__.py +0 -0
  23. code/modules/dataloader/data_loader.py +360 -0
  24. code/modules/dataloader/helpers.py +108 -0
  25. code/modules/dataloader/webpage_crawler.py +115 -0
  26. code/modules/helpers.py +0 -200
  27. code/modules/llm_tutor.py +0 -87
  28. code/modules/retriever/__init__.py +0 -0
  29. code/modules/retriever/base.py +12 -0
  30. code/modules/retriever/chroma_retriever.py +24 -0
  31. code/modules/retriever/colbert_retriever.py +10 -0
  32. code/modules/retriever/faiss_retriever.py +23 -0
  33. code/modules/retriever/helpers.py +39 -0
  34. code/modules/retriever/raptor_retriever.py +16 -0
  35. code/modules/retriever/retriever.py +26 -0
  36. code/modules/vector_db.py +0 -133
  37. code/modules/vectorstore/__init__.py +0 -0
  38. code/modules/vectorstore/base.py +33 -0
  39. code/modules/vectorstore/chroma.py +41 -0
  40. code/modules/vectorstore/colbert.py +39 -0
  41. code/modules/{embedding_model_loader.py → vectorstore/embedding_model_loader.py} +13 -10
  42. code/modules/vectorstore/faiss.py +45 -0
  43. code/modules/vectorstore/helpers.py +0 -0
  44. code/modules/vectorstore/raptor.py +438 -0
  45. code/modules/vectorstore/store_manager.py +163 -0
  46. code/modules/vectorstore/vectorstore.py +57 -0
  47. code/public/acastusphoton-svgrepo-com.svg +2 -0
  48. code/public/adv-screen-recorder-svgrepo-com.svg +2 -0
  49. code/public/alarmy-svgrepo-com.svg +2 -0
  50. code/public/avatars/ai-tutor.png +0 -0
.github/workflows/push_to_hf_space_prototype.yml ADDED
@@ -0,0 +1,20 @@
+name: Push Prototype to HuggingFace
+
+on:
+  pull_request:
+    branches:
+      - dev_branch
+
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Deploy Prototype to HuggingFace
+        uses: nateraw/[email protected]
+        with:
+          github_repo_id: DL4DS/dl4ds_tutor
+          huggingface_repo_id: dl4ds/tutor_dev
+          repo_type: space
+          space_sdk: static
+          hf_token: ${{ secrets.HF_TOKEN }}
.gitignore CHANGED
@@ -160,4 +160,12 @@ cython_debug/
 #.idea/
 
 # log files
-*.log
+*.log
+
+.ragatouille/*
+*/__pycache__/*
+.chainlit/translations/
+storage/logs/*
+vectorstores/*
+
+*/.files/*
Dockerfile CHANGED
@@ -1,18 +1,17 @@
-FROM python:3.9
+FROM python:3.11
 
 WORKDIR /code
 
 COPY ./requirements.txt /code/requirements.txt
 
-RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
-
-RUN pip install --no-cache-dir transformers==4.36.2 torch==2.1.2
+RUN pip install --upgrade pip
 
-RUN pip install --upgrade --force-reinstall --no-cache-dir llama-cpp-python==0.2.32
+RUN pip install --no-cache-dir -r /code/requirements.txt
 
 COPY . /code
 
-RUN ls -R
+# List the contents of the /code directory to verify files are copied correctly
+RUN ls -R /code
 
 # Change permissions to allow writing to the directory
 RUN chmod -R 777 /code
@@ -23,7 +22,10 @@ RUN mkdir /code/logs && chmod 777 /code/logs
 # Create a cache directory within the application's working directory
 RUN mkdir /.cache && chmod -R 777 /.cache
 
+WORKDIR /code/code
+
 RUN --mount=type=secret,id=HUGGINGFACEHUB_API_TOKEN,mode=0444,required=true
 RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=true
 
-CMD python code/modules/vector_db.py && chainlit run code/main.py --host 0.0.0.0 --port 7860
+# Default command to run the application
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 7860"]
Dockerfile.dev ADDED
@@ -0,0 +1,31 @@
+FROM python:3.11
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --upgrade pip
+
+RUN pip install --no-cache-dir -r /code/requirements.txt
+
+COPY . /code
+
+# List the contents of the /code directory to verify files are copied correctly
+RUN ls -R /code
+
+# Change permissions to allow writing to the directory
+RUN chmod -R 777 /code
+
+# Create a logs directory and set permissions
+RUN mkdir /code/logs && chmod 777 /code/logs
+
+# Create a cache directory within the application's working directory
+RUN mkdir /.cache && chmod -R 777 /.cache
+
+WORKDIR /code/code
+
+# Expose the port the app runs on
+EXPOSE 8051
+
+# Default command to run the application
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
README.md CHANGED
@@ -1,35 +1,84 @@
----
-title: Dl4ds Tutor
-emoji: 🏃
-colorFrom: green
-colorTo: red
-sdk: docker
-pinned: false
-hf_oauth: true
----
-
-DL4DS Tutor - DS598
-===========
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
-You can find an implementation of the Tutor at https://dl4ds-dl4ds-tutor.hf.space/, which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor)
-
-To run locally,
-
-Clone the repository from: https://github.com/DL4DS/dl4ds_tutor
-
-Put your data under the `storage/data` directory. Note: You can add urls in the urls.txt file, and other pdf files in the `storage/data` directory.
-
-To create the Vector Database, run the following command:
-```python code/modules/vector_db.py```
-(Note: You would need to run the above when you add new data to the `storage/data` directory, or if the ``storage/data/urls.txt`` file is updated. Or you can set ``["embedding_options"]["embedd_files"]`` to True in the `code/config.yaml` file, which would embed files from the storage directory everytime you run the below chainlit command.)
-
-To run the chainlit app, run the following command:
-```chainlit run code/main.py```
-
+# DL4DS Tutor 🏃
+
+Check out the configuration reference at [Hugging Face Spaces Config Reference](https://huggingface.co/docs/hub/spaces-config-reference).
+
+You can find an implementation of the Tutor at [DL4DS Tutor on Hugging Face](https://dl4ds-dl4ds-tutor.hf.space/), which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor).
+
+## Running Locally
+
+1. **Clone the Repository**
+   ```bash
+   git clone https://github.com/DL4DS/dl4ds_tutor
+   ```
+
+2. **Put your data under the `storage/data` directory**
+   - Add URLs in the `urls.txt` file.
+   - Add other PDF files in the `storage/data` directory.
+
+3. **To test Data Loading (Optional)**
+   ```bash
+   cd code
+   python -m modules.dataloader.data_loader
+   ```
+
+4. **Create the Vector Database**
+   ```bash
+   cd code
+   python -m modules.vectorstore.store_manager
+   ```
+   - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
+   - Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yaml` file, which will embed files from the storage directory every time you run the below chainlit command.
+
+5. **Run the Chainlit App**
+   ```bash
+   chainlit run main.py
+   ```
+
 See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
 
+## File Structure
+
+```plaintext
+code/
+├── modules
+│   ├── chat            # Contains the chatbot implementation
+│   ├── chat_processor  # Contains the implementation to process and log the conversations
+│   ├── config          # Contains the configuration files
+│   ├── dataloader      # Contains the implementation to load the data from the storage directory
+│   ├── retriever       # Contains the implementation to create the retriever
+│   └── vectorstore     # Contains the implementation to create the vector database
+├── public
+│   ├── logo_dark.png   # Dark theme logo
+│   ├── logo_light.png  # Light theme logo
+│   └── test.css        # Custom CSS file
+└── main.py
+
+
+docs/            # Contains the documentation to the codebase and methods used
+
+storage/
+├── data         # Store files and URLs here
+├── logs         # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
+└── models       # Local LLMs are loaded from here
+
+vectorstores/    # Stores the created vector databases
+
+.env             # This needs to be created, store the API keys here
+```
+- `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
+- `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager` class to manage the vector database, and all associated methods.
+- `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
+
+
+## Docker
+
+The HuggingFace Space is built using the `Dockerfile` in the repository. To run it locally, use the `Dockerfile.dev` file.
+
+```bash
+docker build --tag dev -f Dockerfile.dev .
+docker run -it --rm -p 8051:8051 dev
+```
+
 ## Contributing
 
-Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
+Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
{.chainlit → code/.chainlit}/config.toml RENAMED
@@ -2,6 +2,7 @@
 # Whether to enable telemetry (default: true). No personal data is collected.
 enable_telemetry = true
 
+
 # List of environment variables to be provided by each user to use the app.
 user_env = []
 
@@ -11,74 +12,104 @@ session_timeout = 3600
 # Enable third parties caching (e.g LangChain cache)
 cache = false
 
+# Authorized origins
+allow_origins = ["*"]
+
 # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
 # follow_symlink = false
 
 [features]
-# Show the prompt playground
-prompt_playground = true
-
 # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
 unsafe_allow_html = false
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
 latex = false
 
-# Authorize users to upload files with messages
-multi_modal = true
-
-# Allows user to use speech to text
-[features.speech_to_text]
-enabled = false
-# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
-# language = "en-US"
+# Automatically tag threads with the current chat profile (if a chat profile is used)
+auto_tag_thread = true
+
+# Authorize users to spontaneously upload files with messages
+[features.spontaneous_file_upload]
+enabled = true
+accept = ["*/*"]
+max_files = 20
+max_size_mb = 500
+
+[features.audio]
+# Threshold for audio recording
+min_decibels = -45
+# Delay for the user to start speaking in MS
+initial_silence_timeout = 3000
+# Delay for the user to continue speaking in MS. If the user stops speaking for this duration, the recording will stop.
+silence_timeout = 1500
+# Above this duration (MS), the recording will forcefully stop.
+max_duration = 15000
+# Duration of the audio chunks in MS
+chunk_duration = 1000
+# Sample rate of the audio
+sample_rate = 44100
 
 [UI]
-# Name of the app and chatbot.
+# Name of the assistant.
 name = "AI Tutor"
 
-# Show the readme while the conversation is empty.
-show_readme_as_default = true
-
-# Description of the app and chatbot. This is used for HTML tags.
+# Description of the assistant. This is used for HTML tags.
 # description = ""
 
 # Large size content are by default collapsed for a cleaner ui
 default_collapse_content = true
 
-# The default value for the expand messages settings.
-default_expand_messages = false
-
 # Hide the chain of thought details from the user in the UI.
-hide_cot = false
+hide_cot = true
 
 # Link to your github repo. This will add a github button in the UI's header.
-# github = ""
+# github = "https://github.com/DL4DS/dl4ds_tutor"
 
 # Specify a CSS file that can be used to customize the user interface.
 # The CSS file can be served from the public directory or via an external link.
-# custom_css = "/public/test.css"
+custom_css = "/public/test.css"
 
+# Specify a Javascript file that can be used to customize the user interface.
+# The Javascript file can be served from the public directory.
+# custom_js = "/public/test.js"
+
+# Specify a custom font url.
+# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
+
+# Specify a custom meta image url.
+custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Boston_University_seal.svg/1200px-Boston_University_seal.svg.png"
+
+# Specify a custom build directory for the frontend.
+# This can be used to customize the frontend code.
+# Be careful: If this is a relative path, it should not start with a slash.
+# custom_build = "./public/build"
+
+[UI.theme]
+default = "light"
+#layout = "wide"
+#font_family = "Inter, sans-serif"
 # Override default MUI light theme. (Check theme.ts)
 [UI.theme.light]
-background = "#FFFFFF"
+background = "#FAFAFA"
 paper = "#FFFFFF"
 
 [UI.theme.light.primary]
-main = "#000000"
-dark = "#000000"
-light = "#FFE7EB"
-
+main = "#b22222"  # Brighter shade of red
+dark = "#8b0000"  # Darker shade of the brighter red
+light = "#ff6347"  # Lighter shade of the brighter red
+[UI.theme.light.text]
+primary = "#212121"
+secondary = "#616161"
 # Override default MUI dark theme. (Check theme.ts)
 [UI.theme.dark]
-#background = "#FAFAFA"
-#paper = "#FFFFFF"
+background = "#1C1C1C"  # Slightly lighter dark background color
+paper = "#2A2A2A"  # Slightly lighter dark paper color
 
 [UI.theme.dark.primary]
-#main = "#F80061"
-#dark = "#980039"
-#light = "#FFE7EB"
+main = "#89CFF0"  # Primary color
+dark = "#3700B3"  # Dark variant of primary color
+light = "#CFBCFF"  # Lighter variant of primary color
 
 
 [meta]
-generated_by = "0.7.700"
+generated_by = "1.1.302"
code/__init__.py ADDED
@@ -0,0 +1 @@
+from .modules import *
chainlit.md → code/chainlit.md RENAMED
File without changes
code/main.py CHANGED
@@ -1,9 +1,8 @@
-from langchain.document_loaders import PyPDFLoader, DirectoryLoader
-from langchain import PromptTemplate
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+from langchain_core.prompts import PromptTemplate
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
 from langchain.chains import RetrievalQA
-from langchain.llms import CTransformers
 import chainlit as cl
 from langchain_community.chat_models import ChatOpenAI
 from langchain_community.embeddings import OpenAIEmbeddings
@@ -11,37 +10,54 @@ import yaml
 import logging
 from dotenv import load_dotenv
 
-from modules.llm_tutor import LLMTutor
-from modules.constants import *
-from modules.helpers import get_sources
-
+from modules.chat.llm_tutor import LLMTutor
+from modules.config.constants import *
+from modules.chat.helpers import get_sources
+from modules.chat_processor.chat_processor import ChatProcessor
 
+global logger
+# Initialize logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
 
 # Console Handler
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.INFO)
-formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-# File Handler
-log_file_path = "log_file.log"  # Change this to your desired log file path
-file_handler = logging.FileHandler(log_file_path)
-file_handler.setLevel(logging.INFO)
-file_handler.setFormatter(formatter)
-logger.addHandler(file_handler)
+
+@cl.set_starters
+async def set_starters():
+    return [
+        cl.Starter(
+            label="recording on CNNs?",
+            message="Where can I find the recording for the lecture on Transformers?",
+            icon="/public/adv-screen-recorder-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="where's the slides?",
+            message="When are the lectures? I can't find the schedule.",
+            icon="/public/alarmy-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="Due Date?",
+            message="When is the final project due?",
+            icon="/public/calendar-samsung-17-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="Explain backprop.",
+            message="I didn't understand the math behind backprop, could you explain it?",
+            icon="/public/acastusphoton-svgrepo-com.svg",
+        ),
+    ]
 
 
 # Adding option to select the chat profile
 @cl.set_chat_profiles
 async def chat_profile():
     return [
-        cl.ChatProfile(
-            name="Llama",
-            markdown_description="Use the local LLM: **Tiny Llama**.",
-        ),
         # cl.ChatProfile(
         #     name="Mistral",
        #     markdown_description="Use the local LLM: **Mistral**.",
@@ -54,6 +70,10 @@ async def chat_profile():
             name="gpt-4",
             markdown_description="Use OpenAI API for **gpt-4**.",
         ),
+        cl.ChatProfile(
+            name="Llama",
+            markdown_description="Use the local LLM: **Tiny Llama**.",
+        ),
     ]
 
 
@@ -66,12 +86,26 @@ def rename(orig_author: str):
 # chainlit code
 @cl.on_chat_start
 async def start():
-    with open("code/config.yml", "r") as f:
+    with open("modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)
-    print(config)
-    logger.info("Config file loaded")
-    logger.info(f"Config: {config}")
-    logger.info("Creating llm_tutor instance")
+
+    # Ensure log directory exists
+    log_directory = config["log_dir"]
+    if not os.path.exists(log_directory):
+        os.makedirs(log_directory)
+
+    # File Handler
+    log_file_path = (
+        f"{log_directory}/tutor.log"  # Change this to your desired log file path
+    )
+    file_handler = logging.FileHandler(log_file_path, mode="w")
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    logger.info("Config file loaded")
+    logger.info(f"Config: {config}")
+    logger.info("Creating llm_tutor instance")
 
     chat_profile = cl.user_session.get("chat_profile")
     if chat_profile is not None:
@@ -93,32 +127,50 @@ async def start():
     llm_tutor = LLMTutor(config, logger=logger)
 
     chain = llm_tutor.qa_bot()
-    model = config["llm_params"]["local_llm_params"]["model"]
-    msg = cl.Message(content=f"Starting the bot {model}...")
-    await msg.send()
-    msg.content = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
-    await msg.update()
+    # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
+    # await msg.send()
+    # msg.content = opening_message
+    # await msg.update()
 
+    tags = [chat_profile, config["vectorstore"]["db_option"]]
+    chat_processor = ChatProcessor(config, tags=tags)
     cl.user_session.set("chain", chain)
+    cl.user_session.set("counter", 0)
+    cl.user_session.set("chat_processor", chat_processor)
+
+
+@cl.on_chat_end
+async def on_chat_end():
+    await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
 
 
 @cl.on_message
 async def main(message):
+    global logger
     user = cl.user_session.get("user")
     chain = cl.user_session.get("chain")
-    # cb = cl.AsyncLangchainCallbackHandler(
-    #     stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
-    # )
-    # cb.answer_reached = True
-    # res=await chain.acall(message, callbacks=[cb])
-    res = await chain.acall(message.content)
-    print(f"response: {res}")
+
+    counter = cl.user_session.get("counter")
+    counter += 1
+    cl.user_session.set("counter", counter)
+
+    # if counter >= 3:  # Ensure the counter condition is checked
+    #     await cl.Message(content="Your credits are up!").send()
+    #     await on_chat_end()  # Call the on_chat_end function to handle the end of the chat
+    #     return  # Exit the function to stop further processing
+    # else:
+
+    cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
+    cb.answer_reached = True
+
+    processor = cl.user_session.get("chat_processor")
+    res = await processor.rag(message.content, chain, cb)
    try:
         answer = res["answer"]
     except:
         answer = res["result"]
-    print(f"answer: {answer}")
 
-    answer_with_sources, source_elements = get_sources(res, answer)
+    answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
+    processor._process(message.content, answer, sources_dict)
 
     await cl.Message(content=answer_with_sources, elements=source_elements).send()
code/modules/chat/__init__.py ADDED
File without changes
code/modules/{chat_model_loader.py → chat/chat_model_loader.py} RENAMED
@@ -1,8 +1,7 @@
 from langchain_community.chat_models import ChatOpenAI
-from langchain.llms import CTransformers
-from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from transformers import AutoTokenizer, TextStreamer
-from langchain.llms import LlamaCpp
+from langchain_community.llms import LlamaCpp
 import torch
 import transformers
 import os
code/modules/chat/helpers.py ADDED
@@ -0,0 +1,104 @@
+from modules.config.constants import *
+import chainlit as cl
+from langchain_core.prompts import PromptTemplate
+
+
+def get_sources(res, answer):
+    source_elements = []
+    source_dict = {}  # Dictionary to store URL elements
+
+    for idx, source in enumerate(res["source_documents"]):
+        source_metadata = source.metadata
+        url = source_metadata.get("source", "N/A")
+        score = source_metadata.get("score", "N/A")
+        page = source_metadata.get("page", 1)
+
+        lecture_tldr = source_metadata.get("tldr", "N/A")
+        lecture_recording = source_metadata.get("lecture_recording", "N/A")
+        suggested_readings = source_metadata.get("suggested_readings", "N/A")
+        date = source_metadata.get("date", "N/A")
+
+        source_type = source_metadata.get("source_type", "N/A")
+
+        url_name = f"{url}_{page}"
+        if url_name not in source_dict:
+            source_dict[url_name] = {
+                "text": source.page_content,
+                "url": url,
+                "score": score,
+                "page": page,
+                "lecture_tldr": lecture_tldr,
+                "lecture_recording": lecture_recording,
+                "suggested_readings": suggested_readings,
+                "date": date,
+                "source_type": source_type,
+            }
+        else:
+            source_dict[url_name]["text"] += f"\n\n{source.page_content}"
+
+    # First, display the answer
+    full_answer = "**Answer:**\n"
+    full_answer += answer
+
+    # Then, display the sources
+    full_answer += "\n\n**Sources:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+        name = f"Source {idx + 1} Text\n"
+        full_answer += name
+        source_elements.append(
+            cl.Text(name=name, content=source_data["text"], display="side")
+        )
+
+        # Add a PDF element if the source is a PDF file
+        if source_data["url"].lower().endswith(".pdf"):
+            name = f"Source {idx + 1} PDF\n"
+            full_answer += name
+            pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+            source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
+
+    full_answer += "\n**Metadata:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} Metadata:\n"
+        source_elements.append(
+            cl.Text(
+                name=f"Source {idx + 1} Metadata",
+                content=f"Source: {source_data['url']}\n"
+                f"Page: {source_data['page']}\n"
+                f"Type: {source_data['source_type']}\n"
+                f"Date: {source_data['date']}\n"
+                f"TL;DR: {source_data['lecture_tldr']}\n"
+                f"Lecture Recording: {source_data['lecture_recording']}\n"
+                f"Suggested Readings: {source_data['suggested_readings']}\n",
+                display="side",
+            )
+        )
+
+    return full_answer, source_elements, source_dict
+
+
+def get_prompt(config):
+    if config["llm_params"]["use_history"]:
+        if config["llm_params"]["llm_loader"] == "local_llm":
+            custom_prompt_template = tinyllama_prompt_template_with_history
+        elif config["llm_params"]["llm_loader"] == "openai":
+            custom_prompt_template = openai_prompt_template_with_history
+        # else:
+        #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
+        prompt = PromptTemplate(
+            template=custom_prompt_template,
+            input_variables=["context", "chat_history", "question"],
+        )
+    else:
+        if config["llm_params"]["llm_loader"] == "local_llm":
+            custom_prompt_template = tinyllama_prompt_template
+        elif config["llm_params"]["llm_loader"] == "openai":
+            custom_prompt_template = openai_prompt_template
+        # else:
+        #     custom_prompt_template = tinyllama_prompt_template
+        prompt = PromptTemplate(
+            template=custom_prompt_template,
+            input_variables=["context", "question"],
+        )
+    return prompt
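Note: a minimal sketch of how `get_sources` can be exercised outside a live Chainlit session (not part of the commit). `FakeDoc` and the example URL are hypothetical stand-ins for the LangChain `Document` objects the chain returns; only the `page_content`/`metadata` attributes and the `res["source_documents"]` shape are assumed from the code above.

```python
# Hypothetical smoke test for get_sources; assumes chainlit and the
# modules package are importable (run from inside code/).
from modules.chat.helpers import get_sources


class FakeDoc:
    # Mimics the minimal Document interface get_sources reads.
    def __init__(self, metadata, page_content):
        self.metadata = metadata
        self.page_content = page_content


res = {
    "source_documents": [
        FakeDoc(
            metadata={"source": "https://example.edu/lecture1.pdf", "score": 0.87, "page": 2},
            page_content="Backpropagation applies the chain rule layer by layer...",
        )
    ]
}

full_answer, source_elements, source_dict = get_sources(res, answer="Backprop is ...")
print(full_answer)           # answer, then numbered sources, then metadata
print(len(source_elements))  # one cl.Text per source, plus cl.Pdf and metadata elements
```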
code/modules/chat/llm_tutor.py ADDED
@@ -0,0 +1,211 @@
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain.memory import (
+    ConversationBufferWindowMemory,
+    ConversationSummaryBufferMemory,
+)
+from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
+import os
+from modules.config.constants import *
+from modules.chat.helpers import get_prompt
+from modules.chat.chat_model_loader import ChatModelLoader
+from modules.vectorstore.store_manager import VectorStoreManager
+
+from modules.retriever.retriever import Retriever
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+import inspect
+from langchain.chains.conversational_retrieval.base import _get_chat_history
+from langchain_core.messages import BaseMessage
+
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatOpenAI
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+
+    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
+        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
+        buffer = ""
+        for dialogue_turn in chat_history:
+            if isinstance(dialogue_turn, BaseMessage):
+                role_prefix = _ROLE_MAP.get(
+                    dialogue_turn.type, f"{dialogue_turn.type}: "
+                )
+                buffer += f"\n{role_prefix}{dialogue_turn.content}"
+            elif isinstance(dialogue_turn, tuple):
+                human = "Student: " + dialogue_turn[0]
+                ai = "AI Tutor: " + dialogue_turn[1]
+                buffer += "\n" + "\n".join([human, ai])
+            else:
+                raise ValueError(
+                    f"Unsupported chat history format: {type(dialogue_turn)}."
+                    f" Full chat history: {chat_history} "
+                )
+        return buffer
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self._get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+        if chat_history_str:
+            # callbacks = _run_manager.get_child()
+            # new_question = await self.question_generator.arun(
+            #     question=question, chat_history=chat_history_str, callbacks=callbacks
+            # )
+            system = (
+                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+                "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+                "If the question is conversational and doesn't require context, do not rephrase it. "
+                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation?'. "
+                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'. "
+                "Chat history: \n{chat_history_str}\n"
+                "Rephrase the following question only if necessary: '{question}'"
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", system),
+                    ("human", "{question}, {chat_history_str}"),
+                ]
+            )
+            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+            step_back = prompt | llm | StrOutputParser()
+            new_question = step_back.invoke(
+                {"question": question, "chat_history_str": chat_history_str}
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._aget_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
+
+        output: Dict[str, Any] = {}
+        output["original_question"] = question
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+
+            # Prepare the final prompt with metadata
+            context = "\n\n".join(
+                [
+                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
+                    for idx, doc in enumerate(docs)
+                ]
+            )
+            final_prompt = (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+                f"Chat History:\n{chat_history_str}\n\n"
+                f"Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+                f"Student: {question}\n"
+                "AI Tutor:"
+            )
+
+            # new_inputs["input"] = final_prompt
+            new_inputs["question"] = final_prompt
+            # output["final_prompt"] = final_prompt
+
+            answer = await self.combine_docs_chain.arun(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
+        if self.return_source_documents:
+            output["source_documents"] = docs
+            output["rephrased_question"] = new_question
+        return output
+
+
+class LLMTutor:
+    def __init__(self, config, logger=None):
+        self.config = config
+        self.llm = self.load_llm()
+        self.logger = logger
+        self.vector_db = VectorStoreManager(config, logger=self.logger)
+        if self.config["vectorstore"]["embedd_files"]:
+            self.vector_db.create_database()
+            self.vector_db.save_database()
+
+    def set_custom_prompt(self):
+        """
+        Prompt template for QA retrieval for each vectorstore
+        """
+        prompt = get_prompt(self.config)
+        # prompt = QA_PROMPT
+
+        return prompt
+
+    # Retrieval QA Chain
+    def retrieval_qa_chain(self, llm, prompt, db):
+
+        retriever = Retriever(self.config)._return_retriever(db)
+
+        if self.config["llm_params"]["use_history"]:
+            memory = ConversationBufferWindowMemory(
+                k=self.config["llm_params"]["memory_window"],
+                memory_key="chat_history",
+                return_messages=True,
+                output_key="answer",
+                max_token_limit=128,
+            )
+            qa_chain = CustomConversationalRetrievalChain.from_llm(
+                llm=llm,
+                chain_type="stuff",
+                retriever=retriever,
+                return_source_documents=True,
+                memory=memory,
+                combine_docs_chain_kwargs={"prompt": prompt},
+                response_if_no_docs_found="No context found",
+            )
+        else:
+            qa_chain = RetrievalQA.from_chain_type(
+                llm=llm,
+                chain_type="stuff",
+                retriever=retriever,
+                return_source_documents=True,
+                chain_type_kwargs={"prompt": prompt},
+            )
+        return qa_chain
+
+    # Loading the model
+    def load_llm(self):
+        chat_model_loader = ChatModelLoader(self.config)
+        llm = chat_model_loader.load_chat_model()
+        return llm
+
+    # QA Model Function
+    def qa_bot(self):
+        db = self.vector_db.load_database()
+        qa_prompt = self.set_custom_prompt()
+        qa = self.retrieval_qa_chain(
+            self.llm, qa_prompt, db
+        )  # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
+
+        return qa
+
+    # output function
+    def final_result(query):
+        qa_result = qa_bot()
+        response = qa_result({"query": query})
+        return response
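Note: a rough driver showing how `LLMTutor` is wired together, mirroring what `code/main.py` does above. This is a sketch, not part of the commit: it assumes it is run from inside `code/` (so the config's `../` paths resolve) and that the vector database has already been built by `store_manager`.

```python
import asyncio
import logging
import yaml

from modules.chat.llm_tutor import LLMTutor

with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

tutor = LLMTutor(config, logger=logging.getLogger(__name__))
chain = tutor.qa_bot()  # loads the existing vector DB and builds the chain

# With use_history=True the chain is the CustomConversationalRetrievalChain,
# whose only expected input is the question (memory supplies chat_history).
res = asyncio.run(chain.acall("When is the final project due?"))
print(res.get("answer") or res.get("result"))
```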
code/modules/chat_processor/__init__.py ADDED
File without changes
code/modules/chat_processor/base.py ADDED
@@ -0,0 +1,12 @@
+# Template for chat processor classes
+
+
+class ChatProcessorBase:
+    def __init__(self, config):
+        self.config = config
+
+    def process(self, message):
+        """
+        Processes and Logs the message
+        """
+        raise NotImplementedError("process method not implemented")
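Note: `ChatProcessorBase` is a small template; concrete processors override `process`. Observe that the concrete implementation in this commit widens the signature from `process(self, message)` to `process(self, user_message, assistant_message, source_dict)`. A hypothetical console-only subclass, just to illustrate the contract (not part of the commit):

```python
from modules.chat_processor.base import ChatProcessorBase


class ConsoleChatProcessor(ChatProcessorBase):
    # Hypothetical processor that logs turns to stdout instead of a platform.
    def __init__(self, config=None):
        super().__init__(config)

    def process(self, user_message, assistant_message, source_dict):
        # Same three-argument shape LiteralaiChatProcessor uses below.
        print(f"USER: {user_message}")
        print(f"TUTOR: {assistant_message} (sources: {list(source_dict)})")
```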
code/modules/chat_processor/chat_processor.py ADDED
@@ -0,0 +1,30 @@
+from modules.chat_processor.literal_ai import LiteralaiChatProcessor
+
+
+class ChatProcessor:
+    def __init__(self, config, tags=None):
+        self.chat_processor_type = config["chat_logging"]["platform"]
+        self.logging = config["chat_logging"]["log_chat"]
+        self.tags = tags
+        if self.logging:
+            self._init_processor()
+
+    def _init_processor(self):
+        if self.chat_processor_type == "literalai":
+            self.processor = LiteralaiChatProcessor(self.tags)
+        else:
+            raise ValueError(
+                f"Chat processor type {self.chat_processor_type} not supported"
+            )
+
+    def _process(self, user_message, assistant_message, source_dict):
+        if self.logging:
+            return self.processor.process(user_message, assistant_message, source_dict)
+        else:
+            pass
+
+    async def rag(self, user_query: str, chain, cb):
+        if self.logging:
+            return await self.processor.rag(user_query, chain, cb)
+        else:
+            return await chain.acall(user_query, callbacks=[cb])
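Note: with `log_chat: False` (the default in `config.yml` below) the wrapper never instantiates a platform processor; `_process` is a no-op and `rag` forwards straight to `chain.acall`, so the tutor runs without a Literal AI key. A small sketch — the inline `config` dict is a stand-in for the loaded YAML, and importing this module still requires the `literalai` package to be installed:

```python
from modules.chat_processor.chat_processor import ChatProcessor

config = {"chat_logging": {"log_chat": False, "platform": "literalai"}}
processor = ChatProcessor(config, tags=["gpt-3.5-turbo-1106", "FAISS"])

processor._process("hi", "hello!", source_dict={})  # no-op when logging is off
# res = await processor.rag(user_query, chain, cb)  # would just await chain.acall(...)
```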
code/modules/chat_processor/literal_ai.py ADDED
@@ -0,0 +1,37 @@
+from literalai import LiteralClient
+import os
+from .base import ChatProcessorBase
+
+
+class LiteralaiChatProcessor(ChatProcessorBase):
+    def __init__(self, tags=None):
+        self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
+        self.literal_client.reset_context()
+        with self.literal_client.thread(name="TEST") as thread:
+            self.thread_id = thread.id
+            self.thread = thread
+            if tags is not None and type(tags) == list:
+                self.thread.tags = tags
+        print(f"Thread ID: {self.thread}")
+
+    def process(self, user_message, assistant_message, source_dict):
+        with self.literal_client.thread(thread_id=self.thread_id) as thread:
+            self.literal_client.message(
+                content=user_message,
+                type="user_message",
+                name="User",
+            )
+            self.literal_client.message(
+                content=assistant_message,
+                type="assistant_message",
+                name="AI_Tutor",
+            )
+
+    async def rag(self, user_query: str, chain, cb):
+        with self.literal_client.step(
+            type="retrieval", name="RAG", thread_id=self.thread_id
+        ) as step:
+            step.input = {"question": user_query}
+            res = await chain.acall(user_query, callbacks=[cb])
+            step.output = res
+        return res
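Note: the processor reads `LITERAL_API_KEY` from the environment at construction time and immediately opens a thread named "TEST". A quick sanity check, assuming `literalai` is installed and the key is set in the repo-root `.env`:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # .env sits at the repo root per the README's file structure
assert os.getenv("LITERAL_API_KEY"), "set LITERAL_API_KEY, or keep log_chat: False"

from modules.chat_processor.literal_ai import LiteralaiChatProcessor

proc = LiteralaiChatProcessor(tags=["dev"])  # prints the new thread's repr
```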
code/modules/config/__init__.py ADDED
File without changes
code/{config.yml → modules/config/config.yml} RENAMED
@@ -1,23 +1,42 @@
-embedding_options:
+log_dir: '../storage/logs' # str
+log_chunk_dir: '../storage/logs/chunks' # str
+device: 'cpu' # str [cuda, cpu]
+
+vectorstore:
   embedd_files: False # bool
-  persist_directory: null # str or None
-  data_path: 'storage/data' # str
-  url_file_path: 'storage/data/urls.txt' # str
+  data_path: '../storage/data' # str
+  url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
-  db_option : 'FAISS' # str
-  db_path : 'vectorstores' # str
+  db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
+  db_path : '../vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
+  score_threshold : 0.2 # float
+
+  faiss_params: # Not used as of now
+    index_path: '../vectorstores/faiss.index' # str
+    index_type: 'Flat' # str [Flat, HNSW, IVF]
+    index_dimension: 384 # int
+    index_nlist: 100 # int
+    index_nprobe: 10 # int
+
+  colbert_params:
+    index_name: "new_idx" # str
+
 llm_params:
-  use_history: False # bool
+  use_history: True # bool
   memory_window: 3 # int
-  llm_loader: 'local_llm' # str [local_llm, openai]
+  llm_loader: 'openai' # str [local_llm, openai]
   openai_params:
-    model: 'gpt-4' # str [gpt-3.5-turbo-1106, gpt-4]
+    model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
   local_llm_params:
-    model: "storage/models/llama-2-7b-chat.Q4_0.gguf"
-    model_type: "llama"
-    temperature: 0.2
+    model: 'tiny-llama'
+    temperature: 0.7
+
+chat_logging:
+  log_chat: False # bool
+  platform: 'literalai'
+
 splitter_options:
   use_splitter: True # bool
   split_by_token : True # bool
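Note: the app reads this file with a plain `yaml.safe_load` (see `code/main.py` above), so any knob can be inspected or overridden programmatically. A minimal sketch, assuming it is run from inside `code/` so the relative `../` paths resolve:

```python
import yaml

with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

print(config["vectorstore"]["db_option"])  # 'FAISS'
print(config["llm_params"]["llm_loader"])  # 'openai'

# Example: force re-embedding on the next run without editing the file.
config["vectorstore"]["embedd_files"] = True
```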
code/modules/{constants.py → config/constants.py} RENAMED
@@ -6,7 +6,10 @@ load_dotenv()
 # API Keys - Loaded from the .env file
 
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
 
+opening_message = f"Hey, What Can I Help You With?\n\nYou can ask me questions about the course logistics, course content, about the final project, or anything else!"
 
 # Prompt Templates
 
@@ -75,5 +78,5 @@ Question: {question}
 
 # Model Paths
 
-LLAMA_PATH = "storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
+LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
 MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
code/modules/data_loader.py DELETED
@@ -1,287 +0,0 @@
1
- import re
2
- import pysrt
3
- from langchain.text_splitter import RecursiveCharacterTextSplitter
4
- from langchain.document_loaders import (
5
- PyMuPDFLoader,
6
- Docx2txtLoader,
7
- YoutubeLoader,
8
- WebBaseLoader,
9
- TextLoader,
10
- )
11
- from langchain.schema import Document
12
- import tempfile
13
- from tempfile import NamedTemporaryFile
14
- import logging
15
- import requests
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- class DataLoader:
21
- def __init__(self, config):
22
- """
23
- Class for handling all data extraction and chunking
24
- Inputs:
25
- config - dictionary from yaml file, containing all important parameters
26
- """
27
- self.config = config
28
- self.remove_leftover_delimiters = config["splitter_options"][
29
- "remove_leftover_delimiters"
30
- ]
31
-
32
- # Main list of all documents
33
- self.document_chunks_full = []
34
- self.document_names = []
35
-
36
- if config["splitter_options"]["use_splitter"]:
37
- if config["splitter_options"]["split_by_token"]:
38
- self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
39
- chunk_size=config["splitter_options"]["chunk_size"],
40
- chunk_overlap=config["splitter_options"]["chunk_overlap"],
41
- separators=config["splitter_options"]["chunk_separators"],
42
- disallowed_special=()
43
- )
44
- else:
45
- self.splitter = RecursiveCharacterTextSplitter(
46
- chunk_size=config["splitter_options"]["chunk_size"],
47
- chunk_overlap=config["splitter_options"]["chunk_overlap"],
48
- separators=config["splitter_options"]["chunk_separators"],
49
- disallowed_special=()
50
- )
51
- else:
52
- self.splitter = None
53
- logger.info("InfoLoader instance created")
54
-
55
- def extract_text_from_pdf(self, pdf_path):
56
- text = ""
57
- with open(pdf_path, "rb") as file:
58
- reader = PyPDF2.PdfReader(file)
59
- num_pages = len(reader.pages)
60
- for page_num in range(num_pages):
61
- page = reader.pages[page_num]
62
- text += page.extract_text()
63
- return text
64
-
65
- def download_pdf_from_url(self, pdf_url):
66
- response = requests.get(pdf_url)
67
- if response.status_code == 200:
68
- with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
69
- temp_file.write(response.content)
70
- temp_file_path = temp_file.name
71
- return temp_file_path
72
- else:
73
- print("Failed to download PDF from URL:", pdf_url)
74
- return None
75
-
76
- def get_chunks(self, uploaded_files, weblinks):
77
- # Main list of all documents
78
- self.document_chunks_full = []
79
- self.document_names = []
80
-
81
- def remove_delimiters(document_chunks: list):
82
- """
83
- Helper function to remove remaining delimiters in document chunks
84
- """
85
- for chunk in document_chunks:
86
- for delimiter in self.config["splitter_options"][
87
- "delimiters_to_remove"
88
- ]:
89
- chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
90
- return document_chunks
91
-
92
- def remove_chunks(document_chunks: list):
93
- """
94
- Helper function to remove any unwanted document chunks after splitting
95
- """
96
- front = self.config["splitter_options"]["front_chunk_to_remove"]
97
- end = self.config["splitter_options"]["last_chunks_to_remove"]
98
- # Remove pages
99
- for _ in range(front):
100
- del document_chunks[0]
101
- for _ in range(end):
102
- document_chunks.pop()
103
- logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
104
- return document_chunks
105
-
106
- def get_pdf_from_url(pdf_url: str):
107
- temp_pdf_path = self.download_pdf_from_url(pdf_url)
108
- if temp_pdf_path:
109
- title, document_chunks = get_pdf(temp_pdf_path, pdf_url)
110
- os.remove(temp_pdf_path)
111
- return title, document_chunks
112
-
113
- def get_pdf(temp_file_path: str, title: str):
114
- """
115
- Function to process PDF files
116
- """
117
- loader = PyMuPDFLoader(
118
- temp_file_path
119
- ) # This loader preserves more metadata
120
-
121
- if self.splitter:
122
- document_chunks = self.splitter.split_documents(loader.load())
123
- else:
124
- document_chunks = loader.load()
125
-
126
- if "title" in document_chunks[0].metadata.keys():
127
- title = document_chunks[0].metadata["title"]
128
-
129
- logger.info(
130
- f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
131
- )
132
-
133
- return title, document_chunks
134
-
135
- def get_txt(temp_file_path: str, title: str):
136
- """
137
- Function to process TXT files
138
- """
139
- loader = TextLoader(temp_file_path, autodetect_encoding=True)
140
-
141
- if self.splitter:
142
- document_chunks = self.splitter.split_documents(loader.load())
143
- else:
144
- document_chunks = loader.load()
145
-
146
- # Update the metadata
147
- for chunk in document_chunks:
148
- chunk.metadata["source"] = title
149
- chunk.metadata["page"] = "N/A"
150
-
151
- return title, document_chunks
152
-
153
- def get_srt(temp_file_path: str, title: str):
154
- """
155
- Function to process SRT files
156
- """
157
- subs = pysrt.open(temp_file_path)
158
-
159
- text = ""
160
- for sub in subs:
161
- text += sub.text
162
- document_chunks = [Document(page_content=text)]
163
-
164
- if self.splitter:
165
- document_chunks = self.splitter.split_documents(document_chunks)
166
-
167
- # Update the metadata
168
- for chunk in document_chunks:
169
- chunk.metadata["source"] = title
170
- chunk.metadata["page"] = "N/A"
171
-
172
- return title, document_chunks
173
-
174
- def get_docx(temp_file_path: str, title: str):
175
- """
176
-             Function to process DOCX files
-             """
-             loader = Docx2txtLoader(temp_file_path)
-
-             if self.splitter:
-                 document_chunks = self.splitter.split_documents(loader.load())
-             else:
-                 document_chunks = loader.load()
-
-             # Update the metadata
-             for chunk in document_chunks:
-                 chunk.metadata["source"] = title
-                 chunk.metadata["page"] = "N/A"
-
-             return title, document_chunks
-
-         def get_youtube_transcript(url: str):
-             """
-             Function to retrieve youtube transcript and process text
-             """
-             loader = YoutubeLoader.from_youtube_url(
-                 url, add_video_info=True, language=["en"], translation="en"
-             )
-
-             if self.splitter:
-                 document_chunks = self.splitter.split_documents(loader.load())
-             else:
-                 document_chunks = loader.load_and_split()
-
-             # Replace the source with title (for display in st UI later)
-             for chunk in document_chunks:
-                 chunk.metadata["source"] = chunk.metadata["title"]
-                 logger.info(chunk.metadata["title"])
-
-             return title, document_chunks
-
-         def get_html(url: str):
-             """
-             Function to process websites via HTML files
-             """
-             loader = WebBaseLoader(url)
-
-             if self.splitter:
-                 document_chunks = self.splitter.split_documents(loader.load())
-             else:
-                 document_chunks = loader.load_and_split()
-
-             title = document_chunks[0].metadata["title"]
-             logger.info(document_chunks[0].metadata)
-
-             return title, document_chunks
-
-         # Handle file by file
-         for file_index, file_path in enumerate(uploaded_files):
-
-             file_name = file_path.split("/")[-1]
-             file_type = file_name.split(".")[-1]
-
-             # Handle different file types
-             if file_type == "pdf":
-                 try:
-                     title, document_chunks = get_pdf(file_path, file_name)
-                 except:
-                     title, document_chunks = get_pdf_from_url(file_path)
-             elif file_type == "txt":
-                 title, document_chunks = get_txt(file_path, file_name)
-             elif file_type == "docx":
-                 title, document_chunks = get_docx(file_path, file_name)
-             elif file_type == "srt":
-                 title, document_chunks = get_srt(file_path, file_name)
-
-             # Additional wrangling - Remove leftover delimiters and any specified chunks
-             if self.remove_leftover_delimiters:
-                 document_chunks = remove_delimiters(document_chunks)
-             if self.config["splitter_options"]["remove_chunks"]:
-                 document_chunks = remove_chunks(document_chunks)
-
-             logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)} from {file_name}")
-             self.document_names.append(title)
-             self.document_chunks_full.extend(document_chunks)
-
-         # Handle youtube links:
-         if weblinks[0] != "":
-             logger.info(f"Splitting weblinks: total of {len(weblinks)}")
-
-             # Handle link by link
-             for link_index, link in enumerate(weblinks):
-                 try:
-                     logger.info(f"\tSplitting link {link_index+1} : {link}")
-                     if "youtube" in link:
-                         title, document_chunks = get_youtube_transcript(link)
-                     else:
-                         title, document_chunks = get_html(link)
-
-                     # Additional wrangling - Remove leftover delimiters and any specified chunks
-                     if self.remove_leftover_delimiters:
-                         document_chunks = remove_delimiters(document_chunks)
-                     if self.config["splitter_options"]["remove_chunks"]:
-                         document_chunks = remove_chunks(document_chunks)
-
-                     print(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
-                     self.document_names.append(title)
-                     self.document_chunks_full.extend(document_chunks)
-                 except:
-                     logger.info(f"\t\tError splitting link {link_index+1} : {link}")
-                     exit()
-
-         logger.info(
-             f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
-         )
-
-         return self.document_chunks_full, self.document_names
code/modules/dataloader/__init__.py ADDED
File without changes
code/modules/dataloader/data_loader.py ADDED
@@ -0,0 +1,360 @@
+ import os
+ import re
+ import requests
+ import pysrt
+ import tempfile  # used by FileReader.download_pdf_from_url
+ import PyPDF2  # used by FileReader.extract_text_from_pdf
+ from langchain_community.document_loaders import (
+     PyMuPDFLoader,
+     Docx2txtLoader,
+     YoutubeLoader,
+     WebBaseLoader,
+     TextLoader,
+ )
+ from langchain_community.document_loaders import UnstructuredMarkdownLoader
+ from llama_parse import LlamaParse
+ from langchain.schema import Document
+ import logging
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from ragatouille import RAGPretrainedModel
+ from langchain.chains import LLMChain
+ from langchain_community.llms import OpenAI
+ from langchain import PromptTemplate
+ import json
+ from concurrent.futures import ThreadPoolExecutor
+
+ from modules.dataloader.helpers import get_metadata
+
+
+ class PDFReader:
+     def __init__(self):
+         pass
+
+     def get_loader(self, pdf_path):
+         loader = PyMuPDFLoader(pdf_path)
+         return loader
+
+     def get_documents(self, loader):
+         return loader.load()
+
+
+ class FileReader:
+     def __init__(self, logger):
+         self.pdf_reader = PDFReader()
+         self.logger = logger
+
+     def extract_text_from_pdf(self, pdf_path):
+         text = ""
+         with open(pdf_path, "rb") as file:
+             reader = PyPDF2.PdfReader(file)
+             num_pages = len(reader.pages)
+             for page_num in range(num_pages):
+                 page = reader.pages[page_num]
+                 text += page.extract_text()
+         return text
+
+     def download_pdf_from_url(self, pdf_url):
+         response = requests.get(pdf_url)
+         if response.status_code == 200:
+             with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
+                 temp_file.write(response.content)
+                 temp_file_path = temp_file.name
+             return temp_file_path
+         else:
+             self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
+             return None
+
+     def read_pdf(self, temp_file_path: str):
+         loader = self.pdf_reader.get_loader(temp_file_path)
+         documents = self.pdf_reader.get_documents(loader)
+         return documents
+
+     def read_txt(self, temp_file_path: str):
+         loader = TextLoader(temp_file_path, autodetect_encoding=True)
+         return loader.load()
+
+     def read_docx(self, temp_file_path: str):
+         loader = Docx2txtLoader(temp_file_path)
+         return loader.load()
+
+     def read_srt(self, temp_file_path: str):
+         subs = pysrt.open(temp_file_path)
+         text = ""
+         for sub in subs:
+             text += sub.text
+         return [Document(page_content=text)]
+
+     def read_youtube_transcript(self, url: str):
+         loader = YoutubeLoader.from_youtube_url(
+             url, add_video_info=True, language=["en"], translation="en"
+         )
+         return loader.load()
+
+     def read_html(self, url: str):
+         loader = WebBaseLoader(url)
+         return loader.load()
+
+     def read_tex_from_url(self, tex_url):
+         response = requests.get(tex_url)
+         if response.status_code == 200:
+             return [Document(page_content=response.text)]
+         else:
+             self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
+             return None
+
+
+ class ChunkProcessor:
+     def __init__(self, config, logger):
+         self.config = config
+         self.logger = logger
+
+         self.document_data = {}
+         self.document_metadata = {}
+         self.document_chunks_full = []
+
+         if config["splitter_options"]["use_splitter"]:
+             if config["splitter_options"]["split_by_token"]:
+                 self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                     chunk_size=config["splitter_options"]["chunk_size"],
+                     chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                     separators=config["splitter_options"]["chunk_separators"],
+                     disallowed_special=(),
+                 )
+             else:
+                 self.splitter = RecursiveCharacterTextSplitter(
+                     chunk_size=config["splitter_options"]["chunk_size"],
+                     chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                     separators=config["splitter_options"]["chunk_separators"],
+                     disallowed_special=(),
+                 )
+         else:
+             self.splitter = None
+         self.logger.info("ChunkProcessor instance created")
+
+     def remove_delimiters(self, document_chunks: list):
+         for chunk in document_chunks:
+             for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
+                 chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
+         return document_chunks
+
+     def remove_chunks(self, document_chunks: list):
+         front = self.config["splitter_options"]["front_chunk_to_remove"]
+         end = self.config["splitter_options"]["last_chunks_to_remove"]
+         for _ in range(front):
+             del document_chunks[0]
+         for _ in range(end):
+             document_chunks.pop()
+         return document_chunks
+
+     def process_chunks(
+         self, documents, file_type="txt", source="", page=0, metadata={}
+     ):
+         documents = [Document(page_content=documents, source=source, page=page)]
+         if (
+             file_type == "txt"
+             or file_type == "docx"
+             or file_type == "srt"
+             or file_type == "tex"
+         ):
+             document_chunks = self.splitter.split_documents(documents)
+         elif file_type == "pdf":
+             document_chunks = documents  # Full page for now
+
+         # add the source and page number back to the metadata
+         for chunk in document_chunks:
+             chunk.metadata["source"] = source
+             chunk.metadata["page"] = page
+
+             # add the metadata extracted from the document
+             for key, value in metadata.items():
+                 chunk.metadata[key] = value
+
+         if self.config["splitter_options"]["remove_leftover_delimiters"]:
+             document_chunks = self.remove_delimiters(document_chunks)
+         if self.config["splitter_options"]["remove_chunks"]:
+             document_chunks = self.remove_chunks(document_chunks)
+
+         return document_chunks
+
+     def chunk_docs(self, file_reader, uploaded_files, weblinks):
+         addl_metadata = get_metadata(
+             "https://dl4ds.github.io/sp2024/lectures/",
+             "https://dl4ds.github.io/sp2024/schedule/",
+         )  # For any additional metadata
+
+         with ThreadPoolExecutor() as executor:
+             executor.map(
+                 self.process_file,
+                 uploaded_files,
+                 range(len(uploaded_files)),
+                 [file_reader] * len(uploaded_files),
+                 [addl_metadata] * len(uploaded_files),
+             )
+             executor.map(
+                 self.process_weblink,
+                 weblinks,
+                 range(len(weblinks)),
+                 [file_reader] * len(weblinks),
+                 [addl_metadata] * len(weblinks),
+             )
+
+         document_names = [
+             f"{file_name}_{page_num}"
+             for file_name, pages in self.document_data.items()
+             for page_num in pages.keys()
+         ]
+         documents = [
+             page for doc in self.document_data.values() for page in doc.values()
+         ]
+         document_metadata = [
+             page for doc in self.document_metadata.values() for page in doc.values()
+         ]
+
+         self.save_document_data()
+
+         self.logger.info(
+             f"Total document chunks extracted: {len(self.document_chunks_full)}"
+         )
+
+         return self.document_chunks_full, document_names, documents, document_metadata
+
+     def process_documents(
+         self, documents, file_path, file_type, metadata_source, addl_metadata
+     ):
+         file_data = {}
+         file_metadata = {}
+
+         for doc in documents:
+             # if len(doc.page_content) <= 400:  # better approach to filter out non-informative documents
+             #     continue
+
+             page_num = doc.metadata.get("page", 0)
+             file_data[page_num] = doc.page_content
+             metadata = (
+                 addl_metadata.get(file_path, {})
+                 if metadata_source == "file"
+                 else {"source": file_path, "page": page_num}
+             )
+             file_metadata[page_num] = metadata
+
+             if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
+                 document_chunks = self.process_chunks(
+                     doc.page_content,
+                     file_type,
+                     source=file_path,
+                     page=page_num,
+                     metadata=metadata,
+                 )
+                 self.document_chunks_full.extend(document_chunks)
+
+         self.document_data[file_path] = file_data
+         self.document_metadata[file_path] = file_metadata
+
+     def process_file(self, file_path, file_index, file_reader, addl_metadata):
+         file_name = os.path.basename(file_path)
+         if file_name in self.document_data:
+             return
+
+         file_type = file_name.split(".")[-1].lower()
+         self.logger.info(f"Reading file {file_index + 1}: {file_path}")
+
+         read_methods = {
+             "pdf": file_reader.read_pdf,
+             "txt": file_reader.read_txt,
+             "docx": file_reader.read_docx,
+             "srt": file_reader.read_srt,
+             "tex": file_reader.read_tex_from_url,
+         }
+         if file_type not in read_methods:
+             self.logger.warning(f"Unsupported file type: {file_type}")
+             return
+
+         try:
+             documents = read_methods[file_type](file_path)
+             self.process_documents(
+                 documents, file_path, file_type, "file", addl_metadata
+             )
+         except Exception as e:
+             self.logger.error(f"Error processing file {file_name}: {str(e)}")
+
+     def process_weblink(self, link, link_index, file_reader, addl_metadata):
+         if link in self.document_data:
+             return
+
+         self.logger.info(f"Reading link {link_index + 1} : {link}")
+
+         try:
+             if "youtube" in link:
+                 documents = file_reader.read_youtube_transcript(link)
+             else:
+                 documents = file_reader.read_html(link)
+
+             self.process_documents(documents, link, "txt", "link", addl_metadata)
+         except Exception as e:
+             self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
+
+     def save_document_data(self):
+         if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
+             os.makedirs(f"{self.config['log_chunk_dir']}/docs")
+             self.logger.info(
+                 f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
+             )
+         self.logger.info(
+             f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
+         )
+         if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
+             os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
+             self.logger.info(
+                 f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
+             )
+         self.logger.info(
+             f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
+         )
+         with open(
+             f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
+         ) as json_file:
+             json.dump(self.document_data, json_file, indent=4)
+         with open(
+             f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
+         ) as json_file:
+             json.dump(self.document_metadata, json_file, indent=4)
+
+     def load_document_data(self):
+         with open(
+             f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
+         ) as json_file:
+             self.document_data = json.load(json_file)
+         with open(
+             f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
+         ) as json_file:
+             self.document_metadata = json.load(json_file)
+
+
+ class DataLoader:
+     def __init__(self, config, logger=None):
+         self.file_reader = FileReader(logger=logger)
+         self.chunk_processor = ChunkProcessor(config, logger=logger)
+
+     def get_chunks(self, uploaded_files, weblinks):
+         return self.chunk_processor.chunk_docs(
+             self.file_reader, uploaded_files, weblinks
+         )
+
+
+ if __name__ == "__main__":
+     import yaml
+
+     logger = logging.getLogger(__name__)
+     logger.setLevel(logging.INFO)
+
+     with open("../code/modules/config/config.yml", "r") as f:
+         config = yaml.safe_load(f)
+
+     data_loader = DataLoader(config, logger=logger)
+     document_chunks, document_names, documents, document_metadata = (
+         data_loader.get_chunks(
+             [],
+             ["https://dl4ds.github.io/sp2024/"],
+         )
+     )
+
+     print(document_names)
+     print(len(document_chunks))
code/modules/dataloader/helpers.py ADDED
@@ -0,0 +1,108 @@
+ import requests
+ from bs4 import BeautifulSoup
+ from tqdm import tqdm
+ from urllib.parse import urlparse  # used by get_base_url
+
+
+ def get_urls_from_file(file_path: str):
+     """
+     Function to get urls from a file
+     """
+     with open(file_path, "r") as f:
+         urls = f.readlines()
+     urls = [url.strip() for url in urls]
+     return urls
+
+
+ def get_base_url(url):
+     parsed_url = urlparse(url)
+     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
+     return base_url
+
+
+ def get_metadata(lectures_url, schedule_url):
+     """
+     Function to get the lecture metadata from the lectures and schedule URLs.
+     """
+     lecture_metadata = {}
+
+     # Get the main lectures page content
+     r_lectures = requests.get(lectures_url)
+     soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
+
+     # Get the main schedule page content
+     r_schedule = requests.get(schedule_url)
+     soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
+
+     # Find all lecture blocks
+     lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
+
+     # Create a mapping from slides link to date
+     date_mapping = {}
+     schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
+     for row in schedule_rows:
+         try:
+             date = (
+                 row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
+             )
+             description_div = row.find("div", {"data-label": "Description"})
+             slides_link_tag = description_div.find("a", title="Download slides")
+             slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+             slides_link = (
+                 f"https://dl4ds.github.io{slides_link}" if slides_link else None
+             )
+             if slides_link:
+                 date_mapping[slides_link] = date
+         except Exception as e:
+             print(f"Error processing schedule row: {e}")
+             continue
+
+     for block in lecture_blocks:
+         try:
+             # Extract the lecture title
+             title = block.find("span", style="font-weight: bold;").text.strip()
+
+             # Extract the TL;DR
+             tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
+
+             # Extract the link to the slides
+             slides_link_tag = block.find("a", title="Download slides")
+             slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+             slides_link = (
+                 f"https://dl4ds.github.io{slides_link}" if slides_link else None
+             )
+
+             # Extract the link to the lecture recording
+             recording_link_tag = block.find("a", title="Download lecture recording")
+             recording_link = (
+                 recording_link_tag["href"].strip() if recording_link_tag else None
+             )
+
+             # Extract suggested readings or summary if available
+             suggested_readings_tag = block.find("p", text="Suggested Readings:")
+             if suggested_readings_tag:
+                 suggested_readings = suggested_readings_tag.find_next_sibling("ul")
+                 if suggested_readings:
+                     suggested_readings = suggested_readings.get_text(
+                         separator="\n"
+                     ).strip()
+                 else:
+                     suggested_readings = "No specific readings provided."
+             else:
+                 suggested_readings = "No specific readings provided."
+
+             # Get the date from the schedule
+             date = date_mapping.get(slides_link, "No date available")
+
+             # Add to the dictionary
+             lecture_metadata[slides_link] = {
+                 "date": date,
+                 "tldr": tldr,
+                 "title": title,
+                 "lecture_recording": recording_link,
+                 "suggested_readings": suggested_readings,
+             }
+         except Exception as e:
+             print(f"Error processing block: {e}")
+             continue
+
+     return lecture_metadata
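As a quick sanity check, `get_metadata` can be exercised directly; a minimal sketch, assuming the two course pages above are reachable and their markup still matches the selectors:

import json
from modules.dataloader.helpers import get_metadata

metadata = get_metadata(
    "https://dl4ds.github.io/sp2024/lectures/",
    "https://dl4ds.github.io/sp2024/schedule/",
)
# Entries are keyed by slides URL; each value carries date, title, tldr,
# lecture_recording, and suggested_readings.
print(json.dumps(metadata, indent=2))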
code/modules/dataloader/webpage_crawler.py ADDED
@@ -0,0 +1,115 @@
+ import aiohttp
+ from aiohttp import ClientSession
+ import asyncio
+ import requests
+ from bs4 import BeautifulSoup
+ from urllib.parse import urlparse, urljoin, urldefrag
+
+ class WebpageCrawler:
+     def __init__(self):
+         self.dict_href_links = {}
+
+     async def fetch(self, session: ClientSession, url: str) -> str:
+         async with session.get(url) as response:
+             try:
+                 return await response.text()
+             except UnicodeDecodeError:
+                 return await response.text(encoding="latin1")
+
+     def url_exists(self, url: str) -> bool:
+         try:
+             response = requests.head(url)
+             return response.status_code == 200
+         except requests.ConnectionError:
+             return False
+
+     async def get_links(self, session: ClientSession, website_link: str, base_url: str):
+         html_data = await self.fetch(session, website_link)
+         soup = BeautifulSoup(html_data, "html.parser")
+         list_links = []
+         for link in soup.find_all("a", href=True):
+             href = link["href"].strip()
+             full_url = urljoin(base_url, href)
+             normalized_url = self.normalize_url(full_url)  # sections removed
+             if (
+                 normalized_url not in self.dict_href_links
+                 and self.is_child_url(normalized_url, base_url)
+                 and self.url_exists(normalized_url)
+             ):
+                 self.dict_href_links[normalized_url] = None
+                 list_links.append(normalized_url)
+
+         return list_links
+
+     async def get_subpage_links(
+         self, session: ClientSession, urls: list, base_url: str
+     ):
+         tasks = [self.get_links(session, url, base_url) for url in urls]
+         results = await asyncio.gather(*tasks)
+         all_links = [link for sublist in results for link in sublist]
+         return all_links
+
+     async def get_all_pages(self, url: str, base_url: str):
+         async with aiohttp.ClientSession() as session:
+             dict_links = {url: "Not-checked"}
+             counter = None
+             while counter != 0:
+                 unchecked_links = [
+                     link
+                     for link, status in dict_links.items()
+                     if status == "Not-checked"
+                 ]
+                 if not unchecked_links:
+                     break
+                 new_links = await self.get_subpage_links(
+                     session, unchecked_links, base_url
+                 )
+                 for link in unchecked_links:
+                     dict_links[link] = "Checked"
+                     print(f"Checked: {link}")
+                 dict_links.update(
+                     {
+                         link: "Not-checked"
+                         for link in new_links
+                         if link not in dict_links
+                     }
+                 )
+                 counter = len(
+                     [
+                         status
+                         for status in dict_links.values()
+                         if status == "Not-checked"
+                     ]
+                 )
+
+             checked_urls = [
+                 url for url, status in dict_links.items() if status == "Checked"
+             ]
+             return checked_urls
+
+     def is_webpage(self, url: str) -> bool:
+         try:
+             response = requests.head(url, allow_redirects=True)
+             content_type = response.headers.get("Content-Type", "").lower()
+             return "text/html" in content_type
+         except requests.RequestException:
+             return False
+
+     def clean_url_list(self, urls):
+         files, webpages = [], []
+
+         for url in urls:
+             if self.is_webpage(url):
+                 webpages.append(url)
+             else:
+                 files.append(url)
+
+         return files, webpages
+
+     def is_child_url(self, url, base_url):
+         return url.startswith(base_url)
+
+     def normalize_url(self, url: str):
+         # Strip the fragment identifier
+         defragged_url, _ = urldefrag(url)
+         return defragged_url
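`get_all_pages` is a coroutine, so it needs an event loop to drive it; a minimal sketch using `asyncio.run` (the seed URL is illustrative):

import asyncio
from modules.dataloader.webpage_crawler import WebpageCrawler

crawler = WebpageCrawler()
base_url = "https://dl4ds.github.io/sp2024/"
# Breadth-first crawl of every child URL under base_url.
all_urls = asyncio.run(crawler.get_all_pages(base_url, base_url))
# Separate downloadable files (e.g. PDFs) from HTML pages.
files, webpages = crawler.clean_url_list(all_urls)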
code/modules/helpers.py DELETED
@@ -1,200 +0,0 @@
- import requests
- from bs4 import BeautifulSoup
- from tqdm import tqdm
- from urllib.parse import urlparse
- import chainlit as cl
- from langchain import PromptTemplate
- try:
-     from modules.constants import *
- except:
-     from constants import *
-
- """
- Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
- """
-
-
- class WebpageCrawler:
-     def __init__(self):
-         pass
-
-     def getdata(self, url):
-         r = requests.get(url)
-         return r.text
-
-     def url_exists(self, url):
-         try:
-             response = requests.head(url)
-             return response.status_code == 200
-         except requests.ConnectionError:
-             return False
-
-     def get_links(self, website_link, base_url=None):
-         if base_url is None:
-             base_url = website_link
-         html_data = self.getdata(website_link)
-         soup = BeautifulSoup(html_data, "html.parser")
-         list_links = []
-         for link in soup.find_all("a", href=True):
-
-             # clean the link
-             # remove empty spaces
-             link["href"] = link["href"].strip()
-             # Append to list if new link contains original link
-             if str(link["href"]).startswith((str(website_link))):
-                 list_links.append(link["href"])
-
-             # Include all href that do not start with website link but with "/"
-             if str(link["href"]).startswith("/"):
-                 if link["href"] not in self.dict_href_links:
-                     print(link["href"])
-                     self.dict_href_links[link["href"]] = None
-                     link_with_www = base_url + link["href"][1:]
-                     if self.url_exists(link_with_www):
-                         print("adjusted link =", link_with_www)
-                         list_links.append(link_with_www)
-
-         # Convert list of links to dictionary and define keys as the links and the values as "Not-checked"
-         dict_links = dict.fromkeys(list_links, "Not-checked")
-         return dict_links
-
-     def get_subpage_links(self, l, base_url):
-         for link in tqdm(l):
-             print('checking link:', link)
-             if not link.endswith("/"):
-                 l[link] = "Checked"
-                 dict_links_subpages = {}
-             else:
-                 # If not crawled through this page start crawling and get links
-                 if l[link] == "Not-checked":
-                     dict_links_subpages = self.get_links(link, base_url)
-                     # Change the dictionary value of the link to "Checked"
-                     l[link] = "Checked"
-                 else:
-                     # Create an empty dictionary in case every link is checked
-                     dict_links_subpages = {}
-             # Add new dictionary to old dictionary
-             l = {**dict_links_subpages, **l}
-         return l
-
-     def get_all_pages(self, url, base_url):
-         dict_links = {url: "Not-checked"}
-         self.dict_href_links = {}
-         counter, counter2 = None, 0
-         while counter != 0:
-             counter2 += 1
-             dict_links2 = self.get_subpage_links(dict_links, base_url)
-             # Count number of non-values and set counter to 0 if there are no values within the dictionary equal to the string "Not-checked"
-             # https://stackoverflow.com/questions/48371856/count-the-number-of-occurrences-of-a-certain-value-in-a-dictionary-in-python
-             counter = sum(value == "Not-checked" for value in dict_links2.values())
-             dict_links = dict_links2
-         checked_urls = [
-             url for url, status in dict_links.items() if status == "Checked"
-         ]
-         return checked_urls
-
-
- def get_urls_from_file(file_path: str):
-     """
-     Function to get urls from a file
-     """
-     with open(file_path, "r") as f:
-         urls = f.readlines()
-     urls = [url.strip() for url in urls]
-     return urls
-
-
- def get_base_url(url):
-     parsed_url = urlparse(url)
-     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
-     return base_url
-
- def get_prompt(config):
-     if config["llm_params"]["use_history"]:
-         if config["llm_params"]["llm_loader"] == "local_llm":
-             custom_prompt_template = tinyllama_prompt_template_with_history
-         elif config["llm_params"]["llm_loader"] == "openai":
-             custom_prompt_template = openai_prompt_template_with_history
-         # else:
-         #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
-         prompt = PromptTemplate(
-             template=custom_prompt_template,
-             input_variables=["context", "chat_history", "question"],
-         )
-     else:
-         if config["llm_params"]["llm_loader"] == "local_llm":
-             custom_prompt_template = tinyllama_prompt_template
-         elif config["llm_params"]["llm_loader"] == "openai":
-             custom_prompt_template = openai_prompt_template
-         # else:
-         #     custom_prompt_template = tinyllama_prompt_template
-         prompt = PromptTemplate(
-             template=custom_prompt_template,
-             input_variables=["context", "question"],
-         )
-     return prompt
-
- def get_sources(res, answer):
-     source_elements_dict = {}
-     source_elements = []
-     found_sources = []
-
-     source_dict = {}  # Dictionary to store URL elements
-
-     for idx, source in enumerate(res["source_documents"]):
-         source_metadata = source.metadata
-         url = source_metadata["source"]
-
-         if url not in source_dict:
-             source_dict[url] = [source.page_content]
-         else:
-             source_dict[url].append(source.page_content)
-
-     for source_idx, (url, text_list) in enumerate(source_dict.items()):
-         full_text = ""
-         for url_idx, text in enumerate(text_list):
-             full_text += f"Source {url_idx+1}:\n {text}\n\n\n"
-         source_elements.append(cl.Text(name=url, content=full_text))
-         found_sources.append(url)
-
-     if found_sources:
-         answer += f"\n\nSources: {', '.join(found_sources)} "
-     else:
-         answer += f"\n\nNo source found."
-
-     # for idx, source in enumerate(res["source_documents"]):
-     #     title = source.metadata["source"]
-
-     #     if title not in source_elements_dict:
-     #         source_elements_dict[title] = {
-     #             "page_number": [source.metadata["page"]],
-     #             "url": source.metadata["source"],
-     #             "content": source.page_content,
-     #         }
-
-     #     else:
-     #         source_elements_dict[title]["page_number"].append(source.metadata["page"])
-     #         source_elements_dict[title][
-     #             "content_" + str(source.metadata["page"])
-     #         ] = source.page_content
-     #     # sort the page numbers
-     #     # source_elements_dict[title]["page_number"].sort()
-
-     # for title, source in source_elements_dict.items():
-     #     # create a string for the page numbers
-     #     page_numbers = ", ".join([str(x) for x in source["page_number"]])
-     #     text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
-     #     source_elements.append(cl.Pdf(name="File", path=title))
-     #     found_sources.append("File")
-     #     # for pn in source["page_number"]:
-     #     #     source_elements.append(
-     #     #         cl.Text(name=str(pn), content=source["content_"+str(pn)])
-     #     #     )
-     #     #     found_sources.append(str(pn))
-
-     # if found_sources:
-     #     answer += f"\nSource:{', '.join(found_sources)}"
-     # else:
-     #     answer += f"\nNo source found."
-
-     return answer, source_elements
code/modules/llm_tutor.py DELETED
@@ -1,87 +0,0 @@
- from langchain import PromptTemplate
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain_community.chat_models import ChatOpenAI
- from langchain_community.embeddings import OpenAIEmbeddings
- from langchain.vectorstores import FAISS
- from langchain.chains import RetrievalQA, ConversationalRetrievalChain
- from langchain.llms import CTransformers
- from langchain.memory import ConversationBufferWindowMemory
- from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
- import os
-
- from modules.constants import *
- from modules.helpers import get_prompt
- from modules.chat_model_loader import ChatModelLoader
- from modules.vector_db import VectorDB
-
-
- class LLMTutor:
-     def __init__(self, config, logger=None):
-         self.config = config
-         self.vector_db = VectorDB(config, logger=logger)
-         if self.config["embedding_options"]["embedd_files"]:
-             self.vector_db.create_database()
-             self.vector_db.save_database()
-
-     def set_custom_prompt(self):
-         """
-         Prompt template for QA retrieval for each vectorstore
-         """
-         prompt = get_prompt(self.config)
-         # prompt = QA_PROMPT
-
-         return prompt
-
-     # Retrieval QA Chain
-     def retrieval_qa_chain(self, llm, prompt, db):
-         if self.config["llm_params"]["use_history"]:
-             memory = ConversationBufferWindowMemory(
-                 k=self.config["llm_params"]["memory_window"],
-                 memory_key="chat_history", return_messages=True, output_key="answer"
-             )
-             qa_chain = ConversationalRetrievalChain.from_llm(
-                 llm=llm,
-                 chain_type="stuff",
-                 retriever=db.as_retriever(
-                     search_kwargs={
-                         "k": self.config["embedding_options"]["search_top_k"]
-                     }
-                 ),
-                 return_source_documents=True,
-                 memory=memory,
-                 combine_docs_chain_kwargs={"prompt": prompt},
-             )
-         else:
-             qa_chain = RetrievalQA.from_chain_type(
-                 llm=llm,
-                 chain_type="stuff",
-                 retriever=db.as_retriever(
-                     search_kwargs={
-                         "k": self.config["embedding_options"]["search_top_k"]
-                     }
-                 ),
-                 return_source_documents=True,
-                 chain_type_kwargs={"prompt": prompt},
-             )
-         return qa_chain
-
-     # Loading the model
-     def load_llm(self):
-         chat_model_loader = ChatModelLoader(self.config)
-         llm = chat_model_loader.load_chat_model()
-         return llm
-
-     # QA Model Function
-     def qa_bot(self):
-         db = self.vector_db.load_database()
-         self.llm = self.load_llm()
-         qa_prompt = self.set_custom_prompt()
-         qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
-
-         return qa
-
-     # output function
-     def final_result(query):
-         qa_result = qa_bot()
-         response = qa_result({"query": query})
-         return response
code/modules/retriever/__init__.py ADDED
File without changes
code/modules/retriever/base.py ADDED
@@ -0,0 +1,12 @@
+ # template for retriever classes
+
+
+ class BaseRetriever:
+     def __init__(self, config):
+         self.config = config
+
+     def return_retriever(self):
+         """
+         Returns the retriever object
+         """
+         raise NotImplementedError
code/modules/retriever/chroma_retriever.py ADDED
@@ -0,0 +1,24 @@
+ from .helpers import VectorStoreRetrieverScore
+ from .base import BaseRetriever
+
+
+ class ChromaRetriever(BaseRetriever):
+     def __init__(self):
+         pass
+
+     def return_retriever(self, db, config):
+         retriever = VectorStoreRetrieverScore(
+             vectorstore=db,
+             # search_type="similarity_score_threshold",
+             # search_kwargs={
+             #     "score_threshold": self.config["vectorstore"][
+             #         "score_threshold"
+             #     ],
+             #     "k": self.config["vectorstore"]["search_top_k"],
+             # },
+             search_kwargs={
+                 "k": config["vectorstore"]["search_top_k"],
+             },
+         )
+
+         return retriever
code/modules/retriever/colbert_retriever.py ADDED
@@ -0,0 +1,10 @@
+ from .base import BaseRetriever
+
+
+ class ColbertRetriever(BaseRetriever):
+     def __init__(self):
+         pass
+
+     def return_retriever(self, db, config):
+         retriever = db.as_langchain_retriever(k=config["vectorstore"]["search_top_k"])
+         return retriever
code/modules/retriever/faiss_retriever.py ADDED
@@ -0,0 +1,23 @@
+ from .helpers import VectorStoreRetrieverScore
+ from .base import BaseRetriever
+
+
+ class FaissRetriever(BaseRetriever):
+     def __init__(self):
+         pass
+
+     def return_retriever(self, db, config):
+         retriever = VectorStoreRetrieverScore(
+             vectorstore=db,
+             # search_type="similarity_score_threshold",
+             # search_kwargs={
+             #     "score_threshold": self.config["vectorstore"][
+             #         "score_threshold"
+             #     ],
+             #     "k": self.config["vectorstore"]["search_top_k"],
+             # },
+             search_kwargs={
+                 "k": config["vectorstore"]["search_top_k"],
+             },
+         )
+         return retriever
code/modules/retriever/helpers.py ADDED
@@ -0,0 +1,39 @@
+ from langchain.schema.vectorstore import VectorStoreRetriever
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+ from langchain.schema.document import Document
+ from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
+ from typing import List
+
+
+ class VectorStoreRetrieverScore(VectorStoreRetriever):
+
+     # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
+     def _get_relevant_documents(
+         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+     ) -> List[Document]:
+         docs_and_similarities = (
+             self.vectorstore.similarity_search_with_relevance_scores(
+                 query, **self.search_kwargs
+             )
+         )
+         # Make the score part of the document metadata
+         for doc, similarity in docs_and_similarities:
+             doc.metadata["score"] = similarity
+
+         docs = [doc for doc, _ in docs_and_similarities]
+         return docs
+
+     async def _aget_relevant_documents(
+         self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+     ) -> List[Document]:
+         docs_and_similarities = (
+             self.vectorstore.similarity_search_with_relevance_scores(
+                 query, **self.search_kwargs
+             )
+         )
+         # Make the score part of the document metadata
+         for doc, similarity in docs_and_similarities:
+             doc.metadata["score"] = similarity
+
+         docs = [doc for doc, _ in docs_and_similarities]
+         return docs
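The only behavioural change over the stock `VectorStoreRetriever` is that every returned document now carries its relevance score in `metadata["score"]`; a minimal sketch, assuming `db` is an already-loaded FAISS or Chroma store:

retriever = VectorStoreRetrieverScore(vectorstore=db, search_kwargs={"k": 3})
docs = retriever.get_relevant_documents("What is backpropagation?")
for doc in docs:
    # the score was attached by _get_relevant_documents above
    print(doc.metadata["score"], doc.metadata.get("source"))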
code/modules/retriever/raptor_retriever.py ADDED
@@ -0,0 +1,16 @@
+ from .helpers import VectorStoreRetrieverScore
+ from .base import BaseRetriever
+
+
+ class RaptorRetriever(BaseRetriever):
+     def __init__(self):
+         pass
+
+     def return_retriever(self, db, config):
+         retriever = VectorStoreRetrieverScore(
+             vectorstore=db,
+             search_kwargs={
+                 "k": config["vectorstore"]["search_top_k"],
+             },
+         )
+         return retriever
code/modules/retriever/retriever.py ADDED
@@ -0,0 +1,26 @@
+ from modules.retriever.faiss_retriever import FaissRetriever
+ from modules.retriever.chroma_retriever import ChromaRetriever
+ from modules.retriever.colbert_retriever import ColbertRetriever
+ from modules.retriever.raptor_retriever import RaptorRetriever
+
+
+ class Retriever:
+     def __init__(self, config):
+         self.config = config
+         self.retriever_classes = {
+             "FAISS": FaissRetriever,
+             "Chroma": ChromaRetriever,
+             "RAGatouille": ColbertRetriever,
+             "RAPTOR": RaptorRetriever,
+         }
+         self._create_retriever()
+
+     def _create_retriever(self):
+         db_option = self.config["vectorstore"]["db_option"]
+         retriever_class = self.retriever_classes.get(db_option)
+         if not retriever_class:
+             raise ValueError(f"Invalid db_option: {db_option}")
+         self.retriever = retriever_class()
+
+     def _return_retriever(self, db):
+         return self.retriever.return_retriever(db, self.config)
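The factory dispatches purely on `config["vectorstore"]["db_option"]`; a minimal sketch, assuming `db` is a vector store that was loaded with the matching backend:

from modules.retriever.retriever import Retriever

config = {"vectorstore": {"db_option": "FAISS", "search_top_k": 3}}
retriever = Retriever(config)._return_retriever(db)  # db: a loaded FAISS store
docs = retriever.get_relevant_documents("What is a convolution?")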
code/modules/vector_db.py DELETED
@@ -1,133 +0,0 @@
- import logging
- import os
- import yaml
- from langchain.vectorstores import FAISS
-
- try:
-     from modules.embedding_model_loader import EmbeddingModelLoader
-     from modules.data_loader import DataLoader
-     from modules.constants import *
-     from modules.helpers import *
- except:
-     from embedding_model_loader import EmbeddingModelLoader
-     from data_loader import DataLoader
-     from constants import *
-     from helpers import *
-
-
- class VectorDB:
-     def __init__(self, config, logger=None):
-         self.config = config
-         self.db_option = config["embedding_options"]["db_option"]
-         self.document_names = None
-         self.webpage_crawler = WebpageCrawler()
-
-         # Set up logging to both console and a file
-         if logger is None:
-             self.logger = logging.getLogger(__name__)
-             self.logger.setLevel(logging.INFO)
-
-             # Console Handler
-             console_handler = logging.StreamHandler()
-             console_handler.setLevel(logging.INFO)
-             formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-             console_handler.setFormatter(formatter)
-             self.logger.addHandler(console_handler)
-
-             # File Handler
-             log_file_path = "vector_db.log"  # Change this to your desired log file path
-             file_handler = logging.FileHandler(log_file_path, mode="w")
-             file_handler.setLevel(logging.INFO)
-             file_handler.setFormatter(formatter)
-             self.logger.addHandler(file_handler)
-         else:
-             self.logger = logger
-
-         self.logger.info("VectorDB instance instantiated")
-
-     def load_files(self):
-         files = os.listdir(self.config["embedding_options"]["data_path"])
-         files = [
-             os.path.join(self.config["embedding_options"]["data_path"], file)
-             for file in files
-         ]
-         urls = get_urls_from_file(self.config["embedding_options"]["url_file_path"])
-         if self.config["embedding_options"]["expand_urls"]:
-             all_urls = []
-             for url in urls:
-                 base_url = get_base_url(url)
-                 all_urls.extend(self.webpage_crawler.get_all_pages(url, base_url))
-             urls = all_urls
-         return files, urls
-
-     def clean_url_list(self, urls):
-         # get lecture pdf links
-         lecture_pdfs = [link for link in urls if link.endswith(".pdf")]
-         lecture_pdfs = [link for link in lecture_pdfs if "lecture" in link.lower()]
-         urls = [link for link in urls if link.endswith("/")]  # only keep links that end with a '/'. Extract Files Seperately
-
-         return urls, lecture_pdfs
-
-     def create_embedding_model(self):
-         self.logger.info("Creating embedding function")
-         self.embedding_model_loader = EmbeddingModelLoader(self.config)
-         self.embedding_model = self.embedding_model_loader.load_embedding_model()
-
-     def initialize_database(self, document_chunks: list, document_names: list):
-         # Track token usage
-         self.logger.info("Initializing vector_db")
-         self.logger.info("\tUsing {} as db_option".format(self.db_option))
-         if self.db_option == "FAISS":
-             self.vector_db = FAISS.from_documents(
-                 documents=document_chunks, embedding=self.embedding_model
-             )
-         self.logger.info("Completed initializing vector_db")
-
-     def create_database(self):
-         data_loader = DataLoader(self.config)
-         self.logger.info("Loading data")
-         files, urls = self.load_files()
-         urls, lecture_pdfs = self.clean_url_list(urls)
-         files += lecture_pdfs
-         files.remove('storage/data/urls.txt')
-         document_chunks, document_names = data_loader.get_chunks(files, urls)
-         self.logger.info("Completed loading data")
-
-         self.create_embedding_model()
-         self.initialize_database(document_chunks, document_names)
-
-     def save_database(self):
-         self.vector_db.save_local(
-             os.path.join(
-                 self.config["embedding_options"]["db_path"],
-                 "db_"
-                 + self.config["embedding_options"]["db_option"]
-                 + "_"
-                 + self.config["embedding_options"]["model"],
-             )
-         )
-         self.logger.info("Saved database")
-
-     def load_database(self):
-         self.create_embedding_model()
-         self.vector_db = FAISS.load_local(
-             os.path.join(
-                 self.config["embedding_options"]["db_path"],
-                 "db_"
-                 + self.config["embedding_options"]["db_option"]
-                 + "_"
-                 + self.config["embedding_options"]["model"],
-             ),
-             self.embedding_model,
-         )
-         self.logger.info("Loaded database")
-         return self.vector_db
-
-
- if __name__ == "__main__":
-     with open("code/config.yml", "r") as f:
-         config = yaml.safe_load(f)
-     print(config)
-     vector_db = VectorDB(config)
-     vector_db.create_database()
-     vector_db.save_database()
code/modules/vectorstore/__init__.py ADDED
File without changes
code/modules/vectorstore/base.py ADDED
@@ -0,0 +1,33 @@
+ # template for vector store classes
+
+
+ class VectorStoreBase:
+     def __init__(self, config):
+         self.config = config
+
+     def _init_vector_db(self):
+         """
+         Creates a vector store object
+         """
+         raise NotImplementedError
+
+     def create_database(self):
+         """
+         Populates the vector store with documents
+         """
+         raise NotImplementedError
+
+     def load_database(self):
+         """
+         Loads the vector store from disk
+         """
+         raise NotImplementedError
+
+     def as_retriever(self):
+         """
+         Returns the vector store as a retriever
+         """
+         raise NotImplementedError
+
+     def __str__(self):
+         return self.__class__.__name__
code/modules/vectorstore/chroma.py ADDED
@@ -0,0 +1,41 @@
+ from langchain_community.vectorstores import Chroma
+ from modules.vectorstore.base import VectorStoreBase
+ import os
+
+
+ class ChromaVectorStore(VectorStoreBase):
+     def __init__(self, config):
+         self.config = config
+         self._init_vector_db()
+
+     def _init_vector_db(self):
+         self.chroma = Chroma()
+
+     def create_database(self, document_chunks, embedding_model):
+         self.vectorstore = self.chroma.from_documents(
+             documents=document_chunks,
+             embedding=embedding_model,
+             persist_directory=os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             ),
+         )
+
+     def load_database(self, embedding_model):
+         self.vectorstore = Chroma(
+             persist_directory=os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             ),
+             embedding_function=embedding_model,
+         )
+         return self.vectorstore
+
+     def as_retriever(self):
+         return self.vectorstore.as_retriever()
code/modules/vectorstore/colbert.py ADDED
@@ -0,0 +1,39 @@
+ from ragatouille import RAGPretrainedModel
+ from modules.vectorstore.base import VectorStoreBase
+ import os
+
+
+ class ColbertVectorStore(VectorStoreBase):
+     def __init__(self, config):
+         self.config = config
+         self._init_vector_db()
+
+     def _init_vector_db(self):
+         self.colbert = RAGPretrainedModel.from_pretrained(
+             "colbert-ir/colbertv2.0",
+             index_root=os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_" + self.config["vectorstore"]["db_option"],
+             ),
+         )
+
+     def create_database(self, documents, document_names, document_metadata):
+         index_path = self.colbert.index(
+             index_name="new_idx",
+             collection=documents,
+             document_ids=document_names,
+             document_metadatas=document_metadata,
+         )
+
+     def load_database(self):
+         path = os.path.join(
+             self.config["vectorstore"]["db_path"],
+             "db_" + self.config["vectorstore"]["db_option"],
+         )
+         self.vectorstore = RAGPretrainedModel.from_index(
+             f"{path}/colbert/indexes/new_idx"
+         )
+         return self.vectorstore
+
+     def as_retriever(self):
+         return self.vectorstore.as_retriever()
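Once `create_database` has built the index, a later session only needs `load_database` plus the `as_langchain_retriever` wrapper used by `ColbertRetriever` above; a minimal sketch, assuming `documents`, `document_names`, and `document_metadata` come from `DataLoader.get_chunks`:

store = ColbertVectorStore(config)
store.create_database(documents, document_names, document_metadata)

db = store.load_database()  # reloads the index from under db_path
retriever = db.as_langchain_retriever(k=config["vectorstore"]["search_top_k"])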
code/modules/{embedding_model_loader.py → vectorstore/embedding_model_loader.py} RENAMED
@@ -1,10 +1,8 @@
  from langchain_community.embeddings import OpenAIEmbeddings
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.embeddings import LlamaCppEmbeddings
- try:
-     from modules.constants import *
- except:
-     from constants import *
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.embeddings import LlamaCppEmbeddings
+
+ from modules.config.constants import *
  import os


@@ -13,17 +11,22 @@ class EmbeddingModelLoader:
          self.config = config

      def load_embedding_model(self):
-         if self.config["embedding_options"]["model"] in ["text-embedding-ada-002"]:
+         if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
              embedding_model = OpenAIEmbeddings(
                  deployment="SL-document_embedder",
-                 model=self.config["embedding_options"]["model"],
+                 model=self.config["vectorstore"]["model"],
                  show_progress_bar=True,
                  openai_api_key=OPENAI_API_KEY,
+                 disallowed_special=(),
              )
          else:
              embedding_model = HuggingFaceEmbeddings(
-                 model_name="sentence-transformers/all-MiniLM-L6-v2",
-                 model_kwargs={"device": "cpu"},
+                 model_name=self.config["vectorstore"]["model"],
+                 model_kwargs={
+                     "device": f"{self.config['device']}",
+                     "token": f"{HUGGINGFACE_TOKEN}",
+                     "trust_remote_code": True,
+                 },
              )
              # embedding_model = LlamaCppEmbeddings(
              #     model_path=os.path.abspath("storage/llama-7b.ggmlv3.q4_0.bin")
code/modules/vectorstore/faiss.py ADDED
@@ -0,0 +1,45 @@
+ from langchain_community.vectorstores import FAISS
+ from modules.vectorstore.base import VectorStoreBase
+ import os
+
+
+ class FaissVectorStore(VectorStoreBase):
+     def __init__(self, config):
+         self.config = config
+         self._init_vector_db()
+
+     def _init_vector_db(self):
+         self.faiss = FAISS(
+             embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
+         )
+
+     def create_database(self, document_chunks, embedding_model):
+         self.vectorstore = self.faiss.from_documents(
+             documents=document_chunks, embedding=embedding_model
+         )
+         self.vectorstore.save_local(
+             os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             )
+         )
+
+     def load_database(self, embedding_model):
+         self.vectorstore = self.faiss.load_local(
+             os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             ),
+             embedding_model,
+             allow_dangerous_deserialization=True,
+         )
+         return self.vectorstore
+
+     def as_retriever(self):
+         return self.vectorstore.as_retriever()
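`create_database` persists the index under `db_path` as a side effect, so reloading only needs the same config and embedding model; a minimal sketch, assuming `document_chunks` come from `DataLoader.get_chunks` and `embedding_model` from `EmbeddingModelLoader`:

store = FaissVectorStore(config)
store.create_database(document_chunks, embedding_model)  # builds and saves to disk

db = store.load_database(embedding_model)  # reload in a later session
docs = db.similarity_search("What is gradient descent?", k=3)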
code/modules/vectorstore/helpers.py ADDED
File without changes
code/modules/vectorstore/raptor.py ADDED
@@ -0,0 +1,438 @@
1
+ # code modified from https://github.com/langchain-ai/langchain/blob/master/cookbook/RAPTOR.ipynb
2
+
3
+ from typing import Dict, List, Optional, Tuple
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import umap
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from sklearn.mixture import GaussianMixture
11
+ from langchain_community.chat_models import ChatOpenAI
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from modules.vectorstore.base import VectorStoreBase
15
+
16
+ RANDOM_SEED = 42
17
+
18
+
19
+ class RAPTORVectoreStore(VectorStoreBase):
20
+ def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
21
+ self.documents = documents
22
+ self.config = config
23
+ self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
24
+ chunk_size=self.config["splitter_options"]["chunk_size"],
25
+ chunk_overlap=self.config["splitter_options"]["chunk_overlap"],
26
+ separators=self.config["splitter_options"]["chunk_separators"],
27
+ disallowed_special=(),
28
+ )
29
+ self.embd = embedding_model
30
+ self.model = ChatOpenAI(
31
+ model="gpt-3.5-turbo",
32
+ )
33
+
34
+ def concat_documents(self, documents):
35
+ d_sorted = sorted(documents, key=lambda x: x.metadata["source"])
36
+ d_reversed = list(reversed(d_sorted))
37
+ concatenated_content = "\n\n\n --- \n\n\n".join(
38
+ [doc.page_content for doc in d_reversed]
39
+ )
40
+ return concatenated_content
41
+
42
+ def split_documents(self, documents):
43
+ concatenated_content = self.concat_documents(documents)
44
+ texts_split = self.text_splitter.split_text(concatenated_content)
45
+ return texts_split
46
+
47
+ def add_documents(self, documents):
48
+ self.documents.extend(documents)
49
+
50
+ def global_cluster_embeddings(
51
+ self,
52
+ embeddings: np.ndarray,
53
+ dim: int,
54
+ n_neighbors: Optional[int] = None,
55
+ metric: str = "cosine",
56
+ ) -> np.ndarray:
57
+ """
58
+ Perform global dimensionality reduction on the embeddings using UMAP.
59
+
60
+ Parameters:
61
+ - embeddings: The input embeddings as a numpy array.
62
+ - dim: The target dimensionality for the reduced space.
63
+ - n_neighbors: Optional; the number of neighbors to consider for each point.
64
+ If not provided, it defaults to the square root of the number of embeddings.
65
+ - metric: The distance metric to use for UMAP.
66
+
67
+ Returns:
68
+ - A numpy array of the embeddings reduced to the specified dimensionality.
69
+ """
70
+ if n_neighbors is None:
71
+ n_neighbors = int((len(embeddings) - 1) ** 0.5)
72
+ return umap.UMAP(
73
+ n_neighbors=n_neighbors, n_components=dim, metric=metric
74
+ ).fit_transform(embeddings)
75
+
76
+ def local_cluster_embeddings(
77
+ self,
78
+ embeddings: np.ndarray,
79
+ dim: int,
80
+ num_neighbors: int = 10,
81
+ metric: str = "cosine",
82
+ ) -> np.ndarray:
83
+ """
84
+ Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.
85
+
86
+ Parameters:
87
+ - embeddings: The input embeddings as a numpy array.
88
+ - dim: The target dimensionality for the reduced space.
89
+ - num_neighbors: The number of neighbors to consider for each point.
90
+ - metric: The distance metric to use for UMAP.
91
+
92
+ Returns:
93
+ - A numpy array of the embeddings reduced to the specified dimensionality.
94
+ """
95
+ return umap.UMAP(
96
+ n_neighbors=num_neighbors, n_components=dim, metric=metric
97
+ ).fit_transform(embeddings)
98
+
99
+ def get_optimal_clusters(
100
+ self,
101
+ embeddings: np.ndarray,
102
+ max_clusters: int = 50,
103
+ random_state: int = RANDOM_SEED,
104
+ ) -> int:
105
+ """
106
+ Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.
107
+
108
+ Parameters:
109
+ - embeddings: The input embeddings as a numpy array.
110
+ - max_clusters: The maximum number of clusters to consider.
111
+ - random_state: Seed for reproducibility.
112
+
113
+ Returns:
114
+ - An integer representing the optimal number of clusters found.
115
+ """
116
+ max_clusters = min(max_clusters, len(embeddings))
117
+ n_clusters = np.arange(1, max_clusters)
118
+ bics = []
119
+ for n in n_clusters:
120
+ gm = GaussianMixture(n_components=n, random_state=random_state)
121
+ gm.fit(embeddings)
122
+ bics.append(gm.bic(embeddings))
123
+ return n_clusters[np.argmin(bics)]
+ 
+     def GMM_cluster(
+         self, embeddings: np.ndarray, threshold: float, random_state: int = 0
+     ):
+         """
+         Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.
+ 
+         Parameters:
+         - embeddings: The input embeddings as a numpy array.
+         - threshold: The probability threshold for assigning an embedding to a cluster.
+         - random_state: Seed for reproducibility.
+ 
+         Returns:
+         - A tuple containing the cluster labels and the number of clusters determined.
+         """
+         n_clusters = self.get_optimal_clusters(embeddings)
+         gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
+         gm.fit(embeddings)
+         probs = gm.predict_proba(embeddings)
+         labels = [np.where(prob > threshold)[0] for prob in probs]
+         return labels, n_clusters
+ 
+     def perform_clustering(
+         self,
+         embeddings: np.ndarray,
+         dim: int,
+         threshold: float,
+     ) -> List[np.ndarray]:
+         """
+         Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering
+         using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.
+ 
+         Parameters:
+         - embeddings: The input embeddings as a numpy array.
+         - dim: The target dimensionality for UMAP reduction.
+         - threshold: The probability threshold for assigning an embedding to a cluster in GMM.
+ 
+         Returns:
+         - A list of numpy arrays, where each array contains the cluster IDs for each embedding.
+         """
+         if len(embeddings) <= dim + 1:
+             # Avoid clustering when there's insufficient data
+             return [np.array([0]) for _ in range(len(embeddings))]
+ 
+         # Global dimensionality reduction
+         reduced_embeddings_global = self.global_cluster_embeddings(embeddings, dim)
+         # Global clustering
+         global_clusters, n_global_clusters = self.GMM_cluster(
+             reduced_embeddings_global, threshold
+         )
+ 
+         all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
+         total_clusters = 0
+ 
+         # Iterate through each global cluster to perform local clustering
+         for i in range(n_global_clusters):
+             # Extract embeddings belonging to the current global cluster
+             global_cluster_embeddings_ = embeddings[
+                 np.array([i in gc for gc in global_clusters])
+             ]
+ 
+             if len(global_cluster_embeddings_) == 0:
+                 continue
+             if len(global_cluster_embeddings_) <= dim + 1:
+                 # Handle small clusters with direct assignment
+                 local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
+                 n_local_clusters = 1
+             else:
+                 # Local dimensionality reduction and clustering
+                 reduced_embeddings_local = self.local_cluster_embeddings(
+                     global_cluster_embeddings_, dim
+                 )
+                 local_clusters, n_local_clusters = self.GMM_cluster(
+                     reduced_embeddings_local, threshold
+                 )
+ 
+             # Assign local cluster IDs, adjusting for total clusters already processed
+             for j in range(n_local_clusters):
+                 local_cluster_embeddings_ = global_cluster_embeddings_[
+                     np.array([j in lc for lc in local_clusters])
+                 ]
+                 indices = np.where(
+                     (embeddings == local_cluster_embeddings_[:, None]).all(-1)
+                 )[1]
+                 for idx in indices:
+                     all_local_clusters[idx] = np.append(
+                         all_local_clusters[idx], j + total_clusters
+                     )
+ 
+             total_clusters += n_local_clusters
+ 
+         return all_local_clusters
+ 
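# A minimal standalone sketch (made-up probabilities) of the soft-assignment step in
# GMM_cluster above: thresholding predict_proba instead of taking an argmax lets a
# single embedding belong to several clusters, which perform_clustering relies on.
import numpy as np

probs = np.array([[0.60, 0.40], [0.05, 0.95], [0.45, 0.55]])  # stand-in for gm.predict_proba(...)
labels = [np.where(p > 0.3)[0] for p in probs]
# labels -> [array([0, 1]), array([1]), array([0, 1])]: rows 0 and 2 clear the
# threshold for both components, so they join both clusters.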
+     def embed(self, texts):
+         """
+         Generate embeddings for a list of text documents.
+ 
+         This function assumes the existence of an `embd` object with a method `embed_documents`
+         that takes a list of texts and returns their embeddings.
+ 
+         Parameters:
+         - texts: List[str], a list of text documents to be embedded.
+ 
+         Returns:
+         - numpy.ndarray: An array of embeddings for the given text documents.
+         """
+         text_embeddings = self.embd.embed_documents(texts)
+         text_embeddings_np = np.array(text_embeddings)
+         return text_embeddings_np
+ 
+     def embed_cluster_texts(self, texts):
+         """
+         Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.
+ 
+         This function combines embedding generation and clustering into a single step. It assumes the existence
+         of a previously defined `perform_clustering` function that performs clustering on the embeddings.
+ 
+         Parameters:
+         - texts: List[str], a list of text documents to be processed.
+ 
+         Returns:
+         - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.
+         """
+         text_embeddings_np = self.embed(texts)  # Generate embeddings
+         cluster_labels = self.perform_clustering(
+             text_embeddings_np, 10, 0.1
+         )  # Perform clustering on the embeddings
+         df = pd.DataFrame()  # Initialize a DataFrame to store the results
+         df["text"] = texts  # Store original texts
+         df["embd"] = list(
+             text_embeddings_np
+         )  # Store embeddings as a list in the DataFrame
+         df["cluster"] = cluster_labels  # Store cluster labels
+         return df
+ 
+     def fmt_txt(self, df: pd.DataFrame) -> str:
+         """
+         Formats the text documents in a DataFrame into a single string.
+ 
+         Parameters:
+         - df: DataFrame containing the 'text' column with text documents to format.
+ 
+         Returns:
+         - A single string where all text documents are joined by a specific delimiter.
+         """
+         unique_txt = df["text"].tolist()
+         return "--- --- \n --- --- ".join(unique_txt)
+ 
+     def embed_cluster_summarize_texts(
+         self, texts: List[str], level: int
+     ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+         """
+         Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,
+         clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes
+         the content within each cluster.
+ 
+         Parameters:
+         - texts: A list of text documents to be processed.
+         - level: The current level of the summarization hierarchy, recorded alongside each summary.
+ 
+         Returns:
+         - Tuple containing two DataFrames:
+           1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.
+           2. The second DataFrame (`df_summary`) contains summaries for each cluster, the level at which they were
+              generated, and the cluster identifiers.
+         """
+ 
+         # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns
+         df_clusters = self.embed_cluster_texts(texts)
+ 
+         # Prepare to expand the DataFrame for easier manipulation of clusters
+         expanded_list = []
+ 
+         # Expand DataFrame entries to document-cluster pairings for straightforward processing
+         for index, row in df_clusters.iterrows():
+             for cluster in row["cluster"]:
+                 expanded_list.append(
+                     {"text": row["text"], "embd": row["embd"], "cluster": cluster}
+                 )
+ 
+         # Create a new DataFrame from the expanded list
+         expanded_df = pd.DataFrame(expanded_list)
+ 
+         # Retrieve unique cluster identifiers for processing
+         all_clusters = expanded_df["cluster"].unique()
+ 
+         print(f"--Generated {len(all_clusters)} clusters--")
+ 
+         # Summarization
+ template = """Here is content from the course DS598: Deep Learning for Data Science.
314
+
315
+ The content may be form webapge about the course, or lecture content, or any other relevant information.
316
+ If the content is in bullet points (from pdf lectre slides), you can summarize the bullet points.
317
+
318
+ Give a detailed summary of the content below.
319
+
320
+ Documentation:
321
+ {context}
322
+ """
+         prompt = ChatPromptTemplate.from_template(template)
+         chain = prompt | self.model | StrOutputParser()
+ 
+         # Format text within each cluster for summarization
+         summaries = []
+         for i in all_clusters:
+             df_cluster = expanded_df[expanded_df["cluster"] == i]
+             formatted_txt = self.fmt_txt(df_cluster)
+             summaries.append(chain.invoke({"context": formatted_txt}))
+ 
+         # Create a DataFrame to store summaries with their corresponding cluster and level
+         df_summary = pd.DataFrame(
+             {
+                 "summaries": summaries,
+                 "level": [level] * len(summaries),
+                 "cluster": list(all_clusters),
+             }
+         )
+ 
+         return df_clusters, df_summary
+ 
+     def recursive_embed_cluster_summarize(
+         self, texts: List[str], level: int = 1, n_levels: int = 3
+     ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
+         """
+         Recursively embeds, clusters, and summarizes texts up to a specified level or until
+         the number of unique clusters becomes 1, storing the results at each level.
+ 
+         Parameters:
+         - texts: List[str], texts to be processed.
+         - level: int, current recursion level (starts at 1).
+         - n_levels: int, maximum depth of recursion.
+ 
+         Returns:
+         - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion
+           levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.
+         """
+         results = {}  # Dictionary to store results at each level
+ 
+         # Perform embedding, clustering, and summarization for the current level
+         df_clusters, df_summary = self.embed_cluster_summarize_texts(texts, level)
+ 
+         # Store the results of the current level
+         results[level] = (df_clusters, df_summary)
+ 
+         # Determine if further recursion is possible and meaningful
+         unique_clusters = df_summary["cluster"].nunique()
+         if level < n_levels and unique_clusters > 1:
+             # Use summaries as the input texts for the next level of recursion
+             new_texts = df_summary["summaries"].tolist()
+             next_level_results = self.recursive_embed_cluster_summarize(
+                 new_texts, level + 1, n_levels
+             )
+ 
+             # Merge the results from the next level into the current results dictionary
+             results.update(next_level_results)
+ 
+         return results
+ 
+     def get_vector_db(self):
+         """
+         Build the vector database from the loaded documents.
+ 
+         Splits the documents into leaf texts, recursively clusters and summarizes
+         them, and indexes the leaf texts together with all generated summaries.
+ 
+         Returns:
+         - A FAISS vectorstore built over the leaf texts and summaries.
+         """
+         leaf_texts = self.split_documents(self.documents)
+         results = self.recursive_embed_cluster_summarize(
+             leaf_texts, level=1, n_levels=10
+         )
+ 
+         all_texts = leaf_texts.copy()
+         # Iterate through the results to extract summaries from each level and add them to all_texts
+         for level in sorted(results.keys()):
+             # Extract summaries from the current level's DataFrame
+             summaries = results[level][1]["summaries"].tolist()
+             # Extend all_texts with the summaries from the current level
+             all_texts.extend(summaries)
+ 
+         # Now, use all_texts to build the vectorstore
+         vectorstore = FAISS.from_texts(texts=all_texts, embedding=self.embd)
+         return vectorstore
+ 
+     def create_database(self, documents, embedding_model):
+         self.documents = documents
+         self.embd = embedding_model
+         self.vectorstore = self.get_vector_db()
+         self.vectorstore.save_local(
+             os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             )
+         )
+ 
+     def load_database(self, embedding_model):
+         self.vectorstore = FAISS.load_local(
+             os.path.join(
+                 self.config["vectorstore"]["db_path"],
+                 "db_"
+                 + self.config["vectorstore"]["db_option"]
+                 + "_"
+                 + self.config["vectorstore"]["model"],
+             ),
+             embedding_model,
+             allow_dangerous_deserialization=True,
+         )
+         return self.vectorstore
+ 
+     def as_retriever(self):
+         return self.vectorstore.as_retriever()
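Taken together, these methods implement the RAPTOR recipe: leaf chunks are embedded and clustered (global UMAP reduction plus GMM soft assignment, refined within each global cluster), every cluster is summarized by the LLM, and the procedure recurses on the summaries until one cluster remains or the level cap is reached; the leaves and all summaries are then indexed in a single FAISS store. A minimal driving sketch, assuming `config` is the parsed config.yml and that `my_documents` and `embedding_model` are supplied by the caller (both placeholder names); `RAPTORVectoreStore` is the class these methods belong to, per the import in vectorstore.py below:

    store = RAPTORVectoreStore(config)                    # instantiated with the config, as vectorstore.py does
    store.create_database(my_documents, embedding_model)  # build the summary tree and save the FAISS index
    store.load_database(embedding_model)                  # reload the persisted index
    retriever = store.as_retriever()                      # hand the store to the RAG chain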
code/modules/vectorstore/store_manager.py ADDED
@@ -0,0 +1,163 @@
+ from modules.vectorstore.vectorstore import VectorStore
+ from modules.vectorstore.helpers import *
+ from modules.dataloader.webpage_crawler import WebpageCrawler
+ from modules.dataloader.data_loader import DataLoader
+ from modules.dataloader.helpers import *
+ from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
+ import logging
+ import os
+ import time
+ import asyncio
+ 
+ 
+ class VectorStoreManager:
+     def __init__(self, config, logger=None):
+         self.config = config
+         self.document_names = None
+ 
+         # Set up logging to both console and a file
+         self.logger = logger or self._setup_logging()
+         self.webpage_crawler = WebpageCrawler()
+         self.vector_db = VectorStore(self.config)
+ 
+         self.logger.info("VectorDB instance instantiated")
+ 
+     def _setup_logging(self):
+         logger = logging.getLogger(__name__)
+         if not logger.hasHandlers():
+             logger.setLevel(logging.INFO)
+             formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+ 
+             # Console Handler
+             console_handler = logging.StreamHandler()
+             console_handler.setLevel(logging.INFO)
+             console_handler.setFormatter(formatter)
+             logger.addHandler(console_handler)
+ 
+             # Ensure log directory exists
+             log_directory = self.config["log_dir"]
+             os.makedirs(log_directory, exist_ok=True)
+ 
+             # File Handler
+             log_file_path = os.path.join(log_directory, "vector_db.log")
+             file_handler = logging.FileHandler(log_file_path, mode="w")
+             file_handler.setLevel(logging.INFO)
+             file_handler.setFormatter(formatter)
+             logger.addHandler(file_handler)
+ 
+         return logger
+ 
+     def load_files(self):
+         files = os.listdir(self.config["vectorstore"]["data_path"])
+         files = [
+             os.path.join(self.config["vectorstore"]["data_path"], file)
+             for file in files
+         ]
+         urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
+         if self.config["vectorstore"]["expand_urls"]:
+             all_urls = []
+             for url in urls:
+                 loop = asyncio.get_event_loop()
+                 all_urls.extend(
+                     loop.run_until_complete(
+                         self.webpage_crawler.get_all_pages(
+                             url, url
+                         )  # only gets child URLs; to crawl everything, pass the base URL as the second argument
+                     )
+                 )
+             urls = all_urls
+         return files, urls
+ 
+     def create_embedding_model(self):
+         self.logger.info("Creating embedding function")
+         embedding_model_loader = EmbeddingModelLoader(self.config)
+         embedding_model = embedding_model_loader.load_embedding_model()
+         return embedding_model
+ 
+     def initialize_database(
+         self,
+         document_chunks: list,
+         document_names: list,
+         documents: list,
+         document_metadata: list,
+     ):
+         if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
+             self.embedding_model = self.create_embedding_model()
+         else:
+             self.embedding_model = None
+ 
+         self.logger.info("Initializing vector_db")
+         self.logger.info(
+             "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
+         )
+         self.vector_db._create_database(
+             document_chunks,
+             document_names,
+             documents,
+             document_metadata,
+             self.embedding_model,
+         )
+ 
+     def create_database(self):
+         start_time = time.time()  # Start time for creating database
+         data_loader = DataLoader(self.config, self.logger)
+         self.logger.info("Loading data")
+         files, urls = self.load_files()
+         files, webpages = self.webpage_crawler.clean_url_list(urls)
+         self.logger.info(f"Number of files: {len(files)}")
+         self.logger.info(f"Number of webpages: {len(webpages)}")
+         if f"{self.config['vectorstore']['url_file_path']}" in files:
+             files.remove(f"{self.config['vectorstore']['url_file_path']}")  # cleanup
+         document_chunks, document_names, documents, document_metadata = (
+             data_loader.get_chunks(files, webpages)
+         )
+         num_documents = len(document_chunks)
+         self.logger.info(f"Number of documents in the DB: {num_documents}")
+         metadata_keys = list(document_metadata[0].keys())
+         self.logger.info(f"Metadata keys: {metadata_keys}")
+         self.logger.info("Completed loading data")
+         self.initialize_database(
+             document_chunks, document_names, documents, document_metadata
+         )
+         end_time = time.time()  # End time for creating database
+         self.logger.info("Created database")
+         self.logger.info(
+             f"Time taken to create database: {end_time - start_time} seconds"
+         )
+ 
+     def load_database(self):
+         start_time = time.time()  # Start time for loading database
+         if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
+             self.embedding_model = self.create_embedding_model()
+         else:
+             self.embedding_model = None
+         self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
+         end_time = time.time()  # End time for loading database
+         self.logger.info(
+             f"Time taken to load database: {end_time - start_time} seconds"
+         )
+         self.logger.info("Loaded database")
+         return self.loaded_vector_db
+ 
+ 
+ if __name__ == "__main__":
+     import yaml
+ 
+     with open("modules/config/config.yml", "r") as f:
+         config = yaml.safe_load(f)
+     print(config)
+     print(f"Trying to create database with config: {config}")
+     vector_db = VectorStoreManager(config)
+     vector_db.create_database()
+     print("Created database")
+ 
+     print("Trying to load the database")
+     vector_db = VectorStoreManager(config)
+     vector_db.load_database()
+     print("Loaded database")
+ 
+     print(f"View the logs at {config['log_dir']}/vector_db.log")
code/modules/vectorstore/vectorstore.py ADDED
@@ -0,0 +1,57 @@
+ from modules.vectorstore.faiss import FaissVectorStore
+ from modules.vectorstore.chroma import ChromaVectorStore
+ from modules.vectorstore.colbert import ColbertVectorStore
+ from modules.vectorstore.raptor import RAPTORVectoreStore
+ 
+ 
+ class VectorStore:
+     def __init__(self, config):
+         self.config = config
+         self.vectorstore = None
+         self.vectorstore_classes = {
+             "FAISS": FaissVectorStore,
+             "Chroma": ChromaVectorStore,
+             "RAGatouille": ColbertVectorStore,
+             "RAPTOR": RAPTORVectoreStore,
+         }
+ 
+     def _create_database(
+         self,
+         document_chunks,
+         document_names,
+         documents,
+         document_metadata,
+         embedding_model,
+     ):
+         db_option = self.config["vectorstore"]["db_option"]
+         vectorstore_class = self.vectorstore_classes.get(db_option)
+         if not vectorstore_class:
+             raise ValueError(f"Invalid db_option: {db_option}")
+ 
+         self.vectorstore = vectorstore_class(self.config)
+ 
+         if db_option == "RAGatouille":
+             self.vectorstore.create_database(
+                 documents, document_names, document_metadata
+             )
+         else:
+             self.vectorstore.create_database(document_chunks, embedding_model)
+ 
+     def _load_database(self, embedding_model):
+         db_option = self.config["vectorstore"]["db_option"]
+         vectorstore_class = self.vectorstore_classes.get(db_option)
+         if not vectorstore_class:
+             raise ValueError(f"Invalid db_option: {db_option}")
+ 
+         self.vectorstore = vectorstore_class(self.config)
+ 
+         if db_option == "RAGatouille":
+             return self.vectorstore.load_database()
+         else:
+             return self.vectorstore.load_database(embedding_model)
+ 
+     def _as_retriever(self):
+         return self.vectorstore.as_retriever()
+ 
+     def _get_vectorstore(self):
+         return self.vectorstore
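The wrapper's job is pure dispatch: each backend class exposes the same create_database/load_database/as_retriever surface, so switching stores means editing db_option and nothing else. A short usage sketch (argument names are placeholders); note that RAGatouille is the one backend that indexes raw documents rather than pre-chunked text:

    vs = VectorStore(config)
    vs._create_database(chunks, names, docs, metadata, embedding_model)
    retriever = vs._as_retriever()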
code/public/acastusphoton-svgrepo-com.svg ADDED
code/public/adv-screen-recorder-svgrepo-com.svg ADDED
code/public/alarmy-svgrepo-com.svg ADDED
code/public/avatars/ai-tutor.png ADDED