Ashmi Banerjee committed on
Commit
8842640
•
1 Parent(s): 89cd5d5

added more models (Gemini 1.5 Flash/Pro, Claude 3.5 Sonnet), refactored the UI into components and reworked model initialization

README.md CHANGED
@@ -15,18 +15,19 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
15
 
16
  ### TODOs
17
 
18
- - [ ] Refactor the vectordb.py - remove code duplication
19
 
20
  - [x] Sustainability - database paths - move to HF
21
 
22
 - [ ] Fix it for the new models, e.g., Llama and others
23
 
24
  - [ ] Add the space secrets to have it running online
25
-
26
  - [ ] Fix the google application json file
27
 
28
  - [ ] Make the space public
29
 
30
  - [x] Add emissions calculation and starting point
31
  - [x] Add more cities to starting point
32
- - [ ] Experiment with and without the sustainability prompt
 
 
 
15
 
16
  ### TODOs
17
 
18
+ - [x] Refactor the vectordb.py - remove code duplication
19
 
20
  - [x] Sustainability - database paths - move to HF
21
 
22
 - [ ] Fix it for the new models, e.g., Llama and others
23
 
24
  - [ ] Add the space secrets to have it running online
 
25
  - [ ] Fix the google application json file
26
 
27
  - [ ] Make the space public
28
 
29
  - [x] Add emissions calculation and starting point
30
  - [x] Add more cities to starting point
31
+ - [ ] Experiment with and without the sustainability prompt
32
+ - [ ] Adapt the gradio examples to the right format
33
+ - [x] UI refactoring
app.py CHANGED
@@ -1,102 +1,37 @@
1
- from typing import Optional
2
  import gradio as gr
3
  import sys
4
 
 
 
 
 
 
5
  sys.path.append("./src")
6
- from src.pipeline import pipeline
7
- from src.helpers.data_loaders import load_places
8
-
9
-
10
- def clear():
11
- return None, None, None
12
-
13
-
14
- # Function to update the list of cities based on the selected country
15
- def update_cities(selected_country, df):
16
- filtered_cities = df[df['country'] == selected_country]['city'].tolist()
17
- return gr.Dropdown(choices=filtered_cities, interactive=True) # Make it interactive as it is not by default
18
-
19
-
20
- def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
21
- temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
22
- pipeline_response = pipeline(
23
- query=query_text,
24
- model_name=model_name,
25
- sustainability=is_sustainable,
26
- starting_point=starting_point,
27
- )
28
- return pipeline_response
29
 
30
 
31
  def create_ui():
32
- data_file = "cities/eu_200_cities.csv"
33
- df = load_places(data_file)
34
- df = df.sort_values(by=['country', 'city'])
35
-
36
- examples = [
37
- ["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
38
- "local cuisines to try?", "GPT-4"],
39
- ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
40
- ["Suggest some cities that can be visited from London and are very rich in history and culture.",
41
- "Gemini-1.0-pro"],
42
- ]
43
-
44
- with gr.Blocks() as app:
45
- gr.HTML(
46
- "<center><h1 style='font-size:xx-large; font-color: green'>🍀 Green City Finder 🍀</h1><h3>AI Sprint 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the "
47
- "compatibility of"
48
- "Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 "
49
- "Pro</b> \n "
50
- "models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.\n "
51
- "We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector "
52
- "embeddings are stored in a VectorDB (LanceDB) hosted in Google Cloud.\n "
53
- "<p>Sustainability is calculated based on the work by <a href=https://arxiv.org/abs/2403.18604>Banerjee "
54
- "et al.</a></p>\n "
55
- " </p> <br>Google Cloud credits are provided for this project. </p>\n"
56
- " ")
57
-
58
- with gr.Group():
59
- countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
60
- starting_point = gr.Dropdown(choices=[], multiselect=False,
61
- label="Select your starting point for the trip!")
62
-
63
- countries.select(fn=lambda selected_country:
64
- update_cities(selected_country, df),
65
- inputs=countries, outputs=starting_point)
66
-
67
- query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
68
- sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
69
- "with regards to the environment, your starting "
70
- "location and month of travel?")
71
- # TODO: Add model options, month and starting point
72
- model = gr.Dropdown(
73
- ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
74
- "models "
75
- "later!",
76
- )
77
- output = gr.Textbox(label="Generated Results", lines=4)
78
-
79
- with gr.Accordion("Settings", open=False):
80
- max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
81
- interactive=True,
82
- visible=True, info="The maximum number of output tokens")
83
- temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
84
- interactive=True,
85
- visible=True, info="The value used to module the logits distribution")
86
- with gr.Group():
87
- with gr.Row():
88
- submit_btn = gr.Button("Submit", variant="primary")
89
- clear_btn = gr.Button("Clear", variant="secondary")
90
- cancel_btn = gr.Button("Cancel", variant="stop")
91
- submit_btn.click(generate_text, inputs=[query, model, sustainable, starting_point], outputs=[output])
92
- clear_btn.click(clear, inputs=[], outputs=[query, model, output])
93
- cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
94
-
95
- gr.Markdown("## Examples")
96
- # gr.Examples(
97
- # examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
98
- # cache_examples=True,
99
- # )
100
  return app
101
 
102
 
 
 
1
  import gradio as gr
2
  import sys
3
 
4
+ from src.ui.components.actions import generate_text, clear
5
+ from src.ui.components.inputs import main_component
6
+ from src.ui.components.static import load_buttons, load_examples, model_settings
7
+ from src.ui.setup import load_html_from_file
8
+ from src.text_generation.vertexai_setup import initialize_vertexai_params
9
 sys.path.append("./src")
10
 
11
 
12
  def create_ui():
13
+ initialize_vertexai_params()
14
+ # Path to HTML file
15
+ html_file_path = 'src/ui/templates/intro.html'
16
+
17
+ # Create the Gradio HTML component
18
+ html_content = load_html_from_file(html_file_path)
19
+ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green, secondary_hue=gr.themes.colors.blue)) as app:
20
+
21
+ gr.HTML(html_content)
22
+ country, starting_point, query, sustainable, model = main_component()
23
+ output = gr.Textbox(label="Generated Results", lines=4)
24
+ max_new_tokens, temperature = model_settings()
25
+
26
+ # Load the buttons for the interface
27
+ load_buttons(query, model,
28
+ sustainable, starting_point,
29
+ max_new_tokens, temperature,
30
+ output,
31
+ generate_text_fn=generate_text,
32
+ clear_fn=clear)
33
+ # Load the examples for the interface
34
+ # load_examples(country, starting_point, query, model, sustainable, output, generate_text)
35
  return app
36
 
37
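Note: the diff does not show a launcher, but with this refactor the entry point would presumably stay the usual Gradio pattern — a minimal sketch, assuming create_ui() is imported from this module:

if __name__ == "__main__":
    app = create_ui()
    app.launch()  # Gradio serves the Blocks app locally (port 7860 by default)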
 
src/augmentation/prompt_generation.py CHANGED
@@ -1,4 +1,4 @@
1
- from information_retrieval import info_retrieval as ir
2
  import logging
3
 
4
  logger = logging.getLogger(__name__)
 
1
+ from src.information_retrieval import info_retrieval as ir
2
  import logging
3
 
4
  logger = logging.getLogger(__name__)
src/pipeline.py CHANGED
@@ -15,14 +15,17 @@ from text_generation.models import (
15
  Phi3SmallInstruct,
16
  GPT4,
17
  Gemini,
 
18
  )
19
  from text_generation import text_generation as tg
20
  import logging
21
 
22
  logger = logging.getLogger(__name__)
23
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
 
24
 
25
  TEST_DIR = "../tests/"
 
26
  MODELS = {
27
  'GPT-4': GPT4,
28
  'Llama3': Llama3,
@@ -33,11 +36,15 @@ MODELS = {
33
  'Mistral-Instruct': MistralInstruct,
34
  'Llama3.1-Instruct': Llama3Point1Instruct,
35
  'Phi3-Instruct': Phi3SmallInstruct,
36
- 'Gemini-1.0-pro': Gemini,
 
37
  }
38
 
39
 
40
- def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
 
 
 
41
  """
42
 
43
  Executes the entire RAG pipeline, provided the query and model class name.
@@ -53,9 +60,11 @@ def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **
53
 
54
 
55
  """
56
-
57
- model = MODELS[model_name]
58
-
 
 
59
  context_params = {
60
  'limit': 5,
61
  'reranking': 0,
@@ -97,9 +106,9 @@ def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **
97
 
98
  # return prompt
99
 
100
- logger.info(f"Augmented prompt, initializing {model} and generating response..")
101
  try:
102
- response = tg.generate_response(model, prompt)
103
  except Exception as e:
104
  exc_type, exc_obj, exc_tb = sys.exc_info()
105
  logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
@@ -117,11 +126,11 @@ if __name__ == "__main__":
117
  # suggest " \ "some " \ "European cities? "
118
  sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
119
  "cities. "
120
- model = "GPT-4"
121
 
122
  pipeline_response = pipeline(
123
  query=sample_query,
124
- model_name=model,
125
  sustainability=1
126
  )
127
 
 
15
  Phi3SmallInstruct,
16
  GPT4,
17
  Gemini,
18
+ Claude3Point5Sonnet,
19
  )
20
  from text_generation import text_generation as tg
21
  import logging
22
 
23
  logger = logging.getLogger(__name__)
24
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
25
+ from src.text_generation.mapper import MODEL_MAPPER
26
 
27
  TEST_DIR = "../tests/"
28
+
29
  MODELS = {
30
  'GPT-4': GPT4,
31
  'Llama3': Llama3,
 
36
  'Mistral-Instruct': MistralInstruct,
37
  'Llama3.1-Instruct': Llama3Point1Instruct,
38
  'Phi3-Instruct': Phi3SmallInstruct,
39
+ "Gemini-1.0-pro": Gemini,
40
+ "Claude3.5-sonnet": Claude3Point5Sonnet,
41
  }
42
 
43
 
44
+ def pipeline(starting_point: str,
45
+ query: str,
46
+ model_name: str,
47
+ test: int = 0, **params):
48
  """
49
 
50
  Executes the entire RAG pipeline, provided the query and model class name.
 
60
 
61
 
62
  """
63
+ try:
64
+ model_id = MODEL_MAPPER[model_name]
65
+ except KeyError:
66
+ logger.error(f"Model {model_name} not found in the model mapper.")
67
+ model_id = MODEL_MAPPER['Gemini-1.0-pro']
68
  context_params = {
69
  'limit': 5,
70
  'reranking': 0,
 
106
 
107
  # return prompt
108
 
109
+ logger.info(f"Augmented prompt, initializing {model_name} and generating response..")
110
  try:
111
+ response = tg.generate_response(model_id, prompt, **params)
112
  except Exception as e:
113
  exc_type, exc_obj, exc_tb = sys.exc_info()
114
  logger.info(f"Error at line {exc_tb.tb_lineno} while generating response: {e}")
 
126
  # suggest " \ "some " \ "European cities? "
127
  sample_query = "I'm planning a trip in July and enjoy beaches, nightlife, and vibrant cities. Recommend some " \
128
  "cities. "
129
+ model_name = "GPT-4"
130
 
131
  pipeline_response = pipeline(
132
  query=sample_query,
133
+ model_name=model_name,
134
  sustainability=1
135
  )
136
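Note: starting_point is now a required positional parameter, and max_tokens / temperature travel through **params down to generate_response, so a fully explicit call looks like this (values illustrative):

pipeline_response = pipeline(
    starting_point="Munich",
    query=sample_query,
    model_name="GPT-4",
    sustainability=1,
    max_tokens=1024,
    temperature=0.49,
)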
 
src/text_generation/mapper.py ADDED
@@ -0,0 +1,16 @@
1
+ MODEL_MAPPER = {
2
+ 'Gemma-2-9B-it': "google/gemma-2-9b-it",
3
+ 'Gemma-2-2B-it': "google/gemma-2-2b-it",
4
+ "Gemini-1.0-pro": "gemini-1.0-pro",
5
+ "Gemini-1.5-Flash": "gemini-1.5-flash-001",
6
+ "Gemini-1.5-Pro": "gemini-1.5-pro-001",
7
+ "Claude3.5-sonnet": "claude-3-5-sonnet@20240620",
8
+ 'GPT-4': "gpt-4o-mini",
9
+ 'Llama3': "meta-llama/Meta-Llama-3-8B",
10
+ 'Mistral': "mistralai/Mistral-7B",
11
+ 'Llama3.1': "meta-llama/Meta-Llama-3.1-8B",
12
+ 'Llama3-Instruct': "meta-llama/Meta-Llama-3-8B-Instruct",
13
+ 'Mistral-Instruct': "mistralai/Mistral-7B-Instruct-v0.1",
14
+ 'Llama3.1-Instruct': "meta-llama/Meta-Llama-3.1-8B-Instruct",
15
+ 'Phi3-Instruct': "microsoft/Phi-3-small-128k-instruct",
16
+ }
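Note: MODEL_MAPPER is a plain dict from the UI-facing model names to provider-specific IDs. A small sketch of the lookup pattern pipeline() uses, with the same Gemini fallback as its except branch:

from src.text_generation.mapper import MODEL_MAPPER

model_id = MODEL_MAPPER.get("Claude3.5-sonnet", MODEL_MAPPER["Gemini-1.0-pro"])
# -> "claude-3-5-sonnet@20240620"; unknown names fall back to "gemini-1.0-pro"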
src/text_generation/model_init.py CHANGED
@@ -1,123 +1,140 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
  from vertexai.generative_models import GenerativeModel
4
-
5
  from dotenv import load_dotenv
6
  from anthropic import AnthropicVertex
7
  import os
8
  from openai import OpenAI
9
- from src.text_generation.vertexai_setup import initialize_vertexai_params
 
10
 
 
11
  load_dotenv()
12
- if "OPENAI_API_KEY" in os.environ:
13
- OAI_API_KEY = os.environ["OPENAI_API_KEY"]
14
- if "VERTEXAI_PROJECTID" in os.environ:
15
- VERTEXAI_PROJECT = os.environ["VERTEXAI_PROJECTID"]
16
 
17
 
18
  class LLMBaseClass:
19
  """
20
- Base Class for text generation - user needs to provide the HF model ID while instantiating the class after which
21
- the generate method can be called to generate responses
22
-
23
  """
24
 
25
- def __init__(self, model_id) -> None:
26
-
27
- match (model_id[0].lower()):
28
- case "gpt-4o-mini": # for open AI models
29
- self.api_key = OAI_API_KEY
30
- self.model = OpenAI(api_key=self.api_key)
31
- case "claude-3-5-sonnet@20240620": # for Claude through vertexAI
32
- self.api_key = None
33
- self.model = AnthropicVertex(region="europe-west1", project_id=VERTEXAI_PROJECT)
34
- case "gemini-1.0-pro":
35
- self.api_key = None
36
- self.model = GenerativeModel(model_id[0].lower())
37
- case _: # for HF models
38
- self.api_key = None
39
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
40
- self.tokenizer.pad_token = self.tokenizer.eos_token
41
-
42
- self.tokenizer.chat_template = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- " \
43
- "bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{%- " \
44
- "elif " \
45
- "message['role'] == 'system' %}{{- '<<SYS>>\\n' + message[" \
46
- "'content'].strip() + " \
47
- "'\\n<</SYS>>\\n\\n' }}{%- elif message['role'] == 'assistant' %}{{- '[" \
48
- "ASST] ' " \
49
- "+ message['content'] + ' [/ASST]' + eos_token }}{%- endif %}{%- " \
50
- "endfor %} "
51
- # Initialize quantization to use less GPU
52
- if torch.cuda.is_available():
53
- bnb_config = BitsAndBytesConfig(
54
- load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
55
- bnb_4bit_compute_dtype=torch.bfloat16
56
- )
57
- else:
58
- bnb_config = None
59
- self.model = AutoModelForCausalLM.from_pretrained(
60
- model_id,
61
- torch_dtype=torch.bfloat16,
62
- device_map="auto",
63
- quantization_config=bnb_config,
64
- )
65
-
66
- self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
67
-
68
- self.terminators = [
69
- self.tokenizer.eos_token_id,
70
- self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
71
- ]
72
 
73
  def generate(self, messages):
74
- match (self.model_id[0].lower()):
75
- case "gpt-4o-mini":
76
- completion = self.model.chat.completions.create(
77
- model=self.model_id[0],
78
- messages=messages,
79
- temperature=0.6,
80
- top_p=0.9,
81
- )
82
- # Return the generated content from the API response
83
- return completion.choices[0].message.content
84
- case "claude-3-5-sonnet@20240620" | "gemini-1.0-pro":
85
- initialize_vertexai_params()
86
- if "claude" in self.model_id[0].lower():
87
- message = self.model.messages.create(
88
- max_tokens=1024,
89
- model=self.model_id[0],
90
- messages=[
91
- {
92
- "role": "user",
93
- "content": messages[0]["content"],
94
- }
95
- ],
96
- )
97
- return message.content[0].text
98
- else:
99
- response = self.model.generate_content(messages[0]["content"])
100
- return response
101
- case _:
102
- input_ids = self.tokenizer.apply_chat_template(
103
- conversation=messages,
104
- add_generation_prompt=True,
105
- return_tensors="pt",
106
- padding=True
107
- ).to(self.model.device)
108
-
109
- outputs = self.model.generate(
110
- input_ids,
111
- max_new_tokens=1024,
112
- # eos_token_id=self.terminators,
113
- pad_token_id=self.tokenizer.eos_token_id,
114
- do_sample=True,
115
- temperature=0.6,
116
- top_p=0.9,
117
- )
118
- response = outputs[0][input_ids.shape[-1]:]
119
-
120
- return self.tokenizer.decode(response, skip_special_tokens=True)
121
-
122
-
123
- # database/wikivoyage/wikivoyage_listings.lance/data/e2940f51-d754-4b54-a688-004bdb8e7aa2.lance
 
 
 
1
  from vertexai.generative_models import GenerativeModel
 
2
  from dotenv import load_dotenv
3
  from anthropic import AnthropicVertex
4
  import os
5
  from openai import OpenAI
6
+ from src.text_generation.vertexai_setup import initialize_vertexai_params, get_default_config
7
+ from huggingface_hub import InferenceClient
8
 
9
+ # Load environment variables
10
  load_dotenv()
11
+ OAI_API_KEY = os.getenv("OPENAI_API_KEY")
12
+
13
+
14
+ def _validate_tokens(max_tokens: int) -> int:
15
+ """
16
+ Validates the max_tokens parameter. Ensures it's within a valid range (1 to 8192).
17
+ If invalid, defaults to 8192.
18
+ """
19
+ if 1 <= max_tokens <= 8192:
20
+ return max_tokens
21
+ return 8192
22
+
23
+
24
+ def _validate_temperature(temp: float) -> float:
25
+ """
26
+ Validates the temperature parameter. Ensures it's within a valid range (0 to 1).
27
+ If invalid, defaults to 0.49.
28
+ """
29
+ if 0 <= temp <= 1:
30
+ return temp
31
+ return 0.49
32
 
33
 
34
  class LLMBaseClass:
35
  """
36
+ Base class for text generation. Users provide the HF model ID or other model identifiers
37
+ and can call the generate method to get responses.
 
38
  """
39
 
40
+ def __init__(self, model_id: str, max_tokens: int, temp: float) -> None:
41
+ self.model_id = model_id
42
+ self.api_key = None
43
+ self.temp = _validate_temperature(temp)
44
+ self.tokens = _validate_tokens(max_tokens)
45
+ self.model = self._initialize_model()
46
+
47
+ def _initialize_model(self):
48
+ """
49
+ Initialize the model based on the provided model ID.
50
+ """
51
+ if self.model_id == "gpt-4o-mini":
52
+ return self._initialize_openai_model()
53
+ elif self.model_id == "claude-3-5-sonnet@20240620":
54
+ return self._initialize_claude_model()
55
+ elif self.model_id in ["gemini-1.0-pro",  # the Claude id is already handled by the branch above
56
+ "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
57
+ return self._initialize_vertexai_model()
58
+ else:
59
+ return self._initialize_hf_model()
60
+
61
+ def _initialize_openai_model(self):
62
+ """
63
+ Initialize OpenAI model.
64
+ """
65
+ self.api_key = OAI_API_KEY
66
+ return OpenAI(api_key=self.api_key)
67
+
68
+ def _initialize_claude_model(self):
69
+ """
70
+ Initialize Claude model using Anthropic via Vertex AI.
71
+ """
72
+ self.api_key = os.getenv("VERTEXAI_PROJECTID")
73
+ return AnthropicVertex(region="europe-west1", project_id=self.api_key)
74
+
75
+ def _initialize_vertexai_model(self):
76
+ """
77
+ Initialize Google Gemini model using Vertex AI.
78
+ """
79
+ _, default_safe_settings = get_default_config()  # only the default safety settings are used here
80
+ gen_config = {
81
+ "temperature": self.temp,
82
+ "max_output_tokens": self.tokens,
83
+ }
84
+ return GenerativeModel(self.model_id,
85
+ generation_config=gen_config,  # gen_config is always built above, so no fallback is needed
86
+ safety_settings=default_safe_settings)
87
+
88
+ def _initialize_hf_model(self):
89
+ self.api_key = os.getenv("HF_TOKEN")
90
+ return InferenceClient(token=self.api_key, model=self.model_id)
91
 
92
  def generate(self, messages):
93
+ """
94
+ Generate responses based on the model type and provided messages.
95
+ """
96
+ if self.model_id == "gpt-4o-mini":
97
+ return self._generate_openai(messages)
98
+ elif self.model_id in ["claude-3-5-sonnet@20240620",
99
+ "gemini-1.0-pro", "gemini-1.5-flash-001", "gemini-1.5-pro-001"]:
100
+ return self._generate_vertexai(messages)
101
+ else:
102
+ return self._generate_hf(messages)
103
+
104
+ def _generate_openai(self, messages):
105
+ """
106
+ Generate responses using OpenAI model.
107
+ """
108
+ completion = self.model.chat.completions.create(
109
+ model=self.model_id,
110
+ messages=messages,
111
+ temperature=self.temp,
112
+ max_tokens=self.tokens,
113
+ )
114
+ return completion.choices[0].message.content
115
+
116
+ def _generate_vertexai(self, messages):
117
+ """
118
+ Generate responses using Claude or Gemini models via Vertex AI.
119
+ """
120
+ initialize_vertexai_params()
121
+ content = " ".join([message["content"] for message in messages])
122
+ if "claude" in self.model_id:
123
+ message = self.model.messages.create(
124
+ max_tokens=self.tokens,
125
+ model=self.model_id,
126
+ messages=[{"role": "user", "content": content}],
127
+ )
128
+ return message.content[0].text
129
+ else:
130
+ response = self.model.generate_content(content)
131
+ return response.text
132
+
133
+ def _generate_hf(self, messages):
134
+ """
135
+ Generate responses using Hugging Face models.
136
+ """
137
+ response = self.model.chat_completion(
138
+ messages=[{"role": "user", "content": messages[0]["content"] + messages[1]["content"]}],
139
+ max_tokens=self.tokens, temperature=self.temp)
140
+ return response.choices[0].message.content
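Note: a minimal sketch of driving LLMBaseClass directly — assumes the relevant credentials (OPENAI_API_KEY here, or HF_TOKEN / Vertex AI credentials for the other branches) are set in the environment:

llm = LLMBaseClass(model_id="gpt-4o-mini", max_tokens=1024, temp=0.49)
reply = llm.generate([{"role": "user", "content": "Suggest three green European cities."}])
print(reply)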
 
 
src/text_generation/text_generation.py CHANGED
@@ -1,20 +1,13 @@
1
- from augmentation import prompt_generation as pg
2
- from information_retrieval import info_retrieval as ir
3
- from src.text_generation.models import (
4
- Llama3,
5
- Mistral,
6
- Gemma2,
7
- Llama3Point1,
8
- GPT4,
9
- Claude3Point5Sonnet,
10
- )
11
  import logging
12
 
13
  logger = logging.getLogger(__name__)
14
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
15
 
16
 
17
- def generate_response(model, prompt):
18
  """
19
 
20
  Function that initializes the LLM class and calls the generate function.
@@ -26,7 +19,9 @@ def generate_response(model, prompt):
26
  """
27
 
28
  logger.info(f"Initializing LLM configuration for {model}")
29
- llm = model()
 
 
30
 
31
  logger.info("Generating response")
32
  try:
@@ -73,6 +68,7 @@ def test(model):
73
  logger.info(f"Augmented prompt, initializing {model} and generating response..")
74
  try:
75
  response = generate_response(model, without_sfairness)
 
76
  except Exception as e:
77
  logger.info(f"Error while generating response: {e}")
78
  return None
 
1
+ from src.augmentation import prompt_generation as pg
2
+ from src.information_retrieval import info_retrieval as ir
3
+ from src.text_generation.model_init import LLMBaseClass
4
  import logging
5
 
6
  logger = logging.getLogger(__name__)
7
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
8
 
9
 
10
+ def generate_response(model, prompt: list, **params):
11
  """
12
 
13
  Function that initializes the LLM class and calls the generate function.
 
19
  """
20
 
21
  logger.info(f"Initializing LLM configuration for {model}")
22
+ llm = LLMBaseClass(model_id=model,
23
+ max_tokens=params.get('max_tokens', 1024),
24
+ temp=params.get('temperature', 0.49))
25
 
26
  logger.info("Generating response")
27
  try:
 
68
  logger.info(f"Augmented prompt, initializing {model} and generating response..")
69
  try:
70
  response = generate_response(model, without_sfairness)
71
+ print(response)
72
  except Exception as e:
73
  logger.info(f"Error while generating response: {e}")
74
  return None
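Note: with the defaults above, generate_response only needs a resolved model id and a message list — a sketch, values illustrative:

response = generate_response(
    "gemini-1.0-pro",
    [{"role": "user", "content": "Recommend three coastal cities."}],
    max_tokens=1024,
    temperature=0.49,
)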
src/text_generation/vertexai_setup.py CHANGED
@@ -5,6 +5,11 @@ import vertexai
5
  import os
6
  import json
7
 import base64
8
 
9
  load_dotenv()
10
  if "VERTEXAI_PROJECTID" in os.environ:
@@ -12,7 +17,7 @@ if "VERTEXAI_PROJECTID" in os.environ:
12
 
13
 
14
  def decode_service_key():
15
- encoded_key = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
16
  original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
17
  if original_service_key:
18
  return original_service_key
@@ -20,7 +25,6 @@ def decode_service_key():
20
 
21
 
22
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
23
-
24
  creds_file_name = os.getcwd() + "/.config/gcp_default_credentials.json"
25
  print(creds_file_name)
26
  if not os.path.exists(os.path.dirname(creds_file_name)):
@@ -35,3 +39,16 @@ def initialize_vertexai_params(location: Optional[str] = "us-central1"):
35
  scopes=["https://www.googleapis.com/auth/cloud-platform"],
36
  )
37
 vertexai.init(project=VERTEXAI_PROJECT, location=location)
5
  import os
6
  import json
7
  import base64
8
+ import logging
9
+ from vertexai import generative_models
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
13
 
14
  load_dotenv()
15
  if "VERTEXAI_PROJECTID" in os.environ:
 
17
 
18
 
19
  def decode_service_key():
20
+ encoded_key = os.environ["GOOGLE_CREDENTIALS"]
21
  original_service_key = json.loads(base64.b64decode(encoded_key).decode('utf-8'))
22
  if original_service_key:
23
  return original_service_key
 
25
 
26
 
27
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
 
28
  creds_file_name = os.getcwd() + "/.config/gcp_default_credentials.json"
29
  print(creds_file_name)
30
  if not os.path.exists(os.path.dirname(creds_file_name)):
 
39
  scopes=["https://www.googleapis.com/auth/cloud-platform"],
40
  )
41
  vertexai.init(project=VERTEXAI_PROJECT, location=location)
42
+ logger.info("Vertex AI initialized")
43
+
44
+
45
+ def get_default_config() -> tuple[dict, dict]:
46
+ default_gen_config = {
47
+ "temperature": 0.49,
48
+ "max_output_tokens": 1024,
49
+ }
50
+ default_safety_settings = {
51
+ generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
52
+ generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
53
+ }
54
+ return default_gen_config, default_safety_settings
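Note: a sketch of how these helpers combine when constructing a Gemini model, mirroring model_init.py:

from vertexai.generative_models import GenerativeModel

initialize_vertexai_params()          # materializes the service key and calls vertexai.init
gen_config, safety_settings = get_default_config()
model = GenerativeModel("gemini-1.0-pro",
                        generation_config=gen_config,
                        safety_settings=safety_settings)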
src/ui/__init__.py ADDED
File without changes
src/ui/components/actions.py ADDED
@@ -0,0 +1,30 @@
1
+ from typing import Optional
2
+
3
+ from src.pipeline import pipeline
4
+
5
+
6
+ def clear():
7
+ return None, None, None, None, None
8
+
9
+
10
+ def generate_text(query_text: str,
11
+ model_name: Optional[str],
12
+ is_sustainable: Optional[bool],
13
+ starting_point: Optional[str],
14
+ max_tokens: Optional[int] = 1024,
15
+ temp: Optional[float] = 0.49,
16
+ ):
17
+ model_params = {
18
+ 'max_tokens': max_tokens,
19
+ 'temperature': temp
20
+ }
21
+ pipeline_response = pipeline(
22
+ query=query_text,
23
+ model_name=model_name,
24
+ sustainability=is_sustainable,
25
+ starting_point=starting_point,
26
+ **model_params
27
+ )
28
+ if pipeline_response is None:
29
+ return "Error while generating response! Please try again."
30
+ return pipeline_response
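Note: generate_text is plain Python, so it can be exercised without the UI — a sketch, assuming the model name is a MODEL_MAPPER key and credentials are configured:

result = generate_text(
    query_text="Where should I go hiking in August?",
    model_name="Gemini-1.0-pro",
    is_sustainable=True,
    starting_point="Munich",
)
print(result)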
src/ui/components/inputs.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Tuple
2
+
3
+ import gradio as gr
4
+
5
+ from src.helpers.data_loaders import load_places
6
+ from src.text_generation.mapper import MODEL_MAPPER
7
+
8
+
9
+ def get_places():
10
+ data_file = "cities/eu_200_cities.csv"
11
+ df = load_places(data_file)
12
+ df = df.sort_values(by=['country', 'city'])
13
+ return df
14
+
15
+
16
+ # Function to update the list of cities based on the selected country
17
+ def update_cities(selected_country, df):
18
+ filtered_cities = df[df['country'] == selected_country]['city'].tolist()
19
+ return gr.Dropdown(choices=filtered_cities, interactive=True) # Make it interactive as it is not by default
20
+
21
+
22
+ def main_component() -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Checkbox, gr.Dropdown]:
23
+ """
24
+ Creates the main Gradio interface components and returns them.
25
+
26
+ Returns:
27
+ Tuple containing:
28
+ - countries: Dropdown for selecting the country.
29
+ - starting_point: Dropdown for selecting the starting point.
30
+ - query: Textbox for entering the user query.
31
+ - sustainable: Checkbox for sustainable travel.
32
+ - model: Dropdown for selecting the model.
33
+ """
34
+ df = get_places()
35
+ country_names = list(df.country.unique())
36
+ with gr.Group():
37
+ # Country selection dropdown
38
+ countries = gr.Dropdown(choices=country_names, multiselect=False, label="Country")
39
+
40
+ # Starting point selection dropdown
41
+ starting_point = gr.Dropdown(choices=[], multiselect=False, label="Select your starting point for the trip!")
42
+
43
+ # When a country is selected, update the starting point options
44
+ countries.select(
45
+ fn=lambda selected_country: update_cities(selected_country, df),
46
+ inputs=countries,
47
+ outputs=starting_point
48
+ )
49
+
50
+ # User query input
51
+ query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
52
+
53
+ # Checkbox for sustainable travel option
54
+ sustainable = gr.Checkbox(
55
+ label="Sustainable",
56
+ info="Do you want your recommendations to be sustainable with regards to the environment, "
57
+ "your starting location, and month of travel?"
58
+ )
59
+
60
+ models = list(MODEL_MAPPER.keys())[:5]
61
+ # Model selection dropdown
62
+ model = gr.Dropdown(
63
+ choices=models,
64
+ label="Model",
65
+ info="Select your model. The model will generate the recommendations based on your query."
66
+ )
67
+
68
+ # Return all the components individually
69
+ return countries, starting_point, query, sustainable, model
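Note: a sketch of how these inputs compose, assuming "Germany" appears in the EU cities file:

import gradio as gr

df = get_places()
german_cities = update_cities("Germany", df)   # interactive gr.Dropdown of German cities

with gr.Blocks() as demo:
    country, start, query, sustainable, model = main_component()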
src/ui/components/static.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import Callable, Tuple, Optional
2
+
3
+ import gradio as gr
4
+
5
+ from src.ui.components.inputs import get_places, update_cities, main_component
6
+
7
+
8
+ def load_examples(
9
+ country: gr.components.Component,
10
+ starting_point: gr.components.Component,
11
+ query: gr.components.Component,
12
+ model: gr.components.Component,
13
+ is_sustainable: gr.components.Component,
14
+ output: gr.components.Component,
15
+ generate_text_fn: Callable[
16
+ [str, Optional[str], Optional[bool], Optional[str], Optional[int], Optional[float]],
17
+ str],
18
+ ) -> gr.Examples:
19
+
20
+ df = get_places()
21
+ country_names = list(df.country.unique())
22
+
23
+ # Example data
24
+ country_example = "USA"
25
+ starting_point_example = "New York"
26
+ query_example = "What are some top tourist attractions?"
27
+ sustainable_example = True
28
+ model_example = "GPT-4"
29
+
30
+ # Update the starting point options based on the example country
31
+ starting_point_choices = update_cities(country_example, df)
32
+
33
+ # Set the choices and default value for the starting point dropdown
34
+ starting_point.choices = starting_point_choices
35
+ starting_point.value = starting_point_example
36
+
37
+ # Provide examples
38
+ examples = [
39
+ [country_example, starting_point_example, query_example, sustainable_example, model_example]
40
+ ]
41
+
42
+ # Create a Gradio Examples component
43
+ return gr.Examples(  # returned to match the annotated gr.Examples type
44
+ examples=examples,
45
+ inputs=[country, starting_point, query, is_sustainable, model],
46
+ outputs=[]
47
+ )
48
+
49
+
50
+ def load_buttons(
51
+ query: gr.components.Component,
52
+ model: gr.components.Component,
53
+ sustainable: gr.components.Component,
54
+ starting_point: gr.components.Component,
55
+ max_new_tokens: gr.components.Component,
56
+ temperature: gr.components.Component,
57
+ output: gr.components.Component,
58
+ generate_text_fn: Callable[
59
+ [str, Optional[str], Optional[bool], Optional[str], Optional[int], Optional[float]],
60
+ str],
61
+ clear_fn: Callable[[], None]
62
+ ) -> gr.Group:
63
+ """
64
+ Load and return buttons for the Gradio interface.
65
+
66
+ Args:
67
+ query: The input component for user queries.
68
+ model: The input component for selecting the model.
69
+ sustainable: The input component for sustainable travel options.
70
+ starting_point: The input component for the user's starting point.
71
+ output: The output component for displaying the generated text.
72
+ generate_text_fn: The function to be called on submit to generate text.
73
+ clear_fn: The function to clear the input fields and output.
74
+
75
+ Returns:
76
+ Gradio Group component containing the buttons.
77
+ """
78
+ with gr.Group() as btns:
79
+ with gr.Row():
80
+ submit_btn = gr.Button("Submit", variant="primary")
81
+ clear_btn = gr.Button("Clear", variant="secondary")
82
+ cancel_btn = gr.Button("Cancel", variant="stop")
83
+
84
+ # Bind actions to the buttons
85
+ submit_btn.click(
86
+ fn=generate_text_fn, # Function to generate text
87
+ inputs=[query, model, sustainable,
88
+ starting_point, max_new_tokens,
89
+ temperature], # Input components for generation
90
+ outputs=[output] # Output component
91
+ )
92
+ clear_btn.click(
93
+ fn=clear_fn, # Function to clear inputs
94
+ inputs=[], # No inputs for clearing
95
+ outputs=[query, model, sustainable, starting_point, output] # Clear all inputs and output
96
+ )
97
+ cancel_btn.click(
98
+ fn=clear_fn, # Function to cancel and clear inputs
99
+ inputs=[], # No inputs for cancel
100
+ outputs=[query, model, sustainable, starting_point, output] # Clear all inputs and output
101
+ )
102
+ return btns
103
+
104
+
105
+ def model_settings() -> Tuple[gr.Slider, gr.Slider]:
106
+ """
107
+ Creates the model settings components and returns them.
108
+
109
+ Returns:
110
+ Tuple containing:
111
+ - max_new_tokens: Slider for setting the maximum number of new tokens.
112
+ - temperature: Slider for setting the temperature.
113
+ """
114
+ with gr.Accordion("Settings", open=False):
115
+ # Slider for maximum number of new tokens
116
+ max_new_tokens = gr.Slider(
117
+ label="Max new tokens",
118
+ value=1024,
119
+ minimum=0,
120
+ maximum=8192,
121
+ step=64,
122
+ interactive=True,
123
+ visible=True,
124
+ info="The maximum number of output tokens"
125
+ )
126
+
127
+ # Slider for temperature
128
+ temperature = gr.Slider(
129
+ label="Temperature",
130
+ step=0.01,
131
+ minimum=0.01,
132
+ maximum=1.0,
133
+ value=0.49,
134
+ interactive=True,
135
+ visible=True,
136
+ info="The value used to modulate the logits distribution"
137
+ )
138
+
139
+ return max_new_tokens, temperature
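Note: taken together, the pieces in this file wire up exactly as create_ui() does in app.py — a condensed sketch:

import gradio as gr
from src.ui.components.actions import generate_text, clear
from src.ui.components.inputs import main_component

with gr.Blocks() as demo:
    country, start, query, sustainable, model = main_component()
    output = gr.Textbox(label="Generated Results", lines=4)
    max_new_tokens, temperature = model_settings()
    load_buttons(query, model, sustainable, start,
                 max_new_tokens, temperature, output,
                 generate_text_fn=generate_text, clear_fn=clear)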
src/ui/setup.py ADDED
@@ -0,0 +1,19 @@
1
+ from src.helpers.data_loaders import load_places
2
+
3
+
4
+ # TODO add the gcp application json file loading
5
+ # TODO examples
6
+
7
+ def load_html_from_file(file_path: str) -> str:
8
+ """
9
+ Reads the HTML content from a file and returns it as a string.
10
+
11
+ Args:
12
+ file_path (str): The path to the HTML file.
13
+
14
+ Returns:
15
+ str: The HTML content of the file.
16
+ """
17
+ with open(file_path, 'r') as file:
18
+ return file.read()
19
+
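Note: typical use, as in create_ui() — the returned string is handed straight to a gr.HTML component:

html_content = load_html_from_file("src/ui/templates/intro.html")
# gr.HTML(html_content) then renders it inside the Blocks layout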
src/ui/templates/intro.html ADDED
@@ -0,0 +1,15 @@
1
+ <!-- intro.html -->
2
+ <body>
3
+ <h1 style='font-size:xx-large; color: green; text-align: center'>🍀 Green City Finder 🍀</h1>
4
+ <h3 style="text-align: center">AI Sprint 2024 submissions by Ashmi Banerjee.</h3>
5
+ <br>
6
+ <p style="text-align: justify">We're testing the compatibility of
7
+ Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
8
+ models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.
9
+ We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector embeddings are
10
+ stored in a VectorDB (LanceDB) hosted in Google Cloud.
11
+ </p>
12
+ <p style="text-align: justify">Sustainability is calculated based on the work by <a href="https://arxiv.org/abs/2403.18604">Banerjee et al.</a></p>
13
+ <br>
14
+ <p style="text-align: justify">Google Cloud credits are provided for this project.</p>
15
+ </body>