Ashmi Banerjee commited on
Commit
adbebe0
1 Parent(s): bfcfb0e

update sustainability prompt and post processing

Browse files
README.md CHANGED
@@ -19,7 +19,7 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
19
 
20
  - [x] Sustainability - database paths - move to HF
21
 
22
- - [ ] Fix it for the new models e.g. Llama and others
23
 
24
  - [ ] Add the space secrets to have it running online
25
  - [ ] Fix the google application json file
@@ -28,6 +28,8 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
28
 
29
  - [x] Add emissions calculation and starting point
30
  - [x] Add more cities to starting point
31
- - [ ] Experiment with the sustainability & without sustainability prompt
32
  - [ ] Adapt the gradio examples to the right format
33
- - [x] UI refactoring
 
 
 
19
 
20
  - [x] Sustainability - database paths - move to HF
21
 
22
+ - [x] Fix it for the new models e.g. Llama and others
23
 
24
  - [ ] Add the space secrets to have it running online
25
  - [ ] Fix the google application json file
 
28
 
29
  - [x] Add emissions calculation and starting point
30
  - [x] Add more cities to starting point
31
+ - [x] Experiment with the sustainability & without sustainability prompt
32
  - [ ] Adapt the gradio examples to the right format
33
+ - [x] UI refactoring
34
+ - [ ] Log the different queries and the results
35
+ - [ ] Clear does not clear the country
src/augmentation/prompt_generation.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
 
4
  logger = logging.getLogger(__name__)
5
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
-
7
 
8
  def generate_prompt(query, context, template=None):
9
  """
@@ -20,22 +20,7 @@ def generate_prompt(query, context, template=None):
20
  if template:
21
  SYS_PROMPT = template
22
  else:
23
- SYS_PROMPT = """You are an AI recommendation system. Your task is to recommend cities in Europe for travel
24
- based on the user's question. You should use the provided contexts to suggest a list of the 3 best cities
25
- that are best suited to the user's question, as well as the month of travel. If the user has already provided
26
- the month of travel in the question, use the same month; otherwise, provide the ideal month of travel. Each
27
- recommendation should also contain an explanation of why it is being recommended, based on the context. Your
28
- answer must begin with "I recommend " followed by the city name and why you recommended it. Your answers are
29
- correct, high-quality, and written by a domain expert. If the provided context does not contain the answer,
30
- simply state, "The provided context does not have the answer." """
31
-
32
- USER_PROMPT = """ Question: {} Which city do you recommend and why?
33
-
34
- Context: Here are the options: {}
35
-
36
- Answer:
37
-
38
- """
39
 
40
  formatted_prompt = f"{USER_PROMPT.format(query, context)}"
41
  messages = [
@@ -108,7 +93,7 @@ def format_context(context):
108
  return formatted_context
109
 
110
 
111
- def augment_prompt(query, context, **params):
112
  """
113
  Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
114
  retrieved context, which can be passed to the LLM.
@@ -121,27 +106,16 @@ def augment_prompt(query, context, **params):
121
  """
122
 
123
  # what about the cities without s-fairness scores? i.e. they don't have seasonality data
124
-
125
- prompt_with_sustainability = """You are an AI recommendation system. Your task is to recommend cities in Europe
126
- for travel based on the user's question. You should use the provided contexts to suggest the city that is best
127
- suited to the user's question. You recommend a list of the top 3 most sustainable cities to the user, as well as
128
- the best month of travel. Each recommendation should also contain an explanation of why it is being recommended,
129
- on sustainability grounds based on the context. The context contains a sustainability score for each city,
130
- also known as the s-fairness score, along with the ideal month of travel. A lower s-fairness score indicates that
131
- the city is a better destination for the month provided. A city without a sustainability score should not be
132
- considered. You should only consider the s-fairness score while choosing the best city. However, your answer
133
- should not contain the numeric score itself or any mention of the sustainability score. Your answer must begin
134
- with "I recommend " followed by the city names and why you recommended it. Your answers are correct, high-quality,
135
- and written by a domain expert. If the provided context does not contain the answer, simply state, "The provided
136
- context does not have the answer. """
137
 
138
  # format context
139
  formatted_context = format_context(context)
140
 
141
  if "sustainability" in params["params"] and params["params"]["sustainability"]:
142
- prompt = generate_prompt(query, formatted_context, prompt_with_sustainability)
143
  else:
144
- prompt = generate_prompt(query, formatted_context)
145
 
146
  return prompt
147
 
 
3
 
4
  logger = logging.getLogger(__name__)
5
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
6
+ from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT
7
 
8
  def generate_prompt(query, context, template=None):
9
  """
 
20
  if template:
21
  SYS_PROMPT = template
22
  else:
23
+ SYS_PROMPT = SYSTEM_PROMPT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  formatted_prompt = f"{USER_PROMPT.format(query, context)}"
26
  messages = [
 
93
  return formatted_context
94
 
95
 
96
+ def augment_prompt(query: str, starting_point: str, context: dict, **params: dict):
97
  """
98
  Function that accepts the user query as input, obtains relevant documents and augments the prompt with the
99
  retrieved context, which can be passed to the LLM.
 
106
  """
107
 
108
  # what about the cities without s-fairness scores? i.e. they don't have seasonality data
109
+ updated_query = f"With {starting_point} as the starting point, {query}"
110
+ prompt_with_sustainability = SUSTAINABILITY_PROMPT
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # format context
113
  formatted_context = format_context(context)
114
 
115
  if "sustainability" in params["params"] and params["params"]["sustainability"]:
116
+ prompt = generate_prompt(updated_query, formatted_context, prompt_with_sustainability)
117
  else:
118
+ prompt = generate_prompt(updated_query, formatted_context)
119
 
120
  return prompt
121
 
src/augmentation/prompts.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ USER_PROMPT = \
2
+ """ Question: {} Which city do you recommend and why?
3
+
4
+ Context: Here are the options: {}
5
+
6
+ Answer:
7
+
8
+ """
9
+ SYSTEM_PROMPT =\
10
+ """You are an AI recommendation system focused on sustainable travel.
11
+ Your task is to recommend European cities for travel based on the user's query and starting point.
12
+ Using the provided context, suggest the top 3 cities that are the most sustainable, along with the best month to visit for the user with respect to their starting point.
13
+ Each recommendation should:
14
+ 1. Include the ideal mode of travel from the user's starting location which has the lowest levels of emissions.
15
+ 2. Offer an explanation as to why the city is recommended, focusing on sustainability factors such as popularity, emissions and seasonal footfall.
16
+ Your answer must begin with "I recommend " followed by the city names and why you recommended it.
17
+ Your answers are correct, high-quality, and written by a domain expert.
18
+ If the provided context does not contain the answer, simply state,
19
+ "The provided context does not have the answer. """
20
+ SUSTAINABILITY_PROMPT =\
21
+ """You are an AI recommendation system focused on sustainable travel.
22
+ Your task is to recommend European cities for travel based on the user's query and starting point.
23
+ Using the provided context, suggest the top 3 cities that are the most sustainable, along with the best month to visit for the user with respect to their starting point.
24
+ Each recommendation should:
25
+ 1. Be based on the value of the s-fairness score provided in the context. A lower s-fairness score indicates that the city is a better destination for the month provided. A city without a sustainability score should not be considered.
26
+ 2. The system should discourage travel during peak seasons and promote travel during off and shoulder seasons.
27
+ 3. It should recommend hidden gems or off-beat destinations compared to the most popular ones.
28
+ 4. Include the ideal mode of travel from the user's starting location which has the lowest levels of emissions.
29
+ 5. Offer an explanation as to why the city is recommended, focusing on sustainability factors such as popularity, emissions and seasonal footfall.
30
+ You should only consider the s-fairness score while choosing the best city.
31
+ However, your answer should not contain the numeric score itself or any mention of the sustainability score.
32
+ Your answer must begin with "I recommend " followed by the city names and why you recommended it.
33
+ Your answers are correct, high-quality, and written by a domain expert.
34
+ If the provided context does not contain the answer, simply state,
35
+ "The provided context does not have the answer. """
src/pipeline.py CHANGED
@@ -23,6 +23,7 @@ import logging
23
  logger = logging.getLogger(__name__)
24
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
25
  from src.text_generation.mapper import MODEL_MAPPER
 
26
 
27
  TEST_DIR = "../tests/"
28
 
@@ -96,6 +97,7 @@ def pipeline(starting_point: str,
96
  try:
97
  prompt = pg.augment_prompt(
98
  query=query,
 
99
  context=context,
100
  params=context_params
101
  )
@@ -114,11 +116,21 @@ def pipeline(starting_point: str,
114
  logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
115
  return None
116
 
 
 
 
 
 
 
 
 
 
 
117
  if test:
118
- return retrieved_cities, prompt[1]['content'], response
119
 
120
  else:
121
- return response
122
 
123
 
124
  if __name__ == "__main__":
 
23
  logger = logging.getLogger(__name__)
24
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
25
  from src.text_generation.mapper import MODEL_MAPPER
26
+ from src.post_processing.post_process import post_process_output
27
 
28
  TEST_DIR = "../tests/"
29
 
 
97
  try:
98
  prompt = pg.augment_prompt(
99
  query=query,
100
+ starting_point=starting_point,
101
  context=context,
102
  params=context_params
103
  )
 
116
  logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
117
  return None
118
 
119
+ try:
120
+ model_params = {"max_tokens": params["max_tokens"], "temperature": params["temperature"]}
121
+ post_processed_response = post_process_output(
122
+ model_id=model_id, user_query=query,
123
+ starting_point=starting_point,
124
+ context=context, response=response, **model_params)
125
+ except Exception as e:
126
+ exc_type, exc_obj, exc_tb = sys.exc_info()
127
+ logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
128
+ return None
129
  if test:
130
+ return retrieved_cities, prompt[1]['content'], post_processed_response
131
 
132
  else:
133
+ return post_processed_response
134
 
135
 
136
  if __name__ == "__main__":
src/post_processing/post_process.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.text_generation.text_generation import generate_response
2
+
3
+
4
+ def post_process_output(model_id: str, user_query: str, starting_point: str, response: str, context: dict, **params: dict) -> str | None:
5
+ if "s-fairness" in response.lower() or "score" in response.lower():
6
+ formatted_response = \
7
+ f"You are an AI recommendation system focused on sustainable travel. Rewrite this response without " \
8
+ f"mentioning the sustainability score or s-fairness.\n {response} "
9
+ elif "the provided context does not have the answer" \
10
+ in response.lower():
11
+ formatted_response = \
12
+ f'You are an AI recommendation system focused on sustainable travel. Rewrite this response using the information from the context: \n Starting point: {starting_point}\n Query: {user_query}\n Context: {context}'
13
+ else:
14
+ formatted_response = response
15
+
16
+ final_prompt = [{"role": "system", "content": formatted_response}]
17
+ generated_response = generate_response(model_id,
18
+ final_prompt, **params)
19
+ return generated_response
src/text_generation/mapper.py CHANGED
@@ -1,9 +1,9 @@
1
  MODEL_MAPPER = {
2
- 'Gemma-2-9B-it': "google/gemma-2-9b-it",
3
- 'Gemma-2-2B-it': "google/gemma-2-2b-it",
4
  "Gemini-1.0-pro": "gemini-1.0-pro",
5
  "Gemini-1.5-Flash": "gemini-1.5-flash-001",
6
  "Gemini-1.5-Pro": "gemini-1.5-pro-001",
 
 
7
  "Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
8
  'GPT-4': "gpt-4o-mini",
9
  'Llama3': "meta-llama/Meta-Llama-3-8B",
 
1
  MODEL_MAPPER = {
 
 
2
  "Gemini-1.0-pro": "gemini-1.0-pro",
3
  "Gemini-1.5-Flash": "gemini-1.5-flash-001",
4
  "Gemini-1.5-Pro": "gemini-1.5-pro-001",
5
+ 'Gemma-2-9B-it': "google/gemma-2-9b-it",
6
+ 'Gemma-2-2B-it': "google/gemma-2-2b-it",
7
  "Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
8
  'GPT-4': "gpt-4o-mini",
9
  'Llama3': "meta-llama/Meta-Llama-3-8B",
src/text_generation/model_init.py CHANGED
@@ -134,6 +134,7 @@ class LLMBaseClass:
134
  """
135
  Generate responses using Hugging Face models.
136
  """
 
137
  response = self.model.chat_completion(
138
  messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
139
  max_tokens=self.tokens, temperature=self.temp)
 
134
  """
135
  Generate responses using Hugging Face models.
136
  """
137
+ content = " ".join([message["content"] for message in messages])
138
  response = self.model.chat_completion(
139
  messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
140
  max_tokens=self.tokens, temperature=self.temp)
src/ui/components/inputs.py CHANGED
@@ -57,7 +57,7 @@ def main_component() -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Checkbox,
57
  "your starting location, and month of travel?"
58
  )
59
 
60
- models = list(MODEL_MAPPER.keys())[:5]
61
  # Model selection dropdown
62
  model = gr.Dropdown(
63
  choices=models,
 
57
  "your starting location, and month of travel?"
58
  )
59
 
60
+ models = list(MODEL_MAPPER.keys())[:3]
61
  # Model selection dropdown
62
  model = gr.Dropdown(
63
  choices=models,