shimizukawa commited on
Commit
240fe8a
1 Parent(s): e04649d

Update streamlit UI

Browse files
Files changed (3) hide show
  1. README.md +7 -0
  2. app.py +10 -10
  3. config.py +14 -0
README.md CHANGED
@@ -12,6 +12,13 @@ license: mit
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
 
 
 
 
 
 
 
 
15
  # import GitHub issues
16
 
17
  ## export from github
 
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
 
15
+ # Required Environment variables
16
+
17
+ - `INDEX_KEYS`: comma separated index names
18
+ - `QDRANT_URL`: Qdrant API endpoint
19
+ - `QDRANT_API_KEY`: Qdrant API Key
20
+ - `OPENAI_API_KEY`: OpenAI API Key
21
+
22
  # import GitHub issues
23
 
24
  ## export from github
app.py CHANGED
@@ -1,5 +1,5 @@
1
- from time import time
2
  from datetime import datetime
 
3
  from typing import Iterable
4
 
5
  import streamlit as st
@@ -14,7 +14,7 @@ from langchain.chains import RetrievalQA
14
  from openai.error import InvalidRequestError
15
  from langchain.chat_models import ChatOpenAI
16
 
17
- from config import DB_CONFIG
18
  from models import BaseModel
19
 
20
 
@@ -202,8 +202,8 @@ def run_search(
202
 
203
  with st.form("my_form"):
204
  st.title("Document Search")
205
- query = st.text_input(label="query")
206
- index = st.text_input(label="index")
207
 
208
  submit_col1, submit_col2 = st.columns(2)
209
  searched = submit_col1.form_submit_button("Search")
@@ -226,12 +226,12 @@ with st.form("my_form"):
226
  st.write(text)
227
  st.write("score:", score, "Date:", ctime.date(), "User:", user)
228
  st.divider()
229
- qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
230
  if qa_searched:
231
  st.divider()
232
- st.header("QA Search Results by OpenAI GPT-3")
233
  st.divider()
234
- with st.spinner("QA Searching..."):
235
  results = run_qa(
236
  LLM,
237
  query,
@@ -243,12 +243,12 @@ with st.form("my_form"):
243
  st.markdown(html, unsafe_allow_html=True)
244
  st.divider()
245
  if torch.cuda.is_available():
246
- qa_searched_vicuna = submit_col2.form_submit_button("QA Search by Vicuna")
247
  if qa_searched_vicuna:
248
  st.divider()
249
- st.header("QA Search Results by Vicuna-13b-v1.5")
250
  st.divider()
251
- with st.spinner("QA Searching..."):
252
  results = run_qa(
253
  VICUNA_LLM,
254
  query,
 
 
1
  from datetime import datetime
2
+ from time import time
3
  from typing import Iterable
4
 
5
  import streamlit as st
 
14
  from openai.error import InvalidRequestError
15
  from langchain.chat_models import ChatOpenAI
16
 
17
+ from config import DB_CONFIG, INDEX_KEYS
18
  from models import BaseModel
19
 
20
 
 
202
 
203
  with st.form("my_form"):
204
  st.title("Document Search")
205
+ query = st.text_area(label="query")
206
+ index = st.selectbox(label="index", options=INDEX_KEYS)
207
 
208
  submit_col1, submit_col2 = st.columns(2)
209
  searched = submit_col1.form_submit_button("Search")
 
226
  st.write(text)
227
  st.write("score:", score, "Date:", ctime.date(), "User:", user)
228
  st.divider()
229
+ qa_searched = submit_col2.form_submit_button("Q&A by OpenAI")
230
  if qa_searched:
231
  st.divider()
232
+ st.header("Answer by OpenAI GPT-3")
233
  st.divider()
234
+ with st.spinner("Thinking..."):
235
  results = run_qa(
236
  LLM,
237
  query,
 
243
  st.markdown(html, unsafe_allow_html=True)
244
  st.divider()
245
  if torch.cuda.is_available():
246
+ qa_searched_vicuna = submit_col2.form_submit_button("Answer by Vicuna")
247
  if qa_searched_vicuna:
248
  st.divider()
249
+ st.header("Answer by Vicuna-13b-v1.5")
250
  st.divider()
251
+ with st.spinner("Thinking..."):
252
  results = run_qa(
253
  VICUNA_LLM,
254
  query,
config.py CHANGED
@@ -18,4 +18,18 @@ def get_local_db_congin():
18
  return url, None, collection_name
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
 
 
18
  return url, None, collection_name
19
 
20
 
21
+ def get_index_keys():
22
+ keys = [
23
+ k for k in [
24
+ k.strip().lower()
25
+ for k in os.environ["INDEX_KEYS"].split(",")
26
+ ]
27
+ if k
28
+ ]
29
+ if not keys:
30
+ keys = ["INDEX_KEYS is empty"]
31
+ return keys
32
+
33
+
34
  DB_CONFIG = get_db_config() if SAAS else get_local_db_congin()
35
+ INDEX_KEYS = get_index_keys()