Merge pull request #90 from DL4DS/dev2main
- .flake8 +3 -0
- .gitattributes +1 -0
- .github/workflows/code_quality_check.yml +33 -0
- .gitignore +3 -1
- Dockerfile +8 -1
- README.md +18 -39
- code/.chainlit/config.toml +8 -6
- code/__init__.py +0 -1
- code/app.py +351 -0
- code/chainlit.md +1 -6
- code/chainlit_base.py +484 -0
- code/main.py +212 -67
- code/modules/chat/chat_model_loader.py +2 -9
- code/modules/chat/helpers.py +6 -4
- code/modules/chat/langchain/__init__.py +0 -0
- code/modules/chat/langchain/langchain_rag.py +16 -12
- code/modules/chat/langchain/utils.py +12 -34
- code/modules/chat/llm_tutor.py +10 -7
- code/modules/chat_processor/helpers.py +245 -0
- code/modules/chat_processor/literal_ai.py +1 -38
- code/modules/config/config.yml +3 -3
- code/modules/config/constants.py +14 -3
- code/modules/config/project_config.yml +7 -0
- code/modules/dataloader/data_loader.py +96 -55
- code/modules/dataloader/helpers.py +13 -6
- code/modules/dataloader/pdf_readers/gpt.py +27 -19
- code/modules/dataloader/pdf_readers/llama.py +24 -23
- code/modules/dataloader/webpage_crawler.py +5 -3
- code/modules/retriever/helpers.py +0 -1
- code/modules/vectorstore/colbert.py +3 -2
- code/modules/vectorstore/embedding_model_loader.py +1 -7
- code/modules/vectorstore/faiss.py +10 -7
- code/modules/vectorstore/raptor.py +1 -4
- code/modules/vectorstore/store_manager.py +21 -14
- code/public/avatars/{ai-tutor.png → ai_tutor.png} +0 -0
- code/public/space.jpg +3 -0
- code/public/test.css +0 -19
- code/templates/cooldown.html +181 -0
- code/templates/dashboard.html +145 -0
- code/templates/error.html +95 -0
- code/templates/error_404.html +80 -0
- code/templates/login.html +132 -0
- code/templates/logout.html +21 -0
- docs/README.md +0 -51
- docs/contribute.md +33 -0
- docs/setup.md +127 -0
- pyproject.toml +2 -0
- requirements.txt +12 -1
.flake8
ADDED
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203, E266, E501, W503
.gitattributes
ADDED
@@ -0,0 +1 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/code_quality_check.yml
ADDED
@@ -0,0 +1,33 @@
+name: Code Quality and Security Checks
+
+on:
+  push:
+    branches: [ main, dev_branch ]
+  pull_request:
+    branches: [ main, dev_branch ]
+
+jobs:
+  code-quality:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install flake8 black bandit
+
+      - name: Run Black
+        run: black --check .
+
+      - name: Run Flake8
+        run: flake8 .
+
+      - name: Run Bandit
+        run: |
+          bandit -r .
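Aside: the three workflow steps above can be reproduced locally before pushing. A minimal sketch using Python's standard subprocess module (illustrative only; assumes flake8, black, and bandit are installed in the active environment):

```python
import subprocess

# Run the same checks as .github/workflows/code_quality_check.yml, in order.
# check=True raises CalledProcessError on the first non-zero exit code,
# mirroring how a failing step fails the workflow run.
for cmd in (
    ["black", "--check", "."],
    ["flake8", "."],
    ["bandit", "-r", "."],
):
    subprocess.run(cmd, check=True)
```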
.gitignore
CHANGED
@@ -165,7 +165,9 @@ cython_debug/
 .ragatouille/*
 */__pycache__/*
 .chainlit/translations/
+code/.chainlit/translations/
 storage/logs/*
 vectorstores/*
 
 */.files/*
+code/storage/models/
Dockerfile
CHANGED
@@ -26,6 +26,13 @@ WORKDIR /code/code
 
 RUN --mount=type=secret,id=HUGGINGFACEHUB_API_TOKEN,mode=0444,required=true
 RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=true
+RUN --mount=type=secret,id=CHAINLIT_URL,mode=0444,required=true
+RUN --mount=type=secret,id=LITERAL_API_URL,mode=0444,required=true
+RUN --mount=type=secret,id=LLAMA_CLOUD_API_KEY,mode=0444,required=true
+RUN --mount=type=secret,id=OAUTH_GOOGLE_CLIENT_ID,mode=0444,required=true
+RUN --mount=type=secret,id=OAUTH_GOOGLE_CLIENT_SECRET,mode=0444,required=true
+RUN --mount=type=secret,id=LITERAL_API_KEY_LOGGING,mode=0444,required=true
+RUN --mount=type=secret,id=CHAINLIT_AUTH_SECRET,mode=0444,required=true
 
 # Default command to run the application
-CMD ["sh", "-c", "python -m modules.vectorstore.store_manager &&
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && uvicorn app:app --host 0.0.0.0 --port 7860"]
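For context: with BuildKit, each `RUN --mount=type=secret,id=NAME` step exposes the secret as a file at `/run/secrets/NAME` (the default target) for the duration of that step only. A hedged sketch of how a build step could read one of these secrets from Python; `read_buildkit_secret` is a hypothetical helper for illustration, not code from this repo:

```python
from pathlib import Path

def read_buildkit_secret(name: str) -> str:
    """Read a BuildKit secret mounted at /run/secrets/<name>.

    Only available inside the `RUN --mount=type=secret,...` step that
    mounts it; illustrative helper, not part of this PR.
    """
    return (Path("/run/secrets") / name).read_text().strip()

# e.g., inside a build step:
# token = read_buildkit_secret("OPENAI_API_KEY")
```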
README.md
CHANGED
@@ -15,10 +15,14 @@ You can find a "production" implementation of the Tutor running live at [DL4DS T
 Hugging Face [Space](https://huggingface.co/spaces/dl4ds/dl4ds_tutor). It is pushed automatically from the `main` branch of this repo by this
 [Actions Workflow](https://github.com/DL4DS/dl4ds_tutor/blob/main/.github/workflows/push_to_hf_space.yml) upon a push to `main`.
 
-A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](htt
+
+A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](https://dl4ds-tutor-dev.hf.space/) from this Hugging Face
 [Space](https://huggingface.co/spaces/dl4ds/tutor_dev). It is pushed automatically from the `dev_branch` branch of this repo by this
 [Actions Workflow](https://github.com/DL4DS/dl4ds_tutor/blob/dev_branch/.github/workflows/push_to_hf_space_prototype.yml) upon a push to `dev_branch`.
 
+## Setup
+
+Please visit [setup](https://dl4ds.github.io/dl4ds_tutor/guide/setup/) for more information on setting up the project.
 
 ## Running Locally
 
@@ -34,7 +38,7 @@ A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](htt
 3. **To test Data Loading (Optional)**
    ```bash
    cd code
-   python -m modules.dataloader.data_loader
+   python -m modules.dataloader.data_loader --links "your_pdf_link"
    ```
 
 4. **Create the Vector Database**
@@ -43,47 +47,16 @@ A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](htt
    python -m modules.vectorstore.store_manager
    ```
   - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
-  - Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yaml` file, which will embed files from the storage directory every time you run the below chainlit command.
 
-
+6. **Run the FastAPI App**
   ```bash
-
+   cd code
+   uvicorn app:app --port 7860
   ```
 
-
-
-## File Structure
-
-```plaintext
-code/
-├── modules
-│   ├── chat            # Contains the chatbot implementation
-│   ├── chat_processor  # Contains the implementation to process and log the conversations
-│   ├── config          # Contains the configuration files
-│   ├── dataloader      # Contains the implementation to load the data from the storage directory
-│   ├── retriever       # Contains the implementation to create the retriever
-│   └── vectorstore     # Contains the implementation to create the vector database
-├── public
-│   ├── logo_dark.png   # Dark theme logo
-│   ├── logo_light.png  # Light theme logo
-│   └── test.css        # Custom CSS file
-└── main.py
-
-
-docs/                   # Contains the documentation to the codebase and methods used
-
-
-├── data                # Store files and URLs here
-├── logs                # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
-└── models              # Local LLMs are loaded from here
-
-vectorstores/           # Stores the created vector databases
-
-.env                    # This needs to be created, store the API keys here
-```
-- `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
-- `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager:` class to manage the vector database, and all associated methods.
-- `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
+## Documentation
+
+Please visit the [docs](https://dl4ds.github.io/dl4ds_tutor/) for more information.
 
 
 ## Docker
@@ -97,4 +70,10 @@ docker run -it --rm -p 8000:8000 dev
 
 ## Contributing
 
-Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the
+Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the `dev_branch`.
+
+Please visit [contribute](https://dl4ds.github.io/dl4ds_tutor/guide/contribute/) for more information on contributing.
+
+## Future Work
+
+For more information on future work, please visit [roadmap](https://dl4ds.github.io/dl4ds_tutor/guide/readmap/).
code/.chainlit/config.toml
CHANGED
@@ -20,7 +20,7 @@ allow_origins = ["*"]
 
 [features]
 # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
-unsafe_allow_html =
+unsafe_allow_html = true
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
 latex = true
@@ -49,6 +49,8 @@ auto_tag_thread = true
 # Sample rate of the audio
 sample_rate = 44100
 
+edit_message = true
+
 [UI]
 # Name of the assistant.
 name = "AI Tutor"
@@ -59,11 +61,11 @@ name = "AI Tutor"
 # Large size content are by default collapsed for a cleaner ui
 default_collapse_content = true
 
-#
-
+# Chain of Thought (CoT) display mode. Can be "hidden", "tool_call" or "full".
+cot = "hidden"
 
 # Link to your github repo. This will add a github button in the UI's header.
-
+github = "https://github.com/DL4DS/dl4ds_tutor"
 
 # Specify a CSS file that can be used to customize the user interface.
 # The CSS file can be served from the public directory or via an external link.
@@ -85,7 +87,7 @@ custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/
 # custom_build = "./public/build"
 
 [UI.theme]
-default = "
+default = "light"
 #layout = "wide"
 #font_family = "Inter, sans-serif"
 # Override default MUI light theme. (Check theme.ts)
@@ -115,4 +117,4 @@ custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/
 #secondary = "#BDBDBD"
 
 [meta]
-generated_by = "1.1.
+generated_by = "1.1.402"
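Aside: since the CI pins Python 3.11, the updated values can be spot-checked with the standard-library tomllib parser; a minimal sketch (illustrative only, assumes it is run from the repo root):

```python
import tomllib  # standard library on Python 3.11+

# Load the Chainlit config and spot-check the values changed in this PR.
with open("code/.chainlit/config.toml", "rb") as f:  # tomllib requires binary mode
    cfg = tomllib.load(f)

assert cfg["features"]["unsafe_allow_html"] is True
assert cfg["UI"]["cot"] == "hidden"
assert cfg["UI"]["github"] == "https://github.com/DL4DS/dl4ds_tutor"
assert cfg["UI"]["theme"]["default"] == "light"
assert cfg["meta"]["generated_by"] == "1.1.402"
```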
code/__init__.py
DELETED
@@ -1 +0,0 @@
-from .modules import *
code/app.py
ADDED
@@ -0,0 +1,351 @@
+from fastapi import FastAPI, Request, Response, HTTPException
+from fastapi.responses import HTMLResponse, RedirectResponse
+from fastapi.templating import Jinja2Templates
+from google.oauth2 import id_token
+from google.auth.transport import requests as google_requests
+from google_auth_oauthlib.flow import Flow
+from chainlit.utils import mount_chainlit
+import secrets
+import json
+import base64
+from modules.config.constants import (
+    OAUTH_GOOGLE_CLIENT_ID,
+    OAUTH_GOOGLE_CLIENT_SECRET,
+    CHAINLIT_URL,
+    GITHUB_REPO,
+    DOCS_WEBSITE,
+    ALL_TIME_TOKENS_ALLOCATED,
+    TOKENS_LEFT,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from modules.chat_processor.helpers import (
+    get_user_details,
+    get_time,
+    reset_tokens_for_user,
+    check_user_cooldown,
+    update_user_info,
+)
+
+GOOGLE_CLIENT_ID = OAUTH_GOOGLE_CLIENT_ID
+GOOGLE_CLIENT_SECRET = OAUTH_GOOGLE_CLIENT_SECRET
+GOOGLE_REDIRECT_URI = f"{CHAINLIT_URL}/auth/oauth/google/callback"
+
+app = FastAPI()
+app.mount("/public", StaticFiles(directory="public"), name="public")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Update with appropriate origins
+    allow_methods=["*"],
+    allow_headers=["*"],  # or specify the headers you want to allow
+    expose_headers=["X-User-Info"],  # Expose the custom header
+)
+
+templates = Jinja2Templates(directory="templates")
+session_store = {}
+CHAINLIT_PATH = "/chainlit_tutor"
+
+# only admin is given any additional permissions for now -- no limits on tokens
+USER_ROLES = {
+    "[email protected]": ["instructor", "bu"],
+    "[email protected]": ["admin", "instructor", "bu"],
+    "[email protected]": ["instructor", "bu"],
+    "[email protected]": ["guest"],
+    # Add more users and roles as needed
+}
+
+# Create a Google OAuth flow
+flow = Flow.from_client_config(
+    {
+        "web": {
+            "client_id": GOOGLE_CLIENT_ID,
+            "client_secret": GOOGLE_CLIENT_SECRET,
+            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+            "token_uri": "https://oauth2.googleapis.com/token",
+            "redirect_uris": [GOOGLE_REDIRECT_URI],
+            "scopes": [
+                "openid",
+                # "https://www.googleapis.com/auth/userinfo.email",
+                # "https://www.googleapis.com/auth/userinfo.profile",
+            ],
+        }
+    },
+    scopes=[
+        "openid",
+        "https://www.googleapis.com/auth/userinfo.email",
+        "https://www.googleapis.com/auth/userinfo.profile",
+    ],
+    redirect_uri=GOOGLE_REDIRECT_URI,
+)
+
+
+def get_user_role(username: str):
+    return USER_ROLES.get(username, ["guest"])  # Default to "guest" role
+
+
+async def get_user_info_from_cookie(request: Request):
+    user_info_encoded = request.cookies.get("X-User-Info")
+    if user_info_encoded:
+        try:
+            user_info_json = base64.b64decode(user_info_encoded).decode()
+            return json.loads(user_info_json)
+        except Exception as e:
+            print(f"Error decoding user info: {e}")
+            return None
+    return None
+
+
+async def del_user_info_from_cookie(request: Request, response: Response):
+    # Delete cookies from the response
+    response.delete_cookie("X-User-Info")
+    response.delete_cookie("session_token")
+    # Get the session token from the request cookies
+    session_token = request.cookies.get("session_token")
+    # Check if the session token exists in the session_store before deleting
+    if session_token and session_token in session_store:
+        del session_store[session_token]
+
+
+def get_user_info(request: Request):
+    session_token = request.cookies.get("session_token")
+    if session_token and session_token in session_store:
+        return session_store[session_token]
+    return None
+
+
+@app.get("/", response_class=HTMLResponse)
+async def login_page(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if user_info and user_info.get("google_signed_in"):
+        return RedirectResponse("/post-signin")
+    return templates.TemplateResponse(
+        "login.html",
+        {"request": request, "GITHUB_REPO": GITHUB_REPO, "DOCS_WEBSITE": DOCS_WEBSITE},
+    )
+
+
+# @app.get("/login/guest")
+# async def login_guest():
+#     username = "guest"
+#     session_token = secrets.token_hex(16)
+#     unique_session_id = secrets.token_hex(8)
+#     username = f"{username}_{unique_session_id}"
+#     session_store[session_token] = {
+#         "email": username,
+#         "name": "Guest",
+#         "profile_image": "",
+#         "google_signed_in": False,  # Ensure guest users do not have this flag
+#     }
+#     user_info_json = json.dumps(session_store[session_token])
+#     user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+#     # Set cookies
+#     response = RedirectResponse(url="/post-signin", status_code=303)
+#     response.set_cookie(key="session_token", value=session_token)
+#     response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
+#     return response
+
+
+@app.get("/login/google")
+async def login_google(request: Request):
+    # Clear any existing session cookies to avoid conflicts with guest sessions
+    response = RedirectResponse(url="/post-signin")
+    response.delete_cookie(key="session_token")
+    response.delete_cookie(key="X-User-Info")
+
+    user_info = await get_user_info_from_cookie(request)
+    # Check if user is already signed in using Google
+    if user_info and user_info.get("google_signed_in"):
+        return RedirectResponse("/post-signin")
+    else:
+        authorization_url, _ = flow.authorization_url(prompt="consent")
+        return RedirectResponse(authorization_url, headers=response.headers)
+
+
+@app.get("/auth/oauth/google/callback")
+async def auth_google(request: Request):
+    try:
+        flow.fetch_token(code=request.query_params.get("code"))
+        credentials = flow.credentials
+        user_info = id_token.verify_oauth2_token(
+            credentials.id_token, google_requests.Request(), GOOGLE_CLIENT_ID
+        )
+
+        email = user_info["email"]
+        name = user_info.get("name", "")
+        profile_image = user_info.get("picture", "")
+        role = get_user_role(email)
+
+        session_token = secrets.token_hex(16)
+        session_store[session_token] = {
+            "email": email,
+            "name": name,
+            "profile_image": profile_image,
+            "google_signed_in": True,  # Set this flag to True for Google-signed users
+        }
+
+        # add literalai user info to session store to be sent to chainlit
+        literalai_user = await get_user_details(email)
+        session_store[session_token]["literalai_info"] = literalai_user.to_dict()
+        session_store[session_token]["literalai_info"]["metadata"]["role"] = role
+
+        user_info_json = json.dumps(session_store[session_token])
+        user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+        # Set cookies
+        response = RedirectResponse(url="/post-signin", status_code=303)
+        response.set_cookie(key="session_token", value=session_token)
+        response.set_cookie(
+            key="X-User-Info", value=user_info_encoded, httponly=True
+        )  # TODO: is the flag httponly=True necessary?
+        return response
+    except Exception as e:
+        print(f"Error during Google OAuth callback: {e}")
+        return RedirectResponse(url="/", status_code=302)
+
+
+@app.get("/cooldown")
+async def cooldown(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    user_details = await get_user_details(user_info["email"])
+    current_datetime = get_time()
+    cooldown, cooldown_end_time = await check_user_cooldown(
+        user_details, current_datetime
+    )
+    print(f"User in cooldown: {cooldown}")
+    print(f"Cooldown end time: {cooldown_end_time}")
+    if cooldown and "admin" not in get_user_role(user_info["email"]):
+        return templates.TemplateResponse(
+            "cooldown.html",
+            {
+                "request": request,
+                "username": user_info["email"],
+                "role": get_user_role(user_info["email"]),
+                "cooldown_end_time": cooldown_end_time,
+                "tokens_left": user_details.metadata["tokens_left"],
+            },
+        )
+    else:
+        user_details.metadata["in_cooldown"] = False
+        await update_user_info(user_details)
+        await reset_tokens_for_user(user_details)
+        return RedirectResponse("/post-signin")
+
+
+@app.get("/post-signin", response_class=HTMLResponse)
+async def post_signin(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if not user_info:
+        user_info = get_user_info(request)
+    user_details = await get_user_details(user_info["email"])
+    current_datetime = get_time()
+    user_details.metadata["last_login"] = current_datetime
+    # if new user, set the number of tries
+    if "tokens_left" not in user_details.metadata:
+        user_details.metadata["tokens_left"] = (
+            TOKENS_LEFT  # set the number of tokens left for the new user
+        )
+    if "last_message_time" not in user_details.metadata:
+        user_details.metadata["last_message_time"] = current_datetime
+    if "all_time_tokens_allocated" not in user_details.metadata:
+        user_details.metadata["all_time_tokens_allocated"] = ALL_TIME_TOKENS_ALLOCATED
+    if "in_cooldown" not in user_details.metadata:
+        user_details.metadata["in_cooldown"] = False
+    await update_user_info(user_details)
+
+    if "last_message_time" in user_details.metadata and "admin" not in get_user_role(
+        user_info["email"]
+    ):
+        cooldown, _ = await check_user_cooldown(user_details, current_datetime)
+        if cooldown:
+            user_details.metadata["in_cooldown"] = True
+            return RedirectResponse("/cooldown")
+        else:
+            user_details.metadata["in_cooldown"] = False
+            await reset_tokens_for_user(user_details)
+
+    if user_info:
+        username = user_info["email"]
+        role = get_user_role(username)
+        jwt_token = request.cookies.get("X-User-Info")
+        return templates.TemplateResponse(
+            "dashboard.html",
+            {
+                "request": request,
+                "username": username,
+                "role": role,
+                "jwt_token": jwt_token,
+                "tokens_left": user_details.metadata["tokens_left"],
+                "all_time_tokens_allocated": user_details.metadata[
+                    "all_time_tokens_allocated"
+                ],
+                "total_tokens_allocated": ALL_TIME_TOKENS_ALLOCATED,
+            },
+        )
+    return RedirectResponse("/")
+
+
+@app.get("/start-tutor")
+@app.post("/start-tutor")
+async def start_tutor(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if user_info:
+        user_info_json = json.dumps(user_info)
+        user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+        response = RedirectResponse(CHAINLIT_PATH, status_code=303)
+        response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
+        return response
+
+    return RedirectResponse(url="/")
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    if exc.status_code == 404:
+        return templates.TemplateResponse(
+            "error_404.html", {"request": request}, status_code=404
+        )
+    return templates.TemplateResponse(
+        "error.html",
+        {"request": request, "error": str(exc)},
+        status_code=exc.status_code,
+    )
+
+
+@app.exception_handler(Exception)
+async def exception_handler(request: Request, exc: Exception):
+    return templates.TemplateResponse(
+        "error.html", {"request": request, "error": str(exc)}, status_code=500
+    )
+
+
+@app.get("/logout", response_class=HTMLResponse)
+async def logout(request: Request, response: Response):
+    await del_user_info_from_cookie(request=request, response=response)
+    response = RedirectResponse(url="/", status_code=302)
+    # Set cookies to empty values and expire them immediately
+    response.set_cookie(key="session_token", value="", expires=0)
+    response.set_cookie(key="X-User-Info", value="", expires=0)
+    return response
+
+
+@app.get("/get-tokens-left")
+async def get_tokens_left(request: Request):
+    try:
+        user_info = await get_user_info_from_cookie(request)
+        user_details = await get_user_details(user_info["email"])
+        await reset_tokens_for_user(user_details)
+        tokens_left = user_details.metadata["tokens_left"]
+        return {"tokens_left": tokens_left}
+    except Exception as e:
+        print(f"Error getting tokens left: {e}")
+        return {"tokens_left": 0}
+
+
+mount_chainlit(app=app, target="main.py", path=CHAINLIT_PATH)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="127.0.0.1", port=7860)
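One detail worth calling out: app.py transports the session payload to Chainlit via the `X-User-Info` cookie as base64-encoded JSON. A standalone sketch of that round-trip, mirroring `auth_google` and `get_user_info_from_cookie` above (illustrative values only):

```python
import base64
import json

# Encode, as auth_google does before calling response.set_cookie(...)
user_info = {"email": "student@example.com", "name": "Student", "google_signed_in": True}
user_info_encoded = base64.b64encode(json.dumps(user_info).encode()).decode()

# Decode, as get_user_info_from_cookie does with request.cookies.get("X-User-Info")
decoded = json.loads(base64.b64decode(user_info_encoded).decode())
assert decoded == user_info
```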
code/chainlit.md
CHANGED
@@ -1,10 +1,5 @@
 # Welcome to DL4DS Tutor! 🚀🤖
 
-Hi there, this is an LLM chatbot designed to help answer questions on the course content
-This is still very much a Work in Progress.
+Hi there, this is an LLM chatbot designed to help answer questions on the course content.
 
 ### --- Please wait while the Tutor loads... ---
-
-## Useful Links 🔗
-
-- **Documentation:** [Chainlit Documentation](https://docs.chainlit.io) 📚
code/chainlit_base.py
ADDED
@@ -0,0 +1,484 @@
+import chainlit.data as cl_data
+import asyncio
+import yaml
+from typing import Any, Dict, no_type_check
+import chainlit as cl
+from modules.chat.llm_tutor import LLMTutor
+from modules.chat.helpers import (
+    get_sources,
+    get_history_chat_resume,
+    get_history_setup_llm,
+    get_last_config,
+)
+import copy
+from chainlit.types import ThreadDict
+import time
+from langchain_community.callbacks import get_openai_callback
+
+USER_TIMEOUT = 60_000
+SYSTEM = "System"
+LLM = "AI Tutor"
+AGENT = "Agent"
+YOU = "User"
+ERROR = "Error"
+
+with open("modules/config/config.yml", "r") as f:
+    config = yaml.safe_load(f)
+
+
+# async def setup_data_layer():
+#     """
+#     Set up the data layer for chat logging.
+#     """
+#     if config["chat_logging"]["log_chat"]:
+#         data_layer = CustomLiteralDataLayer(
+#             api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
+#         )
+#     else:
+#         data_layer = None
+
+#     return data_layer
+
+
+class Chatbot:
+    def __init__(self, config):
+        """
+        Initialize the Chatbot class.
+        """
+        self.config = config
+
+    async def _load_config(self):
+        """
+        Load the configuration from a YAML file.
+        """
+        with open("modules/config/config.yml", "r") as f:
+            return yaml.safe_load(f)
+
+    @no_type_check
+    async def setup_llm(self):
+        """
+        Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.
+
+        #TODO: Clean this up.
+        """
+        start_time = time.time()
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        (
+            chat_profile,
+            retriever_method,
+            memory_window,
+            llm_style,
+            generate_follow_up,
+            chunking_mode,
+        ) = (
+            llm_settings.get("chat_model"),
+            llm_settings.get("retriever_method"),
+            llm_settings.get("memory_window"),
+            llm_settings.get("llm_style"),
+            llm_settings.get("follow_up_questions"),
+            llm_settings.get("chunking_mode"),
+        )
+
+        chain = cl.user_session.get("chain")
+        memory_list = cl.user_session.get(
+            "memory",
+            (
+                list(chain.store.values())[0].messages
+                if len(chain.store.values()) > 0
+                else []
+            ),
+        )
+        conversation_list = get_history_setup_llm(memory_list)
+
+        old_config = copy.deepcopy(self.config)
+        self.config["vectorstore"]["db_option"] = retriever_method
+        self.config["llm_params"]["memory_window"] = memory_window
+        self.config["llm_params"]["llm_style"] = llm_style
+        self.config["llm_params"]["llm_loader"] = chat_profile
+        self.config["llm_params"]["generate_follow_up"] = generate_follow_up
+        self.config["splitter_options"]["chunking_mode"] = chunking_mode
+
+        self.llm_tutor.update_llm(
+            old_config, self.config
+        )  # update only llm attributes that are changed
+        self.chain = self.llm_tutor.qa_bot(
+            memory=conversation_list,
+        )
+
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+
+        print("Time taken to setup LLM: ", time.time() - start_time)
+
+    @no_type_check
+    async def update_llm(self, new_settings: Dict[str, Any]):
+        """
+        Update the LLM settings and reinitialize the LLM with the new settings.
+
+        Args:
+            new_settings (Dict[str, Any]): The new settings to update.
+        """
+        cl.user_session.set("llm_settings", new_settings)
+        await self.inform_llm_settings()
+        await self.setup_llm()
+
+    async def make_llm_settings_widgets(self, config=None):
+        """
+        Create and send the widgets for LLM settings configuration.
+
+        Args:
+            config: The configuration to use for setting up the widgets.
+        """
+        config = config or self.config
+        await cl.ChatSettings(
+            [
+                cl.input_widget.Select(
+                    id="chat_model",
+                    label="Model Name (Default GPT-3)",
+                    values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
+                    initial_index=[
+                        "local_llm",
+                        "gpt-3.5-turbo-1106",
+                        "gpt-4",
+                        "gpt-4o-mini",
+                    ].index(config["llm_params"]["llm_loader"]),
+                ),
+                cl.input_widget.Select(
+                    id="retriever_method",
+                    label="Retriever (Default FAISS)",
+                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                    initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
+                        config["vectorstore"]["db_option"]
+                    ),
+                ),
+                cl.input_widget.Slider(
+                    id="memory_window",
+                    label="Memory Window (Default 3)",
+                    initial=3,
+                    min=0,
+                    max=10,
+                    step=1,
+                ),
+                cl.input_widget.Switch(
+                    id="view_sources", label="View Sources", initial=False
+                ),
+                cl.input_widget.Switch(
+                    id="stream_response",
+                    label="Stream response",
+                    initial=config["llm_params"]["stream"],
+                ),
+                cl.input_widget.Select(
+                    id="chunking_mode",
+                    label="Chunking mode",
+                    values=["fixed", "semantic"],
+                    initial_index=1,
+                ),
+                cl.input_widget.Switch(
+                    id="follow_up_questions",
+                    label="Generate follow up questions",
+                    initial=False,
+                ),
+                cl.input_widget.Select(
+                    id="llm_style",
+                    label="Type of Conversation (Default Normal)",
+                    values=["Normal", "ELI5"],
+                    initial_index=0,
+                ),
+            ]
+        ).send()
+
+    @no_type_check
+    async def inform_llm_settings(self):
+        """
+        Inform the user about the updated LLM settings and display them as a message.
+        """
+        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+        llm_tutor = cl.user_session.get("llm_tutor")
+        settings_dict = {
+            "model": llm_settings.get("chat_model"),
+            "retriever": llm_settings.get("retriever_method"),
+            "memory_window": llm_settings.get("memory_window"),
+            "num_docs_in_db": (
+                len(llm_tutor.vector_db)
+                if llm_tutor and hasattr(llm_tutor, "vector_db")
+                else 0
+            ),
+            "view_sources": llm_settings.get("view_sources"),
+            "follow_up_questions": llm_settings.get("follow_up_questions"),
+        }
+        print("Settings Dict: ", settings_dict)
+        await cl.Message(
+            author=SYSTEM,
+            content="LLM settings have been updated. You can continue with your Query!",
+            # elements=[
+            #     cl.Text(
+            #         name="settings",
+            #         display="side",
+            #         content=json.dumps(settings_dict, indent=4),
+            #         language="json",
+            #     ),
+            # ],
+        ).send()
+
+    async def set_starters(self):
+        """
+        Set starter messages for the chatbot.
+        """
+        # Return Starters only if the chat is new
+
+        try:
+            thread = cl_data._data_layer.get_thread(
+                cl.context.session.thread_id
+            )  # see if the thread has any steps
+            if thread.steps or len(thread.steps) > 0:
+                return None
+        except Exception as e:
+            print(e)
+        return [
+            cl.Starter(
+                label="recording on CNNs?",
+                message="Where can I find the recording for the lecture on Transformers?",
+                icon="/public/adv-screen-recorder-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="where's the slides?",
+                message="When are the lectures? I can't find the schedule.",
+                icon="/public/alarmy-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Due Date?",
+                message="When is the final project due?",
+                icon="/public/calendar-samsung-17-svgrepo-com.svg",
+            ),
+            cl.Starter(
+                label="Explain backprop.",
+                message="I didn't understand the math behind backprop, could you explain it?",
+                icon="/public/acastusphoton-svgrepo-com.svg",
+            ),
+        ]
+
+    def rename(self, orig_author: str):
+        """
+        Rename the original author to a more user-friendly name.
+
+        Args:
+            orig_author (str): The original author's name.
+
+        Returns:
+            str: The renamed author.
+        """
+        rename_dict = {"Chatbot": LLM}
+        return rename_dict.get(orig_author, orig_author)
+
+    async def start(self, config=None):
+        """
+        Start the chatbot, initialize settings widgets,
+        and display and load previous conversation if chat logging is enabled.
+        """
+
+        start_time = time.time()
+
+        self.config = (
+            await self._load_config() if config is None else config
+        )  # Reload the configuration on chat resume
+
+        await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets
+
+        user = cl.user_session.get("user")
+
+        # TODO: remove self.user with cl.user_session.get("user")
+        try:
+            self.user = {
+                "user_id": user.identifier,
+                "session_id": cl.context.session.thread_id,
+            }
+        except Exception as e:
+            print(e)
+            self.user = {
+                "user_id": "guest",
+                "session_id": cl.context.session.thread_id,
+            }
+
+        memory = cl.user_session.get("memory", [])
+        self.llm_tutor = LLMTutor(self.config, user=self.user)
+
+        self.chain = self.llm_tutor.qa_bot(
+            memory=memory,
+        )
+        self.question_generator = self.llm_tutor.question_generator
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chain", self.chain)
+
+        print("Time taken to start LLM: ", time.time() - start_time)
+
+    async def stream_response(self, response):
+        """
+        Stream the response from the LLM.
+
+        Args:
+            response: The response from the LLM.
+        """
+        msg = cl.Message(content="")
+        await msg.send()
+
+        output = {}
+        for chunk in response:
+            if "answer" in chunk:
+                await msg.stream_token(chunk["answer"])
+
+            for key in chunk:
+                if key not in output:
+                    output[key] = chunk[key]
+                else:
+                    output[key] += chunk[key]
+        return output
+
+    async def main(self, message):
+        """
+        Process and Display the Conversation.
+
+        Args:
+            message: The incoming chat message.
+        """
+
+        start_time = time.time()
+
+        chain = cl.user_session.get("chain")
+        token_count = 0  # initialize token count
+        if not chain:
+            await self.start()  # start the chatbot if the chain is not present
+            chain = cl.user_session.get("chain")
+
+        # update user info with last message time
+        llm_settings = cl.user_session.get("llm_settings", {})
+        view_sources = llm_settings.get("view_sources", False)
+        stream = llm_settings.get("stream_response", False)
+        stream = False  # Fix streaming
+        user_query_dict = {"input": message.content}
+        # Define the base configuration
+        cb = cl.AsyncLangchainCallbackHandler()
+        chain_config = {
+            "configurable": {
+                "user_id": self.user["user_id"],
+                "conversation_id": self.user["session_id"],
+                "memory_window": self.config["llm_params"]["memory_window"],
+            },
+            "callbacks": (
+                [cb]
+                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                else None
+            ),
+        }
+
+        with get_openai_callback() as token_count_cb:
+            if stream:
+                res = chain.stream(user_query=user_query_dict, config=chain_config)
+                res = await self.stream_response(res)
+            else:
+                res = await chain.invoke(
+                    user_query=user_query_dict,
+                    config=chain_config,
+                )
+        token_count += token_count_cb.total_tokens
+
+        answer = res.get("answer", res.get("result"))
+
+        answer_with_sources, source_elements, sources_dict = get_sources(
+            res, answer, stream=stream, view_sources=view_sources
+        )
+        answer_with_sources = answer_with_sources.replace("$$", "$")
+
+        print("Time taken to process the message: ", time.time() - start_time)
+
+        actions = []
+
+        if self.config["llm_params"]["generate_follow_up"]:
+            start_time = time.time()
+            cb_follow_up = cl.AsyncLangchainCallbackHandler()
+            config = {
+                "callbacks": (
+                    [cb_follow_up]
+                    if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                    else None
+                )
+            }
+            with get_openai_callback() as token_count_cb:
+                list_of_questions = await self.question_generator.generate_questions(
+                    query=user_query_dict["input"],
+                    response=answer,
+                    chat_history=res.get("chat_history"),
+                    context=res.get("context"),
+                    config=config,
+                )
+
+            token_count += token_count_cb.total_tokens
+
+            for question in list_of_questions:
+                actions.append(
+                    cl.Action(
+                        name="follow up question",
+                        value="example_value",
+                        description=question,
+                        label=question,
+                    )
+                )
+
+            print("Time taken to generate questions: ", time.time() - start_time)
+        print("Total Tokens Used: ", token_count)
+
+        await cl.Message(
+            content=answer_with_sources,
+            elements=source_elements,
+            author=LLM,
+            actions=actions,
+            metadata=self.config,
+        ).send()
+
+    async def on_chat_resume(self, thread: ThreadDict):
+        thread_config = None
+        steps = thread["steps"]
+        k = self.config["llm_params"][
+            "memory_window"
+        ]  # on resume, alwyas use the default memory window
+        conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
+        thread_config = get_last_config(
+            steps
+        )  # TODO: Returns None for now - which causes config to be reloaded with default values
+        cl.user_session.set("memory", conversation_list)
+        await self.start(config=thread_config)
+
+    async def on_follow_up(self, action: cl.Action):
+        user = cl.user_session.get("user")
+        message = await cl.Message(
+            content=action.description,
+            type="user_message",
+            author=user.identifier,
+        ).send()
+        async with cl.Step(
+            name="on_follow_up", type="run", parent_id=message.id
+        ) as step:
+            await self.main(message)
+            step.output = message.content
+
+
+chatbot = Chatbot(config=config)
+
+
+async def start_app():
+    # cl_data._data_layer = await setup_data_layer()
+    # chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
+    cl.set_starters(chatbot.set_starters)
+    cl.author_rename(chatbot.rename)
+    cl.on_chat_start(chatbot.start)
+    cl.on_chat_resume(chatbot.on_chat_resume)
+    cl.on_message(chatbot.main)
+    cl.on_settings_update(chatbot.update_llm)
+    cl.action_callback("follow up question")(chatbot.on_follow_up)
+
+
+loop = asyncio.get_event_loop()
+if loop.is_running():
+    asyncio.ensure_future(start_app())
+else:
+    asyncio.run(start_app())
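Aside: `stream_response` above both streams the `answer` tokens to the UI and folds every chunk into a single result dict by concatenating values per key. That accumulation pattern in isolation (illustrative data, not part of the diff):

```python
# Illustrative reproduction of the chunk-merging loop in Chatbot.stream_response:
# values for a repeated key are concatenated with +=.
chunks = [
    {"answer": "Backprop computes ", "context": "doc1 "},
    {"answer": "gradients by the chain rule.", "context": "doc2"},
]

output = {}
for chunk in chunks:
    for key in chunk:
        if key not in output:
            output[key] = chunk[key]
        else:
            output[key] += chunk[key]

print(output["answer"])  # "Backprop computes gradients by the chain rule."
```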
code/main.py
CHANGED
@@ -1,15 +1,12 @@
 import chainlit.data as cl_data
 import asyncio
 from modules.config.constants import (
-    LLAMA_PATH,
     LITERAL_API_KEY_LOGGING,
     LITERAL_API_URL,
 )
 from modules.chat_processor.literal_ai import CustomLiteralDataLayer
-
 import json
 import yaml
-import os
 from typing import Any, Dict, no_type_check
 import chainlit as cl
 from modules.chat.llm_tutor import LLMTutor
@@ -19,17 +16,27 @@ from modules.chat.helpers import (
     get_history_setup_llm,
     get_last_config,
 )
+from modules.chat_processor.helpers import (
+    update_user_info,
+    get_time,
+    check_user_cooldown,
+    reset_tokens_for_user,
+    get_user_details,
+)
 import copy
 from typing import Optional
 from chainlit.types import ThreadDict
 import time
+import base64
+from langchain_community.callbacks import get_openai_callback
+from datetime import datetime, timezone
 
 USER_TIMEOUT = 60_000
-SYSTEM = "System
-LLM = "
-AGENT = "Agent
-YOU = "
-ERROR = "Error
+SYSTEM = "System"
+LLM = "AI Tutor"
+AGENT = "Agent"
+YOU = "User"
+ERROR = "Error"
 
 with open("modules/config/config.yml", "r") as f:
     config = yaml.safe_load(f)
@@ -49,6 +56,24 @@ async def setup_data_layer():
     return data_layer
 
 
+async def update_user_from_chainlit(user, token_count=0):
+    if "admin" not in user.metadata["role"]:
+        user.metadata["tokens_left"] = user.metadata["tokens_left"] - token_count
+        user.metadata["all_time_tokens_allocated"] = (
+            user.metadata["all_time_tokens_allocated"] - token_count
+        )
+    user.metadata["tokens_left_at_last_message"] = user.metadata[
+        "tokens_left"
+    ]  # tokens_left will keep regenerating outside of chainlit
+    user.metadata["last_message_time"] = get_time()
+    await update_user_info(user)
+
+    tokens_left = user.metadata["tokens_left"]
+    if tokens_left < 0:
+        tokens_left = 0
+    return tokens_left
+
+
 class Chatbot:
     def __init__(self, config):
         """
@@ -73,7 +98,14 @@ class Chatbot:
         start_time = time.time()
 
         llm_settings = cl.user_session.get("llm_settings", {})
-
+        (
+            chat_profile,
+            retriever_method,
+            memory_window,
+            llm_style,
+            generate_follow_up,
+            chunking_mode,
+        ) = (
             llm_settings.get("chat_model"),
             llm_settings.get("retriever_method"),
             llm_settings.get("memory_window"),
@@ -106,15 +138,8 @@ class Chatbot:
        )  # update only llm attributes that are changed
        self.chain = self.llm_tutor.qa_bot(
            memory=conversation_list,
-           callbacks=(
-               [cl.LangchainCallbackHandler()]
-               if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
-               else None
-           ),
        )
 
-       tags = [chat_profile, self.config["vectorstore"]["db_option"]]
-
        cl.user_session.set("chain", self.chain)
        cl.user_session.set("llm_tutor", self.llm_tutor)
 
@@ -180,7 +205,7 @@ class Chatbot:
                cl.input_widget.Select(
                    id="chunking_mode",
                    label="Chunking mode",
-                   values=[
+                   values=["fixed", "semantic"],
                    initial_index=1,
                ),
                cl.input_widget.Switch(
@@ -216,17 +241,18 @@ class Chatbot:
            "view_sources": llm_settings.get("view_sources"),
            "follow_up_questions": llm_settings.get("follow_up_questions"),
        }
+       print("Settings Dict: ", settings_dict)
        await cl.Message(
            author=SYSTEM,
            content="LLM settings have been updated. You can continue with your Query!",
-           elements=[
-
-
-
-
-
-
-           ],
+           # elements=[
+           #     cl.Text(
+           #         name="settings",
+           #         display="side",
+           #         content=json.dumps(settings_dict, indent=4),
+           #         language="json",
+           #     ),
+           # ],
        ).send()
 
    async def set_starters(self):
@@ -241,7 +267,8 @@ class Chatbot:
            )  # see if the thread has any steps
            if thread.steps or len(thread.steps) > 0:
                return None
-       except:
+       except Exception as e:
+           print(e)
        return [
            cl.Starter(
                label="recording on CNNs?",
@@ -275,7 +302,7 @@ class Chatbot:
        Returns:
            str: The renamed author.
        """
-       rename_dict = {"Chatbot":
+       rename_dict = {"Chatbot": LLM}
        return rename_dict.get(orig_author, orig_author)
 
    async def start(self, config=None):
@@ -292,25 +319,26 @@ class Chatbot:
 
        await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets
 
-       await self.make_llm_settings_widgets(self.config)
        user = cl.user_session.get("user")
-       self.user = {
-           "user_id": user.identifier,
-           "session_id": cl.context.session.thread_id,
-       }
 
-
+       # TODO: remove self.user with cl.user_session.get("user")
+       try:
+           self.user = {
+               "user_id": user.identifier,
+               "session_id": cl.context.session.thread_id,
+           }
+       except Exception as e:
+           print(e)
+           self.user = {
+               "user_id": "guest",
+               "session_id": cl.context.session.thread_id,
+           }
 
-       cl.user_session.
+       memory = cl.user_session.get("memory", [])
        self.llm_tutor = LLMTutor(self.config, user=self.user)
 
        self.chain = self.llm_tutor.qa_bot(
            memory=memory,
-           callbacks=(
-               [cl.LangchainCallbackHandler()]
-               if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
-               else None
-           ),
        )
        self.question_generator = self.llm_tutor.question_generator
        cl.user_session.set("llm_tutor", self.llm_tutor)
@@ -351,29 +379,98 @@ class Chatbot:
        start_time = time.time()
 
        chain = cl.user_session.get("chain")
 
        llm_settings = cl.user_session.get("llm_settings", {})
        view_sources = llm_settings.get("view_sources", False)
        stream = llm_settings.get("stream_response", False)
-
        user_query_dict = {"input": message.content}
        # Define the base configuration
        chain_config = {
            "configurable": {
                "user_id": self.user["user_id"],
                "conversation_id": self.user["session_id"],
                "memory_window": self.config["llm_params"]["memory_window"],
-           }
        }
 
-
-
-
-
-
-
-
-
 
        answer = res.get("answer", res.get("result"))
 
@@ -388,15 +485,26 @@ class Chatbot:
 
        if self.config["llm_params"]["generate_follow_up"]:
            start_time = time.time()
-
-
-
-
-
-
 
-
 
            actions.append(
                cl.Action(
                    name="follow up question",
@@ -408,6 +516,15 @@ class Chatbot:
 
        print("Time taken to generate questions: ", time.time() - start_time)
 
        await cl.Message(
            content=answer_with_sources,
            elements=source_elements,
@@ -429,22 +546,46 @@ class Chatbot:
        cl.user_session.set("memory", conversation_list)
        await self.start(config=thread_config)
 
-   @cl.
-   def
-
-
-
-
-
-
 
    async def on_follow_up(self, action: cl.Action):
        message = await cl.Message(
            content=action.description,
            type="user_message",
-           author=
        ).send()
-
 
 
 chatbot = Chatbot(config=config)
@@ -462,4 +603,8 @@
 cl.action_callback("follow up question")(chatbot.on_follow_up)
 
 
-asyncio.
|
338 |
self.llm_tutor = LLMTutor(self.config, user=self.user)
|
339 |
|
340 |
self.chain = self.llm_tutor.qa_bot(
|
341 |
memory=memory,
|
|
|
|
|
|
|
|
|
|
|
342 |
)
|
343 |
self.question_generator = self.llm_tutor.question_generator
|
344 |
cl.user_session.set("llm_tutor", self.llm_tutor)
|
|
|
379 |
start_time = time.time()
|
380 |
|
381 |
chain = cl.user_session.get("chain")
|
382 |
+
token_count = 0 # initialize token count
|
383 |
+
if not chain:
|
384 |
+
await self.start() # start the chatbot if the chain is not present
|
385 |
+
chain = cl.user_session.get("chain")
|
386 |
+
|
387 |
+
# update user info with last message time
|
388 |
+
user = cl.user_session.get("user")
|
389 |
+
await reset_tokens_for_user(user)
|
390 |
+
updated_user = await get_user_details(user.identifier)
|
391 |
+
user.metadata = updated_user.metadata
|
392 |
+
cl.user_session.set("user", user)
|
393 |
+
|
394 |
+
print("\n\n User Tokens Left: ", user.metadata["tokens_left"])
|
395 |
+
|
396 |
+
# see if user has token credits left
|
397 |
+
# if not, return message saying they have run out of tokens
|
398 |
+
if user.metadata["tokens_left"] <= 0 and "admin" not in user.metadata["role"]:
|
399 |
+
current_datetime = get_time()
|
400 |
+
cooldown, cooldown_end_time = await check_user_cooldown(
|
401 |
+
user, current_datetime
|
402 |
+
)
|
403 |
+
if cooldown:
|
404 |
+
# get time left in cooldown
|
405 |
+
# convert both to datetime objects
|
406 |
+
cooldown_end_time = datetime.fromisoformat(cooldown_end_time).replace(
|
407 |
+
tzinfo=timezone.utc
|
408 |
+
)
|
409 |
+
current_datetime = datetime.fromisoformat(current_datetime).replace(
|
410 |
+
tzinfo=timezone.utc
|
411 |
+
)
|
412 |
+
cooldown_time_left = cooldown_end_time - current_datetime
|
413 |
+
# Get the total seconds
|
414 |
+
total_seconds = int(cooldown_time_left.total_seconds())
|
415 |
+
# Calculate hours, minutes, and seconds
|
416 |
+
hours, remainder = divmod(total_seconds, 3600)
|
417 |
+
minutes, seconds = divmod(remainder, 60)
|
418 |
+
# Format the time as 00 hrs 00 mins 00 secs
|
419 |
+
formatted_time = f"{hours:02} hrs {minutes:02} mins {seconds:02} secs"
|
420 |
+
await cl.Message(
|
421 |
+
content=(
|
422 |
+
"Ah, seems like you have run out of tokens...Click "
|
423 |
+
'<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> for more info. Please come back after {}'.format(
|
424 |
+
formatted_time
|
425 |
+
)
|
426 |
+
),
|
427 |
+
author=SYSTEM,
|
428 |
+
).send()
|
429 |
+
user.metadata["in_cooldown"] = True
|
430 |
+
await update_user_info(user)
|
431 |
+
return
|
432 |
+
else:
|
433 |
+
await cl.Message(
|
434 |
+
content=(
|
435 |
+
"Ah, seems like you don't have any tokens left...Please wait while we regenerate your tokens. Click "
|
436 |
+
'<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> to view your token credits.'
|
437 |
+
),
|
438 |
+
author=SYSTEM,
|
439 |
+
).send()
|
440 |
+
return
|
441 |
+
|
442 |
+
user.metadata["in_cooldown"] = False
|
443 |
|
444 |
llm_settings = cl.user_session.get("llm_settings", {})
|
445 |
view_sources = llm_settings.get("view_sources", False)
|
446 |
stream = llm_settings.get("stream_response", False)
|
447 |
+
stream = False # Fix streaming
|
448 |
user_query_dict = {"input": message.content}
|
449 |
# Define the base configuration
|
450 |
+
cb = cl.AsyncLangchainCallbackHandler()
|
451 |
chain_config = {
|
452 |
"configurable": {
|
453 |
"user_id": self.user["user_id"],
|
454 |
"conversation_id": self.user["session_id"],
|
455 |
"memory_window": self.config["llm_params"]["memory_window"],
|
456 |
+
},
|
457 |
+
"callbacks": (
|
458 |
+
[cb]
|
459 |
+
if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
|
460 |
+
else None
|
461 |
+
),
|
462 |
}
|
463 |
|
464 |
+
with get_openai_callback() as token_count_cb:
|
465 |
+
if stream:
|
466 |
+
res = chain.stream(user_query=user_query_dict, config=chain_config)
|
467 |
+
res = await self.stream_response(res)
|
468 |
+
else:
|
469 |
+
res = await chain.invoke(
|
470 |
+
user_query=user_query_dict,
|
471 |
+
config=chain_config,
|
472 |
+
)
|
473 |
+
token_count += token_count_cb.total_tokens
|
474 |
|
475 |
answer = res.get("answer", res.get("result"))
|
476 |
|
|
|
485 |
|
486 |
if self.config["llm_params"]["generate_follow_up"]:
|
487 |
start_time = time.time()
|
488 |
+
cb_follow_up = cl.AsyncLangchainCallbackHandler()
|
489 |
+
config = {
|
490 |
+
"callbacks": (
|
491 |
+
[cb_follow_up]
|
492 |
+
if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
|
493 |
+
else None
|
494 |
+
)
|
495 |
+
}
|
496 |
+
with get_openai_callback() as token_count_cb:
|
497 |
+
list_of_questions = await self.question_generator.generate_questions(
|
498 |
+
query=user_query_dict["input"],
|
499 |
+
response=answer,
|
500 |
+
chat_history=res.get("chat_history"),
|
501 |
+
context=res.get("context"),
|
502 |
+
config=config,
|
503 |
+
)
|
504 |
|
505 |
+
token_count += token_count_cb.total_tokens
|
506 |
|
507 |
+
for question in list_of_questions:
|
508 |
actions.append(
|
509 |
cl.Action(
|
510 |
name="follow up question",
|
|
|
516 |
|
517 |
print("Time taken to generate questions: ", time.time() - start_time)
|
518 |
|
519 |
+
# # update user info with token count
|
520 |
+
tokens_left = await update_user_from_chainlit(user, token_count)
|
521 |
+
|
522 |
+
answer_with_sources += (
|
523 |
+
'\n\n<footer><span style="font-size: 0.8em; text-align: right; display: block;">Tokens Left: '
|
524 |
+
+ str(tokens_left)
|
525 |
+
+ "</span></footer>\n"
|
526 |
+
)
|
527 |
+
|
528 |
await cl.Message(
|
529 |
content=answer_with_sources,
|
530 |
elements=source_elements,
|
|
|
546 |
cl.user_session.set("memory", conversation_list)
|
547 |
await self.start(config=thread_config)
|
548 |
|
549 |
+
@cl.header_auth_callback
|
550 |
+
def header_auth_callback(headers: dict) -> Optional[cl.User]:
|
551 |
+
print("\n\n\nI am here\n\n\n")
|
552 |
+
# try: # TODO: Add try-except block after testing
|
553 |
+
# TODO: Implement to get the user information from the headers (not the cookie)
|
554 |
+
cookie = headers.get("cookie") # gets back a str
|
555 |
+
# Create a dictionary from the pairs
|
556 |
+
cookie_dict = {}
|
557 |
+
for pair in cookie.split("; "):
|
558 |
+
key, value = pair.split("=", 1)
|
559 |
+
# Strip surrounding quotes if present
|
560 |
+
cookie_dict[key] = value.strip('"')
|
561 |
+
|
562 |
+
decoded_user_info = base64.b64decode(
|
563 |
+
cookie_dict.get("X-User-Info", "")
|
564 |
+
).decode()
|
565 |
+
decoded_user_info = json.loads(decoded_user_info)
|
566 |
+
|
567 |
+
print(
|
568 |
+
f"\n\n USER ROLE: {decoded_user_info['literalai_info']['metadata']['role']} \n\n"
|
569 |
+
)
|
570 |
+
|
571 |
+
return cl.User(
|
572 |
+
id=decoded_user_info["literalai_info"]["id"],
|
573 |
+
identifier=decoded_user_info["literalai_info"]["identifier"],
|
574 |
+
metadata=decoded_user_info["literalai_info"]["metadata"],
|
575 |
+
)
|
576 |
|
577 |
async def on_follow_up(self, action: cl.Action):
|
578 |
+
user = cl.user_session.get("user")
|
579 |
message = await cl.Message(
|
580 |
content=action.description,
|
581 |
type="user_message",
|
582 |
+
author=user.identifier,
|
583 |
).send()
|
584 |
+
async with cl.Step(
|
585 |
+
name="on_follow_up", type="run", parent_id=message.id
|
586 |
+
) as step:
|
587 |
+
await self.main(message)
|
588 |
+
step.output = message.content
|
589 |
|
590 |
|
591 |
chatbot = Chatbot(config=config)
|
|
|
603 |
cl.action_callback("follow up question")(chatbot.on_follow_up)
|
604 |
|
605 |
|
606 |
+
loop = asyncio.get_event_loop()
|
607 |
+
if loop.is_running():
|
608 |
+
asyncio.ensure_future(start_app())
|
609 |
+
else:
|
610 |
+
asyncio.run(start_app())
|
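
The header_auth_callback added above pulls a base64-encoded JSON payload out of the cookie header. A minimal, self-contained sketch of just that parsing step (the helper name and sample payload are hypothetical; the real callback wraps the result in a cl.User):

import base64
import json

def parse_user_cookie(cookie_header: str) -> dict:
    # Split "k1=v1; k2=v2" pairs into a dict, stripping surrounding quotes
    cookie_dict = {}
    for pair in cookie_header.split("; "):
        key, value = pair.split("=", 1)
        cookie_dict[key] = value.strip('"')
    # The X-User-Info value is base64-encoded JSON
    raw = base64.b64decode(cookie_dict.get("X-User-Info", ""))
    return json.loads(raw.decode())

# Hypothetical round-trip showing the cookie shape the callback expects
payload = base64.b64encode(
    json.dumps(
        {"literalai_info": {"id": "u1", "identifier": "student@bu.edu",
                            "metadata": {"role": "student"}}}
    ).encode()
).decode()
print(parse_user_cookie(f'session=abc; X-User-Info="{payload}"'))

Splitting on the first "=" only is what keeps base64 padding characters in the value intact.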
code/modules/chat/chat_model_loader.py
CHANGED

@@ -1,15 +1,8 @@
 from langchain_openai import ChatOpenAI
-from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
-from transformers import AutoTokenizer, TextStreamer
 from langchain_community.llms import LlamaCpp
-import torch
-import transformers
 import os
 from pathlib import Path
 from huggingface_hub import hf_hub_download
-from langchain.callbacks.manager import CallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from modules.config.constants import LLAMA_PATH
 
 
 class ChatModelLoader:
@@ -35,10 +28,10 @@ class ChatModelLoader:
         elif self.config["llm_params"]["llm_loader"] == "local_llm":
             n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
             model_path = self._verify_model_cache(
-                self.config["llm_params"]["local_llm_params"]["…
+                self.config["llm_params"]["local_llm_params"]["model_path"]
             )
             llm = LlamaCpp(
-                model_path=…
+                model_path=model_path,
                 n_batch=n_batch,
                 n_ctx=2048,
                 f16_kv=True,
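
The fix above threads the new model_path config key through _verify_model_cache before handing it to LlamaCpp. The body of _verify_model_cache is not shown in this diff, so the following is only an assumption of what such a cache check can look like, not the repo's implementation:

import os
from pathlib import Path
from huggingface_hub import hf_hub_download

def verify_model_cache(model_path: str, repo_id: str, filename: str) -> str:
    # Reuse the local gguf file if it is already on disk
    if os.path.exists(model_path):
        return model_path
    # Otherwise fetch it from the Hub into the same directory
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(Path(model_path).parent),
    )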
code/modules/chat/helpers.py
CHANGED

@@ -42,7 +42,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
         full_answer += answer
 
     if view_sources:
-
         # Then, display the sources
         # check if the answer has sources
         if len(source_dict) == 0:
@@ -51,7 +50,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
         else:
             full_answer += "\n\n**Sources:**\n"
             for idx, (url_name, source_data) in enumerate(source_dict.items()):
-
                 full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
 
                 name = f"Source {idx + 1} Text\n"
@@ -110,6 +108,7 @@ def get_prompt(config, prompt_type):
         return prompts["openai"]["rephrase_prompt"]
 
 
+# TODO: Do this better
 def get_history_chat_resume(steps, k, SYSTEM, LLM):
     conversation_list = []
     count = 0
@@ -119,14 +118,17 @@ def get_history_chat_resume(steps, k, SYSTEM, LLM):
             conversation_list.append(
                 {"type": "user_message", "content": step["output"]}
             )
+            count += 1
        elif step["type"] == "assistant_message":
             if step["name"] == LLM:
                 conversation_list.append(
                     {"type": "ai_message", "content": step["output"]}
                 )
+                count += 1
         else:
-            raise ValueError("Invalid message type")
-            count += 1
+            pass
+            # raise ValueError("Invalid message type")
+            # count += 1
         if count >= 2 * k:  # 2 * k to account for both user and assistant messages
             break
     conversation_list = conversation_list[::-1]
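
The count += 1 additions above are what make the 2 * k window check effective: in the removed version the only increment sat in the else branch behind a raise, so it could never run for valid messages. A minimal sketch of the fixed loop on hypothetical step dicts (the real function also handles SYSTEM-authored steps, which this sketch omits):

def rebuild_history(steps, k, llm_name="AI Tutor"):
    conversation_list, count = [], 0
    for step in reversed(steps):  # walk newest-first (an assumption for the demo)
        if step["type"] == "user_message":
            conversation_list.append({"type": "user_message", "content": step["output"]})
            count += 1
        elif step["type"] == "assistant_message" and step["name"] == llm_name:
            conversation_list.append({"type": "ai_message", "content": step["output"]})
            count += 1
        if count >= 2 * k:  # keep at most k user/assistant exchanges
            break
    return conversation_list[::-1]  # restore chronological order

steps = [{"type": "user_message", "name": "User", "output": f"q{i}"} for i in range(10)]
print(len(rebuild_history(steps, k=2)))  # -> 4, instead of replaying all 10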
code/modules/chat/langchain/__init__.py
ADDED
File without changes
code/modules/chat/langchain/langchain_rag.py
CHANGED

@@ -1,20 +1,24 @@
 from langchain_core.prompts import ChatPromptTemplate
 
-from modules.chat.langchain.utils import …
-from …
+# from modules.chat.langchain.utils import
+from langchain_community.chat_message_histories import ChatMessageHistory
 from modules.chat.base import BaseRAG
 from langchain_core.prompts import PromptTemplate
-from langchain.memory import …
-…
+from langchain.memory import ConversationBufferWindowMemory
+from langchain_core.runnables.utils import ConfigurableFieldSpec
+from .utils import (
+    CustomConversationalRetrievalChain,
+    create_history_aware_retriever,
+    create_stuff_documents_chain,
+    create_retrieval_chain,
+    return_questions,
+    CustomRunnableWithHistory,
+    BaseChatMessageHistory,
+    InMemoryHistory,
 )
 
-import chainlit as cl
-from langchain_community.chat_models import ChatOpenAI
-
 
 class Langchain_RAG_V1(BaseRAG):
-
     def __init__(
         self,
         llm,
@@ -95,8 +99,8 @@ class QuestionGenerator:
     def __init__(self):
         pass
 
-    def generate_questions(self, query, response, chat_history, context):
-        questions = return_questions(query, response, chat_history, context)
+    def generate_questions(self, query, response, chat_history, context, config):
+        questions = return_questions(query, response, chat_history, context, config)
         return questions
 
 
@@ -199,7 +203,7 @@ class Langchain_RAG_V2(BaseRAG):
                     is_shared=True,
                 ),
             ],
-        )
+        ).with_config(run_name="Langchain_RAG_V2")
 
         if callbacks is not None:
             self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
code/modules/chat/langchain/utils.py
CHANGED

@@ -1,56 +1,31 @@
 from typing import Any, Dict, List, Union, Tuple, Optional
-from langchain_core.messages import (
-    BaseMessage,
-    AIMessage,
-    FunctionMessage,
-    HumanMessage,
-)
-
 from langchain_core.prompts.base import BasePromptTemplate, format_document
-from langchain_core.prompts.chat import MessagesPlaceholder
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.output_parsers.base import BaseOutputParser
 from langchain_core.retrievers import BaseRetriever, RetrieverOutput
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
 from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.runnables.utils import ConfigurableFieldSpec
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain.chains.combine_documents.base import (
     DEFAULT_DOCUMENT_PROMPT,
     DEFAULT_DOCUMENT_SEPARATOR,
     DOCUMENTS_KEY,
-    BaseCombineDocumentsChain,
     _validate_prompt,
 )
-from langchain.chains.llm import LLMChain
-from langchain_core.callbacks import Callbacks
-from langchain_core.documents import Document
-
-
-CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
-
 from langchain_core.runnables.config import RunnableConfig
-from langchain_core.messages import BaseMessage
-
-
-from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_community.chat_models import ChatOpenAI
-
-from langchain.chains import RetrievalQA, ConversationalRetrievalChain
-from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
-
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain.chains import ConversationalRetrievalChain
 from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
 import inspect
-from langchain.chains.conversational_retrieval.base import _get_chat_history
 from langchain_core.messages import BaseMessage
 
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
 
-class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -163,7 +138,6 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
 class CustomRunnableWithHistory(RunnableWithMessageHistory):
-
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -304,8 +278,8 @@ def create_retrieval_chain(
     return retrieval_chain
 
 
-…
-def return_questions(query, response, chat_history_str, context):
+# TODO: Remove Hard-coded values
+async def return_questions(query, response, chat_history_str, context, config):
     system = (
         "You are someone that suggests a question based on the student's input and chat history. "
         "Generate a question that is relevant to the student's input and chat history. "
@@ -322,18 +296,22 @@ def return_questions(query, response, chat_history_str, context):
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", system),
-            ("human", "{chat_history_str}, {context}, {query}, {response}"),
+            # ("human", "{chat_history_str}, {context}, {query}, {response}"),
         ]
     )
     llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
     question_generator = prompt | llm | StrOutputParser()
-…
+    question_generator = question_generator.with_config(
+        run_name="follow_up_question_generator"
+    )
+    new_questions = await question_generator.ainvoke(
         {
             "chat_history_str": chat_history_str,
            "context": context,
             "query": query,
             "response": response,
-        }
+        },
+        config=config,
     )
 
     list_of_questions = new_questions.split("...")
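
return_questions is now async and tagged with a run_name so the logging callbacks can attribute the generation. A runnable offline sketch of the same pipe shape, with a RunnableLambda standing in for the gpt-4o-mini call (an assumption made so the example needs no API key):

import asyncio
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

# Stand-in LLM: returns questions joined by "..." the way the caller expects
fake_llm = RunnableLambda(lambda _prompt: "What is backprop?...How do CNNs differ from MLPs?")

chain = (
    ChatPromptTemplate.from_messages([("system", "Suggest follow-ups for: {query}")])
    | fake_llm
    | StrOutputParser()
).with_config(run_name="follow_up_question_generator")

async def main():
    out = await chain.ainvoke({"query": "convolutions"}, config={"callbacks": None})
    print(out.split("..."))  # the caller splits candidate questions on "..."

asyncio.run(main())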
code/modules/chat/llm_tutor.py
CHANGED

@@ -3,7 +3,6 @@ from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
 from modules.retriever.retriever import Retriever
 from modules.chat.langchain.langchain_rag import (
-    Langchain_RAG_V1,
     Langchain_RAG_V2,
     QuestionGenerator,
 )
@@ -28,9 +27,11 @@ class LLMTutor:
         self.rephrase_prompt = get_prompt(
             config, "rephrase"
         )  # Initialize rephrase_prompt
-        if self.config["vectorstore"]["embedd_files"]:
-            self.vector_db.create_database()
-            self.vector_db.save_database()
+
+        # TODO: Removed this functionality for now, don't know if we need it
+        # if self.config["vectorstore"]["embedd_files"]:
+        #     self.vector_db.create_database()
+        #     self.vector_db.save_database()
 
     def update_llm(self, old_config, new_config):
         """
@@ -48,9 +49,11 @@ class LLMTutor:
             self.vector_db = VectorStoreManager(
                 self.config, logger=self.logger
             ).load_database()  # Reinitialize VectorStoreManager if vectorstore changes
-            if self.config["vectorstore"]["embedd_files"]:
-                self.vector_db.create_database()
-                self.vector_db.save_database()
+
+            # TODO: Removed this functionality for now, don't know if we need it
+            # if self.config["vectorstore"]["embedd_files"]:
+            #     self.vector_db.create_database()
+            #     self.vector_db.save_database()
 
         if "llm_params.llm_style" in changes:
             self.qa_prompt = get_prompt(
code/modules/chat_processor/helpers.py
ADDED

@@ -0,0 +1,245 @@
+import os
+from literalai import AsyncLiteralClient
+from datetime import datetime, timedelta, timezone
+from modules.config.constants import COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME
+from typing_extensions import TypedDict
+import tiktoken
+from typing import Any, Generic, List, Literal, Optional, TypeVar, Union
+
+Field = TypeVar("Field")
+Operators = TypeVar("Operators")
+Value = TypeVar("Value")
+
+BOOLEAN_OPERATORS = Literal["is", "nis"]
+STRING_OPERATORS = Literal["eq", "neq", "ilike", "nilike"]
+NUMBER_OPERATORS = Literal["eq", "neq", "gt", "gte", "lt", "lte"]
+STRING_LIST_OPERATORS = Literal["in", "nin"]
+DATETIME_OPERATORS = Literal["gte", "lte", "gt", "lt"]
+
+OPERATORS = Union[
+    BOOLEAN_OPERATORS,
+    STRING_OPERATORS,
+    NUMBER_OPERATORS,
+    STRING_LIST_OPERATORS,
+    DATETIME_OPERATORS,
+]
+
+
+class Filter(Generic[Field], TypedDict, total=False):
+    field: Field
+    operator: OPERATORS
+    value: Any
+    path: Optional[str]
+
+
+class OrderBy(Generic[Field], TypedDict):
+    column: Field
+    direction: Literal["ASC", "DESC"]
+
+
+threads_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "name",
+    "stepType",
+    "stepName",
+    "stepOutput",
+    "metadata",
+    "tokenCount",
+    "tags",
+    "participantId",
+    "participantIdentifiers",
+    "scoreValue",
+    "duration",
+]
+threads_orderable_fields = Literal["createdAt", "tokenCount"]
+threads_filters = List[Filter[threads_filterable_fields]]
+threads_order_by = OrderBy[threads_orderable_fields]
+
+steps_filterable_fields = Literal[
+    "id",
+    "name",
+    "input",
+    "output",
+    "participantIdentifier",
+    "startTime",
+    "endTime",
+    "metadata",
+    "parentId",
+    "threadId",
+    "error",
+    "tags",
+]
+steps_orderable_fields = Literal["createdAt"]
+steps_filters = List[Filter[steps_filterable_fields]]
+steps_order_by = OrderBy[steps_orderable_fields]
+
+users_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "identifier",
+    "lastEngaged",
+    "threadCount",
+    "tokenCount",
+    "metadata",
+]
+users_filters = List[Filter[users_filterable_fields]]
+
+scores_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "participant",
+    "name",
+    "tags",
+    "value",
+    "type",
+    "comment",
+]
+scores_orderable_fields = Literal["createdAt"]
+scores_filters = List[Filter[scores_filterable_fields]]
+scores_order_by = OrderBy[scores_orderable_fields]
+
+generation_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "model",
+    "duration",
+    "promptLineage",
+    "promptVersion",
+    "tags",
+    "score",
+    "participant",
+    "tokenCount",
+    "error",
+]
+generation_orderable_fields = Literal[
+    "createdAt",
+    "tokenCount",
+    "model",
+    "provider",
+    "participant",
+    "duration",
+]
+generations_filters = List[Filter[generation_filterable_fields]]
+generations_order_by = OrderBy[generation_orderable_fields]
+
+literal_client = AsyncLiteralClient(api_key=os.getenv("LITERAL_API_KEY_LOGGING"))
+
+
+# For consistency, use dictionary for user_info
+def convert_to_dict(user_info):
+    # if already a dictionary, return as is
+    if isinstance(user_info, dict):
+        return user_info
+    if hasattr(user_info, "__dict__"):
+        user_info = user_info.__dict__
+    return user_info
+
+
+def get_time():
+    return datetime.now(timezone.utc).isoformat()
+
+
+async def get_user_details(user_email_id):
+    user_info = await literal_client.api.get_or_create_user(identifier=user_email_id)
+    return user_info
+
+
+async def update_user_info(user_info):
+    # if object type, convert to dictionary
+    user_info = convert_to_dict(user_info)
+    await literal_client.api.update_user(
+        id=user_info["id"],
+        identifier=user_info["identifier"],
+        metadata=user_info["metadata"],
+    )
+
+
+async def check_user_cooldown(user_info, current_time):
+    # # Check if no tokens left
+    tokens_left = user_info.metadata.get("tokens_left", 0)
+    if tokens_left > 0 and not user_info.metadata.get("in_cooldown", False):
+        return False, None
+
+    user_info = convert_to_dict(user_info)
+    last_message_time_str = user_info["metadata"].get("last_message_time")
+
+    # Convert from ISO format string to datetime object and ensure UTC timezone
+    last_message_time = datetime.fromisoformat(last_message_time_str).replace(
+        tzinfo=timezone.utc
+    )
+    current_time = datetime.fromisoformat(current_time).replace(tzinfo=timezone.utc)
+
+    # Calculate the elapsed time
+    elapsed_time = current_time - last_message_time
+    elapsed_time_in_seconds = elapsed_time.total_seconds()
+
+    # Calculate when the cooldown period ends
+    cooldown_end_time = last_message_time + timedelta(seconds=COOLDOWN_TIME)
+    cooldown_end_time_iso = cooldown_end_time.isoformat()
+
+    # Debug: Print the cooldown end time
+    print(f"Cooldown end time (ISO): {cooldown_end_time_iso}")
+
+    # Check if the user is still in cooldown
+    if elapsed_time_in_seconds < COOLDOWN_TIME:
+        return True, cooldown_end_time_iso  # Return in ISO 8601 format
+
+    user_info["metadata"]["in_cooldown"] = False
+    # If not in cooldown, regenerate tokens
+    await reset_tokens_for_user(user_info)
+
+    return False, None
+
+
+async def reset_tokens_for_user(user_info):
+    user_info = convert_to_dict(user_info)
+    last_message_time_str = user_info["metadata"].get("last_message_time")
+
+    last_message_time = datetime.fromisoformat(last_message_time_str).replace(
+        tzinfo=timezone.utc
+    )
+    current_time = datetime.fromisoformat(get_time()).replace(tzinfo=timezone.utc)
+
+    # Calculate the elapsed time since the last message
+    elapsed_time_in_seconds = (current_time - last_message_time).total_seconds()
+
+    # Current token count (can be negative)
+    current_tokens = user_info["metadata"].get("tokens_left_at_last_message", 0)
+    current_tokens = min(current_tokens, TOKENS_LEFT)
+
+    # Maximum tokens that can be regenerated
+    max_tokens = user_info["metadata"].get("max_tokens", TOKENS_LEFT)
+
+    # Calculate how many tokens should have been regenerated proportionally
+    if current_tokens < max_tokens:
+        # Calculate the regeneration rate per second based on REGEN_TIME for full regeneration
+        regeneration_rate_per_second = max_tokens / REGEN_TIME
+
+        # Calculate how many tokens should have been regenerated based on the elapsed time
+        tokens_to_regenerate = int(
+            elapsed_time_in_seconds * regeneration_rate_per_second
+        )
+
+        # Ensure the new token count does not exceed max_tokens
+        new_token_count = min(current_tokens + tokens_to_regenerate, max_tokens)
+
+        print(
+            f"\n\n Adding {tokens_to_regenerate} tokens to the user, Time elapsed: {elapsed_time_in_seconds} seconds, Tokens after regeneration: {new_token_count}, Tokens before: {current_tokens} \n\n"
+        )
+
+        # Update the user's token count
+        user_info["metadata"]["tokens_left"] = new_token_count
+
+        await update_user_info(user_info)
+
+
+async def get_thread_step_info(thread_id):
+    step = await literal_client.api.get_step(thread_id)
+    return step
+
+
+def get_num_tokens(text, model):
+    encoding = tiktoken.encoding_for_model(model)
+    tokens = encoding.encode(text)
+    return len(tokens)
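
The regeneration rule in reset_tokens_for_user is linear: with the defaults above (TOKENS_LEFT = 2000, REGEN_TIME = 180) an empty budget refills completely after 180 seconds. A pure-function sketch of just that arithmetic, reordered slightly so the sample values come out exact:

def regenerated_tokens(current_tokens, elapsed_seconds, max_tokens=2000, regen_time=180):
    # Linear refill: the full budget comes back after regen_time seconds
    current_tokens = min(current_tokens, max_tokens)  # clamp stale counts
    tokens_back = int(elapsed_seconds * max_tokens / regen_time)
    return min(current_tokens + tokens_back, max_tokens)

print(regenerated_tokens(0, 90))     # 90 s at 2000 tokens / 180 s -> 1000
print(regenerated_tokens(500, 600))  # long idle -> capped at max_tokens, 2000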
code/modules/chat_processor/literal_ai.py
CHANGED

@@ -1,44 +1,7 @@
-from chainlit.data import ChainlitDataLayer
+from chainlit.data import ChainlitDataLayer
 
 
 # update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
 class CustomLiteralDataLayer(ChainlitDataLayer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-
-    @queue_until_user_message()
-    async def create_step(self, step_dict: "StepDict"):
-        metadata = dict(
-            step_dict.get("metadata", {}),
-            **{
-                "waitForAnswer": step_dict.get("waitForAnswer"),
-                "language": step_dict.get("language"),
-                "showInput": step_dict.get("showInput"),
-            },
-        )
-
-        step: LiteralStepDict = {
-            "createdAt": step_dict.get("createdAt"),
-            "startTime": step_dict.get("start"),
-            "endTime": step_dict.get("end"),
-            "generation": step_dict.get("generation"),
-            "id": step_dict.get("id"),
-            "parentId": step_dict.get("parentId"),
-            "name": step_dict.get("name"),
-            "threadId": step_dict.get("threadId"),
-            "type": step_dict.get("type"),
-            "tags": step_dict.get("tags"),
-            "metadata": metadata,
-        }
-        if step_dict.get("input"):
-            step["input"] = {"content": step_dict.get("input")}
-        if step_dict.get("output"):
-            step["output"] = {"content": step_dict.get("output")}
-        if step_dict.get("isError"):
-            step["error"] = step_dict.get("output")
-
-        # print("\n\n\n")
-        # print("Step: ", step)
-        # print("\n\n\n")
-
-        await self.client.api.send_steps([step])
code/modules/config/config.yml
CHANGED

@@ -4,7 +4,7 @@ device: 'cpu' # str [cuda, cpu]
 
 vectorstore:
   load_from_HF: True # bool
-…
+  reparse_files: True # bool
   data_path: '../storage/data' # str
   url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
@@ -37,14 +37,14 @@ llm_params:
   temperature: 0.7 # float
   repo_id: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF' # HuggingFace repo id
   filename: 'tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Specific name of gguf file in the repo
-…
+  model_path: 'storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Path to the model file
   stream: False # bool
   pdf_reader: 'gpt' # str [llama, pymupdf, gpt]
 
 chat_logging:
   log_chat: True # bool
   platform: 'literalai'
-  callbacks: …
+  callbacks: True # bool
 
 splitter_options:
   use_splitter: True # bool
code/modules/config/constants.py
CHANGED

@@ -3,6 +3,15 @@ import os
 
 load_dotenv()
 
+TIMEOUT = 60
+COOLDOWN_TIME = 60
+REGEN_TIME = 180
+TOKENS_LEFT = 2000
+ALL_TIME_TOKENS_ALLOCATED = 1000000
+
+GITHUB_REPO = "https://github.com/DL4DS/dl4ds_tutor"
+DOCS_WEBSITE = "https://dl4ds.github.io/dl4ds_tutor/"
+
 # API Keys - Loaded from the .env file
 
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -10,14 +19,16 @@ LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
 LITERAL_API_URL = os.getenv("LITERAL_API_URL")
+CHAINLIT_URL = os.getenv("CHAINLIT_URL")
 
 OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
 OAUTH_GOOGLE_CLIENT_SECRET = os.getenv("OAUTH_GOOGLE_CLIENT_SECRET")
 
-opening_message = …
+opening_message = "Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
+chat_end_message = (
+    "I hope I was able to help you. If you have any more questions, feel free to ask!"
+)
 
 # Model Paths
 
 LLAMA_PATH = "../storage/models/tinyllama"
-
-RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
code/modules/config/project_config.yml
ADDED

@@ -0,0 +1,7 @@
+retriever:
+  retriever_hf_paths:
+    RAGatouille: "XThomasBU/Colbert_Index"
+
+metadata:
+  metadata_links: ["https://dl4ds.github.io/sp2024/lectures/", "https://dl4ds.github.io/sp2024/schedule/"]
+  slide_base_link: "https://dl4ds.github.io"
code/modules/dataloader/data_loader.py
CHANGED
@@ -3,40 +3,26 @@ import re
|
|
3 |
import requests
|
4 |
import pysrt
|
5 |
from langchain_community.document_loaders import (
|
6 |
-
PyMuPDFLoader,
|
7 |
Docx2txtLoader,
|
8 |
YoutubeLoader,
|
9 |
-
WebBaseLoader,
|
10 |
TextLoader,
|
11 |
)
|
12 |
-
from langchain_community.document_loaders import UnstructuredMarkdownLoader
|
13 |
-
from llama_parse import LlamaParse
|
14 |
from langchain.schema import Document
|
15 |
import logging
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain_experimental.text_splitter import SemanticChunker
|
18 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
19 |
-
from ragatouille import RAGPretrainedModel
|
20 |
-
from langchain.chains import LLMChain
|
21 |
-
from langchain_community.llms import OpenAI
|
22 |
-
from langchain import PromptTemplate
|
23 |
import json
|
24 |
from concurrent.futures import ThreadPoolExecutor
|
25 |
from urllib.parse import urljoin
|
26 |
import html2text
|
27 |
import bs4
|
28 |
-
import tempfile
|
29 |
import PyPDF2
|
30 |
from modules.dataloader.pdf_readers.base import PDFReader
|
31 |
from modules.dataloader.pdf_readers.llama import LlamaParser
|
32 |
from modules.dataloader.pdf_readers.gpt import GPTParser
|
33 |
-
|
34 |
-
|
35 |
-
from modules.dataloader.helpers import get_metadata, download_pdf_from_url
|
36 |
-
from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
|
37 |
-
except:
|
38 |
-
from dataloader.helpers import get_metadata, download_pdf_from_url
|
39 |
-
from config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
|
40 |
|
41 |
logger = logging.getLogger(__name__)
|
42 |
BASE_DIR = os.getcwd()
|
@@ -47,7 +33,7 @@ class HTMLReader:
|
|
47 |
pass
|
48 |
|
49 |
def read_url(self, url):
|
50 |
-
response = requests.get(url)
|
51 |
if response.status_code == 200:
|
52 |
return response.text
|
53 |
else:
|
@@ -65,11 +51,13 @@ class HTMLReader:
|
|
65 |
href = href.replace("http", "https")
|
66 |
|
67 |
absolute_url = urljoin(base_url, href)
|
68 |
-
link[
|
69 |
|
70 |
-
resp = requests.head(absolute_url)
|
71 |
if resp.status_code != 200:
|
72 |
-
logger.warning(
|
|
|
|
|
73 |
|
74 |
return str(soup)
|
75 |
|
@@ -85,6 +73,7 @@ class HTMLReader:
|
|
85 |
else:
|
86 |
return None
|
87 |
|
|
|
88 |
class FileReader:
|
89 |
def __init__(self, logger, kind):
|
90 |
self.logger = logger
|
@@ -96,7 +85,9 @@ class FileReader:
|
|
96 |
else:
|
97 |
self.pdf_reader = PDFReader()
|
98 |
self.web_reader = HTMLReader()
|
99 |
-
self.logger.info(
|
|
|
|
|
100 |
|
101 |
def extract_text_from_pdf(self, pdf_path):
|
102 |
text = ""
|
@@ -137,7 +128,7 @@ class FileReader:
|
|
137 |
return [Document(page_content=self.web_reader.read_html(url))]
|
138 |
|
139 |
def read_tex_from_url(self, tex_url):
|
140 |
-
response = requests.get(tex_url)
|
141 |
if response.status_code == 200:
|
142 |
return [Document(page_content=response.text)]
|
143 |
else:
|
@@ -154,17 +145,20 @@ class ChunkProcessor:
|
|
154 |
self.document_metadata = {}
|
155 |
self.document_chunks_full = []
|
156 |
|
157 |
-
|
|
|
158 |
self.load_document_data()
|
159 |
|
160 |
if config["splitter_options"]["use_splitter"]:
|
161 |
if config["splitter_options"]["chunking_mode"] == "fixed":
|
162 |
if config["splitter_options"]["split_by_token"]:
|
163 |
-
self.splitter =
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
168 |
)
|
169 |
else:
|
170 |
self.splitter = RecursiveCharacterTextSplitter(
|
@@ -175,8 +169,7 @@ class ChunkProcessor:
|
|
175 |
)
|
176 |
else:
|
177 |
self.splitter = SemanticChunker(
|
178 |
-
OpenAIEmbeddings(),
|
179 |
-
breakpoint_threshold_type="percentile"
|
180 |
)
|
181 |
|
182 |
else:
|
@@ -203,7 +196,10 @@ class ChunkProcessor:
|
|
203 |
):
|
204 |
# TODO: Clear up this pipeline of re-adding metadata
|
205 |
documents = [Document(page_content=documents, source=source, page=page)]
|
206 |
-
if
|
|
|
|
|
|
|
207 |
document_chunks = documents
|
208 |
else:
|
209 |
document_chunks = self.splitter.split_documents(documents)
|
@@ -226,9 +222,22 @@ class ChunkProcessor:
|
|
226 |
|
227 |
def chunk_docs(self, file_reader, uploaded_files, weblinks):
|
228 |
addl_metadata = get_metadata(
|
229 |
-
|
230 |
-
"https://dl4ds.github.io/sp2024/schedule/",
|
231 |
) # For any additional metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
with ThreadPoolExecutor() as executor:
|
233 |
executor.map(
|
234 |
self.process_file,
|
@@ -298,6 +307,7 @@ class ChunkProcessor:
|
|
298 |
self.document_metadata[file_path] = file_metadata
|
299 |
|
300 |
def process_file(self, file_path, file_index, file_reader, addl_metadata):
|
|
|
301 |
file_name = os.path.basename(file_path)
|
302 |
|
303 |
file_type = file_name.split(".")[-1]
|
@@ -314,10 +324,12 @@ class ChunkProcessor:
|
|
314 |
return
|
315 |
|
316 |
try:
|
317 |
-
|
318 |
if file_path in self.document_data:
|
319 |
self.logger.warning(f"File {file_name} already processed")
|
320 |
-
documents = [
|
|
|
|
|
|
|
321 |
else:
|
322 |
documents = read_methods[file_type](file_path)
|
323 |
|
@@ -370,22 +382,31 @@ class ChunkProcessor:
|
|
370 |
json.dump(self.document_metadata, json_file, indent=4)
|
371 |
|
372 |
def load_document_data(self):
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
|
386 |
class DataLoader:
|
387 |
def __init__(self, config, logger=None):
|
388 |
-
self.file_reader = FileReader(
|
|
|
|
|
389 |
self.chunk_processor = ChunkProcessor(config, logger=logger)
|
390 |
|
391 |
def get_chunks(self, uploaded_files, weblinks):
|
@@ -396,6 +417,15 @@ class DataLoader:
|
|
396 |
|
397 |
if __name__ == "__main__":
|
398 |
import yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
|
400 |
logger = logging.getLogger(__name__)
|
401 |
logger.setLevel(logging.INFO)
|
@@ -403,19 +433,30 @@ if __name__ == "__main__":
|
|
403 |
with open("../code/modules/config/config.yml", "r") as f:
|
404 |
config = yaml.safe_load(f)
|
405 |
|
406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
uploaded_files = [
|
408 |
-
os.path.join(STORAGE_DIR, file)
|
|
|
|
|
409 |
]
|
410 |
|
411 |
data_loader = DataLoader(config, logger=logger)
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
|
|
417 |
)
|
418 |
|
419 |
print(document_names[:5])
|
420 |
print(len(document_chunks))
|
421 |
-
|
|
|
3 |
import requests
|
4 |
import pysrt
|
5 |
from langchain_community.document_loaders import (
|
|
|
6 |
Docx2txtLoader,
|
7 |
YoutubeLoader,
|
|
|
8 |
TextLoader,
|
9 |
)
|
|
|
|
|
10 |
from langchain.schema import Document
|
11 |
import logging
|
12 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
13 |
from langchain_experimental.text_splitter import SemanticChunker
|
14 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
|
|
|
|
|
|
|
|
15 |
import json
|
16 |
from concurrent.futures import ThreadPoolExecutor
|
17 |
from urllib.parse import urljoin
|
18 |
import html2text
|
19 |
import bs4
|
|
|
20 |
import PyPDF2
|
21 |
from modules.dataloader.pdf_readers.base import PDFReader
|
22 |
from modules.dataloader.pdf_readers.llama import LlamaParser
|
23 |
from modules.dataloader.pdf_readers.gpt import GPTParser
|
24 |
+
from modules.dataloader.helpers import get_metadata
|
25 |
+
from modules.config.constants import TIMEOUT
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
logger = logging.getLogger(__name__)
|
28 |
BASE_DIR = os.getcwd()
|
|
|
33 |
pass
|
34 |
|
35 |
def read_url(self, url):
|
36 |
+
response = requests.get(url, timeout=TIMEOUT)
|
37 |
if response.status_code == 200:
|
38 |
return response.text
|
39 |
else:
|
|
|
51 |
href = href.replace("http", "https")
|
52 |
|
53 |
absolute_url = urljoin(base_url, href)
|
54 |
+
link["href"] = absolute_url
|
55 |
|
56 |
+
resp = requests.head(absolute_url, timeout=TIMEOUT)
|
57 |
if resp.status_code != 200:
|
58 |
+
logger.warning(
|
59 |
+
f"Link {absolute_url} is broken. Status code: {resp.status_code}"
|
60 |
+
)
|
61 |
|
62 |
return str(soup)
|
63 |
|
|
|
73 |
else:
|
74 |
return None
|
75 |
|
76 |
+
|
77 |
class FileReader:
|
78 |
def __init__(self, logger, kind):
|
79 |
self.logger = logger
|
|
|
85 |
else:
|
86 |
self.pdf_reader = PDFReader()
|
87 |
self.web_reader = HTMLReader()
|
88 |
+
self.logger.info(
|
89 |
+
f"Initialized FileReader with {kind} PDF reader and HTML reader"
|
90 |
+
)
|
91 |
|
92 |
def extract_text_from_pdf(self, pdf_path):
|
93 |
text = ""
|
|
|
128 |
return [Document(page_content=self.web_reader.read_html(url))]
|
129 |
|
130 |
def read_tex_from_url(self, tex_url):
|
131 |
+
response = requests.get(tex_url, timeout=TIMEOUT)
|
132 |
if response.status_code == 200:
|
133 |
return [Document(page_content=response.text)]
|
134 |
else:
|
|
|
145 |
self.document_metadata = {}
|
146 |
self.document_chunks_full = []
|
147 |
|
148 |
+
# TODO: Fix when reparse_files is False
|
149 |
+
if not config["vectorstore"]["reparse_files"]:
|
150 |
self.load_document_data()
|
151 |
|
152 |
if config["splitter_options"]["use_splitter"]:
|
153 |
if config["splitter_options"]["chunking_mode"] == "fixed":
|
154 |
if config["splitter_options"]["split_by_token"]:
|
155 |
+
self.splitter = (
|
156 |
+
RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
157 |
+
chunk_size=config["splitter_options"]["chunk_size"],
|
158 |
+
chunk_overlap=config["splitter_options"]["chunk_overlap"],
|
159 |
+
code/modules/dataloader/data_loader.py (continued)

                     separators=config["splitter_options"]["chunk_separators"],
+                    disallowed_special=(),
+                )
             )
         else:
             self.splitter = RecursiveCharacterTextSplitter(
...
             )
         else:
             self.splitter = SemanticChunker(
+                OpenAIEmbeddings(), breakpoint_threshold_type="percentile"
             )

         else:
...
     ):
         # TODO: Clear up this pipeline of re-adding metadata
         documents = [Document(page_content=documents, source=source, page=page)]
+        if (
+            file_type == "pdf"
+            and self.config["splitter_options"]["chunking_mode"] == "fixed"
+        ):
             document_chunks = documents
         else:
             document_chunks = self.splitter.split_documents(documents)
...
     def chunk_docs(self, file_reader, uploaded_files, weblinks):
         addl_metadata = get_metadata(
+            *self.config["metadata"]["metadata_links"], self.config
         )  # For any additional metadata
+
+        # remove already processed files if reparse_files is False
+        if not self.config["vectorstore"]["reparse_files"]:
+            total_documents = len(uploaded_files) + len(weblinks)
+            uploaded_files = [
+                file_path
+                for file_path in uploaded_files
+                if file_path not in self.document_data
+            ]
+            weblinks = [link for link in weblinks if link not in self.document_data]
+            print(
+                f"Total documents to process: {total_documents}, Documents already processed: {total_documents - len(uploaded_files) - len(weblinks)}"
+            )
+
         with ThreadPoolExecutor() as executor:
             executor.map(
                 self.process_file,
...
             self.document_metadata[file_path] = file_metadata

     def process_file(self, file_path, file_index, file_reader, addl_metadata):
+        print(f"Processing file {file_index + 1} : {file_path}")
         file_name = os.path.basename(file_path)

         file_type = file_name.split(".")[-1]
...
             return

         try:
             if file_path in self.document_data:
                 self.logger.warning(f"File {file_name} already processed")
+                documents = [
+                    Document(page_content=content)
+                    for content in self.document_data[file_path].values()
+                ]
             else:
                 documents = read_methods[file_type](file_path)
...
             json.dump(self.document_metadata, json_file, indent=4)

     def load_document_data(self):
+        try:
+            with open(
+                f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
+            ) as json_file:
+                self.document_data = json.load(json_file)
+            with open(
+                f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
+            ) as json_file:
+                self.document_metadata = json.load(json_file)
+            self.logger.info(
+                f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}"
+            )
+        except FileNotFoundError:
+            self.logger.warning(
+                f"Document content not found in {self.config['log_chunk_dir']}/docs/doc_content.json"
+            )
+            self.document_data = {}
+            self.document_metadata = {}


 class DataLoader:
     def __init__(self, config, logger=None):
+        self.file_reader = FileReader(
+            logger=logger, kind=config["llm_params"]["pdf_reader"]
+        )
         self.chunk_processor = ChunkProcessor(config, logger=logger)

     def get_chunks(self, uploaded_files, weblinks):
...

 if __name__ == "__main__":
     import yaml
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Process some links.")
+    parser.add_argument(
+        "--links", nargs="+", required=True, help="List of links to process."
+    )
+
+    args = parser.parse_args()
+    links_to_process = args.links

     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
...
     with open("../code/modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)

+    with open("../code/modules/config/project_config.yml", "r") as f:
+        project_config = yaml.safe_load(f)
+
+    # Combine project config with the main config
+    config.update(project_config)
+
+    STORAGE_DIR = os.path.join(BASE_DIR, config["vectorstore"]["data_path"])
     uploaded_files = [
+        os.path.join(STORAGE_DIR, file)
+        for file in os.listdir(STORAGE_DIR)
+        if file != "urls.txt"
     ]

     data_loader = DataLoader(config, logger=logger)
+    # Just for testing
+    (
+        document_chunks,
+        document_names,
+        documents,
+        document_metadata,
+    ) = data_loader.get_chunks(
+        links_to_process,
+        [],
+    )

     print(document_names[:5])
     print(len(document_chunks))
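For orientation, here is a minimal sketch of driving the updated `DataLoader` from Python rather than via the `--links` CLI above. The config paths and the argument order mirror the `__main__` block; the URL is a placeholder.

```python
# Minimal usage sketch; config paths follow the __main__ block above,
# and the lecture URL is a placeholder.
import logging
import yaml
from modules.dataloader.data_loader import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

with open("code/modules/config/config.yml") as f:
    config = yaml.safe_load(f)
with open("code/modules/config/project_config.yml") as f:
    config.update(yaml.safe_load(f))  # layer the project config on top

loader = DataLoader(config, logger=logger)
# The __main__ block passes the CLI links as the first argument,
# with an empty weblinks list as the second.
document_chunks, document_names, documents, document_metadata = loader.get_chunks(
    ["https://example.com/lecture1.pdf"],
    [],
)
print(len(document_chunks))
```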
code/modules/dataloader/helpers.py CHANGED

@@ -2,6 +2,8 @@ import requests
 from bs4 import BeautifulSoup
 from urllib.parse import urlparse
 import tempfile
+from modules.config.constants import TIMEOUT
+

 def get_urls_from_file(file_path: str):
     """
@@ -19,18 +21,19 @@ def get_base_url(url):
     return base_url


-def get_metadata(lectures_url, schedule_url):
+### THIS FUNCTION IS NOT GENERALIZABLE.. IT IS SPECIFIC TO THE COURSE WEBSITE ###
+def get_metadata(lectures_url, schedule_url, config):
     """
     Function to get the lecture metadata from the lectures and schedule URLs.
     """
     lecture_metadata = {}

     # Get the main lectures page content
-    r_lectures = requests.get(lectures_url)
+    r_lectures = requests.get(lectures_url, timeout=TIMEOUT)
     soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")

     # Get the main schedule page content
-    r_schedule = requests.get(schedule_url)
+    r_schedule = requests.get(schedule_url, timeout=TIMEOUT)
     soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")

     # Find all lecture blocks
@@ -48,7 +51,9 @@ def get_metadata(lectures_url, schedule_url):
     slides_link_tag = description_div.find("a", title="Download slides")
     slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
     slides_link = (
-        f"...
+        f"{config['metadata']['slide_base_link']}{slides_link}"
+        if slides_link
+        else None
     )
     if slides_link:
         date_mapping[slides_link] = date
@@ -68,7 +73,9 @@ def get_metadata(lectures_url, schedule_url):
     slides_link_tag = block.find("a", title="Download slides")
     slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
     slides_link = (
-        f"...
+        f"{config['metadata']['slide_base_link']}{slides_link}"
+        if slides_link
+        else None
     )

     # Extract the link to the lecture recording
@@ -118,7 +125,7 @@ def download_pdf_from_url(pdf_url):
     Returns:
         str: The local file path of the downloaded PDF file.
     """
-    response = requests.get(pdf_url)
+    response = requests.get(pdf_url, timeout=TIMEOUT)
     if response.status_code == 200:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
             temp_file.write(response.content)
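Every `requests` call in this file now passes `timeout=TIMEOUT`. A minimal sketch of the pattern, assuming a simple integer constant (the actual value lives in `modules/config/constants.py` and is not shown in this diff):

```python
import requests

# Assumed value; the real constant is defined in modules/config/constants.py.
TIMEOUT = 10  # seconds

def fetch_page(url: str) -> str:
    # Without a timeout, requests.get() can block indefinitely on a stalled
    # server; with one, it raises requests.exceptions.Timeout instead.
    response = requests.get(url, timeout=TIMEOUT)
    response.raise_for_status()
    return response.text
```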
code/modules/dataloader/pdf_readers/gpt.py CHANGED

@@ -6,6 +6,7 @@ from io import BytesIO
 from openai import OpenAI
 from pdf2image import convert_from_path
 from langchain.schema import Document
+from modules.config.constants import TIMEOUT


 class GPTParser:
@@ -19,9 +20,9 @@ class GPTParser:
     self.api_key = os.getenv("OPENAI_API_KEY")
     self.prompt = """
     The provided documents are images of PDFs of lecture slides of deep learning material.
-    They contain LaTeX equations, images, and text.
+    They contain LaTeX equations, images, and text.
     The goal is to extract the text, images and equations from the slides and convert everything to markdown format. Some of the equations may be complicated.
-    The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
+    The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
     For images, give a description and if you can, a source. Separate each page with '---'.
     Just respond with the markdown. Do not include page numbers or any other metadata. Do not try to provide titles. Strictly the content.
     """
@@ -31,36 +32,45 @@ class GPTParser:

     encoded_images = [self.encode_image(image) for image in images]

-    chunks = [encoded_images[i:i + 5] for i in range(0, len(encoded_images), 5)]
+    chunks = [encoded_images[i : i + 5] for i in range(0, len(encoded_images), 5)]

     headers = {
         "Content-Type": "application/json",
-        "Authorization": f"Bearer {self.api_key}"
+        "Authorization": f"Bearer {self.api_key}",
     }

     output = ""
     for chunk_num, chunk in enumerate(chunks):
-        content = [
-            ...
+        content = [
+            {
+                "type": "image_url",
+                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
+            }
+            for image in chunk
+        ]

         content.insert(0, {"type": "text", "text": self.prompt})

         payload = {
             "model": "gpt-4o-mini",
-            "messages": [
-                {
-                    "role": "user",
-                    "content": content
-                }
-            ],
+            "messages": [{"role": "user", "content": content}],
         }

         response = requests.post(
-            "https://api.openai.com/v1/chat/completions",
-            ...
+            "https://api.openai.com/v1/chat/completions",
+            headers=headers,
+            json=payload,
+            timeout=TIMEOUT,
+        )

         resp = response.json()

-        chunk_output = ...
+        chunk_output = (
+            resp["choices"][0]["message"]["content"]
+            .replace("```", "")
+            .replace("markdown", "")
+            .replace("````", "")
+        )

         output += chunk_output + "\n---\n"

@@ -68,14 +78,12 @@ class GPTParser:
     output = [doc for doc in output if doc.strip() != ""]

     documents = [
-        Document(
-            page_content=page,
-            metadata={"source": pdf_path, "page": i}
-        ) for i, page in enumerate(output)
+        Document(page_content=page, metadata={"source": pdf_path, "page": i})
+        for i, page in enumerate(output)
     ]
     return documents

 def encode_image(self, image):
     buffered = BytesIO()
     image.save(buffered, format="JPEG")
-    return base64.b64encode(buffered.getvalue()).decode(
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
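A usage sketch for the parser, under the assumption (consistent with the diff) that `parse()` takes a local PDF path and returns one langchain `Document` per extracted page:

```python
# Hypothetical usage sketch; the PDF path is a placeholder, and OPENAI_API_KEY
# must be set in the environment since GPTParser reads it via os.getenv().
from modules.dataloader.pdf_readers.gpt import GPTParser

parser = GPTParser()
docs = parser.parse("storage/data/lecture_01.pdf")
for doc in docs[:2]:
    # Each Document carries the source path and a zero-based page index.
    print(doc.metadata["page"], doc.page_content[:80])
```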
code/modules/dataloader/pdf_readers/llama.py CHANGED

@@ -2,19 +2,18 @@ import os
 import requests
 from llama_parse import LlamaParse
 from langchain.schema import Document
-from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
+from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY, TIMEOUT
 from modules.dataloader.helpers import download_pdf_from_url


-
 class LlamaParser:
     def __init__(self):
         self.GPT_API_KEY = OPENAI_API_KEY
         self.LLAMA_CLOUD_API_KEY = LLAMA_CLOUD_API_KEY
         self.parse_url = "https://api.cloud.llamaindex.ai/api/parsing/upload"
         self.headers = {
-            ...
-            ...
+            "Accept": "application/json",
+            "Authorization": f"Bearer {LLAMA_CLOUD_API_KEY}",
         }
         self.parser = LlamaParse(
             api_key=LLAMA_CLOUD_API_KEY,
@@ -23,7 +22,7 @@ class LlamaParser:
             language="en",
             gpt4o_mode=False,
             # gpt4o_api_key=OPENAI_API_KEY,
-            parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source."
+            parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source.",
         )

     def parse(self, pdf_path):
@@ -38,10 +37,8 @@ class LlamaParser:
         pages = [page.strip() for page in pages]

         documents = [
-            Document(
-                page_content=page,
-                metadata={"source": pdf_path, "page": i}
-            ) for i, page in enumerate(pages)
+            Document(page_content=page, metadata={"source": pdf_path, "page": i})
+            for i, page in enumerate(pages)
         ]

         return documents
@@ -53,20 +50,30 @@ class LlamaParser:
         }

         files = [
-            (...
+            (
+                "file",
+                (
+                    "file",
+                    requests.get(pdf_url, timeout=TIMEOUT).content,
+                    "application/octet-stream",
+                ),
+            )
         ]

         response = requests.request(
-            "POST", self.parse_url, headers=self.headers, data=payload, files=files
+            "POST", self.parse_url, headers=self.headers, data=payload, files=files
+        )

-        return response.json()[...
+        return response.json()["id"], response.json()["status"]

     async def get_result(self, job_id):
-        url = ...
+        url = (
+            f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}/result/markdown"
+        )

         response = requests.request("GET", url, headers=self.headers, data={})

-        return response.json()[...
+        return response.json()["markdown"]

     async def _parse(self, pdf_path):
         job_id, status = self.make_request(pdf_path)
@@ -78,15 +85,9 @@ class LlamaParser:

         result = await self.get_result(job_id)

-        documents = [
-            Document(
-                page_content=result,
-                metadata={"source": pdf_path}
-            )
-        ]
+        documents = [Document(page_content=result, metadata={"source": pdf_path})]

         return documents

-    async def _parse(self, pdf_path):
-        ...
-        ...
+    # async def _parse(self, pdf_path):
+    #     return await self._parse(pdf_path)
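Since `_parse` is a coroutine, a caller has to await it. A hedged sketch of one way to do that (the path is a placeholder, and whether `parse()` or `_parse()` is the intended public entry point is not settled by this diff):

```python
# Hedged sketch: assumes the async _parse() coroutine is awaited directly.
import asyncio
from modules.dataloader.pdf_readers.llama import LlamaParser

async def main():
    parser = LlamaParser()
    docs = await parser._parse("storage/data/lecture_01.pdf")  # placeholder path
    print(docs[0].metadata["source"])

asyncio.run(main())
```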
code/modules/dataloader/webpage_crawler.py CHANGED

@@ -3,7 +3,9 @@ from aiohttp import ClientSession
 import asyncio
 import requests
 from bs4 import BeautifulSoup
-from urllib.parse import ...
+from urllib.parse import urljoin, urldefrag
+from modules.config.constants import TIMEOUT
+

 class WebpageCrawler:
     def __init__(self):
@@ -18,7 +20,7 @@ class WebpageCrawler:

     def url_exists(self, url: str) -> bool:
         try:
-            response = requests.head(url)
+            response = requests.head(url, timeout=TIMEOUT)
             return response.status_code == 200
         except requests.ConnectionError:
             return False
@@ -88,7 +90,7 @@ class WebpageCrawler:

     def is_webpage(self, url: str) -> bool:
         try:
-            response = requests.head(url, allow_redirects=True)
+            response = requests.head(url, allow_redirects=True, timeout=TIMEOUT)
             content_type = response.headers.get("Content-Type", "").lower()
             return "text/html" in content_type
         except requests.RequestException:
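The explicit `urljoin`/`urldefrag` import suggests the crawler normalizes discovered links. A small self-contained illustration of that pattern (URLs are placeholders):

```python
from urllib.parse import urljoin, urldefrag

base = "https://example.com/docs/index.html"  # placeholder page URL
href = "../api/reference.html#section-2"      # placeholder relative link

absolute = urljoin(base, href)           # resolve against the current page
clean, _fragment = urldefrag(absolute)   # strip '#...' so duplicate pages dedupe
print(clean)  # -> https://example.com/api/reference.html
```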
code/modules/retriever/helpers.py CHANGED

@@ -6,7 +6,6 @@ from typing import List


 class VectorStoreRetrieverScore(VectorStoreRetriever):
-
     # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
code/modules/vectorstore/colbert.py CHANGED

@@ -1,9 +1,9 @@
 from ragatouille import RAGPretrainedModel
 from modules.vectorstore.base import VectorStoreBase
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
-from typing import Any, List
+from typing import Any, List
 import os
 import json

@@ -85,6 +85,7 @@ class ColbertVectorStore(VectorStoreBase):
         document_ids=document_names,
         document_metadatas=document_metadata,
     )
+    print(f"Index created at {index_path}")
     self.colbert.set_document_count(len(document_names))

 def load_database(self):
code/modules/vectorstore/embedding_model_loader.py CHANGED

@@ -1,9 +1,6 @@
 from langchain_community.embeddings import OpenAIEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from ...
-
-from modules.config.constants import *
-import os
+from modules.config.constants import OPENAI_API_KEY, HUGGINGFACE_TOKEN


 class EmbeddingModelLoader:
@@ -28,8 +25,5 @@ class EmbeddingModelLoader:
             "trust_remote_code": True,
         },
     )
-    # embedding_model = LlamaCppEmbeddings(
-    #     model_path=os.path.abspath("storage/llama-7b.ggmlv3.q4_0.bin")
-    # )

     return embedding_model
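For context, a rough sketch of what a loader with these imports typically does. The branch condition, config keys, and the OpenAI/HuggingFace split are assumptions, not the file's confirmed contents; only the `trust_remote_code` kwarg is taken from the diff itself:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings

def load_embedding_model(config: dict):
    # Assumed structure: pick OpenAI embeddings when configured, otherwise
    # fall back to a HuggingFace model (trust_remote_code mirrors the diff).
    if config["vectorstore"]["model"] == "OpenAI":
        return OpenAIEmbeddings()  # reads OPENAI_API_KEY from the environment
    return HuggingFaceEmbeddings(
        model_name=config["vectorstore"]["model"],
        model_kwargs={"trust_remote_code": True},
    )
```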
code/modules/vectorstore/faiss.py CHANGED

@@ -14,10 +14,15 @@ class FaissVectorStore(VectorStoreBase):
 def __init__(self, config):
     self.config = config
     self._init_vector_db()
-    self.local_path = os.path.join(
-        ...
-        ...
-    )
+    self.local_path = os.path.join(
+        self.config["vectorstore"]["db_path"],
+        "db_"
+        + self.config["vectorstore"]["db_option"]
+        + "_"
+        + self.config["vectorstore"]["model"]
+        + "_"
+        + config["splitter_options"]["chunking_mode"],
+    )

 def _init_vector_db(self):
     self.faiss = FAISS(
@@ -28,9 +33,7 @@ class FaissVectorStore(VectorStoreBase):
     self.vectorstore = self.faiss.from_documents(
         documents=document_chunks, embedding=embedding_model
     )
-    self.vectorstore.save_local(
-        self.local_path
-    )
+    self.vectorstore.save_local(self.local_path)

 def load_database(self, embedding_model):
     self.vectorstore = self.faiss.load_local(
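To see what the new `local_path` expression produces, here is the same composition evaluated on illustrative config values (the real values come from `config.yml`):

```python
import os

# Illustrative values only; the real ones come from config.yml.
config = {
    "vectorstore": {
        "db_path": "vectorstores",
        "db_option": "FAISS",
        "model": "text-embedding-ada-002",
    },
    "splitter_options": {"chunking_mode": "fixed"},
}

local_path = os.path.join(
    config["vectorstore"]["db_path"],
    "db_"
    + config["vectorstore"]["db_option"]
    + "_"
    + config["vectorstore"]["model"]
    + "_"
    + config["splitter_options"]["chunking_mode"],
)
print(local_path)  # vectorstores/db_FAISS_text-embedding-ada-002_fixed
```

Encoding the database option, embedding model, and chunking mode into the directory name keeps differently-configured indexes from clobbering each other on disk.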
code/modules/vectorstore/raptor.py CHANGED

@@ -317,13 +317,10 @@ class RAPTORVectoreStore(VectorStoreBase):
     print(f"--Generated {len(all_clusters)} clusters--")

     # Summarization
-    template = """Here is content from the course DS598: Deep Learning for Data Science.
-
+    template = """Here is content from the course DS598: Deep Learning for Data Science.
     The content may be form webapge about the course, or lecture content, or any other relevant information.
     If the content is in bullet points (from pdf lectre slides), you can summarize the bullet points.
-
     Give a detailed summary of the content below.
-
     Documentation:
     {context}
     """
code/modules/vectorstore/store_manager.py CHANGED

@@ -1,9 +1,7 @@
 from modules.vectorstore.vectorstore import VectorStore
-from modules. ...
+from modules.dataloader.helpers import get_urls_from_file
 from modules.dataloader.webpage_crawler import WebpageCrawler
 from modules.dataloader.data_loader import DataLoader
-from modules.dataloader.helpers import *
-from modules.config.constants import RETRIEVER_HF_PATHS
 from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
 import logging
 import os
@@ -49,7 +47,6 @@ class VectorStoreManager:
     return logger

 def load_files(self):
-
     files = os.listdir(self.config["vectorstore"]["data_path"])
     files = [
         os.path.join(self.config["vectorstore"]["data_path"], file)
@@ -71,7 +68,6 @@ class VectorStoreManager:
     return files, urls

 def create_embedding_model(self):
-
     self.logger.info("Creating embedding function")
     embedding_model_loader = EmbeddingModelLoader(self.config)
     embedding_model = embedding_model_loader.load_embedding_model()
@@ -102,7 +98,6 @@ class VectorStoreManager:
     )

 def create_database(self):
-
     start_time = time.time()  # Start time for creating database
     data_loader = DataLoader(self.config, self.logger)
     self.logger.info("Loading data")
@@ -112,12 +107,15 @@ class VectorStoreManager:
     self.logger.info(f"Number of webpages: {len(webpages)}")
     if f"{self.config['vectorstore']['url_file_path']}" in files:
         files.remove(f"{self.config['vectorstores']['url_file_path']}")  # cleanup
-    ...
+    (
+        document_chunks,
+        document_names,
+        documents,
+        document_metadata,
+    ) = data_loader.get_chunks(files, webpages)
     num_documents = len(document_chunks)
     self.logger.info(f"Number of documents in the DB: {num_documents}")
-    metadata_keys = list(document_metadata[0].keys())
+    metadata_keys = list(document_metadata[0].keys()) if document_metadata else []
     self.logger.info(f"Metadata keys: {metadata_keys}")
     self.logger.info("Completed loading data")
     self.initialize_database(
@@ -130,7 +128,6 @@ class VectorStoreManager:
     )

 def load_database(self):
-
     start_time = time.time()  # Start time for loading database
     if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
         self.embedding_model = self.create_embedding_model()
@@ -170,13 +167,23 @@ if __name__ == "__main__":

     with open("modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)
+    with open("modules/config/project_config.yml", "r") as f:
+        project_config = yaml.safe_load(f)
+
+    # combine the two configs
+    config.update(project_config)
     print(config)
     print(f"Trying to create database with config: {config}")
     vector_db = VectorStoreManager(config)
     if config["vectorstore"]["load_from_HF"]:
-        if ...
+        if (
+            config["vectorstore"]["db_option"]
+            in config["retriever"]["retriever_hf_paths"]
+        ):
             vector_db.load_from_HF(
-                HF_PATH=...
+                HF_PATH=config["retriever"]["retriever_hf_paths"][
+                    config["vectorstore"]["db_option"]
+                ]
             )
         else:
             # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
@@ -189,7 +196,7 @@ if __name__ == "__main__":
     vector_db.create_database()
     print("Created database")

-    print(...
+    print("Trying to load the database")
     vector_db = VectorStoreManager(config)
     vector_db.load_database()
     print("Loaded database")
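One caveat worth noting about the `config.update(project_config)` pattern used in both `__main__` blocks: `dict.update` is shallow, so a top-level key present in both files is replaced wholesale rather than deep-merged. An illustration with invented values:

```python
import yaml

# Invented values; the real files are modules/config/config.yml and
# modules/config/project_config.yml.
base = yaml.safe_load("""
vectorstore:
  db_option: FAISS
metadata:
  metadata_links: []
""")
project = yaml.safe_load("""
metadata:
  metadata_links: ["https://example.com/lectures"]
""")

base.update(project)
# Shallow update: the whole 'metadata' mapping is replaced,
# while 'vectorstore' is left untouched.
print(base["metadata"]["metadata_links"])
```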
code/public/avatars/{ai-tutor.png → ai_tutor.png}
RENAMED
File without changes

code/public/space.jpg
ADDED
Binary image, stored via Git LFS
code/public/test.css CHANGED

@@ -13,10 +13,6 @@ a[href*='https://github.com/Chainlit/chainlit'] {
     border-radius: 50%; /* Maintain circular shape */
 }

-/* Hide the default image */
-.MuiAvatar-root.MuiAvatar-circular.css-m2icte .MuiAvatar-img.css-1hy9t21 {
-    display: none;
-}

 .MuiAvatar-root.MuiAvatar-circular.css-v72an7 {
     background-image: url('/public/avatars/ai-tutor.png'); /* Replace with your custom image URL */
@@ -26,18 +22,3 @@ a[href*='https://github.com/Chainlit/chainlit'] {
     height: 40px; /* Ensure the dimensions match the original */
     border-radius: 50%; /* Maintain circular shape */
 }
-
-/* Hide the default image */
-.MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
-    display: none;
-}
-
-/* Hide the new chat button
-#new-chat-button {
-    display: none;
-} */
-
-/* Hide the open sidebar button
-#open-sidebar-button {
-    display: none;
-} */
code/templates/cooldown.html
ADDED (+181 lines)

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Cooldown Period | Terrier Tutor</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

        body, html { margin: 0; padding: 0; font-family: 'Inter', sans-serif; background-color: #f7f7f7; background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); background-repeat: repeat; display: flex; align-items: center; justify-content: center; height: 100vh; color: #333; }
        .container { background: rgba(255, 255, 255, 0.9); border: 1px solid #ddd; border-radius: 8px; width: 100%; max-width: 400px; padding: 50px; box-sizing: border-box; text-align: center; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px); }
        .avatar { width: 90px; height: 90px; border-radius: 50%; margin-bottom: 25px; border: 2px solid #ddd; }
        .container h1 { margin-bottom: 15px; font-size: 24px; font-weight: 600; color: #1a1a1a; }
        .container p { font-size: 16px; color: #4a4a4a; margin-bottom: 30px; line-height: 1.5; }
        .cooldown-message { font-size: 16px; color: #333; margin-bottom: 30px; }
        .tokens-left { font-size: 14px; color: #333; margin-bottom: 30px; font-weight: 600; }
        .button { padding: 12px 0; margin: 12px 0; font-size: 14px; border-radius: 6px; cursor: pointer; width: 100%; border: 1px solid #4285F4; background-color: #fff; color: #4285F4; transition: background-color 0.3s ease, border-color 0.3s ease; display: none; }
        .button.start-tutor { display: none; }
        .button:hover { background-color: #e0e0e0; border-color: #357ae8; }
        .sign-out-button { border: 1px solid #FF4C4C; background-color: #fff; color: #FF4C4C; display: block; }
        .sign-out-button:hover { background-color: #ffe6e6; border-color: #e04343; color: #e04343; }
        #countdown { font-size: 14px; color: #555; margin-bottom: 20px; }
        .footer { font-size: 12px; color: #777; margin-top: 20px; }
    </style>
</head>
<body>
    <div class="container">
        <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
        <h1>Hello, {{ username }}</h1>
        <p>It seems like you need to wait a bit before starting a new session.</p>
        <p class="cooldown-message">Time remaining until the cooldown period ends:</p>
        <p id="countdown"></p>
        <p class="tokens-left">Tokens Left: <span id="tokensLeft">{{ tokens_left }}</span></p>
        <button id="startTutorBtn" class="button start-tutor" onclick="startTutor()">Start AI Tutor</button>
        <form action="/logout" method="get">
            <button type="submit" class="button sign-out-button">Sign Out</button>
        </form>
        <div class="footer">Reload the page to update token stats</div>
    </div>
    <script>
        function startCountdown(endTime) {
            const countdownElement = document.getElementById('countdown');
            const startTutorBtn = document.getElementById('startTutorBtn');
            const endTimeDate = new Date(endTime);

            function updateCountdown() {
                const now = new Date();
                const timeLeft = endTimeDate.getTime() - now.getTime();

                if (timeLeft <= 0) {
                    countdownElement.textContent = "Cooldown period has ended.";
                    startTutorBtn.style.display = "block";
                } else {
                    const hours = Math.floor(timeLeft / 1000 / 60 / 60);
                    const minutes = Math.floor((timeLeft / 1000 / 60) % 60);
                    const seconds = Math.floor((timeLeft / 1000) % 60);
                    countdownElement.textContent = `${hours}h ${minutes}m ${seconds}s`;
                }
            }

            updateCountdown();
            setInterval(updateCountdown, 1000);
        }

        function startTutor() {
            window.location.href = "/start-tutor";
        }

        function updateTokensLeft() {
            fetch('/get-tokens-left')
                .then(response => response.json())
                .then(data => {
                    document.getElementById('tokensLeft').textContent = data.tokens_left;
                })
                .catch(error => console.error('Error fetching tokens:', error));
        }

        // Start the countdown
        startCountdown("{{ cooldown_end_time }}");

        // Update tokens left when the page loads
        updateTokensLeft();
    </script>
</body>
</html>
code/templates/dashboard.html
ADDED (+145 lines)

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Dashboard | Terrier Tutor</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

        body, html { margin: 0; padding: 0; font-family: 'Inter', sans-serif; background-color: #f7f7f7; background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); background-repeat: repeat; display: flex; align-items: center; justify-content: center; height: 100vh; color: #333; }
        .container { background: rgba(255, 255, 255, 0.9); border: 1px solid #ddd; border-radius: 8px; width: 100%; max-width: 400px; padding: 40px; box-sizing: border-box; text-align: center; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px); }
        .avatar { width: 90px; height: 90px; border-radius: 50%; margin-bottom: 20px; border: 2px solid #ddd; }
        .container h1 { margin-bottom: 20px; font-size: 26px; font-weight: 600; color: #1a1a1a; }
        .container p { font-size: 15px; color: #4a4a4a; margin-bottom: 25px; line-height: 1.5; }
        .tokens-left { font-size: 17px; color: #333; margin-bottom: 10px; font-weight: 600; }
        .all-time-tokens { font-size: 14px; color: #555; margin-bottom: 30px; font-weight: 500; white-space: nowrap; /* prevents breaking to a new line */ }
        .button { padding: 12px 0; margin: 12px 0; font-size: 15px; border-radius: 6px; cursor: pointer; width: 100%; border: 1px solid #4285F4; background-color: #fff; color: #4285F4; transition: background-color 0.3s ease, border-color 0.3s ease; }
        .button:hover { background-color: #e0e0e0; border-color: #357ae8; }
        .start-button { border: 1px solid #4285F4; color: #4285F4; background-color: #fff; }
        .start-button:hover { background-color: #e0f0ff; border-color: #357ae8; color: #357ae8; }
        .sign-out-button { border: 1px solid #FF4C4C; background-color: #fff; color: #FF4C4C; }
        .sign-out-button:hover { background-color: #ffe6e6; border-color: #e04343; color: #e04343; }
        .footer { font-size: 12px; color: #777; margin-top: 25px; }
    </style>
</head>
<body>
    <div class="container">
        <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
        <h1>Welcome, {{ username }}</h1>
        <p>Ready to start your AI tutoring session?</p>
        <p class="tokens-left">Tokens Left: {{ tokens_left }}</p>
        <p class="all-time-tokens">All-Time Tokens Allocated: {{ all_time_tokens_allocated }} / {{ total_tokens_allocated }}</p>
        <form action="/start-tutor" method="post">
            <button type="submit" class="button start-button">Start AI Tutor</button>
        </form>
        <form action="/logout" method="get">
            <button type="submit" class="button sign-out-button">Sign Out</button>
        </form>
        <div class="footer">Reload the page to update token stats</div>
    </div>
    <script>
        let token = "{{ jwt_token }}";
        console.log("Token: ", token);
        localStorage.setItem('token', token);
    </script>
</body>
</html>
code/templates/error.html
ADDED (+95 lines)

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Error | Terrier Tutor</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

        body, html { margin: 0; padding: 0; font-family: 'Inter', sans-serif; background-color: #f7f7f7; background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); background-repeat: repeat; display: flex; align-items: center; justify-content: center; height: 100vh; color: #333; }
        .container { background: rgba(255, 255, 255, 0.9); border: 1px solid #ddd; border-radius: 8px; width: 100%; max-width: 400px; padding: 50px; box-sizing: border-box; text-align: center; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px); }
        .container h1 { margin-bottom: 20px; font-size: 26px; font-weight: 600; color: #1a1a1a; }
        .container p { font-size: 18px; color: #4a4a4a; margin-bottom: 35px; line-height: 1.5; }
        .button { padding: 14px 0; margin: 12px 0; font-size: 16px; border-radius: 6px; cursor: pointer; width: 100%; border: 1px solid #ccc; background-color: #007BFF; color: #fff; transition: background-color 0.3s ease, border-color 0.3s ease; }
        .button:hover { background-color: #0056b3; border-color: #0056b3; }
        .error-box { background-color: #2d2d2d; color: #fff; padding: 10px; margin-top: 20px; font-family: 'Courier New', Courier, monospace; text-align: left; overflow-x: auto; white-space: pre-wrap; border-radius: 5px; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Oops! Something went wrong...</h1>
        <p>An unexpected error occurred. The details are below:</p>
        <div class="error-box">
            <code>{{ error }}</code>
        </div>
        <form action="/" method="get">
            <button type="submit" class="button">Return to Home</button>
        </form>
    </div>
</body>
</html>
code/templates/error_404.html
ADDED (+80 lines)

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>404 - Not Found</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

        body, html { margin: 0; padding: 0; font-family: 'Inter', sans-serif; background-color: #f7f7f7; background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); background-repeat: repeat; display: flex; align-items: center; justify-content: center; height: 100vh; color: #333; }
        .container { background: rgba(255, 255, 255, 0.9); border: 1px solid #ddd; border-radius: 8px; width: 100%; max-width: 400px; padding: 50px; box-sizing: border-box; text-align: center; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px); }
        .container h1 { margin-bottom: 20px; font-size: 26px; font-weight: 600; color: #1a1a1a; }
        .container p { font-size: 18px; color: #4a4a4a; margin-bottom: 35px; line-height: 1.5; }
        .button { padding: 14px 0; margin: 12px 0; font-size: 16px; border-radius: 6px; cursor: pointer; width: 100%; border: 1px solid #ccc; background-color: #007BFF; color: #fff; transition: background-color 0.3s ease, border-color 0.3s ease; }
        .button:hover { background-color: #0056b3; border-color: #0056b3; }
    </style>
</head>
<body>
    <div class="container">
        <h1>You have ventured into the abyss...</h1>
        <p>To get back to reality, click the button below.</p>
        <form action="/" method="get">
            <button type="submit" class="button">Return to Home</button>
        </form>
    </div>
</body>
</html>
code/templates/login.html
ADDED (+132 lines)

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Login | Terrier Tutor</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

        body, html { margin: 0; padding: 0; font-family: 'Inter', sans-serif; background-color: #f7f7f7; background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); background-repeat: repeat; display: flex; align-items: center; justify-content: center; height: 100vh; color: #333; }
        .container { background: rgba(255, 255, 255, 0.9); border: 1px solid #ddd; border-radius: 8px; width: 100%; max-width: 400px; padding: 50px; box-sizing: border-box; text-align: center; box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px); }
        .avatar { width: 90px; height: 90px; border-radius: 50%; margin-bottom: 25px; border: 2px solid #ddd; }
        .container h1 { margin-bottom: 15px; font-size: 24px; font-weight: 600; color: #1a1a1a; }
        .container p { font-size: 16px; color: #4a4a4a; margin-bottom: 30px; line-height: 1.5; }
        .button { padding: 12px 0; margin: 12px 0; font-size: 14px; border-radius: 6px; cursor: pointer; width: 100%; border: 1px solid #4285F4; background-color: #fff; color: #4285F4; transition: background-color 0.3s ease, border-color 0.3s ease; }
        .button:hover { background-color: #e0f0ff; border-color: #357ae8; color: #357ae8; }
        .footer { margin-top: 40px; font-size: 15px; color: #666; text-align: center; }
        .footer a { color: #333; text-decoration: none; font-weight: 500; display: inline-flex; align-items: center; justify-content: center; transition: color 0.3s ease; margin-bottom: 8px; width: 100%; }
        .footer a:hover { color: #000; }
        .footer svg { margin-right: 8px; fill: currentColor; }
    </style>
</head>
<body>
    <div class="container">
        <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
        <h1>Terrier Tutor</h1>
        <p>Welcome to the DS598 AI Tutor. Please sign in to continue.</p>
        <form action="/login/google" method="get">
            <button type="submit" class="button">Sign in with Google</button>
        </form>
        <div class="footer">
            <a href="{{ GITHUB_REPO }}" target="_blank">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24">
                    <path d="M12 .5C5.596.5.5 5.596.5 12c0 5.098 3.292 9.414 7.852 10.94.574.105.775-.249.775-.553 0-.272-.01-1.008-.015-1.98-3.194.694-3.87-1.544-3.87-1.544-.521-1.324-1.273-1.676-1.273-1.676-1.04-.714.079-.7.079-.7 1.148.08 1.75 1.181 1.75 1.181 1.022 1.752 2.683 1.246 3.34.954.104-.74.4-1.246.73-1.533-2.551-.292-5.234-1.276-5.234-5.675 0-1.253.447-2.277 1.181-3.079-.12-.293-.51-1.47.113-3.063 0 0 .96-.307 3.15 1.174.913-.255 1.892-.383 2.867-.388.975.005 1.954.133 2.868.388 2.188-1.481 3.147-1.174 3.147-1.174.624 1.593.233 2.77.114 3.063.735.802 1.18 1.826 1.18 3.079 0 4.407-2.688 5.38-5.248 5.668.413.354.782 1.049.782 2.113 0 1.526-.014 2.757-.014 3.132 0 .307.198.662.783.553C20.21 21.411 23.5 17.096 23.5 12c0-6.404-5.096-11.5-11.5-11.5z"/>
                </svg>
                View on GitHub
            </a>
            <a href="{{ DOCS_WEBSITE }}" target="_blank">
                <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24">
                    <path d="M19 2H8c-1.103 0-2 .897-2 2v16c0 1.103.897 2 2 2h12c1.103 0 2-.897 2-2V7l-5-5zm0 2l.001 4H14V4h5zm-1 14H9V4h4v6h6v8zM7 4H6v16c0 1.654 1.346 3 3 3h9v-2H9c-.551 0-1-.449-1-1V4z"/>
                </svg>
                View Docs
            </a>
        </div>
    </div>
</body>
</html>
code/templates/logout.html
ADDED (+21 lines)

<!DOCTYPE html>
<html>
<head>
    <title>Logout</title>
    <script>
        window.onload = function() {
            fetch('/chainlit_tutor/logout', {
                method: 'POST',
                credentials: 'include' // Ensure cookies are sent
            }).then(() => {
                window.location.href = '/';
            }).catch(error => {
                console.error('Logout failed:', error);
            });
        };
    </script>
</head>
<body>
    <p>Logging out... If you are not redirected, <a href="/">click here</a>.</p>
</body>
</html>
docs/README.md
DELETED (-51 lines; the following content was removed)

# Documentation

## File Structure:
- `docs/` - Documentation files
- `code/` - Code files
- `storage/` - Storage files
- `vectorstores/` - Vector Databases
- `.env` - Environment Variables
- `Dockerfile` - Dockerfile for Hugging Face
- `.chainlit` - Chainlit Configuration
- `chainlit.md` - Chainlit README
- `README.md` - Repository README
- `.gitignore` - Gitignore file
- `requirements.txt` - Python Requirements
- `.gitattributes` - Gitattributes file

## Code Structure

- `code/main.py` - Main Chainlit App
- `code/config.yaml` - Configuration File to set Embedding related, Vector Database related, and Chat Model related parameters.
- `code/modules/vector_db.py` - Vector Database Creation
- `code/modules/chat_model_loader.py` - Chat Model Loader (Creates the Chat Model)
- `code/modules/constants.py` - Constants (Loads the Environment Variables, Prompts, Model Paths, etc.)
- `code/modules/data_loader.py` - Loads and Chunks the Data
- `code/modules/embedding_model.py` - Creates the Embedding Model to Embed the Data
- `code/modules/llm_tutor.py` - Creates the RAG LLM Tutor
  - The Function `qa_bot()` loads the vector database and the chat model, and sets the prompt to pass to the chat model.
- `code/modules/helpers.py` - Helper Functions

## Storage and Vectorstores

- `storage/data/` - Data Storage (Put your pdf files under this directory, and urls in the urls.txt file)
- `storage/models/` - Model Storage (Put your local LLMs under this directory)

- `vectorstores/` - Vector Databases (Stores the Vector Databases generated from `code/modules/vector_db.py`)

## Useful Configurations
set these in `code/config.yaml`:
* ``["embedding_options"]["embedd_files"]`` - If set to True, embeds the files from the storage directory everytime you run the chainlit command. If set to False, uses the stored vector database.
* ``["embedding_options"]["expand_urls"]`` - If set to True, gets and reads the data from all the links under the url provided. If set to False, only reads the data in the url provided.
* ``["embedding_options"]["search_top_k"]`` - Number of sources that the retriever returns
* ``["llm_params"]["use_history"]`` - Whether to use history in the prompt or not
* ``["llm_params"]["memory_window"]`` - Number of interactions to keep a track of in the history

## LlamaCpp
* https://python.langchain.com/docs/integrations/llms/llamacpp

## Hugging Face Models
* Download the ``.gguf`` files for your Local LLM from Hugging Face (Example: https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF)
docs/contribute.md
ADDED
@@ -0,0 +1,33 @@
+💡 **Please ensure formatting, linting, and security checks pass before submitting a pull request.**
+
+## Code Formatting
+
+The codebase is formatted using [black](https://github.com/psf/black).
+
+To format the codebase, run the following command:
+
+```bash
+black .
+```
+
+Please ensure that the code is formatted before submitting a pull request.
+
+## Linting
+
+The codebase is linted using [flake8](https://flake8.pycqa.org/en/latest/).
+
+To view the linting errors, run the following command:
+
+```bash
+flake8 .
+```
+
+## Security and Vulnerabilities
+
+The codebase is scanned for security vulnerabilities using [bandit](https://github.com/PyCQA/bandit).
+
+To scan the codebase for security vulnerabilities, run the following command:
+
+```bash
+bandit -r .
+```
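Since all three checks must pass before a pull request, it can be handy to run them together locally. A convenience sketch, not part of the repo; black, flake8, and bandit are all listed in requirements.txt:

```python
# check.py - run the same three checks in sequence and exit nonzero if
# any of them fail. Convenience sketch only, not part of the repository.
import subprocess
import sys

CHECKS = [
    ["black", "--check", "."],
    ["flake8", "."],
    ["bandit", "-r", "."],
]

failed = False
for cmd in CHECKS:
    print(f"$ {' '.join(cmd)}")
    if subprocess.run(cmd).returncode != 0:
        failed = True

sys.exit(1 if failed else 0)
```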
docs/setup.md
ADDED
@@ -0,0 +1,127 @@
+# Initial Setup
+
+⚠️ **Create the .env file inside the `code/` directory.**
+
+## Python Environment
+
+Python Version: 3.11
+
+Create a virtual environment and install the required packages:
+
+```bash
+conda create -n ai_tutor python=3.11
+conda activate ai_tutor
+pip install -r requirements.txt
+```
+
+## Code Formatting
+
+The codebase is formatted using [black](https://github.com/psf/black); if you make changes to the codebase, format the code before submitting a pull request. More instructions can be found in `docs/contribute.md`.
+
+## Google OAuth 2.0 Client ID and Secret
+
+To set up the Google OAuth 2.0 Client ID and Secret, follow these steps:
+
+1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials).
+2. Create a new project or select an existing one.
+3. Navigate to the "Credentials" page.
+4. Click on "Create Credentials" and select "OAuth 2.0 Client ID".
+5. Configure the OAuth consent screen if you haven't already.
+6. Choose "Web application" as the application type.
+7. Configure the redirect URIs as needed.
+8. Copy the generated `Client ID` and `Client Secret`.
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+OAUTH_GOOGLE_CLIENT_ID=<your_client_id>
+OAUTH_GOOGLE_CLIENT_SECRET=<your_client_secret>
+```
+
+## Literal AI API Key
+
+To obtain the Literal AI API key:
+
+1. Sign up or log in to [Literal AI](https://cloud.getliteral.ai/).
+2. Navigate to the API Keys section under your account settings.
+3. Create a new API key if necessary and copy it.
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+LITERAL_API_KEY_LOGGING=<your_api_key>
+LITERAL_API_URL=https://cloud.getliteral.ai
+```
+
+## LlamaCloud API Key
+
+To obtain the LlamaCloud API Key:
+
+1. Go to [LlamaCloud](https://cloud.llamaindex.ai/).
+2. Sign up or log in to your account.
+3. Navigate to the API section and generate a new API key if necessary.
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+LLAMA_CLOUD_API_KEY=<your_api_key>
+```
+
+## Hugging Face Access Token
+
+To obtain your Hugging Face access token:
+
+1. Go to [Hugging Face settings](https://huggingface.co/settings/tokens).
+2. Log in or create an account.
+3. Generate a new token or use an existing one.
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+HUGGINGFACE_TOKEN=<your-huggingface-token>
+```
+
+## Chainlit Authentication Secret
+
+You must provide a JWT secret in the environment to use authentication. Run `chainlit create-secret` to generate one.
+
+```bash
+chainlit create-secret
+```
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+CHAINLIT_AUTH_SECRET=<your_jwt_secret>
+CHAINLIT_URL=<your_chainlit_url> # Example: CHAINLIT_URL=http://localhost:8000
+```
+
+## OpenAI API Key
+
+Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
+
+```bash
+OPENAI_API_KEY=<your_openai_api_key>
+```
+
+## In a Nutshell
+
+Your .env file (secrets on Hugging Face Spaces) should look like this:
+
+```bash
+CHAINLIT_AUTH_SECRET=<your_jwt_secret>
+OPENAI_API_KEY=<your_openai_api_key>
+HUGGINGFACE_TOKEN=<your-huggingface-token>
+LITERAL_API_KEY_LOGGING=<your_api_key>
+LITERAL_API_URL=https://cloud.getliteral.ai
+OAUTH_GOOGLE_CLIENT_ID=<your_client_id>
+OAUTH_GOOGLE_CLIENT_SECRET=<your_client_secret>
+LLAMA_CLOUD_API_KEY=<your_api_key>
+CHAINLIT_URL=<your_chainlit_url>
+```
+
+
+# Configuration
+
+The configuration file `code/modules/config/config.yml` contains the parameters that control the behaviour of your app.
+The configuration file `code/modules/config/project_config.yml` contains project-specific parameters.
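With all the keys above collected, a quick sanity check that the `.env` file is complete can save a failed launch. A minimal sketch, assuming `python-dotenv` is available (it is not in requirements.txt, so install it separately if you use this); the key names are taken verbatim from the list above:

```python
# Sketch: verify that every key from the "In a Nutshell" list is set.
# Assumes python-dotenv is installed (not part of requirements.txt).
import os

from dotenv import load_dotenv

REQUIRED_KEYS = [
    "CHAINLIT_AUTH_SECRET",
    "OPENAI_API_KEY",
    "HUGGINGFACE_TOKEN",
    "LITERAL_API_KEY_LOGGING",
    "LITERAL_API_URL",
    "OAUTH_GOOGLE_CLIENT_ID",
    "OAUTH_GOOGLE_CLIENT_SECRET",
    "LLAMA_CLOUD_API_KEY",
    "CHAINLIT_URL",
]

load_dotenv("code/.env")  # the docs say to create .env inside code/
missing = [key for key in REQUIRED_KEYS if not os.getenv(key)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")
print("All required environment variables are set.")
```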
pyproject.toml
ADDED
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 88
requirements.txt
CHANGED
@@ -22,4 +22,15 @@ umap-learn
 llama-cpp-python
 pymupdf
 websockets
-langchain-openai
+langchain-openai
+langchain-experimental
+html2text
+PyPDF2
+pdf2image
+black
+flake8
+bandit
+fastapi
+google-auth
+google-auth-oauthlib
+Jinja2