Ashmi Banerjee
refactored the vectordb
89cd5d5
raw
history blame
9.05 kB
import sys
import re
import os
import json
from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
sys.path.append("../")
from src.vectordb.search import search_wikivoyage_listings, search_wikivoyage_docs
from src.sustainability import s_fairness
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
from src.helpers.data_loaders import load_scores
def get_travel_months(query):
    """
    Parse the user's query and collect the month(s) of travel mentioned in it.

    Month names are matched as whole words, case-insensitively. Season names
    (spring, summer, fall/autumn, winter) are expanded to their three months.

    Args:
        - query: str

    Returns:
        list of capitalised month names without duplicates: explicitly named
        months first (in calendar order), then the months contributed by any
        season found. An empty list (not None) is returned when the query is
        empty or mentions neither a month nor a season.
    """
    if not query:
        # Nothing to parse; also avoids a TypeError from re.search on None.
        return []

    months = [
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December"
    ]
    seasons = {
        "spring": ["March", "April", "May"],
        "summer": ["June", "July", "August"],
        "fall": ["September", "October", "November"],
        "autumn": ["September", "October", "November"],
        "winter": ["December", "January", "February"]
    }

    def is_mentioned(word):
        # Whole-word, case-insensitive match of a month or season name.
        return re.search(r'\b' + word + r'\b', query, re.IGNORECASE) is not None

    months_in_query = [month for month in months if is_mentioned(month)]

    # Check for seasons in the query
    for season, season_months in seasons.items():
        if is_mentioned(season):
            months_in_query += season_months

    # A month can be named explicitly and also implied by a season (or by both
    # "fall" and "autumn"); keep only its first occurrence so callers do not
    # score the same month twice.
    return list(dict.fromkeys(months_in_query))
def get_wikivoyage_context(query, limit=10, reranking=0):
    """
    Retrieve the relevant documents and listings from the wikivoyage database.

    Retrieval happens in two steps: (i) search_wikivoyage_docs returns the
    matching city documents, and (ii) the cities of those documents are handed
    to search_wikivoyage_listings to fetch their listings.

    Args:
        - query: str
        - limit: int; forwarded to both searches
        - reranking: bool; forwarded to both searches

    Returns:
        dict keyed by city name. Each value holds the document's fields (minus
        'city') plus a 'listings' list of {'type', 'name', 'description'} dicts.
    """
    city_docs = search_wikivoyage_docs(query, limit, reranking)
    logger.info("Finished getting chunked wikivoyage docs.")

    results = {}
    for city_doc in city_docs:
        # Copy every field except the city name, which becomes the dict key.
        entry = {field: content for field, content in city_doc.items() if field != 'city'}
        entry['listings'] = []
        results[city_doc['city']] = entry

    city_names = [city_doc['city'] for city_doc in city_docs]
    found_listings = search_wikivoyage_listings(query, city_names, limit, reranking)
    logger.info("Finished getting wikivoyage listings.")

    for item in found_listings:
        # Attach each listing to the city it belongs to.
        city_listings = results[item['city']]['listings']
        city_listings.append({
            'type': item['type'],
            'name': item['title'],
            'description': item['description']
        })

    logger.info("Returning retrieval results.")
    return results
def get_sustainability_scores(starting_point: str, query: str, destinations: list):
    """
    Get the s-fairness score of each destination for the month(s) of travel found in the query.

    When the query names no month or season, the score function is called
    without a month. When several months (or a season) are provided, the month
    with the minimum s-fairness score is chosen for the city. If any of a
    city's scores has an empty 'month', the city is reported as
    'No data available'.

    Args:
        - starting_point: str; where the trip starts
        - query: str; user query, parsed for months/seasons
        - destinations: iterable of city names

    Returns:
        list of dicts of the format
        {'city': <city>, 'month': <month>, 's-fairness': <score>, 'mode': <mode>}
    """
    months = get_travel_months(query)
    logger.info("Finished parsing query for months.")

    data = [
        load_scores("popularity"),
        load_scores("seasonality"),
        load_scores("emissions"),
    ]

    city_scores = {}
    for city in destinations:
        scores = city_scores.setdefault(city, [])
        if not months:  # no month(s) or seasons provided by the user
            scores.append(s_fairness.compute_sfairness_score(data, starting_point, city))
        else:
            for month in months:
                # The starting point was previously dropped here, which shifted
                # the city into the starting-point slot and the month into the
                # destination slot.
                # NOTE(review): assumes the month is the 4th positional
                # parameter of compute_sfairness_score - confirm against
                # src/sustainability/s_fairness.py.
                scores.append(
                    s_fairness.compute_sfairness_score(data, starting_point, city, month)
                )
    logger.info("Finished getting s-fairness scores.")

    result = []
    for city, scores in city_scores.items():
        # A single score without a month marks the whole city as having no data.
        if any(not score['month'] for score in scores):
            result.append({
                'city': city,
                'month': 'No data available',
                's-fairness': 'No data available',
                'mode': 'No data available'
            })
            continue

        min_score = min(scores, key=lambda x: x['s-fairness'])
        result.append({
            'city': city,
            'month': min_score['month'],
            's-fairness': min_score['s-fairness'],
            'mode': min_score['mode'],
        })

    logger.info("Returning s-fairness results.")
    return result
def get_cities(context: dict):
    """
    Only to be used for testing! Return the cities of the retrieved context with their s-fairness scores.

    Args:
        - context: dict; maps a city name to its info dict. Each info dict must
          contain 'country' and may contain a 'sustainability' dict with
          'month' and 's-fairness'.

    Returns:
        list of dicts {'city', 'country'[, 'month', 's-fairness']}. If at least
        one city carries sustainability info, the list is sorted by ascending
        s-fairness; cities whose score is missing or 'No data available' are
        placed last. Otherwise the context's order is kept. An empty context
        yields an empty list.
    """
    recommended_cities = []
    has_scores = False
    for city, info in context.items():
        city_info = {
            'city': city,
            'country': info['country']
        }
        if "sustainability" in info:
            has_scores = True
            city_info['month'] = info['sustainability']['month']
            city_info['s-fairness'] = info['sustainability']['s-fairness']
        recommended_cities.append(city_info)

    if not has_scores:
        return recommended_cities

    def get_s_fairness_value(item):
        # Cities without a usable score get a high value so they sort last.
        score = item.get('s-fairness', 'No data available')
        if score == 'No data available':
            return float('inf')
        return score

    # sorted() is stable, so cities with equal scores keep their context order.
    return sorted(recommended_cities, key=get_s_fairness_value)
def get_context(starting_point: str, query: str, **params):
    """
    Assemble the full context for a query: the wikivoyage retrieval results and,
    on request, the s-fairness scores of the retrieved destinations.

    S-Fairness scores are not included by default; they are only appended when
    a truthy "sustainability" value is passed in params.

    Args:
        - starting_point: str; forwarded to the sustainability scoring
        - query: str
        - params: dict; optional 'limit' (default 3), 'reranking' (default 1)
          and 'sustainability'
    """
    max_results = params.get('limit', 3)
    use_reranking = params.get('reranking', 1)

    wikivoyage_context = get_wikivoyage_context(query, max_results, use_reranking)
    city_names = wikivoyage_context.keys()

    if not params.get('sustainability'):
        return wikivoyage_context

    for entry in get_sustainability_scores(starting_point, query, city_names):
        # The score's 'mode' is exposed to consumers under the key 'transport'.
        wikivoyage_context[entry['city']]['sustainability'] = {
            'month': entry['month'],
            's-fairness': entry['s-fairness'],
            'transport': entry['mode']
        }
    return wikivoyage_context
def _log_context_error(error):
    """Log a context-retrieval failure with the exception type, file name and line number."""
    exc_type, exc_obj, exc_tb = sys.exc_info()
    fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
    logger.error(f"Error while getting context: {error}, {(exc_type, fname, exc_tb.tb_lineno)}")


def test():
    """
    Smoke test: retrieve the context (with sustainability scores) for a sample
    query starting from Munich and dump it to test_results/test_result.json.

    If the first retrieval raises FileNotFoundError, the wikivoyage docs and
    listings databases are created and the retrieval is attempted once more.
    Errors are logged rather than raised.

    Returns:
        the retrieved context dict, or None if retrieval failed.
    """
    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "
    starting_point = "Munich"
    context = None
    try:
        context = get_context(starting_point, query, sustainability=1)
    except FileNotFoundError:
        try:
            create_wikivoyage_docs_db_and_add_data()
            create_wikivoyage_listings_db_and_add_data()
        except Exception as e:
            logger.error(f"Error while creating DB: {e}")
        else:
            try:
                # The retry must pass the starting point as well; without it
                # the call fails with a TypeError for the missing argument.
                context = get_context(starting_point, query, sustainability=1)
            except Exception as e:
                _log_context_error(e)
    except Exception as e:
        _log_context_error(e)

    file_path = os.path.join(os.getcwd(), "test_results", "test_result.json")
    # Make sure the output directory exists so that open() does not fail.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w') as file:
        json.dump(context, file)
    return context
# Script entry point: run the retrieval smoke test and print what it returned.
if __name__ == "__main__":
    retrieved_context = test()
    print(retrieved_context)