# Author: Ashmi Banerjee
# Last commit 89cd5d5: "refactored the vectordb" (file size 5.28 kB)
from typing import Optional
import gradio as gr
import sys
sys.path.append("./src")
from src.pipeline import pipeline
from src.helpers.data_loaders import load_places
def clear():
    """Produce blank values for three UI components.

    Returns:
        A 3-tuple whose every element is ``None``, one per component to reset.
    """
    blank = None
    return (blank,) * 3
# Function to update the list of cities based on the selected country
def update_cities(selected_country, df):
filtered_cities = df[df['country'] == selected_country]['city'].tolist()
return gr.Dropdown(choices=filtered_cities, interactive=True) # Make it interactive as it is not by default
def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
                  temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
    """Run the recommendation pipeline for a single user query.

    Args:
        query_text: The user's free-text question.
        model_name: Model label forwarded to ``pipeline`` as ``model_name``.
        is_sustainable: Forwarded to ``pipeline`` as ``sustainability``.
        tokens: Maximum number of new tokens. NOTE(review): accepted so the UI
            slider can be wired in, but not currently forwarded to ``pipeline``.
        temp: Sampling temperature. NOTE(review): likewise accepted but not
            forwarded to ``pipeline``.
        starting_point: City the trip starts from. ``None`` or an empty string
            (e.g. the city dropdown left unselected) falls back to "Munich".

    Returns:
        The value returned by ``pipeline``, unchanged.
    """
    # An unselected dropdown yields None; keep the signature's default city
    # rather than forwarding a missing starting point to the pipeline.
    if not starting_point:
        starting_point = "Munich"
    return pipeline(
        query=query_text,
        model_name=model_name,
        sustainability=is_sustainable,
        starting_point=starting_point,
    )
def create_ui(data_file: str = "cities/eu_200_cities.csv"):
    """Build the Gradio Blocks app for the Green City Finder demo.

    Args:
        data_file: Path of the places file passed to ``load_places``. The
            loaded frame must have ``country`` and ``city`` columns; they feed
            the country and starting-point dropdowns.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    df = load_places(data_file)
    df = df.sort_values(by=['country', 'city'])
    examples = [
        ["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
         "local cuisines to try?", "GPT-4"],
        ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
        ["Suggest some cities that can be visited from London and are very rich in history and culture.",
         "Gemini-1.0-pro"],
    ]
    header_html = (
        "<center><h1 style='font-size:xx-large; color: green'>🍀 Green City Finder 🍀</h1>"
        "<h3>AI Sprint 2024 submissions by Ashmi Banerjee. </h3></center> <br>"
        "<p>We're testing the compatibility of "
        "Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> &amp; "
        "<b>Gemini 1.0 Pro</b>\n"
        "models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.\n"
        "We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector "
        "embeddings are stored in a VectorDB (LanceDB) hosted in Google Cloud.</p>\n"
        "<p>Sustainability is calculated based on the work by "
        "<a href='https://arxiv.org/abs/2403.18604'>Banerjee et al.</a></p>\n"
        "<br><p>Google Cloud credits are provided for this project.</p>\n"
    )
    with gr.Blocks() as app:
        gr.HTML(header_html)
        with gr.Group():
            countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
            starting_point = gr.Dropdown(choices=[], multiselect=False,
                                         label="Select your starting point for the trip!")
            # Repopulate the starting-point dropdown with the chosen country's cities.
            countries.select(fn=lambda selected_country: update_cities(selected_country, df),
                             inputs=countries, outputs=starting_point)
            query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
            sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
                                                                "with regards to the environment, your starting "
                                                                "location and month of travel?")
            # TODO: Add more model options and the month of travel
            model = gr.Dropdown(
                ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
                                                                  "models "
                                                                  "later!",
            )
            output = gr.Textbox(label="Generated Results", lines=4)
        with gr.Accordion("Settings", open=False):
            max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
                                       interactive=True,
                                       visible=True, info="The maximum number of output tokens")
            temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
                                    interactive=True,
                                    visible=True, info="The value used to modulate the logits distribution")
        with gr.Group():
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
                cancel_btn = gr.Button("Cancel", variant="stop")
        # Inputs are passed positionally, so this order must match
        # generate_text(query_text, model_name, is_sustainable, tokens, temp, starting_point).
        # Previously starting_point was the 4th input and landed in `tokens`.
        submit_event = submit_btn.click(
            generate_text,
            inputs=[query, model, sustainable, max_new_tokens, temperature, starting_point],
            outputs=[output],
        )
        clear_btn.click(clear, inputs=[], outputs=[query, model, output])
        # Besides clearing the inputs, abort a generation that is still running.
        cancel_btn.click(clear, inputs=[], outputs=[query, model, output], cancels=[submit_event])
        gr.Markdown("## Examples")
        # Clicking an example only fills the query/model inputs; nothing is
        # pre-computed, so no pipeline calls happen at startup.
        gr.Examples(examples, inputs=[query, model])
    return app
# Script entry point: build the Blocks UI and start the Gradio server.
if __name__ == "__main__":
    # `app` is bound at module level here (kept as-is in case a launcher
    # looks the app up by this name).
    app = create_ui()
    # show_api=False: presumably hides Gradio's API-docs link — confirm
    # against the installed gradio version's launch() docs.
    app.launch(show_api=False)