Spaces:
Runtime error
Runtime error
Add all the variables right there
Browse files- app.py +18 -5
- backend/semantic_search.py +19 -4
app.py
CHANGED
@@ -34,7 +34,8 @@ def add_text(history, text):
|
|
34 |
return history, gr.Textbox(value="", interactive=False)
|
35 |
|
36 |
|
37 |
-
def bot(history, api_kind
|
|
|
38 |
query = history[-1][0]
|
39 |
|
40 |
if not query:
|
@@ -44,7 +45,7 @@ def bot(history, api_kind):
|
|
44 |
# Retrieve documents relevant to query
|
45 |
document_start = perf_counter()
|
46 |
|
47 |
-
documents = retrieve(query,
|
48 |
|
49 |
document_time = perf_counter() - document_start
|
50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
@@ -86,19 +87,31 @@ with gr.Blocks() as demo:
|
|
86 |
)
|
87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
prompt_html = gr.HTML()
|
92 |
# Turn off interactivity while generating if you click
|
93 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
94 |
-
bot, [chatbot, api_kind
|
|
|
|
|
95 |
|
96 |
# Turn it back on
|
97 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
98 |
|
99 |
# Turn off interactivity while generating if you hit enter
|
100 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
101 |
-
bot, [chatbot, api_kind
|
|
|
|
|
102 |
|
103 |
# Turn it back on
|
104 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
34 |
return history, gr.Textbox(value="", interactive=False)
|
35 |
|
36 |
|
37 |
+
def bot(history, api_kind,
|
38 |
+
num_docs, model_kind, sub_vector_size, chunk_size, splitter_type):
|
39 |
query = history[-1][0]
|
40 |
|
41 |
if not query:
|
|
|
45 |
# Retrieve documents relevant to query
|
46 |
document_start = perf_counter()
|
47 |
|
48 |
+
documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
|
49 |
|
50 |
document_time = perf_counter() - document_start
|
51 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
|
87 |
)
|
88 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
89 |
|
90 |
+
|
91 |
+
with gr.Row():
|
92 |
+
num_docs = gr.Slider(1, 20, label="number of docs", step=1, value=4)
|
93 |
+
model_kind = gr.Radio(choices=["bge", "minilm"], value="bge", label="embedding model")
|
94 |
+
sub_vector_size = gr.Radio(choices=["8", "16", "32"], value="32", label="sub-vector size")
|
95 |
+
with gr.Row():
|
96 |
+
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="Chat model engine")
|
97 |
+
chunk_size = gr.Radio(choices=["500", "2000"], value="2000", label="chunk size")
|
98 |
+
splitter_type = gr.Radio(choices=["ct", "rct","nltk"], value="nltk", label="splitter")
|
99 |
|
100 |
prompt_html = gr.HTML()
|
101 |
# Turn off interactivity while generating if you click
|
102 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
103 |
+
bot, [chatbot, api_kind,
|
104 |
+
num_docs, model_kind, sub_vector_size, chunk_size, splitter_type
|
105 |
+
], [chatbot, prompt_html])
|
106 |
|
107 |
# Turn it back on
|
108 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
109 |
|
110 |
# Turn off interactivity while generating if you hit enter
|
111 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
112 |
+
bot, [chatbot, api_kind,
|
113 |
+
num_docs, model_kind, sub_vector_size, chunk_size, splitter_type
|
114 |
+
], [chatbot, prompt_html])
|
115 |
|
116 |
# Turn it back on
|
117 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
backend/semantic_search.py
CHANGED
@@ -6,18 +6,33 @@ from sentence_transformers import SentenceTransformer
|
|
6 |
|
7 |
db = lancedb.connect(".lancedb")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
def retrieve(query, k):
|
18 |
-
query_vec = retriever.encode(query)
|
19 |
try:
|
20 |
-
documents =
|
|
|
|
|
21 |
documents = [doc[TEXT_COLUMN] for doc in documents]
|
22 |
|
23 |
return documents
|
|
|
6 |
|
7 |
db = lancedb.connect(".lancedb")
|
8 |
|
9 |
+
# Cache of opened LanceDB tables, keyed by table name, so each table is
# opened at most once per process instead of on every query.
tables = {}


def table(tname):
    """Return the LanceDB table named *tname*, opening it on first use.

    Looks the table up in the module-level ``tables`` cache; on a miss,
    opens it through the module-level ``db`` connection and caches the
    handle for subsequent calls.

    Args:
        tname: Name of the LanceDB table to open.

    Returns:
        The (cached) table handle for ``tname``.
    """
    # `x not in y` is the idiomatic membership test (was: `not tname in tables`).
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]
|
15 |
+
|
16 |
+
|
17 |
TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
18 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
19 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
20 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
21 |
|
22 |
+
retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
|
23 |
+
retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
24 |
+
|
25 |
|
26 |
+
def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type):
|
27 |
+
if model_kind == "bge":
|
28 |
+
query_vec = retriever_bge.encode(query)
|
29 |
+
else:
|
30 |
+
query_vec = retriever_minilm.encode(query)
|
31 |
|
|
|
|
|
32 |
try:
|
33 |
+
documents = table(
|
34 |
+
f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}",
|
35 |
+
).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
36 |
documents = [doc[TEXT_COLUMN] for doc in documents]
|
37 |
|
38 |
return documents
|