Spaces:
Running
Running
Ashmi Banerjee
commited on
Commit
•
adbebe0
1
Parent(s):
bfcfb0e
update sustainability prompt and post processing
Browse files- README.md +5 -3
- src/augmentation/prompt_generation.py +7 -33
- src/augmentation/prompts.py +35 -0
- src/pipeline.py +14 -2
- src/post_processing/post_process.py +19 -0
- src/text_generation/mapper.py +2 -2
- src/text_generation/model_init.py +1 -0
- src/ui/components/inputs.py +1 -1
README.md
CHANGED
@@ -19,7 +19,7 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
|
|
19 |
|
20 |
- [x] Sustainability - database paths - move to HF
|
21 |
|
22 |
-
- [
|
23 |
|
24 |
- [ ] Add the space secrets to have it running online
|
25 |
- [ ] Fix the google application json file
|
@@ -28,6 +28,8 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
|
|
28 |
|
29 |
- [x] Add emissions calculation and starting point
|
30 |
- [x] Add more cities to starting point
|
31 |
-
- [
|
32 |
- [ ] Adapt the gradio examples to the right format
|
33 |
-
- [x] UI refactoring
|
|
|
|
|
|
19 |
|
20 |
- [x] Sustainability - database paths - move to HF
|
21 |
|
22 |
+
- [x] Fix it for the new models e.g. Llama and others
|
23 |
|
24 |
- [ ] Add the space secrets to have it running online
|
25 |
- [ ] Fix the google application json file
|
|
|
28 |
|
29 |
- [x] Add emissions calculation and starting point
|
30 |
- [x] Add more cities to starting point
|
31 |
+
- [x] Experiment with the sustainability & without sustainability prompt
|
32 |
- [ ] Adapt the gradio examples to the right format
|
33 |
+
- [x] UI refactoring
|
34 |
+
- [ ] Log the different queries and the results
|
35 |
+
- [ ] Clear does not clear the country
|
src/augmentation/prompt_generation.py
CHANGED
@@ -3,7 +3,7 @@ import logging
|
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
6 |
-
|
7 |
|
8 |
def generate_prompt(query, context, template=None):
|
9 |
"""
|
@@ -20,22 +20,7 @@ def generate_prompt(query, context, template=None):
|
|
20 |
if template:
|
21 |
SYS_PROMPT = template
|
22 |
else:
|
23 |
-
SYS_PROMPT =
|
24 |
-
based on the user's question. You should use the provided contexts to suggest a list of the 3 best cities
|
25 |
-
that are best suited to the user's question, as well as the month of travel. If the user has already provided
|
26 |
-
the month of travel in the question, use the same month; otherwise, provide the ideal month of travel. Each
|
27 |
-
recommendation should also contain an explanation of why it is being recommended, based on the context. Your
|
28 |
-
answer must begin with "I recommend " followed by the city name and why you recommended it. Your answers are
|
29 |
-
correct, high-quality, and written by a domain expert. If the provided context does not contain the answer,
|
30 |
-
simply state, "The provided context does not have the answer." """
|
31 |
-
|
32 |
-
USER_PROMPT = """ Question: {} Which city do you recommend and why?
|
33 |
-
|
34 |
-
Context: Here are the options: {}
|
35 |
-
|
36 |
-
Answer:
|
37 |
-
|
38 |
-
"""
|
39 |
|
40 |
formatted_prompt = f"{USER_PROMPT.format(query, context)}"
|
41 |
messages = [
|
@@ -108,7 +93,7 @@ def format_context(context):
|
|
108 |
return formatted_context
|
109 |
|
110 |
|
111 |
-
def augment_prompt(query, context, **params):
|
112 |
"""
|
113 |
Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
|
114 |
retrieved context, which can be passed to the LLM.
|
@@ -121,27 +106,16 @@ def augment_prompt(query, context, **params):
|
|
121 |
"""
|
122 |
|
123 |
# what about the cities without s-fairness scores? i.e. they don't have seasonality data
|
124 |
-
|
125 |
-
prompt_with_sustainability =
|
126 |
-
for travel based on the user's question. You should use the provided contexts to suggest the city that is best
|
127 |
-
suited to the user's question. You recommend a list of the top 3 most sustainable cities to the user, as well as
|
128 |
-
the best month of travel. Each recommendation should also contain an explanation of why it is being recommended,
|
129 |
-
on sustainability grounds based on the context. The context contains a sustainability score for each city,
|
130 |
-
also known as the s-fairness score, along with the ideal month of travel. A lower s-fairness score indicates that
|
131 |
-
the city is a better destination for the month provided. A city without a sustainability score should not be
|
132 |
-
considered. You should only consider the s-fairness score while choosing the best city. However, your answer
|
133 |
-
should not contain the numeric score itself or any mention of the sustainability score. Your answer must begin
|
134 |
-
with "I recommend " followed by the city names and why you recommended it. Your answers are correct, high-quality,
|
135 |
-
and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided
|
136 |
-
context does not have the answer. """
|
137 |
|
138 |
# format context
|
139 |
formatted_context = format_context(context)
|
140 |
|
141 |
if "sustainability" in params["params"] and params["params"]["sustainability"]:
|
142 |
-
prompt = generate_prompt(
|
143 |
else:
|
144 |
-
prompt = generate_prompt(
|
145 |
|
146 |
return prompt
|
147 |
|
|
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
6 |
+
from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT
|
7 |
|
8 |
def generate_prompt(query, context, template=None):
|
9 |
"""
|
|
|
20 |
if template:
|
21 |
SYS_PROMPT = template
|
22 |
else:
|
23 |
+
SYS_PROMPT = SYSTEM_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
formatted_prompt = f"{USER_PROMPT.format(query, context)}"
|
26 |
messages = [
|
|
|
93 |
return formatted_context
|
94 |
|
95 |
|
96 |
+
def augment_prompt(query: str, starting_point: str, context: dict, **params: dict):
|
97 |
"""
|
98 |
Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
|
99 |
retrieved context, which can be passed to the LLM.
|
|
|
106 |
"""
|
107 |
|
108 |
# what about the cities without s-fairness scores? i.e. they don't have seasonality data
|
109 |
+
updated_query = f"With {starting_point} as the starting point, {query}"
|
110 |
+
prompt_with_sustainability = SUSTAINABILITY_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
# format context
|
113 |
formatted_context = format_context(context)
|
114 |
|
115 |
if "sustainability" in params["params"] and params["params"]["sustainability"]:
|
116 |
+
prompt = generate_prompt(updated_query, formatted_context, prompt_with_sustainability)
|
117 |
else:
|
118 |
+
prompt = generate_prompt(updated_query, formatted_context)
|
119 |
|
120 |
return prompt
|
121 |
|
src/augmentation/prompts.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
USER_PROMPT = \
|
2 |
+
""" Question: {} Which city do you recommend and why?
|
3 |
+
|
4 |
+
Context: Here are the options: {}
|
5 |
+
|
6 |
+
Answer:
|
7 |
+
|
8 |
+
"""
|
9 |
+
SYSTEM_PROMPT =\
|
10 |
+
"""You are an AI recommendation system focused on sustainable travel.
|
11 |
+
Your task is to recommend European cities for travel based on the user's query and starting point.
|
12 |
+
Using the provided context, suggest the top 3 cities that are the most sustainable, along with the best month to visit for the user with respect to their starting point.
|
13 |
+
Each recommendation should:
|
14 |
+
1. Include the ideal mode of travel from the user's starting location which has the lowest levels of emissions.
|
15 |
+
2. Offer an explanation as to why the city is recommended, focusing on sustainability factors such as popularity, emissions and seasonal footfall.
|
16 |
+
Your answer must begin with "I recommend " followed by the city names and why you recommended it.
|
17 |
+
Your answers are correct, high-quality, and written by a domain expert.
|
18 |
+
If the provided context does not contain the answer, simply state,
|
19 |
+
"The provided context does not have the answer. """
|
20 |
+
SUSTAINABILITY_PROMPT =\
|
21 |
+
"""You are an AI recommendation system focused on sustainable travel.
|
22 |
+
Your task is to recommend European cities for travel based on the user's query and starting point.
|
23 |
+
Using the provided context, suggest the top 3 cities that are the most sustainable, along with the best month to visit for the user with respect to their starting point.
|
24 |
+
Each recommendation should:
|
25 |
+
1. Be based on the value of the s-fairness score provided in the context. A lower s-fairness score indicates that the city is a better destination for the month provided. A city without a sustainability score should not be considered.
|
26 |
+
2. The system should discourage travel during peak seasons and promote travel during off and shoulder seasons.
|
27 |
+
3. It should recommend hidden gems or off-beat destinations compared to the most popular ones.
|
28 |
+
4. Include the ideal mode of travel from the user's starting location which has the lowest levels of emissions.
|
29 |
+
5. Offer an explanation as to why the city is recommended, focusing on sustainability factors such as popularity, emissions and seasonal footfall.
|
30 |
+
You should only consider the s-fairness score while choosing the best city.
|
31 |
+
However, your answer should not contain the numeric score itself or any mention of the sustainability score.
|
32 |
+
Your answer must begin with "I recommend " followed by the city names and why you recommended it.
|
33 |
+
Your answers are correct, high-quality, and written by a domain expert.
|
34 |
+
If the provided context does not contain the answer, simply state,
|
35 |
+
"The provided context does not have the answer. """
|
src/pipeline.py
CHANGED
@@ -23,6 +23,7 @@ import logging
|
|
23 |
logger = logging.getLogger(__name__)
|
24 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
25 |
from src.text_generation.mapper import MODEL_MAPPER
|
|
|
26 |
|
27 |
TEST_DIR = "../tests/"
|
28 |
|
@@ -96,6 +97,7 @@ def pipeline(starting_point: str,
|
|
96 |
try:
|
97 |
prompt = pg.augment_prompt(
|
98 |
query=query,
|
|
|
99 |
context=context,
|
100 |
params=context_params
|
101 |
)
|
@@ -114,11 +116,21 @@ def pipeline(starting_point: str,
|
|
114 |
logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
|
115 |
return None
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
if test:
|
118 |
-
return retrieved_cities, prompt[1]['content'],
|
119 |
|
120 |
else:
|
121 |
-
return
|
122 |
|
123 |
|
124 |
if __name__ == "__main__":
|
|
|
23 |
logger = logging.getLogger(__name__)
|
24 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
25 |
from src.text_generation.mapper import MODEL_MAPPER
|
26 |
+
from src.post_processing.post_process import post_process_output
|
27 |
|
28 |
TEST_DIR = "../tests/"
|
29 |
|
|
|
97 |
try:
|
98 |
prompt = pg.augment_prompt(
|
99 |
query=query,
|
100 |
+
starting_point=starting_point,
|
101 |
context=context,
|
102 |
params=context_params
|
103 |
)
|
|
|
116 |
logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
|
117 |
return None
|
118 |
|
119 |
+
try:
|
120 |
+
model_params = {"max_tokens": params["max_tokens"], "temperature": params["temperature"]}
|
121 |
+
post_processed_response = post_process_output(
|
122 |
+
model_id=model_id, user_query=query,
|
123 |
+
starting_point=starting_point,
|
124 |
+
context=context, response=response, **model_params)
|
125 |
+
except Exception as e:
|
126 |
+
exc_type, exc_obj, exc_tb = sys.exc_info()
|
127 |
+
logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
|
128 |
+
return None
|
129 |
if test:
|
130 |
+
return retrieved_cities, prompt[1]['content'], post_processed_response
|
131 |
|
132 |
else:
|
133 |
+
return post_processed_response
|
134 |
|
135 |
|
136 |
if __name__ == "__main__":
|
src/post_processing/post_process.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.text_generation.text_generation import generate_response
|
2 |
+
|
3 |
+
|
4 |
+
def post_process_output(model_id: str, user_query: str, starting_point: str, response: str, context: dict, **params: dict) -> str | None:
|
5 |
+
if "s-fairness" in response.lower() or "score" in response.lower():
|
6 |
+
formatted_response = \
|
7 |
+
f"You are an AI recommendation system focused on sustainable travel. Rewrite this response without " \
|
8 |
+
f"mentioning the sustainability score or s-fairness.\n {response} "
|
9 |
+
elif "the provided context does not have the answer" \
|
10 |
+
in response.lower():
|
11 |
+
formatted_response = \
|
12 |
+
f'You are an AI recommendation system focused on sustainable travel. Rewrite this response using the information from the context: \n Starting point: {starting_point}\n Query: {user_query}\n Context: {context}'
|
13 |
+
else:
|
14 |
+
formatted_response = response
|
15 |
+
|
16 |
+
final_prompt = [{"role": "system", "content": formatted_response}]
|
17 |
+
generated_response = generate_response(model_id,
|
18 |
+
final_prompt, **params)
|
19 |
+
return generated_response
|
src/text_generation/mapper.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
MODEL_MAPPER = {
|
2 |
-
'Gemma-2-9B-it': "google/gemma-2-9b-it",
|
3 |
-
'Gemma-2-2B-it': "google/gemma-2-2b-it",
|
4 |
"Gemini-1.0-pro": "gemini-1.0-pro",
|
5 |
"Gemini-1.5-Flash": "gemini-1.5-flash-001",
|
6 |
"Gemini-1.5-Pro": "gemini-1.5-pro-001",
|
|
|
|
|
7 |
"Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
|
8 |
'GPT-4': "gpt-4o-mini",
|
9 |
'Llama3': "meta-llama/Meta-Llama-3-8B",
|
|
|
1 |
MODEL_MAPPER = {
|
|
|
|
|
2 |
"Gemini-1.0-pro": "gemini-1.0-pro",
|
3 |
"Gemini-1.5-Flash": "gemini-1.5-flash-001",
|
4 |
"Gemini-1.5-Pro": "gemini-1.5-pro-001",
|
5 |
+
'Gemma-2-9B-it': "google/gemma-2-9b-it",
|
6 |
+
'Gemma-2-2B-it': "google/gemma-2-2b-it",
|
7 |
"Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
|
8 |
'GPT-4': "gpt-4o-mini",
|
9 |
'Llama3': "meta-llama/Meta-Llama-3-8B",
|
src/text_generation/model_init.py
CHANGED
@@ -134,6 +134,7 @@ class LLMBaseClass:
|
|
134 |
"""
|
135 |
Generate responses using Hugging Face models.
|
136 |
"""
|
|
|
137 |
response = self.model.chat_completion(
|
138 |
messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
|
139 |
max_tokens=self.tokens, temperature=self.temp)
|
|
|
134 |
"""
|
135 |
Generate responses using Hugging Face models.
|
136 |
"""
|
137 |
+
content = " ".join([message["content"] for message in messages])
|
138 |
response = self.model.chat_completion(
|
139 |
messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
|
140 |
max_tokens=self.tokens, temperature=self.temp)
|
src/ui/components/inputs.py
CHANGED
@@ -57,7 +57,7 @@ def main_component() -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Checkbox,
|
|
57 |
"your starting location, and month of travel?"
|
58 |
)
|
59 |
|
60 |
-
models = list(MODEL_MAPPER.keys())[:
|
61 |
# Model selection dropdown
|
62 |
model = gr.Dropdown(
|
63 |
choices=models,
|
|
|
57 |
"your starting location, and month of travel?"
|
58 |
)
|
59 |
|
60 |
+
models = list(MODEL_MAPPER.keys())[:3]
|
61 |
# Model selection dropdown
|
62 |
model = gr.Dropdown(
|
63 |
choices=models,
|