shimizukawa commited on
Commit
d7b3b8a
1 Parent(s): 7f20d45

support multiple index

Browse files
Files changed (1) hide show
  1. app.py +25 -11
app.py CHANGED
@@ -77,6 +77,18 @@ def _get_vicuna_llm(temperature=0.2) -> HuggingFacePipeline | None:
77
  VICUNA_LLM = _get_vicuna_llm()
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def make_filter_obj(options: list[dict[str]]):
81
  # print(options)
82
  must = []
@@ -152,20 +164,22 @@ def _get_related_url(metadata) -> Iterable[str]:
152
 
153
  def _get_query_str_filter(
154
  query: str,
155
- index: str,
156
  ) -> tuple[str, Filter]:
157
- options = [{"key": "metadata.index", "value": index}]
158
- filter = make_filter_obj(options=options)
 
 
159
  return query, filter
160
 
161
 
162
  def run_qa(
163
  llm,
164
  query: str,
165
- index: str,
166
  ) -> tuple[str, str]:
167
  now = time()
168
- query_str, filter = _get_query_str_filter(query, index)
169
  qa = get_retrieval_qa(filter, llm)
170
  try:
171
  result = qa(query_str)
@@ -180,9 +194,9 @@ def run_qa(
180
 
181
  def run_search(
182
  query: str,
183
- index: str,
184
  ) -> Iterable[tuple[BaseModel, float, str]]:
185
- query_str, filter = _get_query_str_filter(query, index)
186
  qdocs = get_similay(query_str, filter)
187
  for qdoc, score in qdocs:
188
  text = qdoc.page_content
@@ -203,7 +217,7 @@ def run_search(
203
  with st.form("my_form"):
204
  st.title("Document Search")
205
  query = st.text_area(label="query")
206
- index = st.selectbox(label="index", options=INDEX_NAMES)
207
 
208
  submit_col1, submit_col2 = st.columns(2)
209
  searched = submit_col1.form_submit_button("Search")
@@ -212,7 +226,7 @@ with st.form("my_form"):
212
  st.header("Search Results")
213
  st.divider()
214
  with st.spinner("Searching..."):
215
- results = run_search(query, index)
216
  for doc, score, text in results:
217
  title = doc.title
218
  url = doc.url
@@ -235,7 +249,7 @@ with st.form("my_form"):
235
  results = run_qa(
236
  LLM,
237
  query,
238
- index,
239
  )
240
  answer, html = results
241
  with st.container():
@@ -252,7 +266,7 @@ with st.form("my_form"):
252
  results = run_qa(
253
  VICUNA_LLM,
254
  query,
255
- index,
256
  )
257
  answer, html = results
258
  with st.container():
 
77
  VICUNA_LLM = _get_vicuna_llm()
78
 
79
 
80
+ def make_index_filter_obj(index_list: list[str]):
81
+ should = []
82
+ for index in index_list:
83
+ should.append(
84
+ FieldCondition(
85
+ key="metadata.index", match=MatchValue(value=index)
86
+ )
87
+ )
88
+ filter = Filter(should=should)
89
+ return filter
90
+
91
+
92
  def make_filter_obj(options: list[dict[str]]):
93
  # print(options)
94
  must = []
 
164
 
165
  def _get_query_str_filter(
166
  query: str,
167
+ index_list: list[str],
168
  ) -> tuple[str, Filter]:
169
+ # options = [{"key": "metadata.index", "value": index_list[0]}]
170
+ # filter = make_filter_obj(options=options)
171
+
172
+ filter = make_index_filter_obj(index_list)
173
  return query, filter
174
 
175
 
176
  def run_qa(
177
  llm,
178
  query: str,
179
+ index_list: list[str],
180
  ) -> tuple[str, str]:
181
  now = time()
182
+ query_str, filter = _get_query_str_filter(query, index_list)
183
  qa = get_retrieval_qa(filter, llm)
184
  try:
185
  result = qa(query_str)
 
194
 
195
  def run_search(
196
  query: str,
197
+ index_list: list[str],
198
  ) -> Iterable[tuple[BaseModel, float, str]]:
199
+ query_str, filter = _get_query_str_filter(query, index_list)
200
  qdocs = get_similay(query_str, filter)
201
  for qdoc, score in qdocs:
202
  text = qdoc.page_content
 
217
  with st.form("my_form"):
218
  st.title("Document Search")
219
  query = st.text_area(label="query")
220
+ index_list = st.multiselect(label="index", options=INDEX_NAMES)
221
 
222
  submit_col1, submit_col2 = st.columns(2)
223
  searched = submit_col1.form_submit_button("Search")
 
226
  st.header("Search Results")
227
  st.divider()
228
  with st.spinner("Searching..."):
229
+ results = run_search(query, index_list)
230
  for doc, score, text in results:
231
  title = doc.title
232
  url = doc.url
 
249
  results = run_qa(
250
  LLM,
251
  query,
252
+ index_list,
253
  )
254
  answer, html = results
255
  with st.container():
 
266
  results = run_qa(
267
  VICUNA_LLM,
268
  query,
269
+ index_list,
270
  )
271
  answer, html = results
272
  with st.container():