Spaces: Running

Ashmi Banerjee committed · Commit 4b722ec · 1 Parent(s): 2fdacb0

first draft, with the gcp bucket

Browse files:
- .gitignore +5 -1
- requirements.txt +5 -3
- src/__init__.py +11 -0
- src/app.py +78 -0
- src/augmentation/__init__.py +0 -0
- src/augmentation/prompt_generation.py +185 -0
- src/data_directories.py +15 -0
- src/information_retrieval/__init__.py +0 -0
- src/information_retrieval/info_retrieval.py +267 -0
- src/pipeline.py +128 -0
- src/sustainability/__init__.py +0 -0
- src/sustainability/s_fairness.py +114 -0
- src/text_generation/__init__.py +0 -0
- src/text_generation/model_init.py +123 -0
- src/text_generation/models.py +112 -0
- src/text_generation/text_generation.py +85 -0
- src/text_generation/vertexai_setup.py +37 -0
- src/vectordb/__init__.py +0 -0
- src/vectordb/create_db.py +24 -0
- src/vectordb/helpers.py +150 -0
- src/vectordb/lancedb_init.py +37 -0
- src/vectordb/vectordb.py +186 -0
.gitignore CHANGED
@@ -182,4 +182,8 @@ gradio_cached_examples/
 models/__pycache__/
 setup/__pycache__/
 setup/gcp_default_creds.json
-.config/*.json
+.config/*.json
+gradio_cached_examples/
+src/gradio_cached_examples/
+database/
+european-city-data/
requirements.txt CHANGED
@@ -1,10 +1,12 @@
 sentence-transformers
 gradio
 gradio_client
-pymongo==4.6.2
 python-dotenv
 google-cloud-aiplatform
 google-cloud
-vertexai==1.
+vertexai==1.64.0
 huggingface_hub
-certifi==
+certifi==2024.7.4
+python-dotenv==1.0.1
+bitsandbytes
+accelerate
src/__init__.py ADDED
@@ -0,0 +1,11 @@
+from src.vectordb.vectordb import *
+from src.vectordb.helpers import *
+from src.vectordb.lancedb_init import *
+
+from src.sustainability.s_fairness import *
+from src.information_retrieval.info_retrieval import *
+from src.augmentation.prompt_generation import *
+from src.text_generation.model_init import *
+from src.text_generation.text_generation import *
+
+from src.data_directories import *
src/app.py ADDED
@@ -0,0 +1,78 @@
+from typing import Optional
+import gradio as gr
+import os, sys
+sys.path.append("./src")
+print(os.getcwd())
+from src.pipeline import pipeline
+
+
+def clear():
+    return None, None, None
+
+
+def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
+                  temp: Optional[float] = 0.49):
+    if is_sustainable:
+        sustainability = 1
+    else:
+        sustainability = 0
+    pipeline_response = pipeline(
+        query=query_text,
+        model_name=model_name,
+        sustainability=sustainability
+    )
+    return pipeline_response
+
+
+examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
+             "local cuisines to try?", "GPT-4"],
+            ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
+            ["Suggest some cities that can be visited from London and are very rich in history and culture.",
+             "Gemini-1.0-pro"],
+            ]
+
+with gr.Blocks() as demo:
+    gr.HTML("""<center><h1 style='font-size:xx-large;'>🇪🇺 Euro City Recommender using Gemini & Gemma 🇪🇺</h1><br><h3>Gemini
+    & Gemma Sprints 2024 submissions by Ashmi Banerjee.</h3></center> <br><p>We're testing the compatibility of
+    Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
+    models through HuggingFace and VertexAI, respectively, to generate travel recommendations. This early version (read
+    quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts
+    from 160 European cities to provide answers to your questions. Please be kind with it, as it's a work in progress!
+    </p> <br>Google Cloud credits are provided for this project. </p>
+    """)
+
+    with gr.Group():
+        query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
+        sustainable = gr.Checkbox(label="Sustainable", info="If you want sustainable recommendations for "
+                                                            "hidden gems?")
+        model = gr.Dropdown(
+            ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
+                                                             "models "
+                                                             "later!",
+        )
+        output = gr.Textbox(label="Generated Results", lines=4)
+
+    with gr.Accordion("Settings", open=False):
+        max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
+                                   interactive=True,
+                                   visible=True, info="The maximum number of output tokens")
+        temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
+                                interactive=True,
+                                visible=True, info="The value used to modulate the logits distribution")
+    with gr.Group():
+        with gr.Row():
+            submit_btn = gr.Button("Submit", variant="primary")
+            clear_btn = gr.Button("Clear", variant="secondary")
+            cancel_btn = gr.Button("Cancel", variant="stop")
+        submit_btn.click(generate_text, inputs=[query, model, sustainable], outputs=[output])
+        clear_btn.click(clear, inputs=[], outputs=[query, model, output])
+        cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
+
+    gr.Markdown("## Examples")
+    gr.Examples(
+        examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
+        cache_examples=True,
+    )
+
+if __name__ == "__main__":
+    demo.launch(show_api=False)
src/augmentation/__init__.py ADDED
File without changes
src/augmentation/prompt_generation.py ADDED
@@ -0,0 +1,185 @@
+from information_retrieval import info_retrieval as ir
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+
+def generate_prompt(query, context, template=None):
+    """
+    Function that generates the prompt given the user query and retrieved context. A specific prompt template will be
+    used if provided, otherwise the default base_prompt template is used.
+
+    Args:
+        - query: str
+        - context: list[dict]
+        - template: str
+
+    """
+
+    if template:
+        SYS_PROMPT = template
+    else:
+        SYS_PROMPT = """You are an AI recommendation system. Your task is to recommend cities in Europe for travel
+        based on the user's question. You should use the provided contexts to suggest a list of the 3 best cities
+        that are best suited to the user's question, as well as the month of travel. If the user has already provided
+        the month of travel in the question, use the same month; otherwise, provide the ideal month of travel. Each
+        recommendation should also contain an explanation of why it is being recommended, based on the context. Your
+        answer must begin with "I recommend " followed by the city name and why you recommended it. Your answers are
+        correct, high-quality, and written by a domain expert. If the provided context does not contain the answer,
+        simply state, "The provided context does not have the answer." """
+
+    USER_PROMPT = """ Question: {} Which city do you recommend and why?
+
+    Context: Here are the options: {}
+
+    Answer:
+
+    """
+
+    formatted_prompt = f"{USER_PROMPT.format(query, context)}"
+    messages = [
+        {"role": "system", "content": SYS_PROMPT},
+        {"role": "user", "content": formatted_prompt}
+    ]
+
+    return messages
+
+
+def format_context(context):
+    """
+    Function that formats the retrieved context in a way that is easy for the LLM to understand.
+
+    Args:
+        - context: list[dict]; retrieved context
+
+    """
+
+    formatted_context = ''
+
+    for i, (city, info) in enumerate(context.items()):
+
+        text = f"Option {i + 1}: {city} is a city in {info['country']}."
+        info_text = f"Here is some information about the city. {info['text']}"
+
+        attractions_text = "Here are some attractions: "
+        att_flag = 0
+        restaurants_text = "Here are some places to eat/drink: "
+        rest_flag = 0
+
+        hotels_text = "Here are some hotels: "
+        hotel_flag = 0
+
+        if len(info['listings']):
+            for listing in info['listings']:
+                if listing['type'] in ['see', 'do', 'go', 'view']:
+                    att_flag = 1
+                    attractions_text += f"{listing['name']} ({listing['description']}), "
+                elif listing['type'] in ['eat', 'drink']:
+                    rest_flag = 1
+                    restaurants_text += f"{listing['name']} ({listing['description']}), "
+                else:
+                    hotel_flag = 1
+                    hotels_text += f"{listing['name']} ({listing['description']}), "
+
+        # If we add sustainability in the end then it could get truncated because of context window
+        if "sustainability" in info:
+            if info['sustainability']['month'] == 'No data available':
+                sfairness_text = "This city has no sustainability (or s-fairness) score available."
+
+            else:
+                sfairness_text = f"The sustainability (or s-fairness) score for {city} in {info['sustainability']['month']} is {info['sustainability']['s-fairness']}. \n "
+
+            text += sfairness_text
+
+        text += info_text
+
+        if att_flag:
+            text += f"\n{attractions_text}"
+
+        if rest_flag:
+            text += f"\n{restaurants_text}"
+
+        if hotel_flag:
+            text += f"\n{hotels_text}"
+
+        formatted_context += text + "\n\n "
+
+    return formatted_context
+
+
+def augment_prompt(query, context, **params):
+    """
+    Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
+    retrieved context, which can be passed to the LLM.
+
+    Args:
+        - query: str
+        - context: retrieved context, must be formatted otherwise the LLM cannot understand the nested dictionaries!
+        - sustainability: bool; if true, then the prompt is appended to instruct the LLM to use s-fairness scores
+          while generating the answer
+        - params: key-value parameters to be passed to the get_context function; sets the limit of results and
+          whether to rerank the results
+
+    """
+
+    # what about the cities without s-fairness scores? i.e. they don't have seasonality data
+
+    prompt_with_sustainability = """You are an AI recommendation system. Your task is to recommend cities in Europe
+    for travel based on the user's question. You should use the provided contexts to suggest the city that is best
+    suited to the user's question. You recommend a list of the top 3 most sustainable cities to the user, as well as
+    the best month of travel. Each recommendation should also contain an explanation of why it is being recommended,
+    on sustainability grounds based on the context. The context contains a sustainability score for each city,
+    also known as the s-fairness score, along with the ideal month of travel. A lower s-fairness score indicates that
+    the city is a better destination for the month provided. A city without a sustainability score should not be
+    considered. You should only consider the s-fairness score while choosing the best city. However, your answer
+    should not contain the numeric score itself or any mention of the sustainability score. Your answer must begin
+    with "I recommend " followed by the city names and why you recommended it. Your answers are correct, high-quality,
+    and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided
+    context does not have the answer. """
+
+    # format context
+    formatted_context = format_context(context)
+
+    if "sustainability" in params["params"] and params["params"]["sustainability"]:
+        prompt = generate_prompt(query, formatted_context, prompt_with_sustainability)
+    else:
+        prompt = generate_prompt(query, formatted_context)
+
+    return prompt
+
+
+def test():
+    context_params = {
+        'limit': 3,
+        'reranking': 0,
+        'sustainability': 0
+    }
+
+    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
+            "in winter. "
+
+    # without sustainability
+    context = ir.get_context(query, **context_params)
+    # formatted_context = format_context(context)
+
+    without_sfairness = augment_prompt(
+        query=query,
+        context=context,
+        params=context_params
+    )
+
+    # with sustainability
+    context_params.update({'sustainability': 1})
+    s_context = ir.get_context(query, **context_params)
+    # s_formatted_context = format_context(s_context)
+
+    with_sfairness = augment_prompt(
+        query=query,
+        context=s_context,
+        params=context_params
+    )
+
+    return with_sfairness
+
+
+if __name__ == "__main__":
+    prompt = test()
+    print(prompt)
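Note: a minimal sketch of the message structure generate_prompt returns; the query and context strings below are hypothetical examples, not part of the commit.

from src.augmentation.prompt_generation import generate_prompt

messages = generate_prompt(
    query="Which European city should I visit in May?",  # hypothetical
    context="Option 1: Ljubljana is a city in Slovenia. ...",  # hypothetical
)
# messages[0] is {"role": "system", "content": <default SYS_PROMPT>}
# messages[1] is {"role": "user", "content": " Question: Which European city should I visit in May? ..."}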
src/data_directories.py ADDED
@@ -0,0 +1,15 @@
+import os
+
+current = os.path.dirname(os.path.realpath(''))
+parent = os.path.dirname(current)
+
+data_parent_dir = "../../european-city-data/"
+data_dir = data_parent_dir + "data-sources/"
+wikivoyage_docs_dir = data_dir + "wikivoyage/"
+wikivoyage_listings_dir = wikivoyage_docs_dir + "listings/"
+database_dir = "../../database/wikivoyage/"
+seasonality_dir = data_parent_dir + "computed/seasonality/"
+popularity_dir = data_parent_dir + "computed/popularity/"
+cities_csv = data_parent_dir + "city_abstracts_embeddings.csv"
+prompts_dir = data_parent_dir + "rag-sustainability/prompts/"
+results_dir = data_parent_dir + "rag-sustainability/results/"
src/information_retrieval/__init__.py ADDED
File without changes
src/information_retrieval/info_retrieval.py ADDED
@@ -0,0 +1,267 @@
+import sys
+import re
+import os
+import json
+sys.path.append("../")
+from src.vectordb import vectordb
+from src.sustainability import s_fairness
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+
+def get_travel_months(query):
+    """
+    Function to parse the user's query and search if a month of travel has been provided by the user.
+
+    Args:
+        - query: str
+
+    """
+    months = [
+        "January", "February", "March", "April", "May", "June",
+        "July", "August", "September", "October", "November", "December"
+    ]
+
+    seasons = {
+        "spring": ["March", "April", "May"],
+        "summer": ["June", "July", "August"],
+        "fall": ["September", "October", "November"],
+        "autumn": ["September", "October", "November"],
+        "winter": ["December", "January", "February"]
+    }
+
+    months_in_query = []
+
+    for month in months:
+        if re.search(r'\b' + month + r'\b', query, re.IGNORECASE):
+            months_in_query.append(month)
+
+    # Check for seasons in the query
+    for season, season_months in seasons.items():
+        if re.search(r'\b' + season + r'\b', query, re.IGNORECASE):
+            months_in_query += season_months
+
+    # Returns an empty list if neither months nor seasons are found
+    return months_in_query
+
+
+def get_wikivoyage_context(query, limit=10, reranking=0):
+    """
+    Function to retrieve the relevant documents and listings from the wikivoyage database. Works in two steps:
+    (i) the relevant cities are returned by the wikivoyage_docs table and (ii) then passed on to the wikivoyage
+    listings database to retrieve further information. The user can pass a limit of how many results the search
+    should return, as well as whether to perform reranking (uses a CrossEncoderReranker).
+
+    Args:
+        - query: str
+        - limit: int
+        - reranking: bool
+
+    """
+
+    # limit = params['limit']
+    # reranking = params['reranking']
+
+    docs = vectordb.search_wikivoyage_docs(query, limit, reranking)
+    logger.info("Finished getting chunked wikivoyage docs.")
+
+    results = {}
+    for doc in docs:
+        results[doc['city']] = {key: value for key, value in doc.items() if key != 'city'}
+        results[doc['city']]['listings'] = []
+
+    cities = [result['city'] for result in docs]
+
+    listings = vectordb.search_wikivoyage_listings(query, cities, limit, reranking)
+    logger.info("Finished getting wikivoyage listings.")
+    # logger.info(type(docs), type(listings))
+
+    for listing in listings:
+        # logger.info(listing['city'])
+        results[listing['city']]['listings'].append({
+            'type': listing['type'],
+            'name': listing['title'],
+            'description': listing['description']
+        })
+
+    logger.info("Returning retrieval results.")
+    return results
+
+
+def get_sustainability_scores(query, destinations):
+    """
+    Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if
+    the user hasn't provided a month). If multiple months (or a season) are provided, then the month with the minimum
+    s-fairness score is chosen for the city.
+
+    Args:
+        - query: str
+        - destinations: list
+
+    """
+
+    result = []  # list of dicts of the format {city: <city>, month: <month>, }
+    city_scores = {}
+
+    months = get_travel_months(query)
+    logger.info("Finished parsing query for months.")
+
+    for city in destinations:
+        if city not in city_scores:
+            city_scores[city] = []
+
+        if not months:  # no month(s) or seasons provided by the user
+            city_scores[city].append(s_fairness.compute_sfairness_score(city))
+        else:
+            for month in months:
+                city_scores[city].append(s_fairness.compute_sfairness_score(city, month))
+
+    logger.info("Finished getting s-fairness scores.")
+
+    for city, scores in city_scores.items():
+
+        no_result = 0
+        for score in scores:
+            if not score['month']:
+                no_result = 1
+                result.append({
+                    'city': city,
+                    'month': 'No data available',
+                    's-fairness': 'No data available'
+                })
+                break
+
+        if not no_result:
+            min_score = min(scores, key=lambda x: x['s-fairness'])
+            result.append({
+                'city': city,
+                'month': min_score['month'],
+                's-fairness': min_score['s-fairness']
+            })
+
+    logger.info("Returning s-fairness results.")
+    return result
+
+
+def get_cities(context):
+    """
+    Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the
+    retrieved context.
+
+    Args:
+        - context: dict
+
+    """
+
+    recommended_cities = []
+
+    for city, info in context.items():
+        city_info = {
+            'city': city,
+            'country': info['country']
+        }
+
+        if "sustainability" in info:
+            city_info['month'] = info['sustainability']['month']
+            city_info['s-fairness'] = info['sustainability']['s-fairness']
+
+        recommended_cities.append(city_info)
+
+    if "sustainability" in info:
+        def get_s_fairness_value(item):
+            s_fairness = item['s-fairness']
+            if s_fairness == 'No data available':
+                return float('inf')  # Assign a high value for "No data available"
+            return s_fairness
+
+        # Sort the list using the custom key
+        sorted_cities = sorted(recommended_cities, key=get_s_fairness_value)
+        return sorted_cities
+
+    else:
+        return recommended_cities
+
+
+def get_context(query, **params):
+    """
+    Function that returns all the context: from the database, as well as the respective s-fairness scores for the
+    destinations. The default does not consider s-fairness scores, i.e. to append sustainability scores, a non-zero
+    parameter "sustainability" needs to be explicitly passed to params.
+
+    Args:
+        - query: str
+        - params: dict; contains the values of limit and reranking (and sustainability)
+
+    """
+
+    limit = 3
+    reranking = 1
+
+    if 'limit' in params:
+        limit = params['limit']
+
+    if 'reranking' in params:
+        reranking = params['reranking']
+
+    wikivoyage_context = get_wikivoyage_context(query, limit, reranking)
+    recommended_cities = wikivoyage_context.keys()
+
+    if 'sustainability' in params and params['sustainability']:
+        s_fairness_scores = get_sustainability_scores(query, recommended_cities)
+
+        for score in s_fairness_scores:
+            wikivoyage_context[score['city']]['sustainability'] = {
+                'month': score['month'],
+                's-fairness': score['s-fairness']
+            }
+
+    return wikivoyage_context
+
+
+def test():
+    queries = []
+    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
+            "in winter. "
+
+    context = None
+
+    try:
+        context = get_context(query, sustainability=1)
+        # cities = get_cities(context)
+        # print(cities)
+    except FileNotFoundError as e:
+        try:
+            vectordb.create_wikivoyage_docs_db_and_add_data()
+            vectordb.create_wikivoyage_listings_db_and_add_data()
+
+            try:
+                context = get_context(query, sustainability=1)
+                # cities = get_cities(context)
+                # print(cities)
+            except Exception as e:
+                exc_type, exc_obj, exc_tb = sys.exc_info()
+                fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
+                logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
+
+        except Exception as e:
+            logger.error(f"Error while creating DB: {e}")
+
+    except Exception as e:
+        exc_type, exc_obj, exc_tb = sys.exc_info()
+        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
+        logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
+
+    file_path = os.path.join(os.getcwd(), "test_results", "test_result.json")
+    with open(file_path, 'w') as file:
+        json.dump(context, file)
+
+    return context
+
+
+if __name__ == "__main__":
+    context = test()
+
+    print(context)
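Note: a quick illustration of the month/season parsing in get_travel_months (a sketch, assuming the package is importable as below):

from src.information_retrieval.info_retrieval import get_travel_months

# A season keyword expands to its three months; explicit months are matched as whole words.
print(get_travel_months("I enjoy skiing in winter."))  # ['December', 'January', 'February']
print(get_travel_months("Planning a trip in July."))   # ['July']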
src/pipeline.py ADDED
@@ -0,0 +1,128 @@
+"""
+Main file to execute the TRS Pipeline.
+"""
+import sys
+from augmentation import prompt_generation as pg
+from information_retrieval import info_retrieval as ir
+from text_generation.models import (
+    Llama3,
+    Mistral,
+    Gemma2,
+    Llama3Point1,
+    Llama3Instruct,
+    MistralInstruct,
+    Llama3Point1Instruct,
+    Phi3SmallInstruct,
+    GPT4,
+    Gemini,
+)
+from text_generation import text_generation as tg
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+TEST_DIR = "../tests/"
+MODELS = {
+    'GPT-4': GPT4,
+    'Llama3': Llama3,
+    'Mistral': Mistral,
+    'Gemma2': Gemma2,
+    'Llama3.1': Llama3Point1,
+    'Llama3-Instruct': Llama3Instruct,
+    'Mistral-Instruct': MistralInstruct,
+    'Llama3.1-Instruct': Llama3Point1Instruct,
+    'Phi3-Instruct': Phi3SmallInstruct,
+    'Gemini-1.0-pro': Gemini,
+}
+
+
+def pipeline(query: str, model_name: str, test: int = 0, **params):
+    """
+    Executes the entire RAG pipeline, provided the query and model class name.
+
+    Args:
+        - query: str
+        - model_name: string, one of the following: Llama3, Mistral, Gemma2, Llama3Point1
+        - test: whether the pipeline is running a test
+        - params:
+            - limit (number of results to be retained)
+            - reranking (binary, whether to rerank results using ColBERT or not)
+            - sustainability
+
+    """
+
+    model = MODELS[model_name]
+
+    context_params = {
+        'limit': 5,
+        'reranking': 0,
+        'sustainability': 0,
+    }
+
+    if 'limit' in params:
+        context_params['limit'] = params['limit']
+
+    if 'reranking' in params:
+        context_params['reranking'] = params['reranking']
+
+    if 'sustainability' in params:
+        context_params['sustainability'] = params['sustainability']
+
+    logger.info("Retrieving context..")
+    try:
+        context = ir.get_context(query=query, **context_params)
+        if test:
+            retrieved_cities = ir.get_cities(context)
+        else:
+            retrieved_cities = None
+    except Exception as e:
+        exc_type, exc_obj, exc_tb = sys.exc_info()
+        logger.error(f"Error at line {exc_tb.tb_lineno} while trying to get context: {e}")
+        return None
+
+    logger.info("Retrieved context, augmenting prompt..")
+    try:
+        prompt = pg.augment_prompt(
+            query=query,
+            context=context,
+            params=context_params
+        )
+    except Exception as e:
+        exc_type, exc_obj, exc_tb = sys.exc_info()
+        logger.error(f"Error at line {exc_tb.tb_lineno} while trying to augment prompt: {e}")
+        return None
+
+    # return prompt
+
+    logger.info(f"Augmented prompt, initializing {model} and generating response..")
+    try:
+        response = tg.generate_response(model, prompt)
+    except Exception as e:
+        exc_type, exc_obj, exc_tb = sys.exc_info()
+        logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
+        return None
+
+    if test:
+        return retrieved_cities, prompt[1]['content'], response
+
+    else:
+        return response
+
+
+if __name__ == "__main__":
+    # sample_query = "I'm planning a trip in the summer and I love art, history, and visiting museums. Can you
+    # suggest " \ "some " \ "European cities? "
+    sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
+                   "cities. "
+    model = "GPT-4"
+
+    pipeline_response = pipeline(
+        query=sample_query,
+        model_name=model,
+        sustainability=1
+    )
+
+    print(pipeline_response)
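Note: when test=1 the pipeline returns the retrieved cities and the augmented user prompt alongside the response; a minimal usage sketch (the query string is hypothetical):

from src.pipeline import pipeline

# test=1 surfaces the intermediate artifacts, which helps inspect retrieval
# quality separately from generation.
cities, user_prompt, response = pipeline(
    query="Recommend a quiet coastal city for September.",  # hypothetical query
    model_name="Gemini-1.0-pro",
    test=1,
    sustainability=1,
)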
src/sustainability/__init__.py ADDED
File without changes
src/sustainability/s_fairness.py ADDED
@@ -0,0 +1,114 @@
+import sys
+import os
+import pandas as pd
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.dirname(SCRIPT_DIR))
+
+from data_directories import *
+
+
+def get_popularity(destination):
+    """
+    Returns the popularity score for a particular destination.
+
+    Args:
+        - destination: str
+
+    """
+
+    parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
+
+    if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
+        popularity_path = popularity_dir.replace("../../", "../")
+    else:
+        popularity_path = popularity_dir
+
+    popularity_df = pd.read_csv(popularity_path + "popularity_scores.csv")
+
+    if not len(popularity_df[popularity_df['city'] == destination]):
+        print(f"{destination} does not have popularity data")
+        return None
+
+    return popularity_df[popularity_df['city'] == destination]['weighted_pop_score'].item()
+
+
+def get_seasonality(destination, month=None):
+    """
+    Returns the seasonality score for a particular destination for a particular month. If no month is provided, then
+    the best month, i.e. the month of lowest seasonality, is returned.
+
+    Args:
+        - destination: str
+        - month: str (default: None)
+
+    """
+    parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
+
+    if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
+        seasonality_path = seasonality_dir.replace("../../", "../")
+    else:
+        seasonality_path = seasonality_dir
+    seasonality_df = pd.read_csv(seasonality_path + "seasonality_scores.csv")
+
+    # Check if city is present in dataframe
+    if not len(seasonality_df[seasonality_df['city'] == destination]):
+        logger.info(f"{destination} does not have seasonality data for {month}")
+        return None, None
+
+    if month:
+        m = month.capitalize()[:3]
+    else:
+        seasonality_df['lowest_col'] = seasonality_df.loc[:, seasonality_df.columns != 'city'].idxmin(axis="columns")
+        m = seasonality_df[seasonality_df['city'] == destination]['lowest_col'].item()
+
+    # print(destination, m, seasonality_df[seasonality_df['city'] == destination][m])
+
+    return m, seasonality_df[seasonality_df['city'] == destination][m].item()
+
+
+def compute_sfairness_score(destination, month=None):
+    """
+    Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
+    have popularity or seasonality scores, then the function returns None.
+
+    Args:
+        - destination: str
+        - month: str (default: None)
+
+    """
+    seasonality = get_seasonality(destination, month)
+    month = seasonality[0]
+    popularity = get_popularity(destination)
+    emissions = 0
+
+    # RECHECK
+    if seasonality[1] is not None and popularity is not None:
+        s_fairness = round(0.281 * emissions + 0.334 * popularity + 0.385 * seasonality[1], 3)
+        return {
+            'month': month,
+            's-fairness': s_fairness
+        }
+    # elif popularity is not None:  # => seasonality is None
+    #     s_fairness = 0.281 * emissions + 0.334 * popularity
+    # elif seasonality[1] is not None:  # => popularity is None
+    #     s_fairness = 0.281 * emissions + 0.385 * seasonality[1]
+    # else:  # => both are None
+    #     s_fairness = 0.281 * emissions
+    else:
+        return {
+            'month': None,
+            's-fairness': None
+        }
+
+
+if __name__ == "__main__":
+    print(compute_sfairness_score("Paris", "Oct"))
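Note: to make the weighting in compute_sfairness_score concrete, a worked example with hypothetical popularity and seasonality values (emissions is fixed at 0 in this draft):

emissions, popularity, seasonality = 0, 0.62, 0.41  # hypothetical scores
s_fairness = round(0.281 * emissions + 0.334 * popularity + 0.385 * seasonality, 3)
print(s_fairness)  # 0.365; a lower score marks a better destination/month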
src/text_generation/__init__.py ADDED
File without changes
src/text_generation/model_init.py ADDED
@@ -0,0 +1,123 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from vertexai.generative_models import GenerativeModel
+
+from dotenv import load_dotenv
+from anthropic import AnthropicVertex
+import os
+from openai import OpenAI
+from src.text_generation.vertexai_setup import initialize_vertexai_params
+
+load_dotenv()
+if "OPENAI_API_KEY" in os.environ:
+    OAI_API_KEY = os.environ["OPENAI_API_KEY"]
+if "VERTEXAI_PROJECTID" in os.environ:
+    VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECTID"]
+
+
+class LLMBaseClass:
+    """
+    Base Class for text generation - the user needs to provide the HF model ID while instantiating the class, after
+    which the generate method can be called to generate responses.
+    """
+
+    def __init__(self, model_id) -> None:
+        # Subclasses set self.model_id before calling super().__init__(). API-backed models pass a
+        # 1-tuple (e.g. ("gpt-4o-mini",)), so model_id[0] is the full ID; HF models pass a plain
+        # string, so model_id[0] is just its first character and falls through to the default case.
+        match (model_id[0].lower()):
+            case "gpt-4o-mini":  # for OpenAI models
+                self.api_key = OAI_API_KEY
+                self.model = OpenAI(api_key=self.api_key)
+            case "claude-3-5-sonnet@20240620":  # for Claude through VertexAI
+                self.api_key = None
+                self.model = AnthropicVertex(region="europe-west1", project_id=VERTEXAI_PROJECT)
+            case "gemini-1.0-pro":
+                self.api_key = None
+                self.model = GenerativeModel(model_id[0].lower())
+            case _:  # for HF models
+                self.api_key = None
+                self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+                self.tokenizer.chat_template = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- " \
+                                               "bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{%- " \
+                                               "elif " \
+                                               "message['role'] == 'system' %}{{- '<<SYS>>\\n' + message[" \
+                                               "'content'].strip() + " \
+                                               "'\\n<</SYS>>\\n\\n' }}{%- elif message['role'] == 'assistant' %}{{- '[" \
+                                               "ASST] ' " \
+                                               "+ message['content'] + ' [/ASST]' + eos_token }}{%- endif %}{%- " \
+                                               "endfor %} "
+                # Initialize quantization to use less GPU
+                if torch.cuda.is_available():
+                    bnb_config = BitsAndBytesConfig(
+                        load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch.bfloat16
+                    )
+                else:
+                    bnb_config = None
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.bfloat16,
+                    device_map="auto",
+                    quantization_config=bnb_config,
+                )
+
+                self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
+
+                self.terminators = [
+                    self.tokenizer.eos_token_id,
+                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+                ]
+
+    def generate(self, messages):
+        match (self.model_id[0].lower()):
+            case "gpt-4o-mini":
+                completion = self.model.chat.completions.create(
+                    model=self.model_id[0],
+                    messages=messages,
+                    temperature=0.6,
+                    top_p=0.9,
+                )
+                # Return the generated content from the API response
+                return completion.choices[0].message.content
+            case "claude-3-5-sonnet@20240620" | "gemini-1.0-pro":
+                initialize_vertexai_params()
+                if "claude" in self.model_id[0].lower():
+                    message = self.model.messages.create(
+                        max_tokens=1024,
+                        model=self.model_id[0],
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": messages[0]["content"],
+                            }
+                        ],
+                    )
+                    return message.content[0].text
+                else:
+                    response = self.model.generate_content(messages[0]["content"])
+                    return response
+            case _:
+                input_ids = self.tokenizer.apply_chat_template(
+                    conversation=messages,
+                    add_generation_prompt=True,
+                    return_tensors="pt",
+                    padding=True
+                ).to(self.model.device)
+
+                outputs = self.model.generate(
+                    input_ids,
+                    max_new_tokens=1024,
+                    # eos_token_id=self.terminators,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    do_sample=True,
+                    temperature=0.6,
+                    top_p=0.9,
+                )
+                response = outputs[0][input_ids.shape[-1]:]
+
+                return self.tokenizer.decode(response, skip_special_tokens=True)
+
+
+# database/wikivoyage/wikivoyage_listings.lance/data/e2940f51-d754-4b54-a688-004bdb8e7aa2.lance
src/text_generation/models.py ADDED
@@ -0,0 +1,112 @@
+from src.text_generation.model_init import LLMBaseClass
+
+
+class Llama3(LLMBaseClass):
+    """
+    Initializes a Llama3 model object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "meta-llama/Meta-Llama-3-8B"
+
+        super().__init__(self.model_id)
+
+
+class Mistral(LLMBaseClass):
+    """
+    Initializes a Mistral model object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "mistralai/Mistral-7B-v0.3"
+        super().__init__(self.model_id)
+
+
+class Gemma2(LLMBaseClass):
+    """
+    Initializes a Gemma2 model object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "google/gemma-2-9b"
+        super().__init__(self.model_id)
+
+
+class Llama3Point1(LLMBaseClass):
+    """
+    Initializes a Llama 3.1 object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "meta-llama/Meta-Llama-3.1-8B"
+        super().__init__(self.model_id)
+
+
+class Llama3Instruct(LLMBaseClass):
+    """
+    Initializes a Llama 3 Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        super().__init__(self.model_id)
+
+
+class MistralInstruct(LLMBaseClass):
+    """
+    Initializes a Mistral Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+        super().__init__(self.model_id)
+
+
+class Llama3Point1Instruct(LLMBaseClass):
+    """
+    Initializes a Llama 3.1 Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        super().__init__(self.model_id)
+
+
+class Phi3SmallInstruct(LLMBaseClass):
+    """
+    Initializes a Phi3-Small-Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "microsoft/Phi-3-small-128k-instruct"
+        super().__init__(self.model_id)
+
+
+class GPT4(LLMBaseClass):
+    """
+    Initializes a GPT-4 Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "gpt-4o-mini",  # trailing comma makes this a 1-tuple; model_init dispatches on model_id[0]
+        super().__init__(self.model_id)
+
+
+class Claude3Point5Sonnet(LLMBaseClass):
+    """
+    Initializes a Claude3.5 Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "claude-3-5-sonnet@20240620",  # 1-tuple, see above
+        super().__init__(self.model_id)
+
+
+class Gemini(LLMBaseClass):
+    """
+    Initializes a Gemini Instruct object
+    """
+
+    def __init__(self) -> None:
+        self.model_id = "gemini-1.0-pro",  # 1-tuple, see above
+        super().__init__(self.model_id)
src/text_generation/text_generation.py ADDED
@@ -0,0 +1,85 @@
+from augmentation import prompt_generation as pg
+from information_retrieval import info_retrieval as ir
+from src.text_generation.models import (
+    Llama3,
+    Mistral,
+    Gemma2,
+    Llama3Point1,
+    GPT4,
+    Claude3Point5Sonnet,
+)
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+
+def generate_response(model, prompt):
+    """
+    Function that initializes the LLM class and calls the generate function.
+
+    Args:
+        - model: class; the class of the LLM to be initialized
+        - prompt: list; contains the system and user prompt
+
+    """
+
+    logger.info(f"Initializing LLM configuration for {model}")
+    llm = model()
+
+    logger.info("Generating response")
+    try:
+        response = llm.generate(prompt)
+    except Exception as e:
+        logger.error(f"Error while generating response for {model}: {e}")
+        response = 'ERROR'
+
+    return response
+
+
+def test(model):
+    context_params = {
+        'limit': 3,
+        'reranking': 0
+    }
+    # model = Llama3Point1
+
+    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
+            "in winter. "
+
+    # without sustainability
+    logger.info("Retrieving context..")
+    try:
+        context = ir.get_context(query=query, **context_params)
+    except Exception as e:
+        logger.error(f"Error while trying to get context: {e}")
+        return None
+
+    logger.info("Retrieved context, augmenting prompt (without sustainability)..")
+    try:
+        without_sfairness = pg.augment_prompt(
+            query=query,
+            context=context,
+            sustainability=0,
+            params=context_params
+        )
+    except Exception as e:
+        logger.error(f"Error while trying to augment prompt: {e}")
+        return None
+
+    # return without_sfairness
+
+    logger.info(f"Augmented prompt, initializing {model} and generating response..")
+    try:
+        response = generate_response(model, without_sfairness)
+    except Exception as e:
+        logger.info(f"Error while generating response: {e}")
+        return None
+
+    return response
+
+
+if __name__ == "__main__":
+    response = test(Claude3Point5Sonnet)
+    print(response)
src/text_generation/vertexai_setup.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+from dotenv import load_dotenv
+from google.oauth2 import service_account
+import vertexai
+import os
+import json
+import base64
+
+load_dotenv()
+if "VERTEXAI_PROJECTID" in os.environ:
+    VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECTID"]
+
+
+def decode_service_key():
+    encoded_key = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
+    original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
+    if original_service_key:
+        return original_service_key
+    return None
+
+
+def initialize_vertexai_params(location: Optional[str] = "us-central1"):
+
+    creds_file_name = os.getcwd() + "/.config/application_default_credentials.json"
+    print(creds_file_name)
+    if not os.path.exists(os.path.dirname(creds_file_name)):
+        os.makedirs(os.path.dirname(creds_file_name))  # create .config/ before writing, otherwise open() fails
+        credentials = decode_service_key()
+        with open(creds_file_name, 'w', encoding='utf-8') as file:
+            json.dump(credentials, file, ensure_ascii=False, indent=4)
+
+    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_file_name
+
+    service_account.Credentials.from_service_account_file(
+        filename=os.environ["GOOGLE_APPLICATION_CREDENTIALS"],
+        scopes=["https://www.googleapis.com/auth/cloud-platform"],
+    )
+    vertexai.init(project=VERTEXAI_PROJECT, location=location)
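Note: decode_service_key above expects the GOOGLE_APPLICATION_CREDENTIALS variable to initially hold a base64-encoded service-account JSON rather than a file path. A sketch of producing such a value (the input file name is hypothetical):

import base64, os

# Encode a local service-account file so decode_service_key() can reverse it
# with base64.b64decode(...).decode('utf-8').
with open("service_account.json", "rb") as f:  # hypothetical file
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = base64.b64encode(f.read()).decode("utf-8")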
src/vectordb/__init__.py ADDED
File without changes
src/vectordb/create_db.py ADDED
@@ -0,0 +1,24 @@
+from src.vectordb.vectordb import *
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+
+
+def run():
+    logging.info("Creating database for Wikivoyage Documents")
+    try:
+        create_wikivoyage_docs_db_and_add_data()
+    except Exception as e:
+        logger.error(f"Error for Wikivoyage Documents: {e}")
+
+    logging.info("Creating database for Wikivoyage Listings")
+    try:
+        create_wikivoyage_listings_db_and_add_data()
+        print("Completed")
+    except Exception as e:
+        logger.error(f"Error for Wikivoyage Listings: {e}")
+
+
+if __name__ == "__main__":
+    run()
src/vectordb/helpers.py ADDED
@@ -0,0 +1,150 @@
+import pandas as pd
+import os
+import re
+from sentence_transformers import SentenceTransformer
+import sys
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.dirname(SCRIPT_DIR))
+
+from data_directories import *
+
+
+def create_chunks(city, country, text):
+    """
+    Helper function that creates chunks given paragraph(s) of text based on implicit sections in the text.
+
+    Args:
+        - city: str
+        - country: str
+        - text: str; document that needs to be chunked
+
+    """
+
+    for i, line in enumerate(text):
+        if line[0] == "\n":
+            del text[i]
+
+    index = 0
+    chunks = []
+    pattern = re.compile("==")
+    ignore = re.compile("===")
+    section = 'Introduction'
+    for i, line in enumerate(text):
+        if pattern.search(line) and not ignore.search(line):
+            chunk = ''.join(text[index:i])
+            chunks.append({
+                'city': city,
+                'country': country,
+                'section': section,
+                'text': chunk,
+                # 'vector': f'city: {city}, country: {country}, section: {section}, text: {chunk}'
+            })
+            index = i + 1
+            section = re.sub(pattern, '', line).strip()
+
+    df = pd.DataFrame(chunks)
+    return df
+
+
+def read_docs():
+    """
+    Helper function that reads all of the Wikivoyage documents containing information about the cities.
+
+    """
+
+    df = pd.DataFrame()
+    cities = pd.read_csv(cities_csv)
+    for file_name in os.listdir(wikivoyage_docs_dir + "cleaned/"):
+        city = file_name.split(".")[0]
+        # print(city)
+        country = cities[cities['city'] == city]['country'].item()
+        with open(wikivoyage_docs_dir + "cleaned/" + file_name) as file:
+            text = file.readlines()
+        chunk_df = create_chunks(city, country, text)
+        df = pd.concat([df, chunk_df])
+
+    return df
+
+
+def read_listings():
+    """
+    Helper function that reads the Wikivoyage listings csv containing tabular information about 144 cities.
+
+    """
+    df = pd.read_csv(wikivoyage_listings_dir + "wikivoyage-listings-cleaned.csv")
+    cities = pd.read_csv(cities_csv)
+
+    def find_country(city):
+        return cities[cities['city'] == city]['country'].values[0]
+
+    df['country'] = df['city'].apply(find_country)
+
+    return df
+
+
+def preprocess_df(df):
+    """
+    Helper function that preprocesses the dataframe containing chunks of text: removes hyperlinks and strips the \\n
+    from the text.
+
+    Args:
+        - df: dataframe
+
+    """
+    section_counts = df['section'].value_counts()
+    sections_to_keep = section_counts[section_counts > 150].index
+    filtered_df = df[df['section'].isin(sections_to_keep)]
+
+    def preprocess_text(s):
+        s = re.sub(r'http\S+', '', s)
+        s = re.sub(r'=+', '', s)
+        s = s.strip()
+        return s
+
+    filtered_df['text'] = filtered_df['text'].apply(preprocess_text)
+
+    return filtered_df
+
+
+def compute_wv_docs_embeddings(df):
+    """
+    Helper function that computes embeddings for the text. The all-MiniLM-L6-v2 embedding model is used.
+
+    Args:
+        - df: dataframe
+
+    """
+    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    vector_dimension = model.get_sentence_embedding_dimension()
+
+    print("Computing embeddings")
+    embeddings = []
+    for i, row in df.iterrows():
+        emb = model.encode(row['combined'], show_progress_bar=True).tolist()
+        embeddings.append(emb)
+
+    print("Finished computing embeddings for wikivoyage documents.")
+    df['vector'] = embeddings
+    # df.to_csv(wv_embeddings + "wikivoyage-listings-embeddings.csv")
+    # print("Finished saving file.")
+    return df
+
+
+def embed_query(query):
+    """
+    Helper function that returns the embedded query.
+
+    Args:
+        - query: str
+
+    """
+    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    # vector_dimension = model.get_sentence_embedding_dimension()
+    embedding = model.encode(query).tolist()
+    return embedding
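Note: a small illustration of the section-based chunking in create_chunks (the input lines are hypothetical, in Wikivoyage's ==Section== style):

from src.vectordb.helpers import create_chunks

# "==" marks a section heading; "===" subsections are deliberately ignored.
lines = [
    "Vienna is the capital of Austria.\n",
    "== See ==\n",
    "Schönbrunn Palace is a major attraction.\n",
    "== Eat ==\n",
]
df = create_chunks("Vienna", "Austria", lines)
print(df[["section", "text"]])
# Row 0 is labelled 'Introduction'; later chunks take the preceding heading.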
src/vectordb/lancedb_init.py
ADDED
@@ -0,0 +1,37 @@
import os
import sys
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))


model = get_registry().get("sentence-transformers").create()


class WikivoyageDocuments(LanceModel):
    """
    Schema definition for the Wikivoyage Documents table.
    """
    city: str = model.SourceField()
    country: str = model.SourceField()
    section: str = model.SourceField()
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


class WikivoyageListings(LanceModel):
    """
    Schema definition for the Wikivoyage Listings table.
    """
    city: str = model.SourceField()
    type: str = model.SourceField()
    title: str = model.SourceField()
    description: str = model.SourceField()
    country: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()
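
Because every text column above is declared as a SourceField and the vector as a VectorField, LanceDB handles embedding automatically: strings are embedded on table.add() and plain-string queries are embedded on table.search(), so the schemas rather than the callers own the sentence-transformers model. A minimal sketch (database path, table name, and row are illustrative):

    import lancedb

    db = lancedb.connect("/tmp/demo-db")  # illustrative local path
    table = db.create_table("demo_docs", schema=WikivoyageDocuments)
    table.add([{"city": "Vienna", "country": "Austria", "section": "See",
                "text": "The Belvedere displays Klimt's The Kiss."}])  # embedded on add
    hits = table.search("where can I see Klimt paintings?").limit(1).to_list()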
src/vectordb/vectordb.py
ADDED
@@ -0,0 +1,186 @@
# from src import *
from src.vectordb.helpers import *
from src.vectordb.lancedb_init import *
import logging
import os
import sys
from typing import Optional

import lancedb
from lancedb.rerankers import ColbertReranker

logger = logging.getLogger(__name__)

# db = lancedb.connect("/tmp/db")

def create_wikivoyage_docs_db_and_add_data():
    """
    Creates the wikivoyage documents table and ingests the data.
    """
    uri = database_dir
    current_dir = os.path.split(os.getcwd())[1]

    # the original `if "src" or "tests" in current_dir` was always truthy;
    # compare the directory name explicitly instead
    if current_dir in ("src", "tests"):  # hacky way to get the correct path
        uri = uri.replace("../../", "../")

    db = lancedb.connect(uri)
    logger.info("Connected to DB. Reading data now...")
    df = read_docs()
    filtered_df = preprocess_df(df)
    logger.info("Finished reading data, attempting to create table and ingest the data...")

    db.drop_table("wikivoyage_documents", ignore_missing=True)
    table = db.create_table("wikivoyage_documents", schema=WikivoyageDocuments)

    table.add(filtered_df.to_dict('records'))
    logger.info("Completed ingestion.")


def create_wikivoyage_listings_db_and_add_data():
    """
    Creates the wikivoyage listings table and ingests the data.
    """
    uri = database_dir
    current_dir = os.path.split(os.getcwd())[1]

    if current_dir in ("src", "tests"):  # hacky way to get the correct path
        uri = uri.replace("../../", "../")

    db = lancedb.connect(uri)
    logger.info("Connected to DB. Reading data now...")
    df = read_listings()
    logger.info("Finished reading data, attempting to create table and ingest the data...")
    # filtered_df = preprocess_df(df)

    db.drop_table("wikivoyage_listings", ignore_missing=True)
    table = db.create_table("wikivoyage_listings", schema=WikivoyageListings)

    table.add(df.astype('str').to_dict('records'))
    logger.info("Completed ingestion.")
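
Note that both builders are destructive: they drop_table before create_table, so they are one-shot setup scripts rather than incremental updaters. Assuming the source CSVs and cleaned documents are in place, rebuilding the local database is just:

    create_wikivoyage_docs_db_and_add_data()
    create_wikivoyage_listings_db_and_add_data()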


def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0, run_local: Optional[bool] = False):
    """
    Function to search the wikivoyage documents table and return the most relevant chunked docs.

    Args:
        - query: str
        - limit: number of results (default is 10)
        - reranking: bool (0 or 1); if activated, the ColbertReranker is used.
        - run_local: if True, connect to the local database directory; otherwise use the GCP
          bucket named in the BUCKET_NAME environment variable.
    """
    if run_local:
        uri = database_dir
        current_dir = os.path.split(os.getcwd())[1]

        if current_dir in ("src", "tests"):  # hacky way to get the correct path
            uri = uri.replace("../../", "../")
    else:
        uri = os.environ["BUCKET_NAME"]

    try:
        db = lancedb.connect(uri)
    except Exception as e:
        logger.error(f"Error while connecting to DB: {e}")
        raise  # without a connection there is nothing to search

    logger.info("Connected to Wikivoyage DB.")

    table = db.open_table("wikivoyage_documents")

    if reranking:
        try:
            reranker = ColbertReranker(column='text')
            results = table.search(query) \
                .metric('cosine') \
                .rerank(reranker=reranker) \
                .limit(limit) \
                .to_list()
        except Exception as e:
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            raise
    else:
        try:
            results = table.search(query) \
                .limit(limit) \
                .metric('cosine') \
                .to_list()
        except Exception as e:
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            raise

    logger.info("Found the most relevant documents.")
    city_lists = [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']}
                  for r in results]

    return city_lists
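
For reference, each element that search_wikivoyage_docs returns is a plain dict; the values below are made up to show the shape:

    {
        "city": "Lisbon",
        "country": "Portugal",
        "section": "Eat",
        "text": "Pastéis de nata are best eaten warm...",
    }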


def search_wikivoyage_listings(query, cities, limit=10, reranking=0):
    """
    Function to search the wikivoyage listings table and return the most relevant listings,
    post-filtered by the recommended cities provided by the wikivoyage_documents table.

    Args:
        - query: str
        - cities: list
        - limit: number of results (default is 10)
        - reranking: bool (0 or 1); if activated, the ColbertReranker is used.
    """
    uri = database_dir
    current_dir = os.path.split(os.getcwd())[1]

    if current_dir in ("src", "tests"):  # hacky way to get the correct path
        uri = uri.replace("../../", "../")

    db = lancedb.connect(uri)
    logger.info("Connected to Wikivoyage Listings DB.")

    table = db.open_table("wikivoyage_listings")

    # Build the IN clause explicitly: f"city IN {tuple(cities)}" breaks on a
    # single-element list because of the trailing comma, e.g. "('Paris',)".
    quoted_cities = ", ".join(f"'{city}'" for city in cities)
    cities_filter = f"city IN ({quoted_cities})"

    if reranking:
        try:
            reranker = ColbertReranker(column='description')
            results = table.search(query) \
                .where(cities_filter) \
                .metric('cosine') \
                .rerank(reranker=reranker) \
                .limit(limit) \
                .to_list()
        except Exception as e:
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            raise
    else:
        try:
            results = table.search(query) \
                .where(cities_filter) \
                .metric('cosine') \
                .limit(limit) \
                .to_list()
        except Exception as e:
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
            raise

    logger.info("Found the most relevant listings.")
    city_listings = [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
                      "description": r['description']} for r in results]

    return city_listings
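
End to end, the retrieval flow this module supports is: search the documents table for candidate cities, then re-query the listings table restricted to those cities. A hedged sketch (query strings are illustrative):

    docs = search_wikivoyage_docs("quiet European cities with a strong coffee culture",
                                  limit=10, run_local=True)
    cities = list({d["city"] for d in docs})
    listings = search_wikivoyage_listings("specialty coffee shops", cities, limit=10)
    for listing in listings[:3]:
        print(listing["city"], "-", listing["title"])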