NCTCMumbai committed on
Commit
e5b3236
1 Parent(s): 0e6440f

Upload 24 files

.gitignore ADDED
@@ -0,0 +1,165 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Win executables
+ *.exe
+ rag-env/
+ mixtral-playground/
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM python:3.9-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     curl \
+     software-properties-common \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the packages into matching subdirectories so that the
+ # `from components...` / `from middlewares...` imports resolve.
+ COPY components /app/components/
+ COPY middlewares /app/middlewares/
+ COPY app.py /app/
+ COPY requirements.txt /app/
+ COPY config.yaml /app/
+
+ RUN pip3 install -r requirements.txt
+
+ EXPOSE 8501
+
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Manage ADDED
File without changes
README.md CHANGED
@@ -1,12 +1,41 @@
  ---
- title: NCTC OSINT AGENT
- emoji: 🌖
- colorFrom: indigo
- colorTo: green
+ title: Mixtral Search Engine
+ emoji: 🔍
+ colorFrom: pink
+ colorTo: gray
  sdk: streamlit
- sdk_version: 1.33.0
+ sdk_version: 1.29.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ # Mixtral Search Engine
+
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ ## Docker Setup
+
+ If you prefer using Docker, follow these steps:
+
+ 1. Clone the repository.
+ 2. Create a `.env` file to store API credentials:
+
+ ```
+ HF_TOKEN=...
+ GOOGLE_SEARCH_ENGINE_ID=...
+ GOOGLE_SEARCH_API_KEY=...
+ BING_SEARCH_API_KEY=...
+ ```
+
+ 3. Build the Docker image:
+
+ ```
+ docker build -t mixtral-search .
+ ```
+
+ 4. Run the image:
+
+ ```
+ docker run --env-file .env -p 8501:8501 mixtral-search
+ ```
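+
+ The app should then be reachable at `http://localhost:8501` (the container exposes port 8501, as configured in the Dockerfile).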
app.py ADDED
@@ -0,0 +1,32 @@
+ import yaml
+ import streamlit as st
+ from components.sidebar import sidebar
+ from components.chat_box import chat_box
+ from components.chat_loop import chat_loop
+ from components.init_state import init_state
+ from components.prompt_engineering_dashboard import prompt_engineering_dashboard
+
+
+ with open("config.yaml", "r") as file:
+     config = yaml.safe_load(file)
+
+ st.set_page_config(
+     page_title="NCTC OSINT AGENT",
+     page_icon="📚",
+ )
+
+ init_state(st.session_state, config)
+
+ st.write("# NCTC OSINT AGENT")
+
+ # The Prompt Engineering Dashboard works well for testing but is not production-ready.
+ prompt_engineering_dashboard(st.session_state, config)
+
+ sidebar(st.session_state, config)
+
+ chat_box(st.session_state, config)
+
+ chat_loop(st.session_state, config)
components/__init__.py ADDED
File without changes
components/chat_box.py ADDED
@@ -0,0 +1,7 @@
+ import streamlit as st
+
+
+ def chat_box(session_state, config):
+     # Render the running chat history in order.
+     for message in session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
components/chat_loop.py ADDED
@@ -0,0 +1,23 @@
+ import streamlit as st
+ from components.generate_chat_stream import generate_chat_stream
+ from components.stream_handler import stream_handler
+ from components.show_source import show_source
+
+
+ def chat_loop(session_state, config):
+     if prompt := st.chat_input("Search the web..."):
+         st.chat_message("user").markdown(prompt)
+         session_state.messages.append({"role": "user", "content": prompt})
+
+         chat_stream, links = generate_chat_stream(session_state, prompt, config)
+
+         with st.chat_message("assistant"):
+             placeholder = st.empty()
+             full_response = stream_handler(
+                 session_state, chat_stream, prompt, placeholder
+             )
+             if session_state.rag_enabled:
+                 show_source(links)
+
+         session_state.history.append([prompt, full_response])
+         session_state.messages.append({"role": "assistant", "content": full_response})
components/generate_chat_stream.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from pprint import pformat
+
+ import streamlit as st
+ from notion_client import Client
+
+ from middlewares.utils import gen_augmented_prompt_via_websearch
+ from middlewares.chat_client import chat
+
+
+ def safe_get(data, dot_chained_keys):
+     '''
+     Safely walk a nested dict/list structure with a dot-chained key path:
+     {'a': {'b': [{'c': 1}]}}
+     safe_get(data, 'a.b.0.c') -> 1
+     Returns None if any key or index along the path is missing.
+     '''
+     keys = dot_chained_keys.split('.')
+     for key in keys:
+         try:
+             if isinstance(data, list):
+                 data = data[int(key)]
+             else:
+                 data = data[key]
+         except (KeyError, TypeError, IndexError):
+             return None
+     return data
+
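+ # Illustrative usage of safe_get on a Notion-style payload (hypothetical row,
+ # not taken from the real database):
+ #   sample_row = {'properties': {'SKU': {'number': 42}}}
+ #   safe_get(sample_row, 'properties.SKU.number')    # -> 42
+ #   safe_get(sample_row, 'properties.Missing.path')  # -> None
+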
+ def get_notion_data():
+     # Credentials are read from the environment rather than hard-coded;
+     # the NOTION_TOKEN / NOTION_DATABASE_ID variable names are assumptions.
+     integration_token = os.getenv("NOTION_TOKEN")
+     notion_database_id = os.getenv("NOTION_DATABASE_ID")
+
+     client = Client(auth=integration_token)
+
+     first_db_rows = client.databases.query(notion_database_id)
+     rows = []
+
+     for row in first_db_rows['results']:
+         price = safe_get(row, 'properties.($) Per Unit.number')
+         store_link = safe_get(row, 'properties.Store Link.url')
+         supplier_email = safe_get(row, 'properties.Supplier Email.email')
+         exp_del = safe_get(row, 'properties.Expected Delivery.date')
+
+         # safe_get may return None, so guard the iterations below.
+         collections = safe_get(row, 'properties.Collection.multi_select') or []
+         collection_names = [collection['name'] for collection in collections]
+
+         status = safe_get(row, 'properties.Status.select.name')
+         sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
+         stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
+         prod_name = safe_get(row, 'properties.Product .title.0.text.content')
+         sku = safe_get(row, 'properties.SKU.number')
+         shipped_date = safe_get(row, 'properties.Shipped On.date')
+         on_order = safe_get(row, 'properties.On Order.number')
+         on_hand = safe_get(row, 'properties.On Hand.number')
+         sizes = safe_get(row, 'properties.Size.multi_select') or []
+         size_names = [size['name'] for size in sizes]
+
+         rows.append({
+             'Price Per Unit': price,
+             'Store Link': store_link,
+             'Supplier Email': supplier_email,
+             'Expected Delivery': exp_del,
+             'Collection': collection_names,
+             'Status': status,
+             'Supplier Phone': sup_phone,
+             'Stock Alert': stock_alert,
+             'Product Name': prod_name,
+             'SKU': sku,
+             'Sizes': size_names,
+             'Shipped Date': shipped_date,
+             'On Order': on_order,
+             'On Hand': on_hand,
+         })
+
+     notion_data_string = pformat(rows)
+     return notion_data_string
+
+
+ def generate_chat_stream(session_state, query, config):
+     # 1. Augments the prompt according to the template.
+     # 2. Returns the chat_stream and source links.
+     # 3. chat_stream and links are consumed by stream_handler and show_source.
+     links = []
+     if session_state.rag_enabled:
+         with st.spinner("Fetching relevant documents from the web..."):
+             query, links = gen_augmented_prompt_via_websearch(
+                 prompt=query,
+                 pre_context=session_state.pre_context,
+                 post_context=session_state.post_context,
+                 pre_prompt=session_state.pre_prompt,
+                 post_prompt=session_state.post_prompt,
+                 search_vendor=session_state.search_vendor,
+                 top_k=session_state.top_k,
+                 n_crawl=session_state.n_crawl,
+                 chunk_size=session_state.chunk_size,
+                 pass_prev=session_state.pass_prev,
+                 prev_output=session_state.history[-1][1],
+             )
+
+     notion_data = get_notion_data()
+
+     with st.spinner("Generating response..."):
+         chat_stream = chat(session_state, notion_data + " " + query, config)
+
+     return chat_stream, links
components/init_state.py ADDED
@@ -0,0 +1,64 @@
+ def init_state(session_state, config):
+     initial_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]
+     if "messages" not in session_state:
+         session_state.messages = []
+
+     if "tokens_used" not in session_state:
+         session_state.tokens_used = 0
+
+     if "tps" not in session_state:
+         session_state.tps = 0
+
+     if "temp" not in session_state:
+         session_state.temp = 0.8
+
+     if "history" not in session_state:
+         session_state.history = [
+             [
+                 initial_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
+                 initial_prompt_engineering_dict["SYSTEM_RESPONSE"],
+             ]
+         ]
+
+     if "n_crawl" not in session_state:
+         session_state.n_crawl = 5
+
+     if "repetition_penalty" not in session_state:
+         session_state.repetition_penalty = 1
+
+     if "rag_enabled" not in session_state:
+         session_state.rag_enabled = True
+
+     if "chat_bot" not in session_state:
+         session_state.chat_bot = "Mixtral 8x7B v0.1"
+
+     if "search_vendor" not in session_state:
+         session_state.search_vendor = "Bing"
+
+     if "system_instruction" not in session_state:
+         session_state.system_instruction = initial_prompt_engineering_dict[
+             "SYSTEM_INSTRUCTION"
+         ]
+
+     if "system_response" not in session_state:
+         session_state.system_response = initial_prompt_engineering_dict[
+             "SYSTEM_RESPONSE"
+         ]
+
+     if "pre_context" not in session_state:
+         session_state.pre_context = initial_prompt_engineering_dict["PRE_CONTEXT"]
+
+     if "post_context" not in session_state:
+         session_state.post_context = initial_prompt_engineering_dict["POST_CONTEXT"]
+
+     if "pre_prompt" not in session_state:
+         session_state.pre_prompt = initial_prompt_engineering_dict["PRE_PROMPT"]
+
+     if "post_prompt" not in session_state:
+         session_state.post_prompt = initial_prompt_engineering_dict["POST_PROMPT"]
+
+     if "pass_prev" not in session_state:
+         session_state.pass_prev = False
+
+     if "chunk_size" not in session_state:
+         session_state.chunk_size = 512
components/prompt_engineering_dashboard.py ADDED
@@ -0,0 +1,68 @@
+ import streamlit as st
+
+
+ def prompt_engineering_dashboard(session_state, config):
+     initial_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]
+
+     def engineer_prompt():
+         session_state.history[0] = [
+             session_state.system_instruction,
+             session_state.system_response,
+         ]
+
+     with st.expander("Prompt Engineering Dashboard"):
+         st.info(
+             "**The input to the model follows the template below**",
+         )
+         st.code(
+             """
+             [SYSTEM INSTRUCTION]
+             [SYSTEM RESPONSE]
+
+             [... LIST OF PREV INPUTS]
+
+             [PRE CONTEXT]
+             [CONTEXT RETRIEVED FROM THE WEB]
+             [POST CONTEXT]
+
+             [PRE PROMPT]
+             [PROMPT]
+             [POST PROMPT]
+             [PREV GENERATED OUTPUT]  # Only if "Pass previous Output" is set to True
+             """
+         )
+         session_state.system_instruction = st.text_area(
+             label="SYSTEM INSTRUCTION",
+             value=initial_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
+         )
+         session_state.system_response = st.text_area(
+             "SYSTEM RESPONSE", value=initial_prompt_engineering_dict["SYSTEM_RESPONSE"]
+         )
+
+         col1, col2 = st.columns(2)
+         with col1:
+             session_state.pre_context = st.text_input(
+                 "PRE CONTEXT",
+                 value=initial_prompt_engineering_dict["PRE_CONTEXT"],
+                 disabled=not session_state.rag_enabled,
+             )
+             session_state.post_context = st.text_input(
+                 "POST CONTEXT",
+                 value=initial_prompt_engineering_dict["POST_CONTEXT"],
+                 disabled=not session_state.rag_enabled,
+             )
+
+         with col2:
+             session_state.pre_prompt = st.text_input(
+                 "PRE PROMPT", value=initial_prompt_engineering_dict["PRE_PROMPT"]
+             )
+             session_state.post_prompt = st.text_input(
+                 "POST PROMPT", value=initial_prompt_engineering_dict["POST_PROMPT"]
+             )
+
+         col3, col4 = st.columns(2)
+         with col3:
+             session_state.pass_prev = st.toggle("Pass previous Output")
+         with col4:
+             st.button("Engineer Prompts", on_click=engineer_prompt)
components/show_source.py ADDED
@@ -0,0 +1,8 @@
+ import streamlit as st
+
+
+ def show_source(links):
+     # Expander component listing the source links behind the answer.
+     with st.expander("Show source"):
+         for link in links:
+             st.info(f"{link}")
components/sidebar.py ADDED
@@ -0,0 +1,11 @@
+ import streamlit as st
+ from components.sidebar_components.model_analytics import model_analytics
+ from components.sidebar_components.retrieval_settings import retrieval_settings
+ from components.sidebar_components.model_settings import model_settings
+
+
+ def sidebar(session_state, config):
+     with st.sidebar:
+         retrieval_settings(session_state, config)
+         model_analytics(session_state, config)
+         model_settings(session_state, config)
components/sidebar_components/__init__.py ADDED
File without changes
components/sidebar_components/model_analytics.py ADDED
@@ -0,0 +1,18 @@
+ import streamlit as st
+
+
+ def model_analytics(session_state, config):
+     COST_PER_1000_TOKENS_USD = config["COST_PER_1000_TOKENS_USD"]
+
+     st.markdown("# Model Analytics")
+
+     st.write("Total tokens used:", session_state["tokens_used"])
+     st.write("Speed:", session_state["tps"], " tokens/sec")
+     st.write(
+         "Total cost incurred:",
+         round(
+             COST_PER_1000_TOKENS_USD * session_state["tokens_used"] / 1000,
+             3,
+         ),
+         "USD",
+     )
+
+     st.markdown("---")
components/sidebar_components/model_settings.py ADDED
@@ -0,0 +1,25 @@
+ import streamlit as st
+
+
+ def model_settings(session_state, config):
+     CHAT_BOTS = config["CHAT_BOTS"]
+
+     st.markdown("# Model Settings")
+
+     session_state.chat_bot = st.sidebar.radio(
+         "Select one:", list(CHAT_BOTS.keys())
+     )
+     session_state.temp = st.slider(
+         label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
+     )
+
+     session_state.max_tokens = st.slider(
+         label="New tokens to generate",
+         min_value=64,
+         max_value=2048,
+         step=32,
+         value=512,
+     )
+
+     session_state.repetition_penalty = st.slider(
+         label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
+     )
components/sidebar_components/retrieval_settings.py ADDED
@@ -0,0 +1,37 @@
+ import streamlit as st
+
+
+ def retrieval_settings(session_state, config):
+     st.markdown("# Web Retrieval")
+     session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
+     session_state.search_vendor = st.radio(
+         "Select Search Vendor",
+         ["Bing", "Google"],
+         disabled=not session_state.rag_enabled,
+     )
+     session_state.n_crawl = st.slider(
+         label="Links to Crawl",
+         key=1,
+         min_value=1,
+         max_value=10,
+         value=4,
+         disabled=not session_state.rag_enabled,
+     )
+     session_state.top_k = st.slider(
+         label="Chunks to Retrieve via Reranker",
+         key=2,
+         min_value=1,
+         max_value=20,
+         value=5,
+         disabled=not session_state.rag_enabled,
+     )
+
+     session_state.chunk_size = st.slider(
+         label="Chunk Size",
+         value=512,
+         min_value=128,
+         max_value=1024,
+         step=8,
+         disabled=not session_state.rag_enabled,
+     )
+
+     st.markdown("---")
components/stream_handler.py ADDED
@@ -0,0 +1,41 @@
+ import time
+ import streamlit as st
+
+ COST_PER_1000_TOKENS_USD = 0.139 / 80
+
+
+ def stream_handler(session_state, chat_stream, prompt, placeholder):
+     # 1. Streams chat_stream into the placeholder as chunks arrive.
+     # 2. Returns full_response for token accounting.
+     start_time = time.time()
+     full_response = ""
+
+     for chunk in chat_stream:
+         if chunk.token.text in ["</s>", "<|im_end|>"]:
+             break
+         full_response += chunk.token.text
+         placeholder.markdown(full_response + "▌")
+     placeholder.markdown(full_response)
+
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     total_tokens_processed = len(full_response.split())
+     tokens_per_second = int(total_tokens_processed // elapsed_time)
+     # Rough token estimate: word count scaled by ~1.25 tokens per word.
+     len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
+     col1, col2, col3 = st.columns(3)
+
+     with col1:
+         st.write(f"**{tokens_per_second} tokens/second**")
+
+     with col2:
+         st.write(f"**{int(len_response)} tokens generated**")
+
+     with col3:
+         st.write(
+             f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
+         )
+
+     session_state["tps"] = tokens_per_second
+     session_state["tokens_used"] = len_response + session_state["tokens_used"]
+
+     return full_response
config.yaml ADDED
@@ -0,0 +1,24 @@
+ PROMPT_ENGINEERING_DICT:
+   SYSTEM_INSTRUCTION: |
+     YOU ARE A SEARCH ENGINE AND AN INVENTORY MANAGER HAVING FULL ACCESS TO WEB PAGES AND A NOTION DATABASE IN JSON.
+     YOU GIVE EXTREMELY DETAILED AND ACCURATE INFORMATION ACCORDING TO USER PROMPTS.
+
+   SYSTEM_RESPONSE: |
+     Certainly! I'm here to help. What information are you looking for?
+     Please provide me with a specific topic or question, and I'll do my
+     best to provide you with detailed and accurate information.
+
+   PRE_CONTEXT: NOW YOU ARE SEARCHING THE WEB, AND HERE ARE THE CHUNKS RETRIEVED FROM THE WEB. YOU ALSO HAVE ACCESS TO THE INVENTORY DATASET IN JSON FORMAT.
+   POST_CONTEXT: ""
+   PRE_PROMPT: NOW, ACCORDING TO THE CONTEXT RETRIEVED FROM THE WEB, GENERATE THE CONTENT FOR THE FOLLOWING SUBJECT
+   POST_PROMPT: PRIORITIZE DATA, FACTS AND STATISTICS OVER PERSONAL EXPERIENCES AND OPINIONS; FOCUS MORE ON STATISTICS AND DATA.
+
+ CHAT_BOTS:
+   Nous Hermes 2: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO
+   Mixtral 8x7B v0.1: mistralai/Mixtral-8x7B-Instruct-v0.1
+   Mistral 7B v0.1: mistralai/Mistral-7B-Instruct-v0.1
+   Mistral 7B v0.2: mistralai/Mistral-7B-Instruct-v0.2
+
+ CROSS_ENCODERS:
+
+ COST_PER_1000_TOKENS_USD: 0.001737375
middlewares/chat_client.py ADDED
@@ -0,0 +1,78 @@
+ import os
+
+ from dotenv import load_dotenv
+ from huggingface_hub import InferenceClient
+
+ load_dotenv()
+
+ API_TOKEN = os.getenv("HF_TOKEN")
+
+
+ def format_prompt(session_state, query, history, chat_client):
+     # Nous Hermes 2 expects ChatML; the Mistral/Mixtral models expect [INST] tags.
+     if chat_client == "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
+         model_input = f"<|im_start|>system\n{session_state.system_instruction}<|im_end|>\n"
+         for user_prompt, bot_response in history:
+             model_input += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+             model_input += f"<|im_start|>assistant\n{bot_response}<|im_end|>\n"
+         model_input += f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant"
+         return model_input
+
+     else:
+         model_input = "<s>"
+         for user_prompt, bot_response in history:
+             model_input += f"[INST] {user_prompt} [/INST]"
+             model_input += f" {bot_response}</s> "
+         model_input += f"[INST] {query} [/INST]"
+         return model_input
+
+
+ def chat(session_state, query, config):
+     chat_bot_dict = config["CHAT_BOTS"]
+     chat_client = chat_bot_dict[session_state.chat_bot]
+     temperature = session_state.temp
+     max_new_tokens = session_state.max_tokens
+     repetition_penalty = session_state.repetition_penalty
+     history = session_state.history
+
+     client = InferenceClient(chat_client, token=API_TOKEN)
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(0.95)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(session_state, query, history, chat_client)
+
+     stream = client.text_generation(
+         formatted_prompt,
+         **generate_kwargs,
+         stream=True,
+         details=True,
+         return_full_text=False,
+         truncate=32000,
+     )
+
+     return stream
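+
+ # Illustrative only: for the Mistral-style branch above, a one-turn history
+ # such as [("hi", "hello")] with query "how are you?" yields
+ # "<s>[INST] hi [/INST] hello</s> [INST] how are you? [/INST]".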
middlewares/search_client.py ADDED
@@ -0,0 +1,85 @@
+ import re
+ import concurrent.futures
+
+ import requests
+ from bs4 import BeautifulSoup
+
+
+ class SearchClient:
+     def __init__(self, vendor, engine_id=None, api_key=None):
+         self.vendor = vendor
+         if vendor == "google":
+             self.endpoint = f"https://www.googleapis.com/customsearch/v1?key={api_key}&cx={engine_id}"
+         elif vendor == "bing":
+             self.endpoint = "https://api.bing.microsoft.com/v7.0/search"
+             self.headers = {
+                 "Ocp-Apim-Subscription-Key": api_key,
+             }
+
+     @staticmethod
+     def _extract_text_from_link(link):
+         page = requests.get(link, timeout=10)  # timeout avoids hanging on slow pages
+         if page.status_code == 200:
+             soup = BeautifulSoup(page.content, "html.parser")
+             text = soup.get_text()
+             cleaned_text = re.sub(r"\s+", " ", text)
+             return cleaned_text
+         return None
+
+     def _fetch_text_from_links(self, links):
+         # Crawl the links concurrently and keep whichever extractions succeed.
+         results = []
+         with concurrent.futures.ThreadPoolExecutor() as executor:
+             future_to_link = {
+                 executor.submit(self._extract_text_from_link, link): link
+                 for link in links
+             }
+             for future in concurrent.futures.as_completed(future_to_link):
+                 link = future_to_link[future]
+                 try:
+                     cleaned_text = future.result()
+                     if cleaned_text:
+                         results.append({"text": cleaned_text, "link": link})
+                 except Exception as e:
+                     print(f"Error fetching data from {link}: {e}")
+         return results
+
+     def _google_search(self, query, n_crawl):
+         response = requests.get(self.endpoint, params={"q": query})
+         search_results = response.json()
+
+         results = []
+         count = 0
+         for item in search_results.get("items", []):
+             if count >= n_crawl:
+                 break
+             link = item["link"]
+             results.append(link)
+             count += 1
+
+         text_results = self._fetch_text_from_links(results)
+         return text_results
+
+     def _bing_search(self, query, n_crawl):
+         params = {
+             "q": query,
+             "count": n_crawl,  # May need adjusting to Bing API limits
+             "mkt": "en-US",
+         }
+         response = requests.get(self.endpoint, headers=self.headers, params=params)
+         search_results = response.json()
+
+         results = []
+         for item in search_results.get("webPages", {}).get("value", []):
+             link = item["url"]
+             results.append(link)
+
+         text_results = self._fetch_text_from_links(results)
+         return text_results
+
+     def search(self, query, n_crawl):
+         if self.vendor == "google":
+             return self._google_search(query, n_crawl)
+         elif self.vendor == "bing":
+             return self._bing_search(query, n_crawl)
+         else:
+             return "Invalid vendor"
middlewares/utils.py ADDED
@@ -0,0 +1,117 @@
+ import math
+ import os
+
+ import numpy as np
+ from dotenv import load_dotenv
+ from sentence_transformers import CrossEncoder
+
+ from middlewares.search_client import SearchClient
+
+
+ load_dotenv()
+
+
+ GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
+ GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
+ BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")
+
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+
+ googleSearchClient = SearchClient(
+     "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
+ )
+ bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)
+
+
+ def rerank(query, top_k, search_results, chunk_size=512):
+     # Split each crawled page into word-count chunks, keeping the source link.
+     chunks = []
+     for result in search_results:
+         text = result["text"]
+         words = text.split()
+         num_chunks = math.ceil(len(words) / chunk_size)
+         for i in range(num_chunks):
+             start = i * chunk_size
+             end = (i + 1) * chunk_size
+             chunk = " ".join(words[start:end])
+             chunks.append((result["link"], chunk))
+
+     # Create sentence combinations with the query
+     sentence_combinations = [[query, chunk[1]] for chunk in chunks]
+
+     # Compute similarity scores for these combinations
+     similarity_scores = reranker.predict(sentence_combinations)
+
+     # Sort score indexes in decreasing order
+     sim_scores_argsort = np.argsort(similarity_scores)[::-1]
+
+     # Rearrange search_results based on the reranked scores
+     reranked_results = []
+     for idx in sim_scores_argsort:
+         link = chunks[idx][0]
+         chunk = chunks[idx][1]
+         reranked_results.append({"link": link, "text": chunk})
+
+     # Return the top K ranks
+     return reranked_results[:top_k]
+
+
+ def gen_augmented_prompt_via_websearch(
+     prompt,
+     search_vendor,
+     n_crawl,
+     top_k,
+     pre_context="",
+     post_context="",
+     pre_prompt="",
+     post_prompt="",
+     pass_prev=False,
+     prev_output="",
+     chunk_size=512,
+ ):
+     search_results = []
+     reranked_results = []
+     if search_vendor == "Google":
+         search_results = googleSearchClient.search(prompt, n_crawl)
+     elif search_vendor == "Bing":
+         search_results = bingSearchClient.search(prompt, n_crawl)
+
+     if len(search_results) > 0:
+         reranked_results = rerank(prompt, top_k, search_results, chunk_size)
+
+     links = []
+     context = ""
+     for res in reranked_results:
+         context += res["text"] + "\n\n"
+         link = res["link"]
+         links.append(link)
+
+     # remove duplicate links
+     links = list(set(links))
+
+     prev_output = prev_output if pass_prev else ""
+
+     augmented_prompt = f"""
+     {pre_context}
+
+     {context}
+
+     {post_context}
+
+     {pre_prompt}
+
+     {prompt}
+
+     {post_prompt}
+
+     {prev_output}
+     """
+
+     print(augmented_prompt)
+     return augmented_prompt, links
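+
+ # Illustrative end-to-end call (hypothetical values; requires search API keys):
+ #   prompt, links = gen_augmented_prompt_via_websearch(
+ #       "latest mixtral benchmarks", search_vendor="Bing", n_crawl=4, top_k=5
+ #   )
+ #   # `prompt` is the template-wrapped context + query; `links` are deduped sources.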
requirements.txt ADDED
@@ -0,0 +1,97 @@
+ notion-client==2.2.1
+ altair==5.1.2
+ asttokens==2.2.1
+ attrs==23.1.0
+ backcall==0.2.0
+ beautifulsoup4==4.12.2
+ blinker==1.6.3
+ cachetools==5.3.1
+ certifi==2023.7.22
+ charset-normalizer==3.3.0
+ click==8.1.7
+ colorama==0.4.6
+ comm==0.1.3
+ debugpy==1.6.7
+ decorator==5.1.1
+ dnspython==2.4.2
+ executing==1.2.0
+ filelock==3.12.4
+ fsspec==2023.9.2
+ gitdb==4.0.10
+ GitPython==3.1.37
+ huggingface-hub==0.18.0
+ idna==3.4
+ importlib-metadata==6.8.0
+ ipykernel==6.23.3
+ ipython==8.14.0
+ jedi==0.18.2
+ Jinja2==3.1.2
+ joblib==1.3.2
+ jsonschema==4.19.1
+ jsonschema-specifications==2023.7.1
+ jupyter_client==8.3.0
+ jupyter_core==5.3.1
+ loguru==0.7.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib-inline==0.1.6
+ mdurl==0.1.2
+ mpmath==1.3.0
+ nest-asyncio==1.5.6
+ networkx==3.2.1
+ nltk==3.8.1
+ numpy==1.26.0
+ packaging==23.1
+ pandas==2.1.1
+ parso==0.8.3
+ pickleshare==0.7.5
+ Pillow==10.0.1
+ platformdirs==3.8.0
+ prompt-toolkit==3.0.38
+ protobuf==4.24.4
+ psutil==5.9.5
+ pure-eval==0.2.2
+ pyarrow==13.0.0
+ pydeck==0.8.1b0
+ Pygments==2.15.1
+ python-dateutil==2.8.2
+ python-dotenv==1.0.0
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ pyzmq==25.1.0
+ referencing==0.30.2
+ regex==2023.10.3
+ requests==2.31.0
+ rich==13.6.0
+ rpds-py==0.10.4
+ safetensors==0.4.1
+ scikit-learn==1.3.2
+ scipy==1.11.4
+ sentence-transformers==2.2.2
+ sentencepiece==0.1.99
+ six==1.16.0
+ smmap==5.0.1
+ soupsieve==2.5
+ stack-data==0.6.2
+ streamlit==1.27.2
+ sympy==1.12
+ tenacity==8.2.3
+ threadpoolctl==3.2.0
+ tokenizers==0.15.0
+ toml==0.10.2
+ toolz==0.12.0
+ torch==2.1.2
+ torchvision==0.16.2
+ tornado==6.3.2
+ tqdm==4.66.1
+ traitlets==5.9.0
+ transformers==4.35.2
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ tzlocal==5.1
+ urllib3==2.0.6
+ validators==0.22.0
+ watchdog==3.0.0
+ wcwidth==0.2.6
+ win32-setctime==1.1.0
+ zipp==3.17.0