Spaces · Running

Ashmi Banerjee committed · commit 8842640 · parent 89cd5d5

added more models, refactored a lot of stuff

Changed files:
- README.md +4 -3
- app.py +27 -92
- src/augmentation/prompt_generation.py +1 -1
- src/pipeline.py +18 -9
- src/text_generation/mapper.py +16 -0
- src/text_generation/model_init.py +125 -108
- src/text_generation/text_generation.py +8 -12
- src/text_generation/vertexai_setup.py +19 -2
- src/ui/__init__.py +0 -0
- src/ui/components/actions.py +30 -0
- src/ui/components/inputs.py +69 -0
- src/ui/components/static.py +139 -0
- src/ui/setup.py +19 -0
- src/ui/templates/intro.html +15 -0
README.md
CHANGED
@@ -15,18 +15,19 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 ### TODOs
 
-- [ ] Refactor the vectordb.py - remove code duplication
+- [x] Refactor the vectordb.py - remove code duplication
 
 - [x] Sustainability - database paths - move to HF
 
 - [ ] Fix it for the new models e.g. Llama and others
 
 - [ ] Add the space secrets to have it running online
-
 - [ ] Fix the google application json file
 
 - [ ] Make the space public
 
 - [x] Add emissions calculation and starting point
 - [x] Add more cities to starting point
-- [ ] Experiment with the sustainability & without sustainability prompt
+- [ ] Experiment with the sustainability & without sustainability prompt
+- [ ] Adapt the gradio examples to the right format
+- [x] UI refactoring
app.py
CHANGED
@@ -1,102 +1,37 @@
-from typing import Optional
 import gradio as gr
 import sys
 
+from src.ui.components.actions import generate_text, clear
+from src.ui.components.inputs import main_component
+from src.ui.components.static import load_buttons, load_examples, model_settings
+from src.ui.setup import load_html_from_file
+from src.text_generation.vertexai_setup import initialize_vertexai_params
 sys.path.append("./src")
-from src.pipeline import pipeline
-from src.helpers.data_loaders import load_places
-
-
-def clear():
-    return None, None, None
-
-
-# Function to update the list of cities based on the selected country
-def update_cities(selected_country, df):
-    filtered_cities = df[df['country'] == selected_country]['city'].tolist()
-    return gr.Dropdown(choices=filtered_cities, interactive=True)  # Make it interactive as it is not by default
-
-
-def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
-                  temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
-    pipeline_response = pipeline(
-        query=query_text,
-        model_name=model_name,
-        sustainability=is_sustainable,
-        starting_point=starting_point,
-    )
-    return pipeline_response
 
 
 def create_ui():
-        # ... (old body, lines 32-53: Blocks setup and the inline intro-HTML string, truncated in the page render; only its tail survives)
-        "et al.</a></p>\n "
-        " </p> <br>Google Cloud credits are provided for this project. </p>\n"
-        " ")
-
-        with gr.Group():
-            countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
-            starting_point = gr.Dropdown(choices=[], multiselect=False,
-                                         label="Select your starting point for the trip!")
-
-            countries.select(fn=lambda selected_country:
-                             update_cities(selected_country, df),
-                             inputs=countries, outputs=starting_point)
-
-            query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
-            sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
-                                                                "with regards to the environment, your starting "
-                                                                "location and month of travel?")
-            # TODO: Add model options, month and starting point
-            model = gr.Dropdown(
-                ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
-                                                                 "models "
-                                                                 "later!",
-            )
-            output = gr.Textbox(label="Generated Results", lines=4)
-
-        with gr.Accordion("Settings", open=False):
-            max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
-                                       interactive=True,
-                                       visible=True, info="The maximum number of output tokens")
-            temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
-                                    interactive=True,
-                                    visible=True, info="The value used to module the logits distribution")
-        with gr.Group():
-            with gr.Row():
-                submit_btn = gr.Button("Submit", variant="primary")
-                clear_btn = gr.Button("Clear", variant="secondary")
-                cancel_btn = gr.Button("Cancel", variant="stop")
-        submit_btn.click(generate_text, inputs=[query, model, sustainable, starting_point], outputs=[output])
-        clear_btn.click(clear, inputs=[], outputs=[query, model, output])
-        cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
-
-        gr.Markdown("## Examples")
-        # gr.Examples(
-        #     examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
-        #     cache_examples=True,
-        # )
+    initialize_vertexai_params()
+    # Path to HTML file
+    html_file_path = 'src/ui/templates/intro.html'
+
+    # Create the Gradio HTML component
+    html_content = load_html_from_file(html_file_path)
+    with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green,
+                                           secondary_hue=gr.themes.colors.blue)) as app:
+
+        gr.HTML(html_content)
+        country, starting_point, query, sustainable, model = main_component()
+        output = gr.Textbox(label="Generated Results", lines=4)
+        max_new_tokens, temperature = model_settings()
+
+        # Load the buttons for the interface
+        load_buttons(query, model,
+                     sustainable, starting_point,
+                     max_new_tokens, temperature,
+                     output,
+                     generate_text_fn=generate_text,
+                     clear_fn=clear)
+        # Load the examples for the interface
+        # load_examples(country, starting_point, query, model, sustainable, output, generate_text)
     return app
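The refactored create_ui() returns the assembled gr.Blocks app; a minimal sketch (not part of this commit) of how a Space entry point might mount it — the launch call itself does not appear in this diff:

    # Hypothetical entry point: build the UI and serve it.
    from app import create_ui

    demo = create_ui()  # the gr.Blocks app assembled above
    demo.launch()       # start the Gradio server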
src/augmentation/prompt_generation.py
CHANGED
@@ -1,4 +1,4 @@
-from information_retrieval import info_retrieval as ir
+from src.information_retrieval import info_retrieval as ir
 import logging
 
 logger = logging.getLogger(__name__)
src/pipeline.py
CHANGED
@@ -15,14 +15,17 @@ from text_generation.models import (
     Phi3SmallInstruct,
     GPT4,
     Gemini,
+    Claude3Point5Sonnet,
 )
 from text_generation import text_generation as tg
 import logging
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
+from src.text_generation.mapper import MODEL_MAPPER
 
 TEST_DIR = "../tests/"
+
 MODELS = {
     'GPT-4': GPT4,
     'Llama3': Llama3,
@@ -33,11 +36,15 @@ MODELS = {
     'Mistral-Instruct': MistralInstruct,
     'Llama3.1-Instruct': Llama3Point1Instruct,
     'Phi3-Instruct': Phi3SmallInstruct,
-
+    "Gemini-1.0-pro": Gemini,
+    "Claude3.5-sonnet": Claude3Point5Sonnet,
 }
 
 
-def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
+def pipeline(starting_point: str,
+             query: str,
+             model_name: str,
+             test: int = 0, **params):
     """
 
     Executes the entire RAG pipeline, provided the query and model class name.
@@ -53,9 +60,11 @@ def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
 
 
     """
-
-
-
+    try:
+        model_id = MODEL_MAPPER[model_name]
+    except KeyError:
+        logger.error(f"Model {model_name} not found in the model mapper.")
+        model_id = MODEL_MAPPER['Gemini-1.0-pro']
     context_params = {
         'limit': 5,
         'reranking': 0,
@@ -97,9 +106,9 @@ def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
 
     # return prompt
 
-    logger.info(f"Augmented prompt, initializing {model} and generating response..")
+    logger.info(f"Augmented prompt, initializing {model_name} and generating response..")
     try:
-        response = tg.generate_response(model, prompt)
+        response = tg.generate_response(model_id, prompt, **params)
     except Exception as e:
         exc_type, exc_obj, exc_tb = sys.exc_info()
         logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
@@ -117,11 +126,11 @@ if __name__ == "__main__":
     # suggest " \ "some " \ "European cities? "
     sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
                    "cities. "
-
+    model_name = "GPT-4"
 
    pipeline_response = pipeline(
         query=sample_query,
-        model_name="GPT-4",
+        model_name=model_name,
         sustainability=1
     )
src/text_generation/mapper.py
ADDED
@@ -0,0 +1,16 @@
+MODEL_MAPPER = {
+    'Gemma-2-9B-it': "google/gemma-2-9b-it",
+    'Gemma-2-2B-it': "google/gemma-2-2b-it",
+    "Gemini-1.0-pro": "gemini-1.0-pro",
+    "Gemini-1.5-Flash": "gemini-1.5-flash-001",
+    "Gemini-1.5-Pro": "gemini-1.5-pro-001",
+    "Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
+    'GPT-4': "gpt-4o-mini",
+    'Llama3': "meta-llama/Meta-Llama-3-8B",
+    'Mistral': "mistralai/Mistral-7B",
+    'Llama3.1': "meta-llama/Meta-Llama-3.1-8B",
+    'Llama3-Instruct': "meta-llama/Meta-Llama-3-8B-Instruct",
+    'Mistral-Instruct': "mistralai/Mistral-7B-Instruct-v0.1",
+    'Llama3.1-Instruct': "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    'Phi3-Instruct': "microsoft/Phi-3-small-128k-instruct",
+}
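MODEL_MAPPER is a plain dict from UI-facing names to provider-specific model IDs; pipeline.py resolves names through it with a guarded lookup that falls back to Gemini. A small sketch of that lookup (resolve_model_id is a hypothetical helper mirroring the try/except in pipeline()):

    from src.text_generation.mapper import MODEL_MAPPER

    def resolve_model_id(model_name: str) -> str:
        # Unknown UI names fall back to Gemini, as in pipeline().
        try:
            return MODEL_MAPPER[model_name]
        except KeyError:
            return MODEL_MAPPER["Gemini-1.0-pro"]

    assert resolve_model_id("GPT-4") == "gpt-4o-mini"
    assert resolve_model_id("no-such-model") == "gemini-1.0-pro"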
src/text_generation/model_init.py
CHANGED
@@ -1,123 +1,140 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from vertexai.generative_models import GenerativeModel
-
 from dotenv import load_dotenv
 from anthropic import AnthropicVertex
 import os
 from openai import OpenAI
-from src.text_generation.vertexai_setup import initialize_vertexai_params
+from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config
+from huggingface_hub import InferenceClient
 
+# Load environment variables
 load_dotenv()
-# ... (old module-level setup, lines 12-15, truncated in the page render)
+OAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+
+def _validate_tokens(max_tokens: int) -> int:
+    """
+    Validates the max_tokens parameter. Ensures it's within a valid range (1 to 8192).
+    If invalid, defaults to 8192.
+    """
+    if 1 <= max_tokens <= 8192:
+        return max_tokens
+    return 8192
+
+
+def _validate_temperature(temp: float) -> float:
+    """
+    Validates the temperature parameter. Ensures it's within a valid range (0 to 1).
+    If invalid, defaults to 0.49.
+    """
+    if 0 <= temp <= 1:
+        return temp
+    return 0.49
 
 
 class LLMBaseClass:
     """
-    Base
-    the generate method
-
+    Base class for text generation. Users provide the HF model ID or other model identifiers
+    and can call the generate method to get responses.
     """
 
-    def __init__(self, model_id) -> None:
-        # ... (old __init__ and helper bodies, lines 26-71, truncated in the page render)
+    def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
+        self.model_id = model_id
+        self.api_key = None
+        self.temp = _validate_temperature(temp)
+        self.tokens = _validate_tokens(max_tokens)
+        self.model = self._initialize_model()
+
+    def _initialize_model(self):
+        """
+        Initialize the model based on the provided model ID.
+        """
+        if self.model_id == "gpt-4o-mini":
+            return self._initialize_openai_model()
+        elif self.model_id == "claude-3-5-sonnet@20240620":
+            return self._initialize_claude_model()
+        elif self.model_id in ["claude-3-5-sonnet@20240620",
+                               "gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
+            return self._initialize_vertexai_model()
+        else:
+            return self._initialize_hf_model()
+
+    def _initialize_openai_model(self):
+        """
+        Initialize OpenAI model.
+        """
+        self.api_key = OAI_API_KEY
+        return OpenAI(api_key=self.api_key)
+
+    def _initialize_claude_model(self):
+        """
+        Initialize Claude model using Anthropic via Vertex AI.
+        """
+        self.api_key = os.getenv("VERTEXAI_PROJECTID")
+        return AnthropicVertex(region="europe-west1", project_id=self.api_key)
+
+    def _initialize_vertexai_model(self):
+        """
+        Initialize Google Gemini model using Vertex AI.
+        """
+        default_gen_config, default_safe_settings = get_default_config()
+        gen_config = {
+            "temperature": self.temp,
+            "max_output_tokens": self.tokens,
+        }
+        return GenerativeModel(self.model_id,
+                               generation_config=default_gen_config if gen_config is None else gen_config,
+                               safety_settings=default_safe_settings)
+
+    def _initialize_hf_model(self):
+        self.api_key = os.getenv("HF_TOKEN")
+        return InferenceClient(token=self.api_key, model=self.model_id)
 
     def generate(self, messages):
-        # ... (old generate body, lines 74-122, truncated in the page render)
+        """
+        Generate responses based on the model type and provided messages.
+        """
+        if self.model_id == "gpt-4o-mini":
+            return self._generate_openai(messages)
+        elif self.model_id in ["claude-3-5-sonnet@20240620",
+                               "gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
+            return self._generate_vertexai(messages)
+        else:
+            return self._generate_hf(messages)
+
+    def _generate_openai(self, messages):
+        """
+        Generate responses using OpenAI model.
+        """
+        completion = self.model.chat.completions.create(
+            model=self.model_id,
+            messages=messages,
+            temperature=self.temp,
+            max_tokens=self.tokens,
+        )
+        return completion.choices[0].message.content
+
+    def _generate_vertexai(self, messages):
+        """
+        Generate responses using Claude or Gemini models via Vertex AI.
+        """
+        initialize_vertexai_params()
+        content = " ".join([message["content"] for message in messages])
+        if "claude" in self.model_id:
+            message = self.model.messages.create(
+                max_tokens=self.tokens,
+                model=self.model_id,
+                messages=[{"role": "user", "content": content}],
+            )
+            return message.content[0].text
+        else:
+            response = self.model.generate_content(content)
+            return response.text
+
+    def _generate_hf(self, messages):
+        """
+        Generate responses using Hugging Face models.
+        """
+        response = self.model.chat_completion(
+            messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
+            max_tokens=self.tokens, temperature=self.temp)
+        return response.choices[0].message.content
 
-# database/wikivoyage/wikivoyage_listings.lance/data/e2940f51-d754-4b54-a688-004bdb8e7aa2.lance
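LLMBaseClass now routes on raw model-ID strings: "gpt-4o-mini" goes to OpenAI, the Claude and Gemini IDs to Vertex AI, and everything else to the Hugging Face InferenceClient. A minimal usage sketch (assumes the matching API key, here OPENAI_API_KEY, is set in the environment):

    from src.text_generation.model_init import LLMBaseClass

    llm = LLMBaseClass(model_id="gpt-4o-mini", max_tokens=1024, temp=0.49)
    messages = [
        {"role": "system", "content": "You recommend European cities."},
        {"role": "user", "content": "Suggest a city for a July beach trip."},
    ]
    print(llm.generate(messages))  # routed to the OpenAI chat-completions path

Note that out-of-range settings are silently replaced by the validators: max_tokens=0 becomes 8192 and temp=5.0 becomes 0.49.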
src/text_generation/text_generation.py
CHANGED
@@ -1,20 +1,13 @@
-from augmentation import prompt_generation as pg
-from information_retrieval import info_retrieval as ir
-from src.text_generation.models import (
-    Llama3,
-    Mistral,
-    Gemma2,
-    Llama3Point1,
-    GPT4,
-    Claude3Point5Sonnet,
-)
+from src.augmentation import prompt_generation as pg
+from src.information_retrieval import info_retrieval as ir
+from src.text_generation.model_init import LLMBaseClass
 import logging
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
 
 
-def generate_response(model, prompt):
+def generate_response(model, prompt: list, **params):
     """
 
     Function that initializes the LLM class and calls the generate function.
@@ -26,7 +19,9 @@ def generate_response(model, prompt):
     """
 
     logger.info(f"Initializing LLM configuration for {model}")
-    llm = model
+    llm = LLMBaseClass(model_id=model,
+                       max_tokens=params['max_tokens'],
+                       temp=params['temperature'])
 
     logger.info("Generating response")
     try:
@@ -73,6 +68,7 @@ def test(model):
     logger.info(f"Augmented prompt, initializing {model} and generating response..")
     try:
         response = generate_response(model, without_sfairness)
+        print(response)
     except Exception as e:
         logger.info(f"Error while generating response: {e}")
         return None
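generate_response now builds the LLM through model_init and reads max_tokens and temperature out of **params; since they are accessed with params['max_tokens'] and params['temperature'], omitting either key raises a KeyError. A sketch of the calling convention the pipeline uses:

    from src.text_generation import text_generation as tg

    prompt = [{"role": "system", "content": "You recommend European cities."},
              {"role": "user", "content": "Recommend some cities for July."}]
    response = tg.generate_response(
        "gemini-1.0-pro",  # a model ID taken from MODEL_MAPPER
        prompt,
        max_tokens=1024,
        temperature=0.49,
    )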
src/text_generation/vertexai_setup.py
CHANGED
@@ -5,6 +5,11 @@ import vertexai
 import os
 import json
 import base64
+import logging
+from vertexai import generative_models
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
 
 load_dotenv()
 if "VERTEXAI_PROJECTID" in os.environ:
@@ -12,7 +17,7 @@ if "VERTEXAI_PROJECTID" in os.environ:
 
 
 def decode_service_key():
-    encoded_key = os.environ["
+    encoded_key = os.environ["GOOGLE_CREDENTIALS"]
     original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
     if original_service_key:
         return original_service_key
@@ -20,7 +25,6 @@ def decode_service_key():
 
 
 def initialize_vertexai_params(location: Optional[str] = "us-central1"):
-
     creds_file_name = os.getcwd() + "/.config/gcp_default_credentials.json"
     print(creds_file_name)
     if not os.path.exists(os.path.dirname(creds_file_name)):
@@ -35,3 +39,16 @@ def initialize_vertexai_params(location: Optional[str] = "us-central1"):
         scopes=["https://www.googleapis.com/auth/cloud-platform"],
     )
     vertexai.init(project=VERTEXAI_PROJECT, location=location)
+    logger.info("Vertex AI initialized")
+
+
+def get_default_config() -> tuple[dict, dict]:
+    default_gen_config = {
+        "temperature": 0.49,
+        "max_output_tokens": 1024,
+    }
+    default_safety_settings = {
+        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+    }
+    return default_gen_config, default_safety_settings
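decode_service_key() expects the GOOGLE_CREDENTIALS secret to hold a base64-encoded service-account JSON. One way to produce that value from a downloaded key file (a one-off sketch; service-account.json is a placeholder path):

    import base64

    # Encode a downloaded GCP service-account key for the secret that
    # decode_service_key() reads back with b64decode + json.loads.
    with open("service-account.json", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    print(encoded)  # paste into the GOOGLE_CREDENTIALS Space secret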
src/ui/__init__.py
ADDED
File without changes
src/ui/components/actions.py
ADDED
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from src.pipeline import pipeline
+
+
+def clear():
+    return None, None, None, None, None
+
+
+def generate_text(query_text: str,
+                  model_name: Optional[str],
+                  is_sustainable: Optional[bool],
+                  starting_point: Optional[str],
+                  max_tokens: Optional[int] = 1024,
+                  temp: Optional[float] = 0.49,
+                  ):
+    model_params = {
+        'max_tokens': max_tokens,
+        'temperature': temp
+    }
+    pipeline_response = pipeline(
+        query=query_text,
+        model_name=model_name,
+        sustainability=is_sustainable,
+        starting_point=starting_point,
+        **model_params
+    )
+    if pipeline_response is None:
+        return "Error while generating response! Please try again."
+    return pipeline_response
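generate_text simply packs the two slider values into model_params and forwards everything to pipeline(), so it can also be exercised directly, outside Gradio (a sketch; keyword names match the signature above):

    from src.ui.components.actions import generate_text

    answer = generate_text(
        query_text="I enjoy beaches and nightlife; recommend some cities.",
        model_name="GPT-4",
        is_sustainable=True,
        starting_point="Munich",
        max_tokens=1024,
        temp=0.49,
    )
    print(answer)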
src/ui/components/inputs.py
ADDED
@@ -0,0 +1,69 @@
+from typing import Tuple
+
+import gradio as gr
+
+from src.helpers.data_loaders import load_places
+from src.text_generation.mapper import MODEL_MAPPER
+
+
+def get_places():
+    data_file = "cities/eu_200_cities.csv"
+    df = load_places(data_file)
+    df = df.sort_values(by=['country', 'city'])
+    return df
+
+
+# Function to update the list of cities based on the selected country
+def update_cities(selected_country, df):
+    filtered_cities = df[df['country'] == selected_country]['city'].tolist()
+    return gr.Dropdown(choices=filtered_cities, interactive=True)  # Make it interactive as it is not by default
+
+
+def main_component() -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Checkbox, gr.Dropdown]:
+    """
+    Creates the main Gradio interface components and returns them.
+
+    Returns:
+        Tuple containing:
+        - countries: Dropdown for selecting the country.
+        - starting_point: Dropdown for selecting the starting point.
+        - query: Textbox for entering the user query.
+        - sustainable: Checkbox for sustainable travel.
+        - model: Dropdown for selecting the model.
+    """
+    df = get_places()
+    country_names = list(df.country.unique())
+    with gr.Group():
+        # Country selection dropdown
+        countries = gr.Dropdown(choices=country_names, multiselect=False, label="Country")
+
+        # Starting point selection dropdown
+        starting_point = gr.Dropdown(choices=[], multiselect=False, label="Select your starting point for the trip!")
+
+        # When a country is selected, update the starting point options
+        countries.select(
+            fn=lambda selected_country: update_cities(selected_country, df),
+            inputs=countries,
+            outputs=starting_point
+        )
+
+        # User query input
+        query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
+
+        # Checkbox for sustainable travel option
+        sustainable = gr.Checkbox(
+            label="Sustainable",
+            info="Do you want your recommendations to be sustainable with regards to the environment, "
+                 "your starting location, and month of travel?"
+        )
+
+        models = list(MODEL_MAPPER.keys())[:5]
+        # Model selection dropdown
+        model = gr.Dropdown(
+            choices=models,
+            label="Model",
+            info="Select your model. The model will generate the recommendations based on your query."
+        )
+
+    # Return all the components individually
+    return countries, starting_point, query, sustainable, model
src/ui/components/static.py
ADDED
@@ -0,0 +1,139 @@
+from typing import Callable, Tuple, Optional
+
+import gradio as gr
+
+from src.ui.components.inputs import get_places, update_cities, main_component
+
+
+def load_examples(
+        country: gr.components.Component,
+        starting_point: gr.components.Component,
+        query: gr.components.Component,
+        model: gr.components.Component,
+        is_sustainable: gr.components.Component,
+        output: gr.components.Component,
+        generate_text_fn: Callable[
+            [str, Optional[str], Optional[bool], Optional[int], Optional[float], Optional[str]],
+            str],
+) -> gr.Examples:
+
+    df = get_places()
+    country_names = list(df.country.unique())
+
+    # Example data
+    country_example = "USA"
+    starting_point_example = "New York"
+    query_example = "What are some top tourist attractions?"
+    sustainable_example = True
+    model_example = "GPT-4"
+
+    # Update the starting point options based on the example country
+    starting_point_choices = update_cities(country_example, df)
+
+    # Set the choices and default value for the starting point dropdown
+    starting_point.choices = starting_point_choices
+    starting_point.value = starting_point_example
+
+    # Provide examples
+    examples = [
+        [country_example, starting_point_example, query_example, sustainable_example, model_example]
+    ]
+
+    # Create a Gradio Examples component
+    gr.Examples(
+        examples=examples,
+        inputs=[country, starting_point, query, is_sustainable, model],
+        outputs=[]
+    )
+
+
+def load_buttons(
+        query: gr.components.Component,
+        model: gr.components.Component,
+        sustainable: gr.components.Component,
+        starting_point: gr.components.Component,
+        max_new_tokens: gr.components.Component,
+        temperature: gr.components.Component,
+        output: gr.components.Component,
+        generate_text_fn: Callable[
+            [str, Optional[str], Optional[bool], Optional[int], Optional[float], Optional[str]],
+            str],
+        clear_fn: Callable[[], None]
+) -> gr.Group:
+    """
+    Load and return buttons for the Gradio interface.
+
+    Args:
+        query: The input component for user queries.
+        model: The input component for selecting the model.
+        sustainable: The input component for sustainable travel options.
+        starting_point: The input component for the user's starting point.
+        output: The output component for displaying the generated text.
+        generate_text_fn: The function to be called on submit to generate text.
+        clear_fn: The function to clear the input fields and output.
+
+    Returns:
+        Gradio Group component containing the buttons.
+    """
+    with gr.Group() as btns:
+        with gr.Row():
+            submit_btn = gr.Button("Submit", variant="primary")
+            clear_btn = gr.Button("Clear", variant="secondary")
+            cancel_btn = gr.Button("Cancel", variant="stop")
+
+        # Bind actions to the buttons
+        submit_btn.click(
+            fn=generate_text_fn,  # Function to generate text
+            inputs=[query, model, sustainable,
+                    starting_point, max_new_tokens,
+                    temperature],  # Input components for generation
+            outputs=[output]  # Output component
+        )
+        clear_btn.click(
+            fn=clear_fn,  # Function to clear inputs
+            inputs=[],  # No inputs for clearing
+            outputs=[query, model, sustainable, starting_point, output]  # Clear all inputs and output
+        )
+        cancel_btn.click(
+            fn=clear_fn,  # Function to cancel and clear inputs
+            inputs=[],  # No inputs for cancel
+            outputs=[query, model, sustainable, starting_point, output]  # Clear all inputs and output
+        )
+    return btns
+
+
+def model_settings() -> Tuple[gr.Slider, gr.Slider]:
+    """
+    Creates the model settings components and returns them.
+
+    Returns:
+        Tuple containing:
+        - max_new_tokens: Slider for setting the maximum number of new tokens.
+        - temperature: Slider for setting the temperature.
+    """
+    with gr.Accordion("Settings", open=False):
+        # Slider for maximum number of new tokens
+        max_new_tokens = gr.Slider(
+            label="Max new tokens",
+            value=1024,
+            minimum=0,
+            maximum=8192,
+            step=64,
+            interactive=True,
+            visible=True,
+            info="The maximum number of output tokens"
+        )
+
+        # Slider for temperature
+        temperature = gr.Slider(
+            label="Temperature",
+            step=0.01,
+            minimum=0.01,
+            maximum=1.0,
+            value=0.49,
+            interactive=True,
+            visible=True,
+            info="The value used to modulate the logits distribution"
+        )
+
+    return max_new_tokens, temperature
src/ui/setup.py
ADDED
@@ -0,0 +1,19 @@
+from src.helpers.data_loaders import load_places
+
+
+# TODO add the gcp application json file loading
+# TODO examples
+
+def load_html_from_file(file_path: str) -> str:
+    """
+    Reads the HTML content from a file and returns it as a string.
+
+    Args:
+        file_path (str): The path to the HTML file.
+
+    Returns:
+        str: The HTML content of the file.
+    """
+    with open(file_path, 'r') as file:
+        return file.read()
src/ui/templates/intro.html
ADDED
@@ -0,0 +1,15 @@
+<!-- intro.html -->
+<body>
+<h1 style='font-size:xx-large; color: green; text-align: center'>Green City Finder</h1>
+<h3 style="text-align: center">AI Sprint 2024 submissions by Ashmi Banerjee.</h3>
+<br>
+<p style="text-align: justify">We're testing the compatibility of
+    Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
+    models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.
+    We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector embeddings are
+    stored in a VectorDB (LanceDB) hosted in Google Cloud.
+</p>
+<p style="text-align: justify">Sustainability is calculated based on the work by <a href="https://arxiv.org/abs/2403.18604">Banerjee et al.</a></p>
+<br>
+<p style="text-align: justify">Google Cloud credits are provided for this project.</p>
+</body>