Merge pull request #28 from DL4DS/dev_branch
This view is limited to 50 files because it contains too many changes.
- .github/workflows/push_to_hf_space_prototype.yml +20 -0
- .gitignore +9 -1
- Dockerfile +9 -7
- Dockerfile.dev +31 -0
- README.md +71 -22
- {.chainlit → code/.chainlit}/config.toml +64 -33
- code/__init__.py +1 -0
- chainlit.md → code/chainlit.md +0 -0
- code/main.py +91 -39
- code/modules/chat/__init__.py +0 -0
- code/modules/{chat_model_loader.py → chat/chat_model_loader.py} +2 -3
- code/modules/chat/helpers.py +104 -0
- code/modules/chat/llm_tutor.py +211 -0
- code/modules/chat_processor/__init__.py +0 -0
- code/modules/chat_processor/base.py +12 -0
- code/modules/chat_processor/chat_processor.py +30 -0
- code/modules/chat_processor/literal_ai.py +37 -0
- code/modules/config/__init__.py +0 -0
- code/{config.yml → modules/config/config.yml} +31 -12
- code/modules/{constants.py → config/constants.py} +4 -1
- code/modules/data_loader.py +0 -287
- code/modules/dataloader/__init__.py +0 -0
- code/modules/dataloader/data_loader.py +360 -0
- code/modules/dataloader/helpers.py +108 -0
- code/modules/dataloader/webpage_crawler.py +115 -0
- code/modules/helpers.py +0 -200
- code/modules/llm_tutor.py +0 -87
- code/modules/retriever/__init__.py +0 -0
- code/modules/retriever/base.py +12 -0
- code/modules/retriever/chroma_retriever.py +24 -0
- code/modules/retriever/colbert_retriever.py +10 -0
- code/modules/retriever/faiss_retriever.py +23 -0
- code/modules/retriever/helpers.py +39 -0
- code/modules/retriever/raptor_retriever.py +16 -0
- code/modules/retriever/retriever.py +26 -0
- code/modules/vector_db.py +0 -133
- code/modules/vectorstore/__init__.py +0 -0
- code/modules/vectorstore/base.py +33 -0
- code/modules/vectorstore/chroma.py +41 -0
- code/modules/vectorstore/colbert.py +39 -0
- code/modules/{embedding_model_loader.py → vectorstore/embedding_model_loader.py} +13 -10
- code/modules/vectorstore/faiss.py +45 -0
- code/modules/vectorstore/helpers.py +0 -0
- code/modules/vectorstore/raptor.py +438 -0
- code/modules/vectorstore/store_manager.py +163 -0
- code/modules/vectorstore/vectorstore.py +57 -0
- code/public/acastusphoton-svgrepo-com.svg +2 -0
- code/public/adv-screen-recorder-svgrepo-com.svg +2 -0
- code/public/alarmy-svgrepo-com.svg +2 -0
- code/public/avatars/ai-tutor.png +0 -0
.github/workflows/push_to_hf_space_prototype.yml
ADDED
@@ -0,0 +1,20 @@
+name: Push Prototype to HuggingFace
+
+on:
+  pull_request:
+    branches:
+      - dev_branch
+
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Deploy Prototype to HuggingFace
+        uses: nateraw/[email protected]
+        with:
+          github_repo_id: DL4DS/dl4ds_tutor
+          huggingface_repo_id: dl4ds/tutor_dev
+          repo_type: space
+          space_sdk: static
+          hf_token: ${{ secrets.HF_TOKEN }}
.gitignore
CHANGED
@@ -160,4 +160,12 @@ cython_debug/
 #.idea/
 
 # log files
-*.log
+*.log
+
+.ragatouille/*
+*/__pycache__/*
+.chainlit/translations/
+storage/logs/*
+vectorstores/*
+
+*/.files/*
Dockerfile
CHANGED
@@ -1,18 +1,17 @@
-FROM python:3.
+FROM python:3.11
 
 WORKDIR /code
 
 COPY ./requirements.txt /code/requirements.txt
 
-RUN pip install --
+RUN pip install --upgrade pip
 
-RUN pip install --no-cache-dir
-
-RUN pip install --upgrade --force-reinstall --no-cache-dir llama-cpp-python==0.2.32
+RUN pip install --no-cache-dir -r /code/requirements.txt
 
 COPY . /code
 
-
+# List the contents of the /code directory to verify files are copied correctly
+RUN ls -R /code
 
 # Change permissions to allow writing to the directory
 RUN chmod -R 777 /code
@@ -23,7 +22,10 @@ RUN mkdir /code/logs && chmod 777 /code/logs
 # Create a cache directory within the application's working directory
 RUN mkdir /.cache && chmod -R 777 /.cache
 
+WORKDIR /code/code
+
 RUN --mount=type=secret,id=HUGGINGFACEHUB_API_TOKEN,mode=0444,required=true
 RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=true
 
-
+# Default command to run the application
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 7860"]
Dockerfile.dev
ADDED
@@ -0,0 +1,31 @@
+FROM python:3.11
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --upgrade pip
+
+RUN pip install --no-cache-dir -r /code/requirements.txt
+
+COPY . /code
+
+# List the contents of the /code directory to verify files are copied correctly
+RUN ls -R /code
+
+# Change permissions to allow writing to the directory
+RUN chmod -R 777 /code
+
+# Create a logs directory and set permissions
+RUN mkdir /code/logs && chmod 777 /code/logs
+
+# Create a cache directory within the application's working directory
+RUN mkdir /.cache && chmod -R 777 /.cache
+
+WORKDIR /code/code
+
+# Expose the port the app runs on
+EXPOSE 8051
+
+# Default command to run the application
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
README.md
CHANGED
@@ -1,35 +1,84 @@
----
-title: Dl4ds Tutor
-emoji: 🏃
-colorFrom: green
-colorTo: red
-sdk: docker
-pinned: false
-hf_oauth: true
----
-
-===========
-
-
-
-
-```
-
-```
-
+# DL4DS Tutor 🏃
+
+Check out the configuration reference at [Hugging Face Spaces Config Reference](https://huggingface.co/docs/hub/spaces-config-reference).
+
+You can find an implementation of the Tutor at [DL4DS Tutor on Hugging Face](https://dl4ds-dl4ds-tutor.hf.space/), which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor).
+
+## Running Locally
+
+1. **Clone the Repository**
+   ```bash
+   git clone https://github.com/DL4DS/dl4ds_tutor
+   ```
+
+2. **Put your data under the `storage/data` directory**
+   - Add URLs in the `urls.txt` file.
+   - Add other PDF files in the `storage/data` directory.
+
+3. **To test Data Loading (Optional)**
+   ```bash
+   cd code
+   python -m modules.dataloader.data_loader
+   ```
+
+4. **Create the Vector Database**
+   ```bash
+   cd code
+   python -m modules.vectorstore.store_manager
+   ```
+   - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
+   - Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yml` file, which will embed files from the storage directory every time you run the chainlit command below.
+
+5. **Run the Chainlit App**
+   ```bash
+   chainlit run main.py
+   ```
+
 See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
 
+## File Structure
+
+```plaintext
+code/
+├── modules
+│   ├── chat            # Contains the chatbot implementation
+│   ├── chat_processor  # Contains the implementation to process and log the conversations
+│   ├── config          # Contains the configuration files
+│   ├── dataloader      # Contains the implementation to load the data from the storage directory
+│   ├── retriever       # Contains the implementation to create the retriever
+│   └── vectorstore     # Contains the implementation to create the vector database
+├── public
+│   ├── logo_dark.png   # Dark theme logo
+│   ├── logo_light.png  # Light theme logo
+│   └── test.css        # Custom CSS file
+└── main.py
+
+docs/            # Contains the documentation to the codebase and methods used
+
+storage/
+├── data         # Store files and URLs here
+├── logs         # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
+└── models       # Local LLMs are loaded from here
+
+vectorstores/    # Stores the created vector databases
+
+.env             # This needs to be created, store the API keys here
+```
+
+- `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
+- `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager` class to manage the vector database, and all associated methods.
+- `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
+
+## Docker
+
+The HuggingFace Space is built using the `Dockerfile` in the repository. To run it locally, use the `Dockerfile.dev` file.
+
+```bash
+docker build --tag dev -f Dockerfile.dev .
+docker run -it --rm -p 8051:8051 dev
+```
+
 ## Contributing
 
-Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
+Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
{.chainlit → code/.chainlit}/config.toml
RENAMED
@@ -2,6 +2,7 @@
 # Whether to enable telemetry (default: true). No personal data is collected.
 enable_telemetry = true
 
+
 # List of environment variables to be provided by each user to use the app.
 user_env = []
 
@@ -11,74 +12,104 @@ session_timeout = 3600
 # Enable third parties caching (e.g LangChain cache)
 cache = false
 
+# Authorized origins
+allow_origins = ["*"]
+
 # Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
 # follow_symlink = false
 
 [features]
-# Show the prompt playground
-prompt_playground = true
-
 # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
 unsafe_allow_html = false
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
 latex = false
 
-#
-
-#
-[features.
-enabled =
-
+# Automatically tag threads with the current chat profile (if a chat profile is used)
+auto_tag_thread = true
+
+# Authorize users to spontaneously upload files with messages
+[features.spontaneous_file_upload]
+enabled = true
+accept = ["*/*"]
+max_files = 20
+max_size_mb = 500
+
+[features.audio]
+# Threshold for audio recording
+min_decibels = -45
+# Delay for the user to start speaking in MS
+initial_silence_timeout = 3000
+# Delay for the user to continue speaking in MS. If the user stops speaking for this duration, the recording will stop.
+silence_timeout = 1500
+# Above this duration (MS), the recording will forcefully stop.
+max_duration = 15000
+# Duration of the audio chunks in MS
+chunk_duration = 1000
+# Sample rate of the audio
+sample_rate = 44100
 
 [UI]
-# Name of the
+# Name of the assistant.
 name = "AI Tutor"
 
-#
-show_readme_as_default = true
-
-# Description of the app and chatbot. This is used for HTML tags.
+# Description of the assistant. This is used for HTML tags.
 # description = ""
 
 # Large size content are by default collapsed for a cleaner ui
 default_collapse_content = true
 
-# The default value for the expand messages settings.
-default_expand_messages = false
-
 # Hide the chain of thought details from the user in the UI.
-hide_cot =
+hide_cot = true
 
 # Link to your github repo. This will add a github button in the UI's header.
-# github = ""
+# github = "https://github.com/DL4DS/dl4ds_tutor"
 
 # Specify a CSS file that can be used to customize the user interface.
 # The CSS file can be served from the public directory or via an external link.
-
+custom_css = "/public/test.css"
+
+# Specify a Javascript file that can be used to customize the user interface.
+# The Javascript file can be served from the public directory.
+# custom_js = "/public/test.js"
+
+# Specify a custom font url.
+# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
+
+# Specify a custom meta image url.
+custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Boston_University_seal.svg/1200px-Boston_University_seal.svg.png"
+
+# Specify a custom build directory for the frontend.
+# This can be used to customize the frontend code.
+# Be careful: If this is a relative path, it should not start with a slash.
+# custom_build = "./public/build"
+
+[UI.theme]
+default = "light"
+#layout = "wide"
+#font_family = "Inter, sans-serif"
 # Override default MUI light theme. (Check theme.ts)
 [UI.theme.light]
-background = "#
+background = "#FAFAFA"
 paper = "#FFFFFF"
 
 [UI.theme.light.primary]
-main = "#
-dark = "#
-light = "#
-
+main = "#b22222" # Brighter shade of red
+dark = "#8b0000" # Darker shade of the brighter red
+light = "#ff6347" # Lighter shade of the brighter red
+[UI.theme.light.text]
+primary = "#212121"
+secondary = "#616161"
 # Override default MUI dark theme. (Check theme.ts)
 [UI.theme.dark]
-
-
+background = "#1C1C1C" # Slightly lighter dark background color
+paper = "#2A2A2A" # Slightly lighter dark paper color
 
 [UI.theme.dark.primary]
-
-
-
+main = "#89CFF0" # Primary color
+dark = "#3700B3" # Dark variant of primary color
+light = "#CFBCFF" # Lighter variant of primary color
 
 
 [meta]
-generated_by = "
+generated_by = "1.1.302"
code/__init__.py
ADDED
@@ -0,0 +1 @@
+from .modules import *
chainlit.md → code/chainlit.md
RENAMED
File without changes
code/main.py
CHANGED
@@ -1,9 +1,8 @@
-from
-from
-from
-from
+from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+from langchain_core.prompts import PromptTemplate
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
 from langchain.chains import RetrievalQA
-from langchain.llms import CTransformers
 import chainlit as cl
 from langchain_community.chat_models import ChatOpenAI
 from langchain_community.embeddings import OpenAIEmbeddings
@@ -11,37 +10,54 @@ import yaml
 import logging
 from dotenv import load_dotenv
 
-from modules.llm_tutor import LLMTutor
-from modules.constants import *
-from modules.helpers import get_sources
-
+from modules.chat.llm_tutor import LLMTutor
+from modules.config.constants import *
+from modules.chat.helpers import get_sources
+from modules.chat_processor.chat_processor import ChatProcessor
 
+global logger
+# Initialize logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
 
 # Console Handler
 console_handler = logging.StreamHandler()
 console_handler.setLevel(logging.INFO)
-formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-
-
-
-
-
-
+
+@cl.set_starters
+async def set_starters():
+    return [
+        cl.Starter(
+            label="recording on CNNs?",
+            message="Where can I find the recording for the lecture on Transformers?",
+            icon="/public/adv-screen-recorder-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="where's the slides?",
+            message="When are the lectures? I can't find the schedule.",
+            icon="/public/alarmy-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="Due Date?",
+            message="When is the final project due?",
+            icon="/public/calendar-samsung-17-svgrepo-com.svg",
+        ),
+        cl.Starter(
+            label="Explain backprop.",
+            message="I didn't understand the math behind backprop, could you explain it?",
+            icon="/public/acastusphoton-svgrepo-com.svg",
+        ),
+    ]
 
 
 # Adding option to select the chat profile
 @cl.set_chat_profiles
 async def chat_profile():
     return [
-        cl.ChatProfile(
-            name="Llama",
-            markdown_description="Use the local LLM: **Tiny Llama**.",
-        ),
         # cl.ChatProfile(
        #     name="Mistral",
        #     markdown_description="Use the local LLM: **Mistral**.",
@@ -54,6 +70,10 @@ async def chat_profile():
             name="gpt-4",
             markdown_description="Use OpenAI API for **gpt-4**.",
         ),
+        cl.ChatProfile(
+            name="Llama",
+            markdown_description="Use the local LLM: **Tiny Llama**.",
+        ),
     ]
 
 
@@ -66,12 +86,26 @@ def rename(orig_author: str):
 # chainlit code
 @cl.on_chat_start
 async def start():
-    with open("
+    with open("modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)
-
-
-
-
+
+    # Ensure log directory exists
+    log_directory = config["log_dir"]
+    if not os.path.exists(log_directory):
+        os.makedirs(log_directory)
+
+    # File Handler
+    log_file_path = (
+        f"{log_directory}/tutor.log"  # Change this to your desired log file path
+    )
+    file_handler = logging.FileHandler(log_file_path, mode="w")
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    logger.info("Config file loaded")
+    logger.info(f"Config: {config}")
+    logger.info("Creating llm_tutor instance")
 
     chat_profile = cl.user_session.get("chat_profile")
     if chat_profile is not None:
@@ -93,32 +127,50 @@ async def start():
     llm_tutor = LLMTutor(config, logger=logger)
 
     chain = llm_tutor.qa_bot()
-
-
-
-    msg.
-    await msg.update()
+    # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
+    # await msg.send()
+    # msg.content = opening_message
+    # await msg.update()
 
+    tags = [chat_profile, config["vectorstore"]["db_option"]]
+    chat_processor = ChatProcessor(config, tags=tags)
     cl.user_session.set("chain", chain)
+    cl.user_session.set("counter", 0)
+    cl.user_session.set("chat_processor", chat_processor)
+
+
+@cl.on_chat_end
+async def on_chat_end():
+    await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
 
 
 @cl.on_message
 async def main(message):
+    global logger
    user = cl.user_session.get("user")
    chain = cl.user_session.get("chain")
-
-
-
-
-
-
-
+
+    counter = cl.user_session.get("counter")
+    counter += 1
+    cl.user_session.set("counter", counter)
+
+    # if counter >= 3:  # Ensure the counter condition is checked
+    #     await cl.Message(content="Your credits are up!").send()
+    #     await on_chat_end()  # Call the on_chat_end function to handle the end of the chat
+    #     return  # Exit the function to stop further processing
+    # else:
+
+    cb = cl.AsyncLangchainCallbackHandler()  # TODO: fix streaming here
+    cb.answer_reached = True
+
+    processor = cl.user_session.get("chat_processor")
+    res = await processor.rag(message.content, chain, cb)
    try:
        answer = res["answer"]
    except:
        answer = res["result"]
-    print(f"answer: {answer}")
 
-    answer_with_sources, source_elements = get_sources(res, answer)
+    answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
+    processor._process(message.content, answer, sources_dict)
 
     await cl.Message(content=answer_with_sources, elements=source_elements).send()
code/modules/chat/__init__.py
ADDED
File without changes
code/modules/{chat_model_loader.py → chat/chat_model_loader.py}
RENAMED
@@ -1,8 +1,7 @@
 from langchain_community.chat_models import ChatOpenAI
-from
-from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from transformers import AutoTokenizer, TextStreamer
-from
+from langchain_community.llms import LlamaCpp
 import torch
 import transformers
 import os
code/modules/chat/helpers.py
ADDED
@@ -0,0 +1,104 @@
+from modules.config.constants import *
+import chainlit as cl
+from langchain_core.prompts import PromptTemplate
+
+
+def get_sources(res, answer):
+    source_elements = []
+    source_dict = {}  # Dictionary to store URL elements
+
+    for idx, source in enumerate(res["source_documents"]):
+        source_metadata = source.metadata
+        url = source_metadata.get("source", "N/A")
+        score = source_metadata.get("score", "N/A")
+        page = source_metadata.get("page", 1)
+
+        lecture_tldr = source_metadata.get("tldr", "N/A")
+        lecture_recording = source_metadata.get("lecture_recording", "N/A")
+        suggested_readings = source_metadata.get("suggested_readings", "N/A")
+        date = source_metadata.get("date", "N/A")
+
+        source_type = source_metadata.get("source_type", "N/A")
+
+        url_name = f"{url}_{page}"
+        if url_name not in source_dict:
+            source_dict[url_name] = {
+                "text": source.page_content,
+                "url": url,
+                "score": score,
+                "page": page,
+                "lecture_tldr": lecture_tldr,
+                "lecture_recording": lecture_recording,
+                "suggested_readings": suggested_readings,
+                "date": date,
+                "source_type": source_type,
+            }
+        else:
+            source_dict[url_name]["text"] += f"\n\n{source.page_content}"
+
+    # First, display the answer
+    full_answer = "**Answer:**\n"
+    full_answer += answer
+
+    # Then, display the sources
+    full_answer += "\n\n**Sources:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+        name = f"Source {idx + 1} Text\n"
+        full_answer += name
+        source_elements.append(
+            cl.Text(name=name, content=source_data["text"], display="side")
+        )
+
+        # Add a PDF element if the source is a PDF file
+        if source_data["url"].lower().endswith(".pdf"):
+            name = f"Source {idx + 1} PDF\n"
+            full_answer += name
+            pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+            source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
+
+    full_answer += "\n**Metadata:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} Metadata:\n"
+        source_elements.append(
+            cl.Text(
+                name=f"Source {idx + 1} Metadata",
+                content=f"Source: {source_data['url']}\n"
+                f"Page: {source_data['page']}\n"
+                f"Type: {source_data['source_type']}\n"
+                f"Date: {source_data['date']}\n"
+                f"TL;DR: {source_data['lecture_tldr']}\n"
+                f"Lecture Recording: {source_data['lecture_recording']}\n"
+                f"Suggested Readings: {source_data['suggested_readings']}\n",
+                display="side",
+            )
+        )
+
+    return full_answer, source_elements, source_dict
+
+
+def get_prompt(config):
+    if config["llm_params"]["use_history"]:
+        if config["llm_params"]["llm_loader"] == "local_llm":
+            custom_prompt_template = tinyllama_prompt_template_with_history
+        elif config["llm_params"]["llm_loader"] == "openai":
+            custom_prompt_template = openai_prompt_template_with_history
+        # else:
+        #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
+        prompt = PromptTemplate(
+            template=custom_prompt_template,
+            input_variables=["context", "chat_history", "question"],
+        )
+    else:
+        if config["llm_params"]["llm_loader"] == "local_llm":
+            custom_prompt_template = tinyllama_prompt_template
+        elif config["llm_params"]["llm_loader"] == "openai":
+            custom_prompt_template = openai_prompt_template
+        # else:
+        #     custom_prompt_template = tinyllama_prompt_template
+        prompt = PromptTemplate(
+            template=custom_prompt_template,
+            input_variables=["context", "question"],
+        )
+    return prompt
code/modules/chat/llm_tutor.py
ADDED
@@ -0,0 +1,211 @@
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain.memory import (
+    ConversationBufferWindowMemory,
+    ConversationSummaryBufferMemory,
+)
+from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
+import os
+from modules.config.constants import *
+from modules.chat.helpers import get_prompt
+from modules.chat.chat_model_loader import ChatModelLoader
+from modules.vectorstore.store_manager import VectorStoreManager
+
+from modules.retriever.retriever import Retriever
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+import inspect
+from langchain.chains.conversational_retrieval.base import _get_chat_history
+from langchain_core.messages import BaseMessage
+
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatOpenAI
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+
+    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
+        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
+        buffer = ""
+        for dialogue_turn in chat_history:
+            if isinstance(dialogue_turn, BaseMessage):
+                role_prefix = _ROLE_MAP.get(
+                    dialogue_turn.type, f"{dialogue_turn.type}: "
+                )
+                buffer += f"\n{role_prefix}{dialogue_turn.content}"
+            elif isinstance(dialogue_turn, tuple):
+                human = "Student: " + dialogue_turn[0]
+                ai = "AI Tutor: " + dialogue_turn[1]
+                buffer += "\n" + "\n".join([human, ai])
+            else:
+                raise ValueError(
+                    f"Unsupported chat history format: {type(dialogue_turn)}."
+                    f" Full chat history: {chat_history} "
+                )
+        return buffer
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self._get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+        if chat_history_str:
+            # callbacks = _run_manager.get_child()
+            # new_question = await self.question_generator.arun(
+            #     question=question, chat_history=chat_history_str, callbacks=callbacks
+            # )
+            system = (
+                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+                "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+                "If the question is conversational and doesn't require context, do not rephrase it. "
+                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation.'. "
+                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'. "
+                "Chat history: \n{chat_history_str}\n"
+                "Rephrase the following question only if necessary: '{question}'"
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", system),
+                    ("human", "{question}, {chat_history_str}"),
+                ]
+            )
+            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+            step_back = prompt | llm | StrOutputParser()
+            new_question = step_back.invoke(
+                {"question": question, "chat_history_str": chat_history_str}
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._aget_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
+
+        output: Dict[str, Any] = {}
+        output["original_question"] = question
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+
+            # Prepare the final prompt with metadata
+            context = "\n\n".join(
+                [
+                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
+                    for idx, doc in enumerate(docs)
+                ]
+            )
+            final_prompt = (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+                f"Chat History:\n{chat_history_str}\n\n"
+                f"Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+                f"Student: {question}\n"
+                "AI Tutor:"
+            )
+
+            # new_inputs["input"] = final_prompt
+            new_inputs["question"] = final_prompt
+            # output["final_prompt"] = final_prompt
+
+            answer = await self.combine_docs_chain.arun(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
+        if self.return_source_documents:
+            output["source_documents"] = docs
+            output["rephrased_question"] = new_question
+        return output
+
+
+class LLMTutor:
+    def __init__(self, config, logger=None):
+        self.config = config
+        self.llm = self.load_llm()
+        self.logger = logger
+        self.vector_db = VectorStoreManager(config, logger=self.logger)
+        if self.config["vectorstore"]["embedd_files"]:
+            self.vector_db.create_database()
+            self.vector_db.save_database()
+
+    def set_custom_prompt(self):
+        """
+        Prompt template for QA retrieval for each vectorstore
+        """
+        prompt = get_prompt(self.config)
+        # prompt = QA_PROMPT
+
+        return prompt
+
+    # Retrieval QA Chain
+    def retrieval_qa_chain(self, llm, prompt, db):
+
+        retriever = Retriever(self.config)._return_retriever(db)
+
+        if self.config["llm_params"]["use_history"]:
+            memory = ConversationBufferWindowMemory(
+                k=self.config["llm_params"]["memory_window"],
+                memory_key="chat_history",
+                return_messages=True,
+                output_key="answer",
+                max_token_limit=128,
+            )
+            qa_chain = CustomConversationalRetrievalChain.from_llm(
+                llm=llm,
+                chain_type="stuff",
+                retriever=retriever,
+                return_source_documents=True,
+                memory=memory,
+                combine_docs_chain_kwargs={"prompt": prompt},
+                response_if_no_docs_found="No context found",
+            )
+        else:
+            qa_chain = RetrievalQA.from_chain_type(
+                llm=llm,
+                chain_type="stuff",
+                retriever=retriever,
+                return_source_documents=True,
+                chain_type_kwargs={"prompt": prompt},
+            )
+        return qa_chain
+
+    # Loading the model
+    def load_llm(self):
+        chat_model_loader = ChatModelLoader(self.config)
+        llm = chat_model_loader.load_chat_model()
+        return llm
+
+    # QA Model Function
+    def qa_bot(self):
+        db = self.vector_db.load_database()
+        qa_prompt = self.set_custom_prompt()
+        qa = self.retrieval_qa_chain(
+            self.llm, qa_prompt, db
+        )  # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
+
+        return qa
+
+    # output function
+    def final_result(query):
+        qa_result = qa_bot()
+        response = qa_result({"query": query})
+        return response
code/modules/chat_processor/__init__.py
ADDED
File without changes
code/modules/chat_processor/base.py
ADDED
@@ -0,0 +1,12 @@
+# Template for chat processor classes
+
+
+class ChatProcessorBase:
+    def __init__(self, config):
+        self.config = config
+
+    def process(self, message):
+        """
+        Processes and logs the message
+        """
+        raise NotImplementedError("process method not implemented")
code/modules/chat_processor/chat_processor.py
ADDED
@@ -0,0 +1,30 @@
+from modules.chat_processor.literal_ai import LiteralaiChatProcessor
+
+
+class ChatProcessor:
+    def __init__(self, config, tags=None):
+        self.chat_processor_type = config["chat_logging"]["platform"]
+        self.logging = config["chat_logging"]["log_chat"]
+        self.tags = tags
+        if self.logging:
+            self._init_processor()
+
+    def _init_processor(self):
+        if self.chat_processor_type == "literalai":
+            self.processor = LiteralaiChatProcessor(self.tags)
+        else:
+            raise ValueError(
+                f"Chat processor type {self.chat_processor_type} not supported"
+            )
+
+    def _process(self, user_message, assistant_message, source_dict):
+        if self.logging:
+            return self.processor.process(user_message, assistant_message, source_dict)
+        else:
+            pass
+
+    async def rag(self, user_query: str, chain, cb):
+        if self.logging:
+            return await self.processor.rag(user_query, chain, cb)
+        else:
+            return await chain.acall(user_query, callbacks=[cb])
code/modules/chat_processor/literal_ai.py
ADDED
@@ -0,0 +1,37 @@
+from literalai import LiteralClient
+import os
+from .base import ChatProcessorBase
+
+
+class LiteralaiChatProcessor(ChatProcessorBase):
+    def __init__(self, tags=None):
+        self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
+        self.literal_client.reset_context()
+        with self.literal_client.thread(name="TEST") as thread:
+            self.thread_id = thread.id
+            self.thread = thread
+            if tags is not None and type(tags) == list:
+                self.thread.tags = tags
+        print(f"Thread ID: {self.thread}")
+
+    def process(self, user_message, assistant_message, source_dict):
+        with self.literal_client.thread(thread_id=self.thread_id) as thread:
+            self.literal_client.message(
+                content=user_message,
+                type="user_message",
+                name="User",
+            )
+            self.literal_client.message(
+                content=assistant_message,
+                type="assistant_message",
+                name="AI_Tutor",
+            )
+
+    async def rag(self, user_query: str, chain, cb):
+        with self.literal_client.step(
+            type="retrieval", name="RAG", thread_id=self.thread_id
+        ) as step:
+            step.input = {"question": user_query}
+            res = await chain.acall(user_query, callbacks=[cb])
+            step.output = res
+        return res
code/modules/config/__init__.py
ADDED
File without changes
code/{config.yml → modules/config/config.yml}
RENAMED
@@ -1,23 +1,42 @@
+log_dir: '../storage/logs' # str
+log_chunk_dir: '../storage/logs/chunks' # str
+device: 'cpu' # str [cuda, cpu]
+
+vectorstore:
-
   embedd_files: False # bool
-
-
-url_file_path: 'storage/data/urls.txt' # str
+  data_path: '../storage/data' # str
+  url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
-db_option : 'FAISS' # str
-db_path : 'vectorstores' # str
+  db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
+  db_path : '../vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
+  score_threshold : 0.2 # float
+
+  faiss_params: # Not used as of now
+    index_path: '../vectorstores/faiss.index' # str
+    index_type: 'Flat' # str [Flat, HNSW, IVF]
+    index_dimension: 384 # int
+    index_nlist: 100 # int
+    index_nprobe: 10 # int
+
+  colbert_params:
+    index_name: "new_idx" # str
+
 llm_params:
-use_history:
+  use_history: True # bool
   memory_window: 3 # int
-llm_loader: '
+  llm_loader: 'openai' # str [local_llm, openai]
   openai_params:
-model: 'gpt-
+    model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
   local_llm_params:
-model:
-
-
+    model: 'tiny-llama'
+    temperature: 0.7
+
+chat_logging:
+  log_chat: False # bool
+  platform: 'literalai'
+
 splitter_options:
   use_splitter: True # bool
   split_by_token : True # bool
code/modules/{constants.py → config/constants.py}
RENAMED
@@ -6,7 +6,10 @@ load_dotenv()
 # API Keys - Loaded from the .env file
 
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
 
+opening_message = f"Hey, What Can I Help You With?\n\nYou can ask me questions about the course logistics, course content, about the final project, or anything else!"
 
 # Prompt Templates
 
@@ -75,5 +78,5 @@ Question: {question}
 
 # Model Paths
 
-LLAMA_PATH = "storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
+LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
 MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
code/modules/data_loader.py
DELETED
@@ -1,287 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import pysrt
|
3 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
-
from langchain.document_loaders import (
|
5 |
-
PyMuPDFLoader,
|
6 |
-
Docx2txtLoader,
|
7 |
-
YoutubeLoader,
|
8 |
-
WebBaseLoader,
|
9 |
-
TextLoader,
|
10 |
-
)
|
11 |
-
from langchain.schema import Document
|
12 |
-
import tempfile
|
13 |
-
from tempfile import NamedTemporaryFile
|
14 |
-
import logging
|
15 |
-
import requests
|
16 |
-
|
17 |
-
logger = logging.getLogger(__name__)
|
18 |
-
|
19 |
-
|
20 |
-
class DataLoader:
|
21 |
-
def __init__(self, config):
|
22 |
-
"""
|
23 |
-
Class for handling all data extraction and chunking
|
24 |
-
Inputs:
|
25 |
-
config - dictionary from yaml file, containing all important parameters
|
26 |
-
"""
|
27 |
-
self.config = config
|
28 |
-
self.remove_leftover_delimiters = config["splitter_options"][
|
29 |
-
"remove_leftover_delimiters"
|
30 |
-
]
|
31 |
-
|
32 |
-
# Main list of all documents
|
33 |
-
self.document_chunks_full = []
|
34 |
-
self.document_names = []
|
35 |
-
|
36 |
-
if config["splitter_options"]["use_splitter"]:
|
37 |
-
if config["splitter_options"]["split_by_token"]:
|
38 |
-
self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
39 |
-
chunk_size=config["splitter_options"]["chunk_size"],
|
40 |
-
chunk_overlap=config["splitter_options"]["chunk_overlap"],
|
41 |
-
separators=config["splitter_options"]["chunk_separators"],
|
42 |
-
disallowed_special=()
|
43 |
-
)
|
44 |
-
else:
|
45 |
-
self.splitter = RecursiveCharacterTextSplitter(
|
46 |
-
chunk_size=config["splitter_options"]["chunk_size"],
|
47 |
-
chunk_overlap=config["splitter_options"]["chunk_overlap"],
|
48 |
-
separators=config["splitter_options"]["chunk_separators"],
|
49 |
-
disallowed_special=()
|
50 |
-
)
|
51 |
-
else:
|
52 |
-
self.splitter = None
|
53 |
-
logger.info("InfoLoader instance created")
|
54 |
-
|
55 |
-
def extract_text_from_pdf(self, pdf_path):
|
56 |
-
text = ""
|
57 |
-
with open(pdf_path, "rb") as file:
|
58 |
-
reader = PyPDF2.PdfReader(file)
|
59 |
-
num_pages = len(reader.pages)
|
60 |
-
for page_num in range(num_pages):
|
61 |
-
page = reader.pages[page_num]
|
62 |
-
text += page.extract_text()
|
63 |
-
return text
|
64 |
-
|
65 |
-
def download_pdf_from_url(self, pdf_url):
|
66 |
-
response = requests.get(pdf_url)
|
67 |
-
if response.status_code == 200:
|
68 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
|
69 |
-
temp_file.write(response.content)
|
70 |
-
temp_file_path = temp_file.name
|
71 |
-
return temp_file_path
|
72 |
-
else:
|
73 |
-
print("Failed to download PDF from URL:", pdf_url)
|
74 |
-
return None
|
75 |
-
|
76 |
-
def get_chunks(self, uploaded_files, weblinks):
|
77 |
-
# Main list of all documents
|
78 |
-
self.document_chunks_full = []
|
79 |
-
self.document_names = []
|
80 |
-
|
81 |
-
def remove_delimiters(document_chunks: list):
|
82 |
-
"""
|
83 |
-
Helper function to remove remaining delimiters in document chunks
|
84 |
-
"""
|
85 |
-
for chunk in document_chunks:
|
86 |
-
for delimiter in self.config["splitter_options"][
|
87 |
-
"delimiters_to_remove"
|
88 |
-
]:
|
89 |
-
chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
|
90 |
-
return document_chunks
|
91 |
-
|
92 |
-
def remove_chunks(document_chunks: list):
|
93 |
-
"""
|
94 |
-
Helper function to remove any unwanted document chunks after splitting
|
95 |
-
"""
|
96 |
-        front = self.config["splitter_options"]["front_chunk_to_remove"]
-        end = self.config["splitter_options"]["last_chunks_to_remove"]
-        # Remove pages
-        for _ in range(front):
-            del document_chunks[0]
-        for _ in range(end):
-            document_chunks.pop()
-        logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
-        return document_chunks
-
-    def get_pdf_from_url(pdf_url: str):
-        temp_pdf_path = self.download_pdf_from_url(pdf_url)
-        if temp_pdf_path:
-            title, document_chunks = get_pdf(temp_pdf_path, pdf_url)
-            os.remove(temp_pdf_path)
-        return title, document_chunks
-
-    def get_pdf(temp_file_path: str, title: str):
-        """
-        Function to process PDF files
-        """
-        loader = PyMuPDFLoader(
-            temp_file_path
-        )  # This loader preserves more metadata
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(loader.load())
-        else:
-            document_chunks = loader.load()
-
-        if "title" in document_chunks[0].metadata.keys():
-            title = document_chunks[0].metadata["title"]
-
-        logger.info(
-            f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
-        )
-
-        return title, document_chunks
-
-    def get_txt(temp_file_path: str, title: str):
-        """
-        Function to process TXT files
-        """
-        loader = TextLoader(temp_file_path, autodetect_encoding=True)
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(loader.load())
-        else:
-            document_chunks = loader.load()
-
-        # Update the metadata
-        for chunk in document_chunks:
-            chunk.metadata["source"] = title
-            chunk.metadata["page"] = "N/A"
-
-        return title, document_chunks
-
-    def get_srt(temp_file_path: str, title: str):
-        """
-        Function to process SRT files
-        """
-        subs = pysrt.open(temp_file_path)
-
-        text = ""
-        for sub in subs:
-            text += sub.text
-        document_chunks = [Document(page_content=text)]
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(document_chunks)
-
-        # Update the metadata
-        for chunk in document_chunks:
-            chunk.metadata["source"] = title
-            chunk.metadata["page"] = "N/A"
-
-        return title, document_chunks
-
-    def get_docx(temp_file_path: str, title: str):
-        """
-        Function to process DOCX files
-        """
-        loader = Docx2txtLoader(temp_file_path)
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(loader.load())
-        else:
-            document_chunks = loader.load()
-
-        # Update the metadata
-        for chunk in document_chunks:
-            chunk.metadata["source"] = title
-            chunk.metadata["page"] = "N/A"
-
-        return title, document_chunks
-
-    def get_youtube_transcript(url: str):
-        """
-        Function to retrieve youtube transcript and process text
-        """
-        loader = YoutubeLoader.from_youtube_url(
-            url, add_video_info=True, language=["en"], translation="en"
-        )
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(loader.load())
-        else:
-            document_chunks = loader.load_and_split()
-
-        # Replace the source with title (for display in st UI later)
-        for chunk in document_chunks:
-            chunk.metadata["source"] = chunk.metadata["title"]
-            logger.info(chunk.metadata["title"])
-
-        return title, document_chunks
-
-    def get_html(url: str):
-        """
-        Function to process websites via HTML files
-        """
-        loader = WebBaseLoader(url)
-
-        if self.splitter:
-            document_chunks = self.splitter.split_documents(loader.load())
-        else:
-            document_chunks = loader.load_and_split()
-
-        title = document_chunks[0].metadata["title"]
-        logger.info(document_chunks[0].metadata)
-
-        return title, document_chunks
-
-    # Handle file by file
-    for file_index, file_path in enumerate(uploaded_files):
-
-        file_name = file_path.split("/")[-1]
-        file_type = file_name.split(".")[-1]
-
-        # Handle different file types
-        if file_type == "pdf":
-            try:
-                title, document_chunks = get_pdf(file_path, file_name)
-            except:
-                title, document_chunks = get_pdf_from_url(file_path)
-        elif file_type == "txt":
-            title, document_chunks = get_txt(file_path, file_name)
-        elif file_type == "docx":
-            title, document_chunks = get_docx(file_path, file_name)
-        elif file_type == "srt":
-            title, document_chunks = get_srt(file_path, file_name)
-
-        # Additional wrangling - Remove leftover delimiters and any specified chunks
-        if self.remove_leftover_delimiters:
-            document_chunks = remove_delimiters(document_chunks)
-        if self.config["splitter_options"]["remove_chunks"]:
-            document_chunks = remove_chunks(document_chunks)
-
-        logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)} from {file_name}")
-        self.document_names.append(title)
-        self.document_chunks_full.extend(document_chunks)
-
-    # Handle youtube links:
-    if weblinks[0] != "":
-        logger.info(f"Splitting weblinks: total of {len(weblinks)}")
-
-        # Handle link by link
-        for link_index, link in enumerate(weblinks):
-            try:
-                logger.info(f"\tSplitting link {link_index+1} : {link}")
-                if "youtube" in link:
-                    title, document_chunks = get_youtube_transcript(link)
-                else:
-                    title, document_chunks = get_html(link)
-
-                # Additional wrangling - Remove leftover delimiters and any specified chunks
-                if self.remove_leftover_delimiters:
-                    document_chunks = remove_delimiters(document_chunks)
-                if self.config["splitter_options"]["remove_chunks"]:
-                    document_chunks = remove_chunks(document_chunks)
-
-                print(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
-                self.document_names.append(title)
-                self.document_chunks_full.extend(document_chunks)
-            except:
-                logger.info(f"\t\tError splitting link {link_index+1} : {link}")
-                exit()
-
-    logger.info(
-        f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
-    )
-
-    return self.document_chunks_full, self.document_names
code/modules/dataloader/__init__.py
ADDED
File without changes
code/modules/dataloader/data_loader.py
ADDED
@@ -0,0 +1,360 @@
+import os
+import re
+import requests
+import tempfile
+import PyPDF2
+import pysrt
+from langchain_community.document_loaders import (
+    PyMuPDFLoader,
+    Docx2txtLoader,
+    YoutubeLoader,
+    WebBaseLoader,
+    TextLoader,
+)
+from langchain_community.document_loaders import UnstructuredMarkdownLoader
+from llama_parse import LlamaParse
+from langchain.schema import Document
+import logging
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from ragatouille import RAGPretrainedModel
+from langchain.chains import LLMChain
+from langchain_community.llms import OpenAI
+from langchain import PromptTemplate
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+from modules.dataloader.helpers import get_metadata
+
+
+class PDFReader:
+    def __init__(self):
+        pass
+
+    def get_loader(self, pdf_path):
+        loader = PyMuPDFLoader(pdf_path)
+        return loader
+
+    def get_documents(self, loader):
+        return loader.load()
+
+
+class FileReader:
+    def __init__(self, logger):
+        self.pdf_reader = PDFReader()
+        self.logger = logger
+
+    def extract_text_from_pdf(self, pdf_path):
+        text = ""
+        with open(pdf_path, "rb") as file:
+            reader = PyPDF2.PdfReader(file)
+            num_pages = len(reader.pages)
+            for page_num in range(num_pages):
+                page = reader.pages[page_num]
+                text += page.extract_text()
+        return text
+
+    def download_pdf_from_url(self, pdf_url):
+        response = requests.get(pdf_url)
+        if response.status_code == 200:
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
+                temp_file.write(response.content)
+                temp_file_path = temp_file.name
+            return temp_file_path
+        else:
+            self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
+            return None
+
+    def read_pdf(self, temp_file_path: str):
+        loader = self.pdf_reader.get_loader(temp_file_path)
+        documents = self.pdf_reader.get_documents(loader)
+        return documents
+
+    def read_txt(self, temp_file_path: str):
+        loader = TextLoader(temp_file_path, autodetect_encoding=True)
+        return loader.load()
+
+    def read_docx(self, temp_file_path: str):
+        loader = Docx2txtLoader(temp_file_path)
+        return loader.load()
+
+    def read_srt(self, temp_file_path: str):
+        subs = pysrt.open(temp_file_path)
+        text = ""
+        for sub in subs:
+            text += sub.text
+        return [Document(page_content=text)]
+
+    def read_youtube_transcript(self, url: str):
+        loader = YoutubeLoader.from_youtube_url(
+            url, add_video_info=True, language=["en"], translation="en"
+        )
+        return loader.load()
+
+    def read_html(self, url: str):
+        loader = WebBaseLoader(url)
+        return loader.load()
+
+    def read_tex_from_url(self, tex_url):
+        response = requests.get(tex_url)
+        if response.status_code == 200:
+            return [Document(page_content=response.text)]
+        else:
+            self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
+            return None
+
+
+class ChunkProcessor:
+    def __init__(self, config, logger):
+        self.config = config
+        self.logger = logger
+
+        self.document_data = {}
+        self.document_metadata = {}
+        self.document_chunks_full = []
+
+        if config["splitter_options"]["use_splitter"]:
+            if config["splitter_options"]["split_by_token"]:
+                self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                    chunk_size=config["splitter_options"]["chunk_size"],
+                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                    separators=config["splitter_options"]["chunk_separators"],
+                    disallowed_special=(),
+                )
+            else:
+                self.splitter = RecursiveCharacterTextSplitter(
+                    chunk_size=config["splitter_options"]["chunk_size"],
+                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                    separators=config["splitter_options"]["chunk_separators"],
+                    disallowed_special=(),
+                )
+        else:
+            self.splitter = None
+        self.logger.info("ChunkProcessor instance created")
+
+    def remove_delimiters(self, document_chunks: list):
+        for chunk in document_chunks:
+            for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
+                chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
+        return document_chunks
+
+    def remove_chunks(self, document_chunks: list):
+        front = self.config["splitter_options"]["front_chunk_to_remove"]
+        end = self.config["splitter_options"]["last_chunks_to_remove"]
+        for _ in range(front):
+            del document_chunks[0]
+        for _ in range(end):
+            document_chunks.pop()
+        return document_chunks
+
+    def process_chunks(
+        self, documents, file_type="txt", source="", page=0, metadata={}
+    ):
+        documents = [Document(page_content=documents, source=source, page=page)]
+        if (
+            file_type == "txt"
+            or file_type == "docx"
+            or file_type == "srt"
+            or file_type == "tex"
+        ):
+            document_chunks = self.splitter.split_documents(documents)
+        elif file_type == "pdf":
+            document_chunks = documents  # Full page for now
+
+        # add the source and page number back to the metadata
+        for chunk in document_chunks:
+            chunk.metadata["source"] = source
+            chunk.metadata["page"] = page
+
+            # add the metadata extracted from the document
+            for key, value in metadata.items():
+                chunk.metadata[key] = value
+
+        if self.config["splitter_options"]["remove_leftover_delimiters"]:
+            document_chunks = self.remove_delimiters(document_chunks)
+        if self.config["splitter_options"]["remove_chunks"]:
+            document_chunks = self.remove_chunks(document_chunks)
+
+        return document_chunks
+
+    def chunk_docs(self, file_reader, uploaded_files, weblinks):
+        addl_metadata = get_metadata(
+            "https://dl4ds.github.io/sp2024/lectures/",
+            "https://dl4ds.github.io/sp2024/schedule/",
+        )  # For any additional metadata
+
+        with ThreadPoolExecutor() as executor:
+            executor.map(
+                self.process_file,
+                uploaded_files,
+                range(len(uploaded_files)),
+                [file_reader] * len(uploaded_files),
+                [addl_metadata] * len(uploaded_files),
+            )
+            executor.map(
+                self.process_weblink,
+                weblinks,
+                range(len(weblinks)),
+                [file_reader] * len(weblinks),
+                [addl_metadata] * len(weblinks),
+            )
+
+        document_names = [
+            f"{file_name}_{page_num}"
+            for file_name, pages in self.document_data.items()
+            for page_num in pages.keys()
+        ]
+        documents = [
+            page for doc in self.document_data.values() for page in doc.values()
+        ]
+        document_metadata = [
+            page for doc in self.document_metadata.values() for page in doc.values()
+        ]
+
+        self.save_document_data()
+
+        self.logger.info(
+            f"Total document chunks extracted: {len(self.document_chunks_full)}"
+        )
+
+        return self.document_chunks_full, document_names, documents, document_metadata
+
+    def process_documents(
+        self, documents, file_path, file_type, metadata_source, addl_metadata
+    ):
+        file_data = {}
+        file_metadata = {}
+
+        for doc in documents:
+            # if len(doc.page_content) <= 400:  # better approach to filter out non-informative documents
+            #     continue
+
+            page_num = doc.metadata.get("page", 0)
+            file_data[page_num] = doc.page_content
+            metadata = (
+                addl_metadata.get(file_path, {})
+                if metadata_source == "file"
+                else {"source": file_path, "page": page_num}
+            )
+            file_metadata[page_num] = metadata
+
+            if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
+                document_chunks = self.process_chunks(
+                    doc.page_content,
+                    file_type,
+                    source=file_path,
+                    page=page_num,
+                    metadata=metadata,
+                )
+                self.document_chunks_full.extend(document_chunks)
+
+        self.document_data[file_path] = file_data
+        self.document_metadata[file_path] = file_metadata
+
+    def process_file(self, file_path, file_index, file_reader, addl_metadata):
+        file_name = os.path.basename(file_path)
+        if file_name in self.document_data:
+            return
+
+        file_type = file_name.split(".")[-1].lower()
+        self.logger.info(f"Reading file {file_index + 1}: {file_path}")
+
+        read_methods = {
+            "pdf": file_reader.read_pdf,
+            "txt": file_reader.read_txt,
+            "docx": file_reader.read_docx,
+            "srt": file_reader.read_srt,
+            "tex": file_reader.read_tex_from_url,
+        }
+        if file_type not in read_methods:
+            self.logger.warning(f"Unsupported file type: {file_type}")
+            return
+
+        try:
+            documents = read_methods[file_type](file_path)
+            self.process_documents(
+                documents, file_path, file_type, "file", addl_metadata
+            )
+        except Exception as e:
+            self.logger.error(f"Error processing file {file_name}: {str(e)}")
+
+    def process_weblink(self, link, link_index, file_reader, addl_metadata):
+        if link in self.document_data:
+            return
+
+        self.logger.info(f"Reading link {link_index + 1} : {link}")
+
+        try:
+            if "youtube" in link:
+                documents = file_reader.read_youtube_transcript(link)
+            else:
+                documents = file_reader.read_html(link)
+
+            self.process_documents(documents, link, "txt", "link", addl_metadata)
+        except Exception as e:
+            self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
+
+    def save_document_data(self):
+        if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
+            os.makedirs(f"{self.config['log_chunk_dir']}/docs")
+            self.logger.info(
+                f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
+            )
+        self.logger.info(
+            f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
+        )
+        if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
+            os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
+            self.logger.info(
+                f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
+            )
+        self.logger.info(
+            f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
+        )
+        with open(
+            f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
+        ) as json_file:
+            json.dump(self.document_data, json_file, indent=4)
+        with open(
+            f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
+        ) as json_file:
+            json.dump(self.document_metadata, json_file, indent=4)
+
+    def load_document_data(self):
+        with open(
+            f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
+        ) as json_file:
+            self.document_data = json.load(json_file)
+        with open(
+            f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
+        ) as json_file:
+            self.document_metadata = json.load(json_file)
+
+
+class DataLoader:
+    def __init__(self, config, logger=None):
+        self.file_reader = FileReader(logger=logger)
+        self.chunk_processor = ChunkProcessor(config, logger=logger)
+
+    def get_chunks(self, uploaded_files, weblinks):
+        return self.chunk_processor.chunk_docs(
+            self.file_reader, uploaded_files, weblinks
+        )
+
+
+if __name__ == "__main__":
+    import yaml
+
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    with open("../code/modules/config/config.yml", "r") as f:
+        config = yaml.safe_load(f)
+
+    data_loader = DataLoader(config, logger=logger)
+    document_chunks, document_names, documents, document_metadata = (
+        data_loader.get_chunks(
+            [],
+            ["https://dl4ds.github.io/sp2024/"],
+        )
+    )
+
+    print(document_names)
+    print(len(document_chunks))
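Review note: `chunk_docs` fans work out through `ThreadPoolExecutor.map` with one iterable per parameter of the worker. A minimal, self-contained sketch of that pattern (names here are illustrative, not from this PR); note that `Executor.map` returns a lazy iterator, so exceptions raised in a worker are silently dropped unless the result is consumed, which is why this sketch wraps the call in `list()`:

from concurrent.futures import ThreadPoolExecutor

def process_file(path, index, reader, metadata):
    # each positional argument is drawn from the iterable in the same position
    print(f"{index}: {path} via {reader} with {metadata}")

paths = ["a.pdf", "b.txt"]  # illustrative inputs
with ThreadPoolExecutor() as executor:
    # consuming the iterator forces completion and re-raises worker exceptions
    list(executor.map(process_file, paths, range(len(paths)),
                      ["reader"] * len(paths), [{}] * len(paths)))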
code/modules/dataloader/helpers.py
ADDED
@@ -0,0 +1,108 @@
+import requests
+from bs4 import BeautifulSoup
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+
+def get_urls_from_file(file_path: str):
+    """
+    Function to get urls from a file
+    """
+    with open(file_path, "r") as f:
+        urls = f.readlines()
+    urls = [url.strip() for url in urls]
+    return urls
+
+
+def get_base_url(url):
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
+    return base_url
+
+
+def get_metadata(lectures_url, schedule_url):
+    """
+    Function to get the lecture metadata from the lectures and schedule URLs.
+    """
+    lecture_metadata = {}
+
+    # Get the main lectures page content
+    r_lectures = requests.get(lectures_url)
+    soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
+
+    # Get the main schedule page content
+    r_schedule = requests.get(schedule_url)
+    soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
+
+    # Find all lecture blocks
+    lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
+
+    # Create a mapping from slides link to date
+    date_mapping = {}
+    schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
+    for row in schedule_rows:
+        try:
+            date = (
+                row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
+            )
+            description_div = row.find("div", {"data-label": "Description"})
+            slides_link_tag = description_div.find("a", title="Download slides")
+            slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+            slides_link = (
+                f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            )
+            if slides_link:
+                date_mapping[slides_link] = date
+        except Exception as e:
+            print(f"Error processing schedule row: {e}")
+            continue
+
+    for block in lecture_blocks:
+        try:
+            # Extract the lecture title
+            title = block.find("span", style="font-weight: bold;").text.strip()
+
+            # Extract the TL;DR
+            tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
+
+            # Extract the link to the slides
+            slides_link_tag = block.find("a", title="Download slides")
+            slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+            slides_link = (
+                f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            )
+
+            # Extract the link to the lecture recording
+            recording_link_tag = block.find("a", title="Download lecture recording")
+            recording_link = (
+                recording_link_tag["href"].strip() if recording_link_tag else None
+            )
+
+            # Extract suggested readings or summary if available
+            suggested_readings_tag = block.find("p", text="Suggested Readings:")
+            if suggested_readings_tag:
+                suggested_readings = suggested_readings_tag.find_next_sibling("ul")
+                if suggested_readings:
+                    suggested_readings = suggested_readings.get_text(
+                        separator="\n"
+                    ).strip()
+                else:
+                    suggested_readings = "No specific readings provided."
+            else:
+                suggested_readings = "No specific readings provided."
+
+            # Get the date from the schedule
+            date = date_mapping.get(slides_link, "No date available")
+
+            # Add to the dictionary
+            lecture_metadata[slides_link] = {
+                "date": date,
+                "tldr": tldr,
+                "title": title,
+                "lecture_recording": recording_link,
+                "suggested_readings": suggested_readings,
+            }
+        except Exception as e:
+            print(f"Error processing block: {e}")
+            continue
+
+    return lecture_metadata
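Review note: the mapping returned by `get_metadata` is keyed by the absolute slides URL; a hedged sketch of the expected shape (key and values below are illustrative, not taken from the course site):

lecture_metadata = {
    "https://dl4ds.github.io/sp2024/static_files/lectures/01_intro.pdf": {  # hypothetical key
        "date": "Tue Jan 23",                                # illustrative value
        "tldr": "What the course is about.",                 # illustrative value
        "title": "Lecture 1: Introduction",                  # illustrative value
        "lecture_recording": None,
        "suggested_readings": "No specific readings provided.",
    }
}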
code/modules/dataloader/webpage_crawler.py
ADDED
@@ -0,0 +1,115 @@
+import aiohttp
+from aiohttp import ClientSession
+import asyncio
+import requests
+from bs4 import BeautifulSoup
+from urllib.parse import urlparse, urljoin, urldefrag
+
+
+class WebpageCrawler:
+    def __init__(self):
+        self.dict_href_links = {}
+
+    async def fetch(self, session: ClientSession, url: str) -> str:
+        async with session.get(url) as response:
+            try:
+                return await response.text()
+            except UnicodeDecodeError:
+                return await response.text(encoding="latin1")
+
+    def url_exists(self, url: str) -> bool:
+        try:
+            response = requests.head(url)
+            return response.status_code == 200
+        except requests.ConnectionError:
+            return False
+
+    async def get_links(self, session: ClientSession, website_link: str, base_url: str):
+        html_data = await self.fetch(session, website_link)
+        soup = BeautifulSoup(html_data, "html.parser")
+        list_links = []
+        for link in soup.find_all("a", href=True):
+            href = link["href"].strip()
+            full_url = urljoin(base_url, href)
+            normalized_url = self.normalize_url(full_url)  # sections removed
+            if (
+                normalized_url not in self.dict_href_links
+                and self.is_child_url(normalized_url, base_url)
+                and self.url_exists(normalized_url)
+            ):
+                self.dict_href_links[normalized_url] = None
+                list_links.append(normalized_url)
+
+        return list_links
+
+    async def get_subpage_links(
+        self, session: ClientSession, urls: list, base_url: str
+    ):
+        tasks = [self.get_links(session, url, base_url) for url in urls]
+        results = await asyncio.gather(*tasks)
+        all_links = [link for sublist in results for link in sublist]
+        return all_links
+
+    async def get_all_pages(self, url: str, base_url: str):
+        async with aiohttp.ClientSession() as session:
+            dict_links = {url: "Not-checked"}
+            counter = None
+            while counter != 0:
+                unchecked_links = [
+                    link
+                    for link, status in dict_links.items()
+                    if status == "Not-checked"
+                ]
+                if not unchecked_links:
+                    break
+                new_links = await self.get_subpage_links(
+                    session, unchecked_links, base_url
+                )
+                for link in unchecked_links:
+                    dict_links[link] = "Checked"
+                    print(f"Checked: {link}")
+                dict_links.update(
+                    {
+                        link: "Not-checked"
+                        for link in new_links
+                        if link not in dict_links
+                    }
+                )
+                counter = len(
+                    [
+                        status
+                        for status in dict_links.values()
+                        if status == "Not-checked"
+                    ]
+                )
+
+            checked_urls = [
+                url for url, status in dict_links.items() if status == "Checked"
+            ]
+            return checked_urls
+
+    def is_webpage(self, url: str) -> bool:
+        try:
+            response = requests.head(url, allow_redirects=True)
+            content_type = response.headers.get("Content-Type", "").lower()
+            return "text/html" in content_type
+        except requests.RequestException:
+            return False
+
+    def clean_url_list(self, urls):
+        files, webpages = [], []
+
+        for url in urls:
+            if self.is_webpage(url):
+                webpages.append(url)
+            else:
+                files.append(url)
+
+        return files, webpages
+
+    def is_child_url(self, url, base_url):
+        return url.startswith(base_url)
+
+    def normalize_url(self, url: str):
+        # Strip the fragment identifier
+        defragged_url, _ = urldefrag(url)
+        return defragged_url
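Review note: `get_all_pages` is a coroutine, so callers need an event loop. A minimal usage sketch (the seed URL is illustrative):

import asyncio

crawler = WebpageCrawler()
seed = "https://dl4ds.github.io/sp2024/"  # illustrative seed URL
pages = asyncio.run(crawler.get_all_pages(seed, seed))
# split the crawl results into downloadable files vs. HTML pages
files, webpages = crawler.clean_url_list(pages)
print(len(files), len(webpages))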
code/modules/helpers.py
DELETED
@@ -1,200 +0,0 @@
-import requests
-from bs4 import BeautifulSoup
-from tqdm import tqdm
-from urllib.parse import urlparse
-import chainlit as cl
-from langchain import PromptTemplate
-try:
-    from modules.constants import *
-except:
-    from constants import *
-
-"""
-Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
-"""
-
-
-class WebpageCrawler:
-    def __init__(self):
-        pass
-
-    def getdata(self, url):
-        r = requests.get(url)
-        return r.text
-
-    def url_exists(self, url):
-        try:
-            response = requests.head(url)
-            return response.status_code == 200
-        except requests.ConnectionError:
-            return False
-
-    def get_links(self, website_link, base_url=None):
-        if base_url is None:
-            base_url = website_link
-        html_data = self.getdata(website_link)
-        soup = BeautifulSoup(html_data, "html.parser")
-        list_links = []
-        for link in soup.find_all("a", href=True):
-
-            # clean the link
-            # remove empty spaces
-            link["href"] = link["href"].strip()
-            # Append to list if new link contains original link
-            if str(link["href"]).startswith((str(website_link))):
-                list_links.append(link["href"])
-
-            # Include all href that do not start with website link but with "/"
-            if str(link["href"]).startswith("/"):
-                if link["href"] not in self.dict_href_links:
-                    print(link["href"])
-                    self.dict_href_links[link["href"]] = None
-                    link_with_www = base_url + link["href"][1:]
-                    if self.url_exists(link_with_www):
-                        print("adjusted link =", link_with_www)
-                        list_links.append(link_with_www)
-
-        # Convert list of links to dictionary and define keys as the links and the values as "Not-checked"
-        dict_links = dict.fromkeys(list_links, "Not-checked")
-        return dict_links
-
-    def get_subpage_links(self, l, base_url):
-        for link in tqdm(l):
-            print('checking link:', link)
-            if not link.endswith("/"):
-                l[link] = "Checked"
-                dict_links_subpages = {}
-            else:
-                # If not crawled through this page start crawling and get links
-                if l[link] == "Not-checked":
-                    dict_links_subpages = self.get_links(link, base_url)
-                    # Change the dictionary value of the link to "Checked"
-                    l[link] = "Checked"
-                else:
-                    # Create an empty dictionary in case every link is checked
-                    dict_links_subpages = {}
-            # Add new dictionary to old dictionary
-            l = {**dict_links_subpages, **l}
-        return l
-
-    def get_all_pages(self, url, base_url):
-        dict_links = {url: "Not-checked"}
-        self.dict_href_links = {}
-        counter, counter2 = None, 0
-        while counter != 0:
-            counter2 += 1
-            dict_links2 = self.get_subpage_links(dict_links, base_url)
-            # Count number of non-values and set counter to 0 if there are no values within the dictionary equal to the string "Not-checked"
-            # https://stackoverflow.com/questions/48371856/count-the-number-of-occurrences-of-a-certain-value-in-a-dictionary-in-python
-            counter = sum(value == "Not-checked" for value in dict_links2.values())
-            dict_links = dict_links2
-        checked_urls = [
-            url for url, status in dict_links.items() if status == "Checked"
-        ]
-        return checked_urls
-
-
-def get_urls_from_file(file_path: str):
-    """
-    Function to get urls from a file
-    """
-    with open(file_path, "r") as f:
-        urls = f.readlines()
-    urls = [url.strip() for url in urls]
-    return urls
-
-
-def get_base_url(url):
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
-    return base_url
-
-
-def get_prompt(config):
-    if config["llm_params"]["use_history"]:
-        if config["llm_params"]["llm_loader"] == "local_llm":
-            custom_prompt_template = tinyllama_prompt_template_with_history
-        elif config["llm_params"]["llm_loader"] == "openai":
-            custom_prompt_template = openai_prompt_template_with_history
-        # else:
-        #     custom_prompt_template = tinyllama_prompt_template_with_history  # default
-        prompt = PromptTemplate(
-            template=custom_prompt_template,
-            input_variables=["context", "chat_history", "question"],
-        )
-    else:
-        if config["llm_params"]["llm_loader"] == "local_llm":
-            custom_prompt_template = tinyllama_prompt_template
-        elif config["llm_params"]["llm_loader"] == "openai":
-            custom_prompt_template = openai_prompt_template
-        # else:
-        #     custom_prompt_template = tinyllama_prompt_template
-        prompt = PromptTemplate(
-            template=custom_prompt_template,
-            input_variables=["context", "question"],
-        )
-    return prompt
-
-
-def get_sources(res, answer):
-    source_elements_dict = {}
-    source_elements = []
-    found_sources = []
-
-    source_dict = {}  # Dictionary to store URL elements
-
-    for idx, source in enumerate(res["source_documents"]):
-        source_metadata = source.metadata
-        url = source_metadata["source"]
-
-        if url not in source_dict:
-            source_dict[url] = [source.page_content]
-        else:
-            source_dict[url].append(source.page_content)
-
-    for source_idx, (url, text_list) in enumerate(source_dict.items()):
-        full_text = ""
-        for url_idx, text in enumerate(text_list):
-            full_text += f"Source {url_idx+1}:\n {text}\n\n\n"
-        source_elements.append(cl.Text(name=url, content=full_text))
-        found_sources.append(url)
-
-    if found_sources:
-        answer += f"\n\nSources: {', '.join(found_sources)} "
-    else:
-        answer += f"\n\nNo source found."
-
-    # for idx, source in enumerate(res["source_documents"]):
-    #     title = source.metadata["source"]
-
-    #     if title not in source_elements_dict:
-    #         source_elements_dict[title] = {
-    #             "page_number": [source.metadata["page"]],
-    #             "url": source.metadata["source"],
-    #             "content": source.page_content,
-    #         }
-
-    #     else:
-    #         source_elements_dict[title]["page_number"].append(source.metadata["page"])
-    #         source_elements_dict[title][
-    #             "content_" + str(source.metadata["page"])
-    #         ] = source.page_content
-    #     # sort the page numbers
-    #     # source_elements_dict[title]["page_number"].sort()
-
-    # for title, source in source_elements_dict.items():
-    #     # create a string for the page numbers
-    #     page_numbers = ", ".join([str(x) for x in source["page_number"]])
-    #     text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
-    #     source_elements.append(cl.Pdf(name="File", path=title))
-    #     found_sources.append("File")
-    #     # for pn in source["page_number"]:
-    #     #     source_elements.append(
-    #     #         cl.Text(name=str(pn), content=source["content_"+str(pn)])
-    #     #     )
-    #     #     found_sources.append(str(pn))
-
-    # if found_sources:
-    #     answer += f"\nSource:{', '.join(found_sources)}"
-    # else:
-    #     answer += f"\nNo source found."
-
-    return answer, source_elements
code/modules/llm_tutor.py
DELETED
@@ -1,87 +0,0 @@
-from langchain import PromptTemplate
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain_community.chat_models import ChatOpenAI
-from langchain_community.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import FAISS
-from langchain.chains import RetrievalQA, ConversationalRetrievalChain
-from langchain.llms import CTransformers
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
-import os
-
-from modules.constants import *
-from modules.helpers import get_prompt
-from modules.chat_model_loader import ChatModelLoader
-from modules.vector_db import VectorDB
-
-
-class LLMTutor:
-    def __init__(self, config, logger=None):
-        self.config = config
-        self.vector_db = VectorDB(config, logger=logger)
-        if self.config["embedding_options"]["embedd_files"]:
-            self.vector_db.create_database()
-            self.vector_db.save_database()
-
-    def set_custom_prompt(self):
-        """
-        Prompt template for QA retrieval for each vectorstore
-        """
-        prompt = get_prompt(self.config)
-        # prompt = QA_PROMPT
-
-        return prompt
-
-    # Retrieval QA Chain
-    def retrieval_qa_chain(self, llm, prompt, db):
-        if self.config["llm_params"]["use_history"]:
-            memory = ConversationBufferWindowMemory(
-                k=self.config["llm_params"]["memory_window"],
-                memory_key="chat_history", return_messages=True, output_key="answer"
-            )
-            qa_chain = ConversationalRetrievalChain.from_llm(
-                llm=llm,
-                chain_type="stuff",
-                retriever=db.as_retriever(
-                    search_kwargs={
-                        "k": self.config["embedding_options"]["search_top_k"]
-                    }
-                ),
-                return_source_documents=True,
-                memory=memory,
-                combine_docs_chain_kwargs={"prompt": prompt},
-            )
-        else:
-            qa_chain = RetrievalQA.from_chain_type(
-                llm=llm,
-                chain_type="stuff",
-                retriever=db.as_retriever(
-                    search_kwargs={
-                        "k": self.config["embedding_options"]["search_top_k"]
-                    }
-                ),
-                return_source_documents=True,
-                chain_type_kwargs={"prompt": prompt},
-            )
-        return qa_chain
-
-    # Loading the model
-    def load_llm(self):
-        chat_model_loader = ChatModelLoader(self.config)
-        llm = chat_model_loader.load_chat_model()
-        return llm
-
-    # QA Model Function
-    def qa_bot(self):
-        db = self.vector_db.load_database()
-        self.llm = self.load_llm()
-        qa_prompt = self.set_custom_prompt()
-        qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
-
-        return qa
-
-    # output function
-    def final_result(query):
-        qa_result = qa_bot()
-        response = qa_result({"query": query})
-        return response
code/modules/retriever/__init__.py
ADDED
File without changes
code/modules/retriever/base.py
ADDED
@@ -0,0 +1,12 @@
+# template for retriever classes
+
+
+class BaseRetriever:
+    def __init__(self, config):
+        self.config = config
+
+    def return_retriever(self):
+        """
+        Returns the retriever object
+        """
+        raise NotImplementedError
code/modules/retriever/chroma_retriever.py
ADDED
@@ -0,0 +1,24 @@
+from .helpers import VectorStoreRetrieverScore
+from .base import BaseRetriever
+
+
+class ChromaRetriever(BaseRetriever):
+    def __init__(self):
+        pass
+
+    def return_retriever(self, db, config):
+        retriever = VectorStoreRetrieverScore(
+            vectorstore=db,
+            # search_type="similarity_score_threshold",
+            # search_kwargs={
+            #     "score_threshold": self.config["vectorstore"][
+            #         "score_threshold"
+            #     ],
+            #     "k": self.config["vectorstore"]["search_top_k"],
+            # },
+            search_kwargs={
+                "k": config["vectorstore"]["search_top_k"],
+            },
+        )
+
+        return retriever
code/modules/retriever/colbert_retriever.py
ADDED
@@ -0,0 +1,10 @@
+from .base import BaseRetriever
+
+
+class ColbertRetriever(BaseRetriever):
+    def __init__(self):
+        pass
+
+    def return_retriever(self, db, config):
+        retriever = db.as_langchain_retriever(k=config["vectorstore"]["search_top_k"])
+        return retriever
ADDED
@@ -0,0 +1,23 @@
|
+from .helpers import VectorStoreRetrieverScore
+from .base import BaseRetriever
+
+
+class FaissRetriever(BaseRetriever):
+    def __init__(self):
+        pass
+
+    def return_retriever(self, db, config):
+        retriever = VectorStoreRetrieverScore(
+            vectorstore=db,
+            # search_type="similarity_score_threshold",
+            # search_kwargs={
+            #     "score_threshold": self.config["vectorstore"][
+            #         "score_threshold"
+            #     ],
+            #     "k": self.config["vectorstore"]["search_top_k"],
+            # },
+            search_kwargs={
+                "k": config["vectorstore"]["search_top_k"],
+            },
+        )
+        return retriever
code/modules/retriever/helpers.py
ADDED
@@ -0,0 +1,39 @@
+from langchain.schema.vectorstore import VectorStoreRetriever
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain.schema.document import Document
+from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
+from typing import List
+
+
+class VectorStoreRetrieverScore(VectorStoreRetriever):
+
+    # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        docs_and_similarities = (
+            self.vectorstore.similarity_search_with_relevance_scores(
+                query, **self.search_kwargs
+            )
+        )
+        # Make the score part of the document metadata
+        for doc, similarity in docs_and_similarities:
+            doc.metadata["score"] = similarity
+
+        docs = [doc for doc, _ in docs_and_similarities]
+        return docs
+
+    async def _aget_relevant_documents(
+        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        docs_and_similarities = (
+            self.vectorstore.similarity_search_with_relevance_scores(
+                query, **self.search_kwargs
+            )
+        )
+        # Make the score part of the document metadata
+        for doc, similarity in docs_and_similarities:
+            doc.metadata["score"] = similarity
+
+        docs = [doc for doc, _ in docs_and_similarities]
+        return docs
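Review note: `VectorStoreRetrieverScore` copies each relevance score into the matching document's metadata, so downstream code can surface it alongside the source. A hedged sketch, assuming `db` is an already-built langchain vector store (e.g. FAISS or Chroma):

# hypothetical wiring; `db` is an existing langchain vector store instance
retriever = VectorStoreRetrieverScore(vectorstore=db, search_kwargs={"k": 3})
docs = retriever.get_relevant_documents("what is backpropagation?")  # illustrative query
for doc in docs:
    # "score" was injected by _get_relevant_documents above
    print(doc.metadata["score"], doc.metadata.get("source"))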
code/modules/retriever/raptor_retriever.py
ADDED
@@ -0,0 +1,16 @@
+from .helpers import VectorStoreRetrieverScore
+from .base import BaseRetriever
+
+
+class RaptorRetriever(BaseRetriever):
+    def __init__(self):
+        pass
+
+    def return_retriever(self, db, config):
+        retriever = VectorStoreRetrieverScore(
+            vectorstore=db,
+            search_kwargs={
+                "k": config["vectorstore"]["search_top_k"],
+            },
+        )
+        return retriever
code/modules/retriever/retriever.py
ADDED
@@ -0,0 +1,26 @@
+from modules.retriever.faiss_retriever import FaissRetriever
+from modules.retriever.chroma_retriever import ChromaRetriever
+from modules.retriever.colbert_retriever import ColbertRetriever
+from modules.retriever.raptor_retriever import RaptorRetriever
+
+
+class Retriever:
+    def __init__(self, config):
+        self.config = config
+        self.retriever_classes = {
+            "FAISS": FaissRetriever,
+            "Chroma": ChromaRetriever,
+            "RAGatouille": ColbertRetriever,
+            "RAPTOR": RaptorRetriever,
+        }
+        self._create_retriever()
+
+    def _create_retriever(self):
+        db_option = self.config["vectorstore"]["db_option"]
+        retriever_class = self.retriever_classes.get(db_option)
+        if not retriever_class:
+            raise ValueError(f"Invalid db_option: {db_option}")
+        self.retriever = retriever_class()
+
+    def _return_retriever(self, db):
+        return self.retriever.return_retriever(db, self.config)
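Review note: a short sketch of how the factory is meant to be driven (the config values are illustrative, and `db` is a placeholder for a store returned by the matching vector store's `load_database()`):

config = {"vectorstore": {"db_option": "FAISS", "search_top_k": 3}}  # illustrative config
retriever_factory = Retriever(config)
# `db` is assumed to be the loaded FAISS store; not constructed in this sketch
langchain_retriever = retriever_factory._return_retriever(db)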
code/modules/vector_db.py
DELETED
@@ -1,133 +0,0 @@
-import logging
-import os
-import yaml
-from langchain.vectorstores import FAISS
-
-try:
-    from modules.embedding_model_loader import EmbeddingModelLoader
-    from modules.data_loader import DataLoader
-    from modules.constants import *
-    from modules.helpers import *
-except:
-    from embedding_model_loader import EmbeddingModelLoader
-    from data_loader import DataLoader
-    from constants import *
-    from helpers import *
-
-
-class VectorDB:
-    def __init__(self, config, logger=None):
-        self.config = config
-        self.db_option = config["embedding_options"]["db_option"]
-        self.document_names = None
-        self.webpage_crawler = WebpageCrawler()
-
-        # Set up logging to both console and a file
-        if logger is None:
-            self.logger = logging.getLogger(__name__)
-            self.logger.setLevel(logging.INFO)
-
-            # Console Handler
-            console_handler = logging.StreamHandler()
-            console_handler.setLevel(logging.INFO)
-            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-            console_handler.setFormatter(formatter)
-            self.logger.addHandler(console_handler)
-
-            # File Handler
-            log_file_path = "vector_db.log"  # Change this to your desired log file path
-            file_handler = logging.FileHandler(log_file_path, mode="w")
-            file_handler.setLevel(logging.INFO)
-            file_handler.setFormatter(formatter)
-            self.logger.addHandler(file_handler)
-        else:
-            self.logger = logger
-
-        self.logger.info("VectorDB instance instantiated")
-
-    def load_files(self):
-        files = os.listdir(self.config["embedding_options"]["data_path"])
-        files = [
-            os.path.join(self.config["embedding_options"]["data_path"], file)
-            for file in files
-        ]
-        urls = get_urls_from_file(self.config["embedding_options"]["url_file_path"])
-        if self.config["embedding_options"]["expand_urls"]:
-            all_urls = []
-            for url in urls:
-                base_url = get_base_url(url)
-                all_urls.extend(self.webpage_crawler.get_all_pages(url, base_url))
-            urls = all_urls
-        return files, urls
-
-    def clean_url_list(self, urls):
-        # get lecture pdf links
-        lecture_pdfs = [link for link in urls if link.endswith(".pdf")]
-        lecture_pdfs = [link for link in lecture_pdfs if "lecture" in link.lower()]
-        urls = [link for link in urls if link.endswith("/")]  # only keep links that end with a '/'. Extract Files Seperately
-
-        return urls, lecture_pdfs
-
-    def create_embedding_model(self):
-        self.logger.info("Creating embedding function")
-        self.embedding_model_loader = EmbeddingModelLoader(self.config)
-        self.embedding_model = self.embedding_model_loader.load_embedding_model()
-
-    def initialize_database(self, document_chunks: list, document_names: list):
-        # Track token usage
-        self.logger.info("Initializing vector_db")
-        self.logger.info("\tUsing {} as db_option".format(self.db_option))
-        if self.db_option == "FAISS":
-            self.vector_db = FAISS.from_documents(
-                documents=document_chunks, embedding=self.embedding_model
-            )
-        self.logger.info("Completed initializing vector_db")
-
-    def create_database(self):
-        data_loader = DataLoader(self.config)
-        self.logger.info("Loading data")
-        files, urls = self.load_files()
-        urls, lecture_pdfs = self.clean_url_list(urls)
-        files += lecture_pdfs
-        files.remove('storage/data/urls.txt')
-        document_chunks, document_names = data_loader.get_chunks(files, urls)
-        self.logger.info("Completed loading data")
-
-        self.create_embedding_model()
-        self.initialize_database(document_chunks, document_names)
-
-    def save_database(self):
-        self.vector_db.save_local(
-            os.path.join(
-                self.config["embedding_options"]["db_path"],
-                "db_"
-                + self.config["embedding_options"]["db_option"]
-                + "_"
-                + self.config["embedding_options"]["model"],
-            )
-        )
-        self.logger.info("Saved database")
-
-    def load_database(self):
-        self.create_embedding_model()
-        self.vector_db = FAISS.load_local(
-            os.path.join(
-                self.config["embedding_options"]["db_path"],
-                "db_"
-                + self.config["embedding_options"]["db_option"]
-                + "_"
-                + self.config["embedding_options"]["model"],
-            ),
-            self.embedding_model,
-        )
-        self.logger.info("Loaded database")
-        return self.vector_db
-
-
-if __name__ == "__main__":
-    with open("code/config.yml", "r") as f:
-        config = yaml.safe_load(f)
-    print(config)
-    vector_db = VectorDB(config)
-    vector_db.create_database()
-    vector_db.save_database()
code/modules/vectorstore/__init__.py
ADDED
File without changes
code/modules/vectorstore/base.py
ADDED
@@ -0,0 +1,33 @@
+# template for vector store classes
+
+
+class VectorStoreBase:
+    def __init__(self, config):
+        self.config = config
+
+    def _init_vector_db(self):
+        """
+        Creates a vector store object
+        """
+        raise NotImplementedError
+
+    def create_database(self):
+        """
+        Populates the vector store with documents
+        """
+        raise NotImplementedError
+
+    def load_database(self):
+        """
+        Loads the vector store from disk
+        """
+        raise NotImplementedError
+
+    def as_retriever(self):
+        """
+        Returns the vector store as a retriever
+        """
+        raise NotImplementedError
+
+    def __str__(self):
+        return self.__class__.__name__
code/modules/vectorstore/chroma.py
ADDED
@@ -0,0 +1,41 @@
+from langchain_community.vectorstores import Chroma
+from modules.vectorstore.base import VectorStoreBase
+import os
+
+
+class ChromaVectorStore(VectorStoreBase):
+    def __init__(self, config):
+        self.config = config
+        self._init_vector_db()
+
+    def _init_vector_db(self):
+        self.chroma = Chroma()
+
+    def create_database(self, document_chunks, embedding_model):
+        self.vectorstore = self.chroma.from_documents(
+            documents=document_chunks,
+            embedding=embedding_model,
+            persist_directory=os.path.join(
+                self.config["vectorstore"]["db_path"],
+                "db_"
+                + self.config["vectorstore"]["db_option"]
+                + "_"
+                + self.config["vectorstore"]["model"],
+            ),
+        )
+
+    def load_database(self, embedding_model):
+        self.vectorstore = Chroma(
+            persist_directory=os.path.join(
+                self.config["vectorstore"]["db_path"],
+                "db_"
+                + self.config["vectorstore"]["db_option"]
+                + "_"
+                + self.config["vectorstore"]["model"],
+            ),
+            embedding_function=embedding_model,
+        )
+        return self.vectorstore
+
+    def as_retriever(self):
+        return self.vectorstore.as_retriever()
code/modules/vectorstore/colbert.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
from ragatouille import RAGPretrainedModel
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ColbertVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.colbert = RAGPretrainedModel.from_pretrained(
|
13 |
+
"colbert-ir/colbertv2.0",
|
14 |
+
index_root=os.path.join(
|
15 |
+
self.config["vectorstore"]["db_path"],
|
16 |
+
"db_" + self.config["vectorstore"]["db_option"],
|
17 |
+
),
|
18 |
+
)
|
19 |
+
|
20 |
+
def create_database(self, documents, document_names, document_metadata):
|
21 |
+
index_path = self.colbert.index(
|
22 |
+
index_name="new_idx",
|
23 |
+
collection=documents,
|
24 |
+
document_ids=document_names,
|
25 |
+
document_metadatas=document_metadata,
|
26 |
+
)
|
27 |
+
|
28 |
+
def load_database(self):
|
29 |
+
path = os.path.join(
|
30 |
+
self.config["vectorstore"]["db_path"],
|
31 |
+
"db_" + self.config["vectorstore"]["db_option"],
|
32 |
+
)
|
33 |
+
self.vectorstore = RAGPretrainedModel.from_index(
|
34 |
+
f"{path}/colbert/indexes/new_idx"
|
35 |
+
)
|
36 |
+
return self.vectorstore
|
37 |
+
|
38 |
+
def as_retriever(self):
|
39 |
+
return self.vectorstore.as_retriever()
|
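
Unlike the other backends, ColBERT builds its own token-level index through ragatouille from raw text plus ids and metadata, so no separate embedding model is passed. A hedged sketch, with placeholder strings and the same config shape as above (`db_option` set to "RAGatouille", matching the dispatcher in vectorstore.py):

store = ColbertVectorStore(config)
store.create_database(
    documents=["first chunk of text ...", "second chunk ..."],
    document_names=["doc_0", "doc_1"],
    document_metadata=[{"source": "doc_0"}, {"source": "doc_1"}],
)  # the class writes the index under <db_path>/db_RAGatouille/colbert/indexes/new_idx
store.load_database()
retriever = store.as_retriever()
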
code/modules/{embedding_model_loader.py → vectorstore/embedding_model_loader.py}
RENAMED
@@ -1,10 +1,8 @@
 from langchain_community.embeddings import OpenAIEmbeddings
-from
-from
-
-
-except:
-    from constants import *
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import LlamaCppEmbeddings
+
+from modules.config.constants import *
 import os


@@ -13,17 +11,22 @@ class EmbeddingModelLoader:
         self.config = config

     def load_embedding_model(self):
-        if self.config["
+        if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
             embedding_model = OpenAIEmbeddings(
                 deployment="SL-document_embedder",
-            model=self.config["
+                model=self.config["vectorstore"]["model"],
                 show_progress_bar=True,
                 openai_api_key=OPENAI_API_KEY,
+                disallowed_special=(),
             )
         else:
             embedding_model = HuggingFaceEmbeddings(
-            model_name="
-            model_kwargs={
+                model_name=self.config["vectorstore"]["model"],
+                model_kwargs={
+                    "device": f"{self.config['device']}",
+                    "token": f"{HUGGINGFACE_TOKEN}",
+                    "trust_remote_code": True,
+                },
             )
             # embedding_model = LlamaCppEmbeddings(
             #     model_path=os.path.abspath("storage/llama-7b.ggmlv3.q4_0.bin")
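
The loader branches on `config["vectorstore"]["model"]`: the OpenAI embedding model goes through `OpenAIEmbeddings`, anything else is treated as a HuggingFace model id. A loading sketch, assuming only the two config keys the loader reads; the model id is an example, and `OPENAI_API_KEY`/`HUGGINGFACE_TOKEN` are supplied by `modules.config.constants`:

config = {
    "device": "cpu",
    "vectorstore": {"model": "sentence-transformers/all-MiniLM-L6-v2"},
}
embedding_model = EmbeddingModelLoader(config).load_embedding_model()
vectors = embedding_model.embed_documents(["hello", "world"])  # one float vector per input text
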
code/modules/vectorstore/faiss.py
ADDED
@@ -0,0 +1,45 @@
from langchain_community.vectorstores import FAISS
from modules.vectorstore.base import VectorStoreBase
import os


class FaissVectorStore(VectorStoreBase):
    def __init__(self, config):
        self.config = config
        self._init_vector_db()

    def _init_vector_db(self):
        self.faiss = FAISS(
            embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
        )

    def create_database(self, document_chunks, embedding_model):
        self.vectorstore = self.faiss.from_documents(
            documents=document_chunks, embedding=embedding_model
        )
        self.vectorstore.save_local(
            os.path.join(
                self.config["vectorstore"]["db_path"],
                "db_"
                + self.config["vectorstore"]["db_option"]
                + "_"
                + self.config["vectorstore"]["model"],
            )
        )

    def load_database(self, embedding_model):
        self.vectorstore = self.faiss.load_local(
            os.path.join(
                self.config["vectorstore"]["db_path"],
                "db_"
                + self.config["vectorstore"]["db_option"]
                + "_"
                + self.config["vectorstore"]["model"],
            ),
            embedding_model,
            allow_dangerous_deserialization=True,
        )
        return self.vectorstore

    def as_retriever(self):
        return self.vectorstore.as_retriever()
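
The FAISS backend round-trips through `save_local`/`load_local`; `allow_dangerous_deserialization=True` is needed because loading unpickles the docstore, so only indexes you built yourself should be loaded. A sketch under the same placeholder config as above, with `db_option` set to "FAISS":

store = FaissVectorStore(config)
store.create_database(document_chunks, embedding_model)  # writes index files under <db_path>/db_FAISS_<model>
store.load_database(embedding_model)
docs = store.as_retriever().get_relevant_documents("what is backpropagation?")
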
code/modules/vectorstore/helpers.py
ADDED
File without changes
code/modules/vectorstore/raptor.py
ADDED
@@ -0,0 +1,438 @@
# code modified from https://github.com/langchain-ai/langchain/blob/master/cookbook/RAPTOR.ipynb

from typing import Dict, List, Optional, Tuple
import os
import numpy as np
import pandas as pd
import umap
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from sklearn.mixture import GaussianMixture
from langchain_community.chat_models import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from modules.vectorstore.base import VectorStoreBase

RANDOM_SEED = 42


class RAPTORVectoreStore(VectorStoreBase):
    def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
        self.documents = documents
        self.config = config
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=self.config["splitter_options"]["chunk_size"],
            chunk_overlap=self.config["splitter_options"]["chunk_overlap"],
            separators=self.config["splitter_options"]["chunk_separators"],
            disallowed_special=(),
        )
        self.embd = embedding_model
        self.model = ChatOpenAI(
            model="gpt-3.5-turbo",
        )

    def concat_documents(self, documents):
        d_sorted = sorted(documents, key=lambda x: x.metadata["source"])
        d_reversed = list(reversed(d_sorted))
        concatenated_content = "\n\n\n --- \n\n\n".join(
            [doc.page_content for doc in d_reversed]
        )
        return concatenated_content

    def split_documents(self, documents):
        concatenated_content = self.concat_documents(documents)
        texts_split = self.text_splitter.split_text(concatenated_content)
        return texts_split

    def add_documents(self, documents):
        self.documents.extend(documents)

    def global_cluster_embeddings(
        self,
        embeddings: np.ndarray,
        dim: int,
        n_neighbors: Optional[int] = None,
        metric: str = "cosine",
    ) -> np.ndarray:
        """
        Perform global dimensionality reduction on the embeddings using UMAP.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for the reduced space.
        - n_neighbors: Optional; the number of neighbors to consider for each point.
                       If not provided, it defaults to the square root of the number of embeddings.
        - metric: The distance metric to use for UMAP.

        Returns:
        - A numpy array of the embeddings reduced to the specified dimensionality.
        """
        if n_neighbors is None:
            n_neighbors = int((len(embeddings) - 1) ** 0.5)
        return umap.UMAP(
            n_neighbors=n_neighbors, n_components=dim, metric=metric
        ).fit_transform(embeddings)

    def local_cluster_embeddings(
        self,
        embeddings: np.ndarray,
        dim: int,
        num_neighbors: int = 10,
        metric: str = "cosine",
    ) -> np.ndarray:
        """
        Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for the reduced space.
        - num_neighbors: The number of neighbors to consider for each point.
        - metric: The distance metric to use for UMAP.

        Returns:
        - A numpy array of the embeddings reduced to the specified dimensionality.
        """
        return umap.UMAP(
            n_neighbors=num_neighbors, n_components=dim, metric=metric
        ).fit_transform(embeddings)

    def get_optimal_clusters(
        self,
        embeddings: np.ndarray,
        max_clusters: int = 50,
        random_state: int = RANDOM_SEED,
    ) -> int:
        """
        Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - max_clusters: The maximum number of clusters to consider.
        - random_state: Seed for reproducibility.

        Returns:
        - An integer representing the optimal number of clusters found.
        """
        max_clusters = min(max_clusters, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        return n_clusters[np.argmin(bics)]

    def GMM_cluster(
        self, embeddings: np.ndarray, threshold: float, random_state: int = 0
    ):
        """
        Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - threshold: The probability threshold for assigning an embedding to a cluster.
        - random_state: Seed for reproducibility.

        Returns:
        - A tuple containing the cluster labels and the number of clusters determined.
        """
        n_clusters = self.get_optimal_clusters(embeddings)
        gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
        gm.fit(embeddings)
        probs = gm.predict_proba(embeddings)
        labels = [np.where(prob > threshold)[0] for prob in probs]
        return labels, n_clusters

    def perform_clustering(
        self,
        embeddings: np.ndarray,
        dim: int,
        threshold: float,
    ) -> List[np.ndarray]:
        """
        Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering
        using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for UMAP reduction.
        - threshold: The probability threshold for assigning an embedding to a cluster in GMM.

        Returns:
        - A list of numpy arrays, where each array contains the cluster IDs for each embedding.
        """
        if len(embeddings) <= dim + 1:
            # Avoid clustering when there's insufficient data
            return [np.array([0]) for _ in range(len(embeddings))]

        # Global dimensionality reduction
        reduced_embeddings_global = self.global_cluster_embeddings(embeddings, dim)
        # Global clustering
        global_clusters, n_global_clusters = self.GMM_cluster(
            reduced_embeddings_global, threshold
        )

        all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
        total_clusters = 0

        # Iterate through each global cluster to perform local clustering
        for i in range(n_global_clusters):
            # Extract embeddings belonging to the current global cluster
            global_cluster_embeddings_ = embeddings[
                np.array([i in gc for gc in global_clusters])
            ]

            if len(global_cluster_embeddings_) == 0:
                continue
            if len(global_cluster_embeddings_) <= dim + 1:
                # Handle small clusters with direct assignment
                local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
                n_local_clusters = 1
            else:
                # Local dimensionality reduction and clustering
                reduced_embeddings_local = self.local_cluster_embeddings(
                    global_cluster_embeddings_, dim
                )
                local_clusters, n_local_clusters = self.GMM_cluster(
                    reduced_embeddings_local, threshold
                )

            # Assign local cluster IDs, adjusting for total clusters already processed
            for j in range(n_local_clusters):
                local_cluster_embeddings_ = global_cluster_embeddings_[
                    np.array([j in lc for lc in local_clusters])
                ]
                indices = np.where(
                    (embeddings == local_cluster_embeddings_[:, None]).all(-1)
                )[1]
                for idx in indices:
                    all_local_clusters[idx] = np.append(
                        all_local_clusters[idx], j + total_clusters
                    )

            total_clusters += n_local_clusters

        return all_local_clusters

    def embed(self, texts):
        """
        Generate embeddings for a list of text documents.

        This function assumes the existence of an `embd` object with a method `embed_documents`
        that takes a list of texts and returns their embeddings.

        Parameters:
        - texts: List[str], a list of text documents to be embedded.

        Returns:
        - numpy.ndarray: An array of embeddings for the given text documents.
        """
        text_embeddings = self.embd.embed_documents(texts)
        text_embeddings_np = np.array(text_embeddings)
        return text_embeddings_np

    def embed_cluster_texts(self, texts):
        """
        Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.

        This function combines embedding generation and clustering into a single step. It assumes the existence
        of a previously defined `perform_clustering` function that performs clustering on the embeddings.

        Parameters:
        - texts: List[str], a list of text documents to be processed.

        Returns:
        - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.
        """
        text_embeddings_np = self.embed(texts)  # Generate embeddings
        cluster_labels = self.perform_clustering(
            text_embeddings_np, 10, 0.1
        )  # Perform clustering on the embeddings
        df = pd.DataFrame()  # Initialize a DataFrame to store the results
        df["text"] = texts  # Store original texts
        df["embd"] = list(
            text_embeddings_np
        )  # Store embeddings as a list in the DataFrame
        df["cluster"] = cluster_labels  # Store cluster labels
        return df

    def fmt_txt(self, df: pd.DataFrame) -> str:
        """
        Formats the text documents in a DataFrame into a single string.

        Parameters:
        - df: DataFrame containing the 'text' column with text documents to format.

        Returns:
        - A single string where all text documents are joined by a specific delimiter.
        """
        unique_txt = df["text"].tolist()
        return "--- --- \n --- --- ".join(unique_txt)

    def embed_cluster_summarize_texts(
        self, texts: List[str], level: int
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,
        clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes
        the content within each cluster.

        Parameters:
        - texts: A list of text documents to be processed.
        - level: The current recursion level, recorded alongside the cluster summaries.

        Returns:
        - Tuple containing two DataFrames:
          1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.
          2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail,
             and the cluster identifiers.
        """

        # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns
        df_clusters = self.embed_cluster_texts(texts)

        # Prepare to expand the DataFrame for easier manipulation of clusters
        expanded_list = []

        # Expand DataFrame entries to document-cluster pairings for straightforward processing
        for index, row in df_clusters.iterrows():
            for cluster in row["cluster"]:
                expanded_list.append(
                    {"text": row["text"], "embd": row["embd"], "cluster": cluster}
                )

        # Create a new DataFrame from the expanded list
        expanded_df = pd.DataFrame(expanded_list)

        # Retrieve unique cluster identifiers for processing
        all_clusters = expanded_df["cluster"].unique()

        print(f"--Generated {len(all_clusters)} clusters--")

        # Summarization
        template = """Here is content from the course DS598: Deep Learning for Data Science.

        The content may be from a webpage about the course, or lecture content, or any other relevant information.
        If the content is in bullet points (from pdf lecture slides), you can summarize the bullet points.

        Give a detailed summary of the content below.

        Documentation:
        {context}
        """
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | self.model | StrOutputParser()

        # Format text within each cluster for summarization
        summaries = []
        for i in all_clusters:
            df_cluster = expanded_df[expanded_df["cluster"] == i]
            formatted_txt = self.fmt_txt(df_cluster)
            summaries.append(chain.invoke({"context": formatted_txt}))

        # Create a DataFrame to store summaries with their corresponding cluster and level
        df_summary = pd.DataFrame(
            {
                "summaries": summaries,
                "level": [level] * len(summaries),
                "cluster": list(all_clusters),
            }
        )

        return df_clusters, df_summary

    def recursive_embed_cluster_summarize(
        self, texts: List[str], level: int = 1, n_levels: int = 3
    ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Recursively embeds, clusters, and summarizes texts up to a specified level or until
        the number of unique clusters becomes 1, storing the results at each level.

        Parameters:
        - texts: List[str], texts to be processed.
        - level: int, current recursion level (starts at 1).
        - n_levels: int, maximum depth of recursion.

        Returns:
        - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion
          levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.
        """
        results = {}  # Dictionary to store results at each level

        # Perform embedding, clustering, and summarization for the current level
        df_clusters, df_summary = self.embed_cluster_summarize_texts(texts, level)

        # Store the results of the current level
        results[level] = (df_clusters, df_summary)

        # Determine if further recursion is possible and meaningful
        unique_clusters = df_summary["cluster"].nunique()
        if level < n_levels and unique_clusters > 1:
            # Use summaries as the input texts for the next level of recursion
            new_texts = df_summary["summaries"].tolist()
            next_level_results = self.recursive_embed_cluster_summarize(
                new_texts, level + 1, n_levels
            )

            # Merge the results from the next level into the current results dictionary
            results.update(next_level_results)

        return results

    def get_vector_db(self):
        """
        Generate a retriever object from a list of documents.

        Parameters:
        - documents: List of document objects.

        Returns:
        - A retriever object.
        """
        leaf_texts = self.split_documents(self.documents)
        results = self.recursive_embed_cluster_summarize(
            leaf_texts, level=1, n_levels=10
        )

        all_texts = leaf_texts.copy()
        # Iterate through the results to extract summaries from each level and add them to all_texts
        for level in sorted(results.keys()):
            # Extract summaries from the current level's DataFrame
            summaries = results[level][1]["summaries"].tolist()
            # Extend all_texts with the summaries from the current level
            all_texts.extend(summaries)

        # Now, use all_texts to build the vectorstore
        vectorstore = FAISS.from_texts(texts=all_texts, embedding=self.embd)
        return vectorstore

    def create_database(self, documents, embedding_model):
        self.documents = documents
        self.embd = embedding_model
        self.vectorstore = self.get_vector_db()
        self.vectorstore.save_local(
            os.path.join(
                self.config["vectorstore"]["db_path"],
                "db_"
                + self.config["vectorstore"]["db_option"]
                + "_"
                + self.config["vectorstore"]["model"],
            )
        )

    def load_database(self, embedding_model):
        self.vectorstore = FAISS.load_local(
            os.path.join(
                self.config["vectorstore"]["db_path"],
                "db_"
                + self.config["vectorstore"]["db_option"]
                + "_"
                + self.config["vectorstore"]["model"],
            ),
            embedding_model,
            allow_dangerous_deserialization=True,
        )
        return self.vectorstore

    def as_retriever(self):
        return self.vectorstore.as_retriever()
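
In short, RAPTOR turns the leaf chunks into a tree: embed, UMAP-reduce, GMM-cluster, summarize each cluster with gpt-3.5-turbo, then recurse on the summaries; the leaves and all summaries are indexed together in FAISS. A sketch of driving the class directly; beyond the vectorstore keys used above, it assumes `splitter_options` in the config, a valid `OPENAI_API_KEY`, and the umap-learn and scikit-learn packages:

raptor = RAPTORVectoreStore(config)
raptor.create_database(documents, embedding_model)  # builds the summary tree, then saves a FAISS index
raptor.load_database(embedding_model)
retriever = raptor.as_retriever()  # retrieval sees both raw chunks and cluster summaries
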
code/modules/vectorstore/store_manager.py
ADDED
@@ -0,0 +1,163 @@
from modules.vectorstore.vectorstore import VectorStore
from modules.vectorstore.helpers import *
from modules.dataloader.webpage_crawler import WebpageCrawler
from modules.dataloader.data_loader import DataLoader
from modules.dataloader.helpers import *
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
import logging
import os
import time
import asyncio


class VectorStoreManager:
    def __init__(self, config, logger=None):
        self.config = config
        self.document_names = None

        # Set up logging to both console and a file
        self.logger = logger or self._setup_logging()
        self.webpage_crawler = WebpageCrawler()
        self.vector_db = VectorStore(self.config)

        self.logger.info("VectorDB instance instantiated")

    def _setup_logging(self):
        logger = logging.getLogger(__name__)
        if not logger.hasHandlers():
            logger.setLevel(logging.INFO)
            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

            # Console Handler
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

            # Ensure log directory exists
            log_directory = self.config["log_dir"]
            os.makedirs(log_directory, exist_ok=True)

            # File Handler
            log_file_path = os.path.join(log_directory, "vector_db.log")
            file_handler = logging.FileHandler(log_file_path, mode="w")
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        return logger

    def load_files(self):
        files = os.listdir(self.config["vectorstore"]["data_path"])
        files = [
            os.path.join(self.config["vectorstore"]["data_path"], file)
            for file in files
        ]
        urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
        if self.config["vectorstore"]["expand_urls"]:
            all_urls = []
            for url in urls:
                loop = asyncio.get_event_loop()
                all_urls.extend(
                    loop.run_until_complete(
                        self.webpage_crawler.get_all_pages(
                            url, url
                        )  # only gets child URLs; to get all URLs, replace the second argument with the base URL
                    )
                )
            urls = all_urls
        return files, urls

    def create_embedding_model(self):
        self.logger.info("Creating embedding function")
        embedding_model_loader = EmbeddingModelLoader(self.config)
        embedding_model = embedding_model_loader.load_embedding_model()
        return embedding_model

    def initialize_database(
        self,
        document_chunks: list,
        document_names: list,
        documents: list,
        document_metadata: list,
    ):
        if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
            self.embedding_model = self.create_embedding_model()
        else:
            self.embedding_model = None

        self.logger.info("Initializing vector_db")
        self.logger.info(
            "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
        )
        self.vector_db._create_database(
            document_chunks,
            document_names,
            documents,
            document_metadata,
            self.embedding_model,
        )

    def create_database(self):
        start_time = time.time()  # Start time for creating database
        data_loader = DataLoader(self.config, self.logger)
        self.logger.info("Loading data")
        files, urls = self.load_files()
        files, webpages = self.webpage_crawler.clean_url_list(urls)
        self.logger.info(f"Number of files: {len(files)}")
        self.logger.info(f"Number of webpages: {len(webpages)}")
        if f"{self.config['vectorstore']['url_file_path']}" in files:
            files.remove(f"{self.config['vectorstore']['url_file_path']}")  # cleanup
        document_chunks, document_names, documents, document_metadata = (
            data_loader.get_chunks(files, webpages)
        )
        num_documents = len(document_chunks)
        self.logger.info(f"Number of documents in the DB: {num_documents}")
        metadata_keys = list(document_metadata[0].keys())
        self.logger.info(f"Metadata keys: {metadata_keys}")
        self.logger.info("Completed loading data")
        self.initialize_database(
            document_chunks, document_names, documents, document_metadata
        )
        end_time = time.time()  # End time for creating database
        self.logger.info("Created database")
        self.logger.info(
            f"Time taken to create database: {end_time - start_time} seconds"
        )

    def load_database(self):
        start_time = time.time()  # Start time for loading database
        if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
            self.embedding_model = self.create_embedding_model()
        else:
            self.embedding_model = None
        self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
        end_time = time.time()  # End time for loading database
        self.logger.info(
            f"Time taken to load database: {end_time - start_time} seconds"
        )
        self.logger.info("Loaded database")
        return self.loaded_vector_db


if __name__ == "__main__":
    import yaml

    with open("modules/config/config.yml", "r") as f:
        config = yaml.safe_load(f)
    print(config)
    print(f"Trying to create database with config: {config}")
    vector_db = VectorStoreManager(config)
    vector_db.create_database()
    print("Created database")

    print("Trying to load the database")
    vector_db = VectorStoreManager(config)
    vector_db.load_database()
    print("Loaded database")

    print(f"View the logs at {config['log_dir']}/vector_db.log")
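
For reference, a sketch of the config shape the manager expects; every key below is read somewhere in this file or in the backends above, but the values are placeholders rather than the repo's actual config.yml:

config = {
    "log_dir": "storage/logs",
    "device": "cpu",
    "vectorstore": {
        "db_option": "FAISS",  # or "Chroma", "RAGatouille", "RAPTOR"
        "db_path": "vectorstores",
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "data_path": "storage/data",
        "url_file_path": "storage/data/urls.txt",
        "expand_urls": True,
    },
}
VectorStoreManager(config).create_database()
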
code/modules/vectorstore/vectorstore.py
ADDED
@@ -0,0 +1,57 @@
from modules.vectorstore.faiss import FaissVectorStore
from modules.vectorstore.chroma import ChromaVectorStore
from modules.vectorstore.colbert import ColbertVectorStore
from modules.vectorstore.raptor import RAPTORVectoreStore


class VectorStore:
    def __init__(self, config):
        self.config = config
        self.vectorstore = None
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if not vectorstore_class:
            raise ValueError(f"Invalid db_option: {db_option}")

        self.vectorstore = vectorstore_class(self.config)

        if db_option == "RAGatouille":
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if not vectorstore_class:
            raise ValueError(f"Invalid db_option: {db_option}")

        self.vectorstore = vectorstore_class(self.config)

        if db_option == "RAGatouille":
            return self.vectorstore.load_database()
        else:
            return self.vectorstore.load_database(embedding_model)

    def _as_retriever(self):
        return self.vectorstore.as_retriever()

    def _get_vectorstore(self):
        return self.vectorstore
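
This wrapper is a small factory/dispatch layer: `db_option` picks the backend class, and the RAGatouille branch exists because ColBERT indexes raw text instead of embedded chunks. A sketch of the two call paths, with `document_chunks`, `document_names`, `documents`, and `document_metadata` as produced by the data loader:

vs = VectorStore(config)
vs._create_database(
    document_chunks, document_names, documents, document_metadata, embedding_model
)
db = vs._load_database(embedding_model)  # embedding_model is ignored for RAGatouille
retriever = vs._as_retriever()
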
code/public/acastusphoton-svgrepo-com.svg
ADDED
code/public/adv-screen-recorder-svgrepo-com.svg
ADDED
code/public/alarmy-svgrepo-com.svg
ADDED
code/public/avatars/ai-tutor.png
ADDED