Spaces: NCTCMumbai
Commit e5b3236 (1 parent: 0e6440f), committed by NCTCMumbai
Upload 24 files
Browse files:
- .gitignore +165 -0
- Dockerfile +22 -0
- Manage +0 -0
- README.md +34 -5
- app.py +32 -0
- components/__init__.py +0 -0
- components/chat_box.py +7 -0
- components/chat_loop.py +23 -0
- components/generate_chat_stream.py +104 -0
- components/init_state.py +64 -0
- components/prompt_engineering_dashboard.py +68 -0
- components/show_source.py +8 -0
- components/sidebar.py +11 -0
- components/sidebar_components/__init__.py +0 -0
- components/sidebar_components/model_analytics.py +18 -0
- components/sidebar_components/model_settings.py +25 -0
- components/sidebar_components/retrieval_settings.py +37 -0
- components/stream_handler.py +41 -0
- config.yaml +24 -0
- middlewares/chat_client.py +78 -0
- middlewares/search_client.py +85 -0
- middlewares/utils.py +117 -0
- requirements.txt +97 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Win executables
*.exe
rag-env/
mixtral-playground/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
Dockerfile
ADDED
@@ -0,0 +1,22 @@
FROM python:3.9-slim

WORKDIR /app

RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    && rm -rf /var/lib/apt/lists/*

# COPY <dir> <dest>/ copies the directory's *contents*, so the package
# directories must be named explicitly to keep the Python import paths intact.
COPY components /app/components
COPY middlewares /app/middlewares
COPY app.py /app/
COPY requirements.txt /app/
COPY config.yaml /app/

RUN pip3 install -r requirements.txt

EXPOSE 8501

ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Manage
ADDED
File without changes
README.md
CHANGED
@@ -1,12 +1,41 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Mixtral Search Engine
+emoji: 🔍
+colorFrom: pink
+colorTo: gray
 sdk: streamlit
-sdk_version: 1.
+sdk_version: 1.29.0
 app_file: app.py
 pinned: false
+license: mit
 ---
 
+# Mixtral Search Engine
+
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Docker Setup
+
+If you prefer using Docker, follow these steps:
+
+1. Clone the repository.
+2. Create a `.env` file to store API credentials.
+
+```
+HF_TOKEN = ...
+GOOGLE_SEARCH_ENGINE_ID = ...
+GOOGLE_SEARCH_API_KEY = ...
+BING_SEARCH_API_KEY = ...
+```
+
+3. Build the docker image using
+
+```
+docker build -t mixtral-search .
+```
+
+4. Run the image using
+
+```
+docker run --env-file .env -p 8501:8501 mixtral-search
+```
app.py
ADDED
@@ -0,0 +1,32 @@
import yaml
import streamlit as st
from components.sidebar import sidebar
from components.chat_box import chat_box
from components.chat_loop import chat_loop
from components.init_state import init_state
from components.prompt_engineering_dashboard import prompt_engineering_dashboard


with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

st.set_page_config(
    page_title="NCTC OSINT AGENT",
    page_icon="📚",
)


init_state(st.session_state, config)

st.write("# NCTC OSINT AGENT")

# The prompt engineering dashboard works well for testing but is not production-ready.
prompt_engineering_dashboard(st.session_state, config)

sidebar(st.session_state, config)

chat_box(st.session_state, config)

chat_loop(st.session_state, config)
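app.py wires everything together: config.yaml is parsed once and threaded through every component alongside st.session_state. A minimal sketch of the shape the components expect from the parsed config (keys taken from the config.yaml added later in this commit):

```
import yaml

with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

# Top-level keys the components read:
config["PROMPT_ENGINEERING_DICT"]   # default system/context/prompt strings
config["CHAT_BOTS"]                 # display label -> Hugging Face model id
config["COST_PER_1000_TOKENS_USD"]  # flat rate used by model_analytics
```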
components/__init__.py
ADDED
File without changes
components/chat_box.py
ADDED
@@ -0,0 +1,7 @@
import streamlit as st


def chat_box(session_state, config):
    for message in session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
components/chat_loop.py
ADDED
@@ -0,0 +1,23 @@
import streamlit as st
from components.generate_chat_stream import generate_chat_stream
from components.stream_handler import stream_handler
from components.show_source import show_source


def chat_loop(session_state, config):
    if prompt := st.chat_input("Search the web..."):
        st.chat_message("user").markdown(prompt)
        session_state.messages.append({"role": "user", "content": prompt})

        chat_stream, links = generate_chat_stream(session_state, prompt, config)

        with st.chat_message("assistant"):
            placeholder = st.empty()
            full_response = stream_handler(
                session_state, chat_stream, prompt, placeholder
            )
            if session_state.rag_enabled:
                show_source(links)

        session_state.history.append([prompt, full_response])
        session_state.messages.append({"role": "assistant", "content": full_response})
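chat_loop keeps two parallel records: messages (role/content dicts used for rendering) and history (prompt/response pairs used for prompt formatting). Roughly, after one exchange:

```
session_state.messages  # [{"role": "user", "content": "..."},
                        #  {"role": "assistant", "content": "..."}]
session_state.history   # [[system_instruction, system_response],
                        #  ["user prompt", "assistant reply"]]
```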
components/generate_chat_stream.py
ADDED
@@ -0,0 +1,104 @@
import os
import streamlit as st
from middlewares.utils import gen_augmented_prompt_via_websearch
from middlewares.chat_client import chat
from pprint import pformat
from notion_client import Client


def safe_get(data, dot_chained_keys):
    '''
    Safely walk a nested dict/list structure via a dot-chained key string.

    {'a': {'b': [{'c': 1}]}}
    safe_get(data, 'a.b.0.c') -> 1
    '''
    keys = dot_chained_keys.split('.')
    for key in keys:
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError, ValueError):
            return None
    return data


def get_notion_data():
    # The integration token belongs in the environment, never hardcoded in
    # source; the env var name here is an assumption, match it to your .env.
    integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
    notion_database_id = "6c0d877b823a4e3699016fa7083f3006"

    client = Client(auth=integration_token)

    first_db_rows = client.databases.query(notion_database_id)
    rows = []

    for row in first_db_rows['results']:
        price = safe_get(row, 'properties.($) Per Unit.number')
        store_link = safe_get(row, 'properties.Store Link.url')
        supplier_email = safe_get(row, 'properties.Supplier Email.email')
        exp_del = safe_get(row, 'properties.Expected Delivery.date')

        # safe_get returns None for missing properties; default to an empty
        # list so the comprehensions below cannot raise TypeError.
        collections = safe_get(row, 'properties.Collection.multi_select') or []
        collection_names = [collection['name'] for collection in collections]

        status = safe_get(row, 'properties.Status.select.name')
        sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
        stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
        prod_name = safe_get(row, 'properties.Product .title.0.text.content')
        sku = safe_get(row, 'properties.SKU.number')
        shipped_date = safe_get(row, 'properties.Shipped On.date')
        on_order = safe_get(row, 'properties.On Order.number')
        on_hand = safe_get(row, 'properties.On Hand.number')
        sizes = safe_get(row, 'properties.Size.multi_select') or []
        size_names = [size['name'] for size in sizes]

        rows.append({
            'Price Per unit': price,
            'Store Link': store_link,
            'Supplier Email': supplier_email,
            'Expected Delivery': exp_del,
            'Collection': collection_names,
            'Status': status,
            'Supplier Phone': sup_phone,
            'Stock Alert': stock_alert,
            'Product Name': prod_name,
            'SKU': sku,
            'Sizes': size_names,
            'Shipped Date': shipped_date,
            'On Order': on_order,
            'On Hand': on_hand,
        })

    notion_data_string = pformat(rows)
    return notion_data_string


def generate_chat_stream(session_state, query, config):
    # 1. Augments the prompt according to the template.
    # 2. Returns the chat stream and source links.
    # 3. Both are consumed downstream by stream_handler and show_source.
    links = []
    if session_state.rag_enabled:
        with st.spinner("Fetching relevant documents from the web..."):
            query, links = gen_augmented_prompt_via_websearch(
                prompt=query,
                pre_context=session_state.pre_context,
                post_context=session_state.post_context,
                pre_prompt=session_state.pre_prompt,
                post_prompt=session_state.post_prompt,
                search_vendor=session_state.search_vendor,
                top_k=session_state.top_k,
                n_crawl=session_state.n_crawl,
                pass_prev=session_state.pass_prev,
                prev_output=session_state.history[-1][1],
            )

    notion_data = get_notion_data()

    with st.spinner("Generating response..."):
        chat_stream = chat(session_state, notion_data + " " + query, config)

    return chat_stream, links
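A quick sketch of safe_get's failure modes, using the function defined above (standalone, no Notion access needed):

```
data = {"a": {"b": [{"c": 1}]}}

print(safe_get(data, "a.b.0.c"))  # 1
print(safe_get(data, "a.b.9.c"))  # None (list index out of range)
print(safe_get(data, "a.x.y"))    # None (missing key)
```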
components/init_state.py
ADDED
@@ -0,0 +1,64 @@
def init_state(session_state, config):
    initial_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]
    if "messages" not in session_state:
        session_state.messages = []

    if "tokens_used" not in session_state:
        session_state.tokens_used = 0

    if "tps" not in session_state:
        session_state.tps = 0

    if "temp" not in session_state:
        session_state.temp = 0.8

    if "history" not in session_state:
        session_state.history = [
            [
                initial_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
                initial_prompt_engineering_dict["SYSTEM_RESPONSE"],
            ]
        ]

    if "n_crawl" not in session_state:
        session_state.n_crawl = 5

    if "repetion_penalty" not in session_state:
        session_state.repetion_penalty = 1

    if "rag_enabled" not in session_state:
        session_state.rag_enabled = True

    if "chat_bot" not in session_state:
        session_state.chat_bot = "Mixtral 8x7B v0.1"

    if "search_vendor" not in session_state:
        session_state.search_vendor = "Bing"

    if "system_instruction" not in session_state:
        session_state.system_instruction = initial_prompt_engineering_dict[
            "SYSTEM_INSTRUCTION"
        ]

    if "system_response" not in session_state:
        session_state.system_response = initial_prompt_engineering_dict[
            "SYSTEM_RESPONSE"
        ]

    if "pre_context" not in session_state:
        session_state.pre_context = initial_prompt_engineering_dict["PRE_CONTEXT"]

    if "post_context" not in session_state:
        session_state.post_context = initial_prompt_engineering_dict["POST_CONTEXT"]

    if "pre_prompt" not in session_state:
        session_state.pre_prompt = initial_prompt_engineering_dict["PRE_PROMPT"]

    if "post_prompt" not in session_state:
        session_state.post_prompt = initial_prompt_engineering_dict["POST_PROMPT"]

    if "pass_prev" not in session_state:
        session_state.pass_prev = False

    if "chunk_size" not in session_state:
        session_state.chunk_size = 512
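The one-guard-per-key pattern above is idiomatic Streamlit. Since st.session_state exposes a dict-like interface, the simple scalar defaults could also be collapsed; a sketch, assuming that interface:

```
SCALAR_DEFAULTS = {
    "tokens_used": 0,
    "tps": 0,
    "temp": 0.8,
    "n_crawl": 5,
    "repetion_penalty": 1,
    "chunk_size": 512,
}

def init_scalars(session_state):
    # setdefault only writes a key on first run, like the `if ... not in` guards.
    for key, value in SCALAR_DEFAULTS.items():
        session_state.setdefault(key, value)
```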
components/prompt_engineering_dashboard.py
ADDED
@@ -0,0 +1,68 @@
import streamlit as st


def prompt_engineering_dashboard(session_state, config):
    initial_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]

    def engineer_prompt():
        session_state.history[0] = [
            session_state.system_instruction,
            session_state.system_response,
        ]

    with st.expander("Prompt Engineering Dashboard"):
        st.info(
            "**The input to the model follows the template below**",
        )
        st.code(
            """
            [SYSTEM INSTRUCTION]
            [SYSTEM RESPONSE]

            [... LIST OF PREV INPUTS]

            [PRE CONTEXT]
            [CONTEXT RETRIEVED FROM THE WEB]
            [POST CONTEXT]

            [PRE PROMPT]
            [PROMPT]
            [POST PROMPT]
            [PREV GENERATED OUTPUT]  # Only if "Pass previous Output" is set True
            """
        )
        session_state.system_instruction = st.text_area(
            label="SYSTEM INSTRUCTION",
            value=initial_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
        )
        session_state.system_response = st.text_area(
            "SYSTEM RESPONSE", value=initial_prompt_engineering_dict["SYSTEM_RESPONSE"]
        )

        col1, col2 = st.columns(2)
        with col1:
            session_state.pre_context = st.text_input(
                "PRE CONTEXT",
                value=initial_prompt_engineering_dict["PRE_CONTEXT"],
                disabled=not session_state.rag_enabled,
            )
            session_state.post_context = st.text_input(
                "POST CONTEXT",
                value=initial_prompt_engineering_dict["POST_CONTEXT"],
                disabled=not session_state.rag_enabled,
            )

        with col2:
            session_state.pre_prompt = st.text_input(
                "PRE PROMPT", value=initial_prompt_engineering_dict["PRE_PROMPT"]
            )
            session_state.post_prompt = st.text_input(
                "POST PROMPT", value=initial_prompt_engineering_dict["POST_PROMPT"]
            )

        col3, col4 = st.columns(2)
        with col3:
            session_state.pass_prev = st.toggle("Pass previous Output")
        with col4:
            st.button("Engineer Prompts", on_click=engineer_prompt)
components/show_source.py
ADDED
@@ -0,0 +1,8 @@
import streamlit as st


def show_source(links):
    # Expander component listing the source links behind the generated answer.
    with st.expander("Show source"):
        for link in links:
            st.info(f"{link}")
components/sidebar.py
ADDED
@@ -0,0 +1,11 @@
import streamlit as st
from components.sidebar_components.model_analytics import model_analytics
from components.sidebar_components.retrieval_settings import retrieval_settings
from components.sidebar_components.model_settings import model_settings


def sidebar(session_state, config):
    with st.sidebar:
        retrieval_settings(session_state, config)
        model_analytics(session_state, config)
        model_settings(session_state, config)
components/sidebar_components/__init__.py
ADDED
File without changes
components/sidebar_components/model_analytics.py
ADDED
@@ -0,0 +1,18 @@
import streamlit as st


def model_analytics(session_state, config):
    COST_PER_1000_TOKENS_USD = config["COST_PER_1000_TOKENS_USD"]

    st.markdown("# Model Analytics")

    st.write("Total tokens used :", session_state["tokens_used"])
    st.write("Speed :", session_state["tps"], " tokens/sec")
    st.write(
        "Total cost incurred :",
        round(
            COST_PER_1000_TOKENS_USD * session_state["tokens_used"] / 1000,
            3,
        ),
        "USD",
    )

    st.markdown("---")
components/sidebar_components/model_settings.py
ADDED
@@ -0,0 +1,25 @@
import streamlit as st


def model_settings(session_state, config):
    CHAT_BOTS = config["CHAT_BOTS"]

    st.markdown("# Model Settings")

    session_state.chat_bot = st.sidebar.radio(
        "Select one:", list(CHAT_BOTS.keys())
    )
    session_state.temp = st.slider(
        label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
    )

    session_state.max_tokens = st.slider(
        label="New tokens to generate",
        min_value=64,
        max_value=2048,
        step=32,
        value=512,
    )

    session_state.repetion_penalty = st.slider(
        label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
    )
components/sidebar_components/retrieval_settings.py
ADDED
@@ -0,0 +1,37 @@
import streamlit as st


def retrieval_settings(session_state, config):
    st.markdown("# Web Retrieval")
    session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
    session_state.search_vendor = st.radio(
        "Select Search Vendor",
        ["Bing", "Google"],
        disabled=not session_state.rag_enabled,
    )
    session_state.n_crawl = st.slider(
        label="Links to Crawl",
        key=1,
        min_value=1,
        max_value=10,
        value=4,
        disabled=not session_state.rag_enabled,
    )
    session_state.top_k = st.slider(
        label="Chunks to Retrieve via Reranker",
        key=2,
        min_value=1,
        max_value=20,
        value=5,
        disabled=not session_state.rag_enabled,
    )

    session_state.chunk_size = st.slider(
        label="Chunk Size",
        value=512,
        min_value=128,
        max_value=1024,
        step=8,
        disabled=not session_state.rag_enabled,
    )

    st.markdown("---")
components/stream_handler.py
ADDED
@@ -0,0 +1,41 @@
import time
import streamlit as st

COST_PER_1000_TOKENS_USD = 0.139 / 80


def stream_handler(session_state, chat_stream, prompt, placeholder):
    # 1. Streams the chat_stream tokens into the placeholder as they arrive.
    # 2. Returns full_response for token accounting.
    start_time = time.time()
    full_response = ""

    for chunk in chat_stream:
        if chunk.token.text in ["</s>", "<|im_end|>"]:
            break
        full_response += chunk.token.text
        placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    end_time = time.time()
    elapsed_time = end_time - start_time
    total_tokens_processed = len(full_response.split())
    tokens_per_second = total_tokens_processed // elapsed_time
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
    col1, col2, col3 = st.columns(3)

    with col1:
        st.write(f"**{tokens_per_second} tokens/second**")

    with col2:
        st.write(f"**{int(len_response)} tokens generated**")

    with col3:
        st.write(
            f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
        )

    session_state["tps"] = tokens_per_second
    session_state["tokens_used"] = len_response + session_state["tokens_used"]

    return full_response
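The analytics above are estimates: whitespace word counts scaled by 1.25, a common rough words-to-tokens ratio for English. A worked sketch of the cost arithmetic:

```
COST_PER_1000_TOKENS_USD = 0.139 / 80  # ~0.0017375 USD per 1K tokens

prompt_words, response_words = 20, 380
est_tokens = (prompt_words + response_words) * 1.25        # 500.0 tokens
est_cost = est_tokens * COST_PER_1000_TOKENS_USD / 1000    # ~0.00086875 USD
print(round(est_cost, 5))  # 0.00087
```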
config.yaml
ADDED
@@ -0,0 +1,24 @@
PROMPT_ENGINEERING_DICT:
  SYSTEM_INSTRUCTION: |
    YOU ARE A SEARCH ENGINE AND AN INVENTORY MANAGER HAVING FULL ACCESS TO WEB PAGES AND NOTION DATABASE IN JSON,
    YOU GIVE EXTREMELY DETAILED AND ACCURATE INFORMATION ACCORDING TO USER PROMPTS.

  SYSTEM_RESPONSE: |
    Certainly! I'm here to help. What information are you looking for?
    Please provide me with a specific topic or question, and I'll do my
    best to provide you with detailed and accurate information.

  PRE_CONTEXT: NOW YOU ARE SEARCHING THE WEB, AND HERE ARE THE CHUNKS RETRIEVED FROM THE WEB, YOU ALSO HAVE ACCESS TO INVENTORY DATASET IN JSON FORMAT.
  POST_CONTEXT: ""
  PRE_PROMPT: NOW, ACCORDING TO THE CONTEXT RETRIEVED FROM THE WEB, GENERATE THE CONTENT FOR THE FOLLOWING SUBJECT
  POST_PROMPT: PRIORITIZE DATA, FACTS AND STATISTICS OVER PERSONAL EXPERIENCES AND OPINIONS, FOCUS MORE ON STATISTICS AND DATA.

CHAT_BOTS:
  Nous Hermes 2: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
  Mixtral 8x7B v0.1: mistralai/Mixtral-8x7B-Instruct-v0.1
  Mistral 7B v0.1: mistralai/Mistral-7B-Instruct-v0.1
  Mistral 7B v0.2: mistralai/Mistral-7B-Instruct-v0.2

CROSS_ENCODERS:

COST_PER_1000_TOKENS_USD: 0.001737375
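CHAT_BOTS maps a sidebar display label to a Hugging Face model id; the chat client resolves whichever label the user picked, roughly:

```
chat_bot_dict = config["CHAT_BOTS"]
model_id = chat_bot_dict["Mixtral 8x7B v0.1"]
print(model_id)  # mistralai/Mixtral-8x7B-Instruct-v0.1
```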
middlewares/chat_client.py
ADDED
@@ -0,0 +1,78 @@
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv

load_dotenv()

API_TOKEN = os.getenv("HF_TOKEN")


def format_prompt(session_state, query, history, chat_client):
    if chat_client == "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
        # ChatML format used by the Nous Hermes 2 model.
        model_input = f"""<|im_start|>system
{session_state.system_instruction}
"""
        for user_prompt, bot_response in history:
            model_input += f"""<|im_start|>user
{user_prompt}<|im_end|>
"""
            model_input += f"""<|im_start|>assistant
{bot_response}<|im_end|>
"""
        model_input += f"""<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant"""

        return model_input

    else:
        # Mistral/Mixtral instruction format.
        model_input = "<s>"
        for user_prompt, bot_response in history:
            model_input += f"[INST] {user_prompt} [/INST]"
            model_input += f" {bot_response}</s> "
        model_input += f"[INST] {query} [/INST]"
        return model_input


def chat(session_state, query, config):
    chat_bot_dict = config["CHAT_BOTS"]
    chat_client = chat_bot_dict[session_state.chat_bot]
    temperature = session_state.temp
    max_new_tokens = session_state.max_tokens
    repetion_penalty = session_state.repetion_penalty
    history = session_state.history

    client = InferenceClient(chat_client, token=API_TOKEN)
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = 0.95

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetion_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(session_state, query, history, chat_client)

    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
        truncate=32000,
    )

    return stream
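For the Mistral-family models, format_prompt produces the [INST] instruction format; a sketch of the resulting string for a single prior turn (values illustrative):

```
history = [["What is OSINT?", "Open-source intelligence is ..."]]
# format_prompt(session_state, "Give one example", history,
#               "mistralai/Mixtral-8x7B-Instruct-v0.1")
# returns:
# "<s>[INST] What is OSINT? [/INST] Open-source intelligence is ...</s> [INST] Give one example [/INST]"
```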
middlewares/search_client.py
ADDED
@@ -0,0 +1,85 @@
import requests
from bs4 import BeautifulSoup
import re
import concurrent.futures


class SearchClient:
    def __init__(self, vendor, engine_id=None, api_key=None):
        self.vendor = vendor
        if vendor == "google":
            self.endpoint = f"https://www.googleapis.com/customsearch/v1?key={api_key}&cx={engine_id}"
        elif vendor == "bing":
            self.endpoint = "https://api.bing.microsoft.com/v7.0/search"
            self.headers = {
                "Ocp-Apim-Subscription-Key": api_key,
            }

    @staticmethod
    def _extract_text_from_link(link):
        # A timeout keeps one slow page from stalling the whole crawl.
        page = requests.get(link, timeout=10)
        if page.status_code == 200:
            soup = BeautifulSoup(page.content, "html.parser")
            text = soup.get_text()
            cleaned_text = re.sub(r"\s+", " ", text)
            return cleaned_text
        return None

    def _fetch_text_from_links(self, links):
        results = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_link = {
                executor.submit(self._extract_text_from_link, link): link
                for link in links
            }
            for future in concurrent.futures.as_completed(future_to_link):
                link = future_to_link[future]
                try:
                    cleaned_text = future.result()
                    if cleaned_text:
                        results.append({"text": cleaned_text, "link": link})
                except Exception as e:
                    print(f"Error fetching data from {link}: {e}")
        return results

    def _google_search(self, query, n_crawl):
        response = requests.get(self.endpoint, params={"q": query})
        search_results = response.json()

        results = []
        count = 0
        for item in search_results.get("items", []):
            if count >= n_crawl:
                break

            link = item["link"]
            results.append(link)
            count += 1

        text_results = self._fetch_text_from_links(results)
        return text_results

    def _bing_search(self, query, n_crawl):
        params = {
            "q": query,
            "count": n_crawl,  # may need adjusting to Bing API limits
            "mkt": "en-US",
        }
        response = requests.get(self.endpoint, headers=self.headers, params=params)
        search_results = response.json()

        results = []
        for item in search_results.get("webPages", {}).get("value", []):
            link = item["url"]
            results.append(link)

        text_results = self._fetch_text_from_links(results)
        return text_results

    def search(self, query, n_crawl):
        if self.vendor == "google":
            return self._google_search(query, n_crawl)
        elif self.vendor == "bing":
            return self._bing_search(query, n_crawl)
        else:
            return "Invalid vendor"
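A minimal usage sketch, assuming a valid Bing key in the environment:

```
import os

client = SearchClient("bing", api_key=os.getenv("BING_SEARCH_API_KEY"))
results = client.search("mixtral 8x7b benchmarks", n_crawl=3)
# results -> [{"text": "...cleaned page text...", "link": "https://..."}, ...]
```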
middlewares/utils.py
ADDED
@@ -0,0 +1,117 @@
from sentence_transformers import CrossEncoder

import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv


load_dotenv()


GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)


def rerank(query, top_k, search_results, chunk_size=512):
    chunks = []
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Create sentence combinations with the query
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]

    # Compute similarity scores for these combinations
    similarity_scores = reranker.predict(sentence_combinations)

    # Sort score indexes in decreasing order
    sim_scores_argsort = reversed(np.argsort(similarity_scores))

    # Rearrange search_results based on the reranked scores
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
        reranked_results.append({"link": link, "text": chunk})

    # Return the top K ranks
    return reranked_results[:top_k]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    search_results = []
    reranked_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if len(search_results) > 0:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        link = res["link"]
        links.append(link)

    # Remove duplicate links
    links = list(set(links))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
{pre_context}

{context}

{post_context}

{pre_prompt}

{prompt}

{post_prompt}

{prev_output}
"""

    print(augmented_prompt)
    return augmented_prompt, links
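The chunking in rerank is word-based, not token-based; a worked sketch of how many chunks one crawled page produces:

```
import math

words = 1200          # whitespace-split words in one crawled page
chunk_size = 512      # the sidebar "Chunk Size" slider value

num_chunks = math.ceil(words / chunk_size)
print(num_chunks)     # 3 -> chunks of 512, 512, and 176 words
```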
requirements.txt
ADDED
@@ -0,0 +1,97 @@
notion-client==2.2.1
altair==5.1.2
asttokens==2.2.1
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
blinker==1.6.3
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.3.0
click==8.1.7
colorama==0.4.6
comm==0.1.3
debugpy==1.6.7
decorator==5.1.1
dnspython==2.4.2
executing==1.2.0
filelock==3.12.4
fsspec==2023.9.2
gitdb==4.0.10
GitPython==3.1.37
huggingface-hub==0.18.0
idna==3.4
importlib-metadata==6.8.0
ipykernel==6.23.3
ipython==8.14.0
jedi==0.18.2
Jinja2==3.1.2
joblib==1.3.2
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
jupyter_client==8.3.0
jupyter_core==5.3.1
loguru==0.7.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mdurl==0.1.2
mpmath==1.3.0
nest-asyncio==1.5.6
networkx==3.2.1
nltk==3.8.1
numpy==1.26.0
packaging==23.1
pandas==2.1.1
parso==0.8.3
pickleshare==0.7.5
Pillow==10.0.1
platformdirs==3.8.0
prompt-toolkit==3.0.38
protobuf==4.24.4
psutil==5.9.5
pure-eval==0.2.2
pyarrow==13.0.0
pydeck==0.8.1b0
Pygments==2.15.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.0
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rich==13.6.0
rpds-py==0.10.4
safetensors==0.4.1
scikit-learn==1.3.2
scipy==1.11.4
sentence-transformers==2.2.2
sentencepiece==0.1.99
six==1.16.0
smmap==5.0.1
soupsieve==2.5
stack-data==0.6.2
streamlit==1.27.2
sympy==1.12
tenacity==8.2.3
threadpoolctl==3.2.0
tokenizers==0.15.0
toml==0.10.2
toolz==0.12.0
torch==2.1.2
torchvision==0.16.2
tornado==6.3.2
tqdm==4.66.1
traitlets==5.9.0
transformers==4.35.2
typing_extensions==4.8.0
tzdata==2023.3
tzlocal==5.1
urllib3==2.0.6
validators==0.22.0
watchdog==3.0.0
wcwidth==0.2.6
win32-setctime==1.1.0
zipp==3.17.0