Spaces:
Running
Running
import sys | |
import re | |
import os | |
import json | |
from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data | |
sys.path.append("../") | |
from src.vectordb.search import search_wikivoyage_listings, search_wikivoyage_docs | |
from src.sustainability import s_fairness | |
import logging | |
logger = logging.getLogger(__name__) | |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG) | |
from src.helpers.data_loaders import load_scores | |
def get_travel_months(query):
    """
    Parse the user's query for explicitly mentioned months or seasons of travel.

    Month names are matched as whole words, case-insensitively. A season name
    ("spring", "summer", "fall"/"autumn", "winter") expands to its three months
    using the Northern-Hemisphere mapping below (e.g. winter = December-February).

    NOTE: matching is purely lexical, so e.g. the verb "may" also counts as May.

    Args:
        - query: str; free-text user query. ``None`` or ``""`` yields no months.

    Returns:
        list[str]: capitalised month names without duplicates, in order of
        discovery (explicit months in calendar order first, then season months).
        An empty list means neither a month nor a season was found.
    """
    if not query:
        return []
    months = [
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December"
    ]
    seasons = {
        "spring": ["March", "April", "May"],
        "summer": ["June", "July", "August"],
        "fall": ["September", "October", "November"],
        "autumn": ["September", "October", "November"],
        "winter": ["December", "January", "February"]
    }
    months_in_query = []
    for month in months:
        if re.search(r'\b' + month + r'\b', query, re.IGNORECASE):
            months_in_query.append(month)
    # Check for seasons in the query
    for season, season_months in seasons.items():
        if re.search(r'\b' + season + r'\b', query, re.IGNORECASE):
            months_in_query.extend(season_months)
    # A month can be found twice (e.g. "December" plus "winter", or "fall" plus
    # "autumn"); drop repeats so callers do not score the same month twice.
    return list(dict.fromkeys(months_in_query))
def get_wikivoyage_context(query, limit=10, reranking=0):
    """
    Retrieve Wikivoyage context for a query in two stages.

    1. ``search_wikivoyage_docs`` returns the matching documents; each one
       becomes an entry keyed by its ``'city'`` value.
    2. ``search_wikivoyage_listings`` is queried for those cities, and every
       listing is appended to its city's ``'listings'`` list.

    Args:
        - query: str; the user query, used for both searches.
        - limit: int; result limit forwarded to both searches.
        - reranking: bool/int; forwarded to both searches to toggle reranking
          (uses a CrossEncoderReranker).

    Returns:
        dict: ``{city: {<doc fields except 'city'>, 'listings': [{'type', 'name', 'description'}, ...]}}``
    """
    docs = search_wikivoyage_docs(query, limit, reranking)
    logger.info("Finished getting chunked wikivoyage docs.")

    # One entry per city; a later doc for the same city replaces an earlier one.
    results = {
        doc['city']: {**{field: val for field, val in doc.items() if field != 'city'}, 'listings': []}
        for doc in docs
    }

    cities = [doc['city'] for doc in docs]
    listings = search_wikivoyage_listings(query, cities, limit, reranking)
    logger.info("Finished getting wikivoyage listings.")

    for listing in listings:
        # Presumably listings are restricted to `cities`, so the city is already a key
        # (an unknown city would raise KeyError) — TODO confirm in search_wikivoyage_listings.
        city_listings = results[listing['city']]['listings']
        city_listings.append({
            'type': listing['type'],
            'name': listing['title'],
            'description': listing['description'],
        })

    logger.info("Returning retrieval results.")
    return results
def get_sustainability_scores(starting_point: str, query: str, destinations: list):
    """
    Compute the s-fairness score of each destination and keep one month per city.

    Months are parsed from ``query`` with ``get_travel_months``. If none are
    found, one score per city is requested without a month (the scorer then uses
    the ideal month of travel). Otherwise one score is computed per parsed month
    (or season month), and the month with the minimum s-fairness score is kept.

    Args:
        - starting_point: str; the user's point of departure, forwarded to the scorer.
        - query: str; the user's query, scanned for months/seasons.
        - destinations: list; candidate city names (any iterable of str works).

    Returns:
        list[dict]: one dict per distinct city, in first-seen order, with keys
        ``'city'``, ``'month'``, ``'s-fairness'`` and ``'mode'``. If any of a
        city's scores has a falsy ``'month'``, those three values are all the
        string ``'No data available'``.
    """
    # Score each month only once, even if the parser reported it twice.
    months = list(dict.fromkeys(get_travel_months(query)))
    logger.info("Finished parsing query for months.")
    popularity_data = load_scores("popularity")
    seasonality_data = load_scores("seasonality")
    emissions_data = load_scores("emissions")
    data = [popularity_data, seasonality_data, emissions_data]

    city_scores = {}
    for city in destinations:
        scores = city_scores.setdefault(city, [])
        if not months:  # no month(s) or seasons provided by the user
            scores.append(s_fairness.compute_sfairness_score(data, starting_point, city))
        else:
            for month in months:
                # Bug fix: this call used to be (data, city, month), dropping
                # starting_point so `city` landed in its slot. Pass arguments as in
                # the no-month call above, with the month appended.
                # NOTE(review): assumes the scorer's 4th positional parameter is the
                # month — confirm against s_fairness.compute_sfairness_score.
                scores.append(s_fairness.compute_sfairness_score(data, starting_point, city, month))
    logger.info("Finished getting s-fairness scores.")

    no_data = 'No data available'
    result = []  # list of dicts of the format {city, month, s-fairness, mode}
    for city, scores in city_scores.items():
        if any(not score['month'] for score in scores):
            # A single month without data marks the whole city as unscored.
            result.append({
                'city': city,
                'month': no_data,
                's-fairness': no_data,
                'mode': no_data,
            })
            continue
        min_score = min(scores, key=lambda score: score['s-fairness'])
        result.append({
            'city': city,
            'month': min_score['month'],
            's-fairness': min_score['s-fairness'],
            'mode': min_score['mode'],
        })
    logger.info("Returning s-fairness results.")
    return result
def get_cities(context: dict):
    """
    Only to be used for testing! Flatten the retrieved context into a list of
    cities, sorted by s-fairness score when sustainability data is present.

    Args:
        - context: dict; maps city name to an info dict that must contain
          ``'country'`` and may contain ``'sustainability'`` with ``'month'`` and
          ``'s-fairness'`` keys (presumably the output of ``get_context``).

    Returns:
        list[dict]: ``{'city', 'country'}`` per city, plus ``'month'`` and
        ``'s-fairness'`` when that city has sustainability data. If any city has
        such data, the list is sorted ascending by s-fairness. Cities scored
        ``'No data available'``, or with no score at all, are placed last; the
        sort is stable, so ties keep the context's order. Otherwise the context
        order is kept. An empty context yields ``[]``.
    """
    recommended_cities = []
    has_sustainability = False
    for city, info in context.items():
        city_info = {
            'city': city,
            'country': info['country']
        }
        if "sustainability" in info:
            has_sustainability = True
            city_info['month'] = info['sustainability']['month']
            city_info['s-fairness'] = info['sustainability']['s-fairness']
        recommended_cities.append(city_info)

    if not has_sustainability:
        return recommended_cities

    def get_s_fairness_value(item):
        # Missing or unavailable scores sort after every numeric score.
        score = item.get('s-fairness', 'No data available')
        if score == 'No data available':
            return float('inf')
        return score

    return sorted(recommended_cities, key=get_s_fairness_value)
def get_context(starting_point: str, query: str, **params):
    """
    Gather all context for a query: the Wikivoyage retrieval results and,
    optionally, the s-fairness scores of the recommended destinations.

    Sustainability scores are NOT added by default; pass a truthy
    ``sustainability`` keyword to attach them.

    Args:
        - starting_point: str; the user's point of departure, used only for the
          s-fairness scores.
        - query: str; the user's query.
        - params: optional keywords — ``limit`` (default 3), ``reranking``
          (default 1) and ``sustainability`` (default: off).

    Returns:
        dict: the mapping from ``get_wikivoyage_context``. When sustainability is
        requested, each scored city also gets a ``'sustainability'`` entry with
        ``'month'``, ``'s-fairness'`` and ``'transport'``.
    """
    limit = params.get('limit', 3)
    reranking = params.get('reranking', 1)
    context = get_wikivoyage_context(query, limit, reranking)

    if not params.get('sustainability'):
        return context

    # Score every city the retrieval step returned and attach the result in place.
    for entry in get_sustainability_scores(starting_point, query, context.keys()):
        context[entry['city']]['sustainability'] = {
            'month': entry['month'],
            's-fairness': entry['s-fairness'],
            'transport': entry['mode'],
        }
    return context
def test():
    """
    Smoke test: retrieve context (with s-fairness scores) for a sample query and
    write it to ``test_results/test_result.json`` under the current directory.

    If the first attempt raises ``FileNotFoundError`` (presumably because the
    vector DBs have not been built yet), the Wikivoyage docs and listings DBs
    are created and retrieval is retried once. Errors are logged rather than
    raised; on failure the JSON file contains ``null``.

    Returns:
        dict | None: the retrieved context, or None if retrieval failed.
    """
    query = ("Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing "
             "in winter. ")
    starting_point = "Munich"
    context = None
    try:
        context = get_context(starting_point, query, sustainability=1)
    except FileNotFoundError:
        try:
            create_wikivoyage_docs_db_and_add_data()
            create_wikivoyage_listings_db_and_add_data()
        except Exception as e:
            logger.error(f"Error while creating DB: {e}")
        else:
            try:
                # Bug fix: the retry used to omit starting_point (TypeError).
                context = get_context(starting_point, query, sustainability=1)
            except Exception as e:
                logger.exception("Error while getting context: %s", e)
    except Exception as e:
        logger.exception("Error while getting context: %s", e)

    results_dir = os.path.join(os.getcwd(), "test_results")
    # Create the output directory on first run instead of failing in open().
    os.makedirs(results_dir, exist_ok=True)
    file_path = os.path.join(results_dir, "test_result.json")
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(context, file)
    return context
if __name__ == "__main__":
    # Run the smoke test and echo the retrieved context (None if retrieval failed).
    result_context = test()
    print(result_context)