Ashmi Banerjee committed on
Commit
89cd5d5
1 Parent(s): ac20456

refactored the vectordb

Browse files
app.py CHANGED
@@ -56,7 +56,7 @@ def create_ui():
56
  " ")
57
 
58
  with gr.Group():
59
- countries = gr.Dropdown(choices=list(df.country), multiselect=False, label="Countries")
60
  starting_point = gr.Dropdown(choices=[], multiselect=False,
61
  label="Select your starting point for the trip!")
62
 
 
56
  " ")
57
 
58
  with gr.Group():
59
+ countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
60
  starting_point = gr.Dropdown(choices=[], multiselect=False,
61
  label="Select your starting point for the trip!")
62
 
src/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
- from src.vectordb.vectordb import *
2
  from src.vectordb.helpers import *
3
- from src.vectordb.lancedb_init import *
4
 
5
  from src.sustainability.s_fairness import *
6
  from src.information_retrieval.info_retrieval import *
 
1
+ from src.vectordb.search import *
2
  from src.vectordb.helpers import *
3
+ from src.vectordb.schema import *
4
 
5
  from src.sustainability.s_fairness import *
6
  from src.information_retrieval.info_retrieval import *
src/information_retrieval/info_retrieval.py CHANGED
@@ -2,8 +2,11 @@ import sys
2
  import re
3
  import os
4
  import json
 
 
 
5
  sys.path.append("../")
6
- from src.vectordb import vectordb
7
  from src.sustainability import s_fairness
8
  import logging
9
 
@@ -12,6 +15,7 @@ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
 
13
  from src.helpers.data_loaders import load_scores
14
 
 
15
  def get_travel_months(query):
16
  """
17
 
@@ -66,7 +70,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
66
  # limit = params['limit']
67
  # reranking = params['reranking']
68
 
69
- docs = vectordb.search_wikivoyage_docs(query, limit, reranking)
70
  logger.info("Finished getting chunked wikivoyage docs.")
71
 
72
  results = {}
@@ -76,7 +80,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
76
 
77
  cities = [result['city'] for result in docs]
78
 
79
- listings = vectordb.search_wikivoyage_listings(query, cities, limit, reranking)
80
  logger.info("Finished getting wikivoyage listings.")
81
  # logger.info(type(docs), type(listings))
82
 
@@ -92,7 +96,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
92
  return results
93
 
94
 
95
- def get_sustainability_scores(starting_point: str , query: str, destinations: list):
96
  """
97
 
98
  Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
@@ -164,7 +168,7 @@ def get_cities(context: dict):
164
  """
165
 
166
  recommended_cities = []
167
-
168
  for city, info in context.items():
169
  city_info = {
170
  'city': city,
@@ -242,8 +246,8 @@ def test():
242
  # print(cities)
243
  except FileNotFoundError as e:
244
  try:
245
- vectordb.create_wikivoyage_docs_db_and_add_data()
246
- vectordb.create_wikivoyage_listings_db_and_add_data()
247
 
248
  try:
249
  context = get_context(query, sustainability=1)
 
2
  import re
3
  import os
4
  import json
5
+
6
+ from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
7
+
8
  sys.path.append("../")
9
+ from src.vectordb.search import search_wikivoyage_listings, search_wikivoyage_docs
10
  from src.sustainability import s_fairness
11
  import logging
12
 
 
15
 
16
  from src.helpers.data_loaders import load_scores
17
 
18
+
19
  def get_travel_months(query):
20
  """
21
 
 
70
  # limit = params['limit']
71
  # reranking = params['reranking']
72
 
73
+ docs = search_wikivoyage_docs(query, limit, reranking)
74
  logger.info("Finished getting chunked wikivoyage docs.")
75
 
76
  results = {}
 
80
 
81
  cities = [result['city'] for result in docs]
82
 
83
+ listings = search_wikivoyage_listings(query, cities, limit, reranking)
84
  logger.info("Finished getting wikivoyage listings.")
85
  # logger.info(type(docs), type(listings))
86
 
 
96
  return results
97
 
98
 
99
+ def get_sustainability_scores(starting_point: str, query: str, destinations: list):
100
  """
101
 
102
  Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
 
168
  """
169
 
170
  recommended_cities = []
171
+ info = context[list(context.keys())[0]]
172
  for city, info in context.items():
173
  city_info = {
174
  'city': city,
 
246
  # print(cities)
247
  except FileNotFoundError as e:
248
  try:
249
+ create_wikivoyage_docs_db_and_add_data()
250
+ create_wikivoyage_listings_db_and_add_data()
251
 
252
  try:
253
  context = get_context(query, sustainability=1)
src/vectordb/create_db.py CHANGED
@@ -1,9 +1,11 @@
1
- from src.vectordb.vectordb import *
2
  import logging
3
 
4
  logger = logging.getLogger(__name__)
5
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
 
 
 
7
 
8
  def run():
9
  logging.info("Creating database for Wikivoyage Documents")
 
1
+ from src.vectordb.search import *
2
  import logging
3
 
4
  logger = logging.getLogger(__name__)
5
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
 
7
+ from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
8
+
9
 
10
  def run():
11
  logging.info("Creating database for Wikivoyage Documents")
src/vectordb/helpers.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import pandas as pd
2
  import os
3
  import re
@@ -7,7 +9,7 @@ import sys
7
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
8
  sys.path.append(os.path.dirname(SCRIPT_DIR))
9
 
10
- from data_directories import *
11
 
12
 
13
  def create_chunks(city, country, text):
@@ -148,3 +150,15 @@ def embed_query(query):
148
  # vector_dimension = model.get_sentence_embedding_dimension()
149
  embedding = model.encode(query).tolist()
150
  return embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
  import pandas as pd
4
  import os
5
  import re
 
9
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
10
  sys.path.append(os.path.dirname(SCRIPT_DIR))
11
 
12
+ from src.data_directories import *
13
 
14
 
15
  def create_chunks(city, country, text):
 
150
  # vector_dimension = model.get_sentence_embedding_dimension()
151
  embedding = model.encode(query).tolist()
152
  return embedding
153
+
154
+
155
def set_uri(run_local: Optional[bool] = False):
    """
    Resolve the URI used to connect to the vector database.

    Args:
    - run_local: Optional[bool], if truthy the local `database_dir` path is used,
      otherwise the URI is read from the `BUCKET_NAME` environment variable.

    Returns:
        The URI as a string.

    Raises:
        KeyError: if `run_local` is falsy and `BUCKET_NAME` is not set.
    """
    if not run_local:
        return os.environ["BUCKET_NAME"]

    uri = database_dir
    current_dir = os.path.split(os.getcwd())[1]

    # hacky way to get the correct path: when the process runs from a `src` or
    # `tests` directory the database sits one level closer than `database_dir` says.
    # The earlier form `"src" or "tests" in current_dir` was always true (the
    # non-empty literal "src" is truthy), so the rewrite was applied everywhere.
    if any(marker in current_dir for marker in ("src", "tests")):
        uri = uri.replace("../../", "../")
    return uri
src/vectordb/ingest.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Callable
2
+ import logging
3
+ logger = logging.getLogger(__name__)
4
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
5
+ from src.vectordb.helpers import read_docs, read_listings, preprocess_df
6
+ from src.vectordb.schema import WikivoyageDocuments, WikivoyageListings
7
+ from src.vectordb.helpers import set_uri
8
+ import lancedb
9
+
10
+
11
def _create_table_and_ingest_data(table_name: str, schema: object, data_fetcher: Callable,
                                  preprocessor: Optional[Callable] = None):
    """
    Build a fresh table in the database and load it with fetched data.

    Args:
    - table_name: str, name of the table to (re)create.
    - schema: object, schema handed to `create_table`.
    - data_fetcher: Callable, called without arguments; its result must offer `to_dict('records')`.
    - preprocessor: Optional[Callable], applied to the fetched data before ingestion when given.
    """
    # The URI comes from set_uri() called with its default arguments.
    connection = lancedb.connect(set_uri())
    logger.info(f"Connected to DB. Reading data for table {table_name} now...")

    frame = data_fetcher()
    if preprocessor:
        frame = preprocessor(frame)
    logger.info(f"Finished reading data for {table_name}, attempting to create table and ingest the data...")

    # Any existing table of the same name is dropped, so every run starts from an empty table.
    connection.drop_table(table_name, ignore_missing=True)
    connection.create_table(table_name, schema=schema).add(frame.to_dict('records'))
    logger.info(f"Completed ingestion for {table_name}.")
39
+
40
+
41
def create_wikivoyage_docs_db_and_add_data():
    """
    (Re)create the `wikivoyage_documents` table and fill it.

    Data is fetched with `read_docs` and passed through `preprocess_df` before ingestion.
    """
    _create_table_and_ingest_data("wikivoyage_documents", WikivoyageDocuments, read_docs, preprocess_df)
51
+
52
+
53
def _cast_listings_to_str(df):
    """Return the listings frame with every column cast to `str`."""
    return df.astype('str')


def create_wikivoyage_listings_db_and_add_data():
    """
    Creates the Wikivoyage listings table and ingests data.

    Data is fetched with `read_listings` and every column is cast to `str`
    before the records are added to the `wikivoyage_listings` table.
    """
    _create_table_and_ingest_data(
        table_name="wikivoyage_listings",
        schema=WikivoyageListings,
        data_fetcher=read_listings,
        # The previous implementation ingested `df.astype('str')`; the cast is kept
        # so the refactor does not change what gets written to the table.
        # NOTE(review): presumably needed so values match the schema's field types -
        # confirm against WikivoyageListings.
        preprocessor=_cast_listings_to_str,
    )
src/vectordb/{lancedb_init.py → schema.py} RENAMED
File without changes
src/vectordb/search.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from src import *
2
+
3
+ import logging
4
+ import os
5
+ import lancedb
6
+ from lancedb.rerankers import ColbertReranker
7
+
8
+ import sys
9
+
10
+ logger = logging.getLogger(__name__)
11
+ from typing import Optional
12
+ from src.vectordb.helpers import set_uri
13
+
14
+
15
+ # db = lancedb.connect("/tmp/db")
16
+
17
+
18
def search(query: str, table_name: str, filter_condition: Optional[str] = None,
           category: str = "docs", limit: int = 10, reranking: int = 0,
           run_local: Optional[bool] = False) -> list | None:
    """
    Run a cosine-metric search against one table, optionally filtered and reranked.

    Args:
    - query: str, search query.
    - table_name: str, name of the table to open.
    - filter_condition: Optional[str], condition passed to `where` when given.
    - category: str, 'docs' selects the document fields; any other value selects the listing fields.
    - limit: int, maximum number of results (default is 10).
    - reranking: int (0 or 1), when truthy a ColbertReranker is applied.
    - run_local: Optional[bool], forwarded to `set_uri` to pick the database URI.

    Returns:
        A list of dicts with the fields of the chosen category, or None when
        connecting, searching or reranking fails (the error is logged).
    """
    connection_uri = set_uri(run_local)

    try:
        connection = lancedb.connect(connection_uri)
    except Exception as e:
        logger.error(f"Error while connecting to DB: {e}")
        return None

    logger.info(f"Connected to {table_name} DB.")
    # Opening the table is intentionally not guarded: errors propagate to the caller.
    pending = connection.open_table(table_name).search(query).metric('cosine')
    if filter_condition:
        pending = pending.where(filter_condition)

    if not reranking:
        try:
            hits = pending.limit(limit).to_list()
        except Exception as e:
            exc_type, _, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while searching: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            return None
    else:
        try:
            # Listings are reranked on their description, everything else on the text column.
            text_column = 'text' if category != 'listings' else 'description'
            ranked = pending.rerank(reranker=ColbertReranker(column=text_column))
            hits = ranked.limit(limit).to_list()
        except Exception as e:
            exc_type, _, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while reranking results: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            return None

    logger.info("Found the most relevant documents.")

    if category == "docs":
        fields = ("city", "country", "section", "text")
    else:
        fields = ("city", "country", "type", "title", "description")
    return [{field: hit[field] for field in fields} for hit in hits]
79
+
80
+
81
def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0,
                           run_local: Optional[bool] = False) -> list | None:
    """
    Search the `wikivoyage_documents` table and return whatever `search` returns
    for the 'docs' category.
    """
    params = {
        "query": query,
        "table_name": "wikivoyage_documents",
        "category": "docs",
        "limit": limit,
        "reranking": reranking,
        "run_local": run_local,
    }
    return search(**params)
88
+
89
+
90
+ def search_wikivoyage_listings(query: str, cities: list, limit: int = 10, reranking: int = 0,
91
+ run_local: Optional[bool] = False) -> list | None:
92
+ """
93
+ Function to search listings in the Wikivoyage database, post-filtered by cities.
94
+ """
95
+ cities_filter = f"city IN {tuple(cities)}"
96
+ return search(query=query, table_name="wikivoyage_listings", filter_condition=cities_filter,
97
+ category="listings", limit=limit, reranking=reranking, run_local=run_local)
src/vectordb/vectordb.py DELETED
@@ -1,190 +0,0 @@
1
- # from src import *
2
- from src.vectordb.helpers import *
3
- from src.vectordb.lancedb_init import *
4
- import logging
5
- import os
6
- import lancedb
7
- from lancedb.rerankers import ColbertReranker
8
-
9
- import sys
10
- logger = logging.getLogger(__name__)
11
- from typing import Optional
12
-
13
- # db = lancedb.connect("/tmp/db")
14
-
15
- def create_wikivoyage_docs_db_and_add_data():
16
- """
17
-
18
- Creates wikivoyage documents table and ingests data
19
-
20
- """
21
- uri = database_dir
22
- current_dir = os.path.split(os.getcwd())[1]
23
-
24
- if "src" or "tests" in current_dir: # hacky way to get the correct path
25
- uri = uri.replace("../../", "../")
26
-
27
- db = lancedb.connect(uri)
28
- logger.info("Connected to DB. Reading data now...")
29
- df = read_docs()
30
- filtered_df = preprocess_df(df)
31
- logger.info("Finished reading data, attempting to create table and ingest the data...")
32
-
33
- db.drop_table("wikivoyage_documents", ignore_missing=True)
34
- table = db.create_table("wikivoyage_documents", schema=WikivoyageDocuments)
35
-
36
- table.add(filtered_df.to_dict('records'))
37
- logger.info("Completed ingestion.")
38
-
39
-
40
- def create_wikivoyage_listings_db_and_add_data():
41
- """
42
-
43
- Creates wikivoyage listings table and ingests data
44
-
45
- """
46
- uri = database_dir
47
- current_dir = os.path.split(os.getcwd())[1]
48
-
49
- if "src" or "tests" in current_dir: # hacky way to get the correct path
50
- uri = uri.replace("../../", "../")
51
-
52
- db = lancedb.connect(uri)
53
- logger.info("Connected to DB. Reading data now...")
54
- df = read_listings()
55
- logger.info("Finished reading data, attempting to create table and ingest the data...")
56
- # filtered_df = preprocess_df(df)
57
-
58
- db.drop_table("wikivoyage_listings", ignore_missing=True)
59
- table = db.create_table("wikivoyage_listings", schema=WikivoyageListings)
60
-
61
- table.add(df.astype('str').to_dict('records'))
62
- logger.info("Completed ingestion.")
63
-
64
-
65
- def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0, run_local: Optional[bool] = False):
66
- """
67
-
68
- Function to search the wikivoyage database an return most relevant chunked docs.
69
-
70
- Args:
71
- - query: str
72
- - limit: number of results (default is 10)
73
- - reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
74
-
75
- """
76
- if run_local:
77
- uri = database_dir
78
- current_dir = os.path.split(os.getcwd())[1]
79
-
80
- if "src" or "tests" in current_dir: # hacky way to get the correct path
81
- uri = uri.replace("../../", "../")
82
- else:
83
- uri = os.environ["BUCKET_NAME"]
84
- # print(uri)
85
- try:
86
- db = lancedb.connect(uri)
87
- except Exception as e:
88
- logger.error(f"Error while connecting to DB: {e}")
89
-
90
- logger.info("Connected to Wikivoyage DB.")
91
- print("Tablenames: ", db.table_names())
92
-
93
- # query_embedding = embed_query(query)
94
- table = db.open_table("wikivoyage_documents")
95
-
96
- if reranking:
97
- try:
98
- reranker = ColbertReranker(column='text')
99
- results = table.search(query) \
100
- .metric('cosine') \
101
- .rerank(reranker=reranker) \
102
- .limit(limit) \
103
- .to_list()
104
- except Exception as e:
105
- exc_type, exc_obj, exc_tb = sys.exc_info()
106
- fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
107
- logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
108
-
109
- else:
110
- try:
111
- results = table.search(query) \
112
- .limit(limit) \
113
- .metric('cosine') \
114
- .to_list()
115
- except Exception as e:
116
- exc_type, exc_obj, exc_tb = sys.exc_info()
117
- fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
118
- logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
119
-
120
- logger.info("Found the most relevant documents.")
121
- city_lists = [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']} for r in
122
- results]
123
-
124
- # context = [f"city: {r['city']}, country: {r['country']}, name: {r['title']}, description: {r['description']}"
125
- # for r in results]
126
-
127
- return city_lists
128
-
129
-
130
- def search_wikivoyage_listings(query:str, cities: list, limit: int=10, reranking: int = 0, run_local: Optional[bool] = False):
131
- """
132
-
133
- Function to search the wikivoyage database an return most relevant listings, post-filtered by the recommended
134
- cities provided by wikivoyage_documents table.
135
-
136
- Args:
137
- - query: str
138
- - cities: list
139
- - limit: number of results (default is 10)
140
- - reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
141
-
142
- """
143
- if run_local:
144
- uri = database_dir
145
- current_dir = os.path.split(os.getcwd())[1]
146
-
147
- if "src" or "tests" in current_dir: # hacky way to get the correct path
148
- uri = uri.replace("../../", "../")
149
- else:
150
- uri = os.environ["BUCKET_NAME"]
151
-
152
- db = lancedb.connect(uri)
153
- logger.info("Connected to Wikivoyage Listings DB.")
154
-
155
- table = db.open_table("wikivoyage_listings")
156
-
157
- cities_filter = f"city IN {tuple(cities)}"
158
-
159
- if reranking:
160
- try:
161
- reranker = ColbertReranker(column='description')
162
- results = table.search(query) \
163
- .where(cities_filter) \
164
- .metric('cosine') \
165
- .rerank(reranker=reranker) \
166
- .limit(limit) \
167
- .to_list()
168
-
169
- except Exception as e:
170
- exc_type, exc_obj, exc_tb = sys.exc_info()
171
- fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
172
- logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
173
-
174
- else:
175
- try:
176
- results = table.search(query) \
177
- .where(cities_filter) \
178
- .metric('cosine') \
179
- .limit(limit) \
180
- .to_list()
181
- except Exception as e:
182
- exc_type, exc_obj, exc_tb = sys.exc_info()
183
- fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
184
- logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
185
-
186
- logger.info("Found the most relevant documents.")
187
- city_listings = [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
188
- "description": r['description']} for r in results]
189
-
190
- return city_listings