complynx committed on
Commit
5fa8f2f
·
1 Parent(s): e5ef522

Add all the variables right there

Browse files
Files changed (2) hide show
  1. app.py +18 -5
  2. backend/semantic_search.py +19 -4
app.py CHANGED
@@ -34,7 +34,8 @@ def add_text(history, text):
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
- def bot(history, api_kind):
 
38
  query = history[-1][0]
39
 
40
  if not query:
@@ -44,7 +45,7 @@ def bot(history, api_kind):
44
  # Retrieve documents relevant to query
45
  document_start = perf_counter()
46
 
47
- documents = retrieve(query, TOP_K)
48
 
49
  document_time = perf_counter() - document_start
50
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
@@ -86,19 +87,31 @@ with gr.Blocks() as demo:
86
  )
87
  txt_btn = gr.Button(value="Submit text", scale=1)
88
 
89
- api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
 
 
 
 
 
 
 
 
90
 
91
  prompt_html = gr.HTML()
92
  # Turn off interactivity while generating if you click
93
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
94
- bot, [chatbot, api_kind], [chatbot, prompt_html])
 
 
95
 
96
  # Turn it back on
97
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
98
 
99
  # Turn off interactivity while generating if you hit enter
100
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
101
- bot, [chatbot, api_kind], [chatbot, prompt_html])
 
 
102
 
103
  # Turn it back on
104
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
+ def bot(history, api_kind,
38
+ num_docs, model_kind, sub_vector_size, chunk_size, splitter_type):
39
  query = history[-1][0]
40
 
41
  if not query:
 
45
  # Retrieve documents relevant to query
46
  document_start = perf_counter()
47
 
48
+ documents = retrieve(query, int(num_docs), model_kind, sub_vector_size, chunk_size, splitter_type)
49
 
50
  document_time = perf_counter() - document_start
51
  logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
87
  )
88
  txt_btn = gr.Button(value="Submit text", scale=1)
89
 
90
+
91
+ with gr.Row():
92
+ num_docs = gr.Slider(1, 20, label="number of docs", step=1, value=4)
93
+ model_kind = gr.Radio(choices=["bge", "minilm"], value="bge", label="embedding model")
94
+ sub_vector_size = gr.Radio(choices=["8", "16", "32"], value="32", label="sub-vector size")
95
+ with gr.Row():
96
+ api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="Chat model engine")
97
+ chunk_size = gr.Radio(choices=["500", "2000"], value="2000", label="chunk size")
98
+ splitter_type = gr.Radio(choices=["ct", "rct","nltk"], value="nltk", label="splitter")
99
 
100
  prompt_html = gr.HTML()
101
  # Turn off interactivity while generating if you click
102
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
103
+ bot, [chatbot, api_kind,
104
+ num_docs, model_kind, sub_vector_size, chunk_size, splitter_type
105
+ ], [chatbot, prompt_html])
106
 
107
  # Turn it back on
108
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
109
 
110
  # Turn off interactivity while generating if you hit enter
111
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
112
+ bot, [chatbot, api_kind,
113
+ num_docs, model_kind, sub_vector_size, chunk_size, splitter_type
114
+ ], [chatbot, prompt_html])
115
 
116
  # Turn it back on
117
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
backend/semantic_search.py CHANGED
@@ -6,18 +6,33 @@ from sentence_transformers import SentenceTransformer
6
 
7
  db = lancedb.connect(".lancedb")
8
 
 
 
 
 
 
 
 
 
9
  TABLE = db.open_table(os.getenv("TABLE_NAME"))
10
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
11
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
12
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
13
 
14
- retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
 
 
15
 
 
 
 
 
 
16
 
17
- def retrieve(query, k):
18
- query_vec = retriever.encode(query)
19
  try:
20
- documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
 
 
21
  documents = [doc[TEXT_COLUMN] for doc in documents]
22
 
23
  return documents
 
6
 
7
  db = lancedb.connect(".lancedb")
8
 
9
# Cache of opened LanceDB table handles, keyed by table name, so each
# table is opened at most once per process.
tables = {}


def table(tname):
    """Return a LanceDB table handle for *tname*, opening and caching it on first use.

    Parameters
    ----------
    tname : str
        Name of the table to open via the module-level ``db`` connection.

    Returns
    -------
    The (possibly cached) table object from ``db.open_table(tname)``.
    """
    # Idiomatic membership test (`x not in d`) instead of `not x in d`.
    if tname not in tables:
        tables[tname] = db.open_table(tname)
    return tables[tname]
15
+
16
+
17
  TABLE = db.open_table(os.getenv("TABLE_NAME"))
18
  VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
19
  TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
20
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
21
 
22
+ retriever_bge = SentenceTransformer("BAAI/bge-large-en-v1.5")
23
+ retriever_minilm = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
24
+
25
 
26
+ def retrieve(query, k, model_kind, sub_vector_size, chunk_size, splitter_type):
27
+ if model_kind == "bge":
28
+ query_vec = retriever_bge.encode(query)
29
+ else:
30
+ query_vec = retriever_minilm.encode(query)
31
 
 
 
32
  try:
33
+ documents = table(
34
+ f"{splitter_type}_{model_kind}_{sub_vector_size}_{chunk_size}",
35
+ ).search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
36
  documents = [doc[TEXT_COLUMN] for doc in documents]
37
 
38
  return documents