buzzCraft committed on
Commit
291bc70
1 Parent(s): 0e37947

Adding ChainLit demo

Browse files
Files changed (13) hide show
  1. .env_demo +5 -3
  2. .gitignore +11 -0
  3. README.md +16 -2
  4. app.py +176 -0
  5. chainlit.md +14 -0
  6. extractor.py +560 -0
  7. main.py +9 -2
  8. main_cli.py +27 -0
  9. media/chainlit.png +0 -0
  10. requirements.txt +3 -1
  11. src/database.py +45 -38
  12. src/extractor.py +103 -52
  13. src/sql_chain.py +40 -15
.env_demo CHANGED
@@ -1,4 +1,6 @@
1
- OPENAI_API_KEY=API_KEY_HERE
 
 
2
  LANGSMITH = False
3
- LANGSMITH_API_KEY=API_KEY_HERE -NOT NEEDED IF LANGSMITH IS FALSE
4
- ```
 
1
+ OPENAI_API_KEY=OPENAI_API_KEY
2
+ OPENAI_MODEL = gpt-3.5-turbo-0125
3
+ DATABASE_PATH = data/games.db
4
  LANGSMITH = False
5
+ LANGSMITH_API_KEY=
6
+ LANGSMITH_PROJECT=SoccerRag
.gitignore CHANGED
@@ -1,2 +1,13 @@
1
 
2
  *.pyc
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  *.pyc
3
+ .env
4
+ .chainlit/config.toml
5
+ .chainlit/translations/en-US.json
6
+ .idea/inspectionProfiles/profiles_settings.xml
7
+ .idea/inspectionProfiles/Project_Default.xml
8
+ .idea/misc.xml
9
+ .idea/modules.xml
10
+ .idea/soccer-rag.iml
11
+ .idea/vcs.xml
12
+ extractor.log
13
+ data/games.db
README.md CHANGED
@@ -32,12 +32,25 @@ python src/database.py
32
  ````
33
  Adjust the path to the data in the database.py file as needed.
34
 
35
- ## Running the code
36
  To run the code, execute the following command:
37
  ````bash
 
 
38
  python main.py
39
  ````
40
- The code will prompt you to enter a natural language query.
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  ### Example query
43
  ````angular2html
@@ -48,6 +61,7 @@ Lionel Messi has scored the following number of goals each season:
48
  - 2016-2017: 31 goals
49
  ````
50
 
 
51
  ## Results
52
  ![result-table.png](media%2Fresult-table.png)
53
 
 
32
  ````
33
  Adjust the path to the data in the database.py file as needed.
34
 
35
+ ## Running the code in command line
36
  To run the code, execute the following command:
37
  ````bash
38
+ The code will prompt you to enter a natural language query.
39
+
40
  python main.py
41
  ````
42
+ You can also call main_cli.py with a query as an argument:
43
+ ````bash
44
+ python main_cli.py -q "How many goals has Messi scored each season?"
45
+ ````
46
+
47
+ ## Running the code in ChainLit (GUI)
48
+ To run the code in ChainLit, execute the following command:
49
+ ````bash
50
+ chainlit run app.py
51
+ ````
52
+ This will open up a browser window with the GUI.
53
+ ![ChainLit](media/chainlit.png)
54
 
55
  ### Example query
56
  ````angular2html
 
61
  - 2016-2017: 31 goals
62
  ````
63
 
64
+
65
  ## Results
66
  ![result-table.png](media%2Fresult-table.png)
67
 
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from src.extractor import create_extractor
3
+ from src.sql_chain import create_agent
4
+ from dotenv import load_dotenv
5
+ import chainlit as cl
6
+ import json
7
+ # Loading the environment variables
8
+ load_dotenv(".env")
9
+ # Create the extractor and agent
10
+
11
+ model = os.getenv('OPENAI_MODEL')
12
+ # Check if model exists, if not, set it to default
13
+ # if not model:
14
+ # model = "gpt-3.5-turbo-0125"
15
+ ex = create_extractor()
16
+ ag = create_agent(llm_model=model)
17
+ # ag = create_agent(llm_model = "gpt-4-0125-preview")
18
+ openai_api_key = os.getenv('OPENAI_API_KEY')
19
+
20
+
21
+
22
+
23
+ def extract_func(user_prompt: str):
24
+ """
25
+
26
+ Parameters
27
+ ----------
28
+ user_prompt: str
29
+
30
+ Returns
31
+ -------
32
+ A dictionary of extracted properties
33
+ """
34
+ extracted = ex.extract_chainlit(user_prompt)
35
+ return extracted
36
+ def validate_func(properties:dict): # Auto validate as much as possible
37
+ """
38
+ Parameters
39
+ ----------
40
+ extracted properties: dict
41
+
42
+ Returns
43
+ -------
44
+ Two dictionaries:
45
+ 1. validated: The validated properties
46
+ 2. need_input: Properties that need human validation
47
+ """
48
+ validated, need_input = ex.validate_chainlit(properties)
49
+ return validated, need_input
50
+
51
+ def human_validate_func(human, validated, user_prompt):
52
+ """
53
+
54
+ Parameters
55
+ ----------
56
+ human - Human validated properties in the form of a list of dictionaries
57
+ validated - Validated properties in the form of a dictionary
58
+ user_prompt - The user prompt
59
+
60
+ Returns
61
+ -------
62
+ The cleaned prompt with updated values
63
+ """
64
+ for item in human:
65
+ # Iterate through key-value pairs in the current dictionary
66
+ for key, value in item.items():
67
+ if value == "":
68
+ continue
69
+ # Check if the key exists in the validated dictionary
70
+ if key in validated:
71
+ # Append the value to the existing list
72
+ validated[key].append(value)
73
+ else:
74
+ # Create a new key with the value as a new list
75
+ validated[key] = [value]
76
+ val_list = [validated]
77
+
78
+ return ex.build_prompt_chainlit(val_list, user_prompt)
79
+
80
+ def no_human(validated, user_prompt):
81
+ """
82
+ In case there is no need for human validation, this function will be called
83
+ Parameters
84
+ ----------
85
+ validated
86
+ user_prompt
87
+
88
+ Returns
89
+ -------
90
+ Updated prompt
91
+ """
92
+ return ex.build_prompt_chainlit([validated], user_prompt)
93
+
94
+
95
+ def ask(text):
96
+ """
97
+ Calls the SQL Agent to get the final answer
98
+ Parameters
99
+ ----------
100
+ text
101
+
102
+ Returns
103
+ -------
104
+ The final answer
105
+ """
106
+ ans, const = ag.ask(text)
107
+ return {"output": ans["output"]}, 12
108
+
109
+
110
+ @cl.step
111
+ async def Cleaner(text): # just for printing
112
+ return text
113
+
114
+
115
+ @cl.step
116
+ async def LLM(cleaned_prompt): # just for printing
117
+ ans, const = ask(cleaned_prompt)
118
+ return ans, const
119
+
120
+
121
+ @cl.step
122
+ async def Choice(text):
123
+ return text
124
+
125
+ @cl.step
126
+ async def Extractor(user_prompt):
127
+ extracted_values = extract_func(user_prompt)
128
+ return extracted_values
129
+
130
+
131
+ @cl.on_message # this function will be called every time a user inputs a message in the UI
132
+ async def main(message: cl.Message):
133
+ user_prompt = message.content # Get the user prompt
134
+ # extracted_values = extract_func(user_prompt)
135
+ #
136
+ # json_formatted = json.dumps(extracted_values, indent=4)
137
+ extracted_values = await Extractor(user_prompt)
138
+ json_formatted = json.dumps(extracted_values, indent=4)
139
+ # Print the extracted values in json format
140
+ await cl.Message(author="Extractor", content=f"Extracted properties:\n```json\n{json_formatted}\n```").send()
141
+ # Try to validate everything
142
+ validated, need_input = validate_func(extracted_values)
143
+ await cl.Message(author="Validator", content=f"Extracted properties will now be validated against the database.").send()
144
+ if need_input:
145
+ # If we need validation, we will ask the user to select the correct value
146
+ for element in need_input:
147
+ key = next(iter(element)) # Get the first key in the dictionary
148
+ # Present user with options to choose from
149
+ actions = [
150
+ cl.Action(name=value, value=value, description=str(value))
151
+ for value in element['top_matches']
152
+ ]
153
+ actions.append(cl.Action(name="No Update", value="", description="No Update"))
154
+ # Add a "No Update" option
155
+ res = await cl.AskActionMessage(
156
+ author="Validator",
157
+ content=f"Select the correct value for {element[key]}",
158
+ actions=actions
159
+ ).send()
160
+ selected_value = res.get("value", "") if res else ""
161
+ element[key] = selected_value
162
+ element.pop("top_matches")
163
+ await Choice(selected_value) # Logging choice
164
+ # Get the cleaned prompt
165
+ cleaned_prompt = human_validate_func(need_input, validated, user_prompt)
166
+ else:
167
+ cleaned_prompt = no_human(validated, user_prompt)
168
+ # Print the cleaned prompt
169
+ cleaner_message = cl.Message(author="Cleaner", content=f"New prompt is as follows:\n{cleaned_prompt}")
170
+ await cleaner_message.send()
171
+
172
+ # Call the SQL agent to get the final answer
173
+ # ans, const = ask(cleaned_prompt) # Get the final answer from some function
174
+ await cl.Message(content=f"I will now query the database for information.").send()
175
+ ans, const = await LLM(cleaned_prompt)
176
+ await cl.Message(content=f"This is the final answer: \n\n{ans['output']}").send()
chainlit.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to Chainlit! 🚀🤖
2
+
3
+ Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
4
+
5
+ ## Useful Links 🔗
6
+
7
+ - **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
8
+ - **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬
9
+
10
+ We can't wait to see what you create with Chainlit! Happy coding! 💻😊
11
+
12
+ ## Welcome screen
13
+
14
+ To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
extractor.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from langchain.chains import create_extraction_chain_pydantic
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain.chains import create_extraction_chain
6
+ from copy import deepcopy
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_community.utilities import SQLDatabase
9
+ import os
10
+ import difflib
11
+ import ast
12
+ import json
13
+ import re
14
+ from thefuzz import process
15
+ # Set up logging
16
+ import logging
17
+
18
+ from dotenv import load_dotenv
19
+
20
+ load_dotenv(".env")
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ # Save the log to a file
24
+ handler = logging.FileHandler('extractor.log')
25
+ logger = logging.getLogger(__name__)
26
+
27
+ os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
28
+ # os.environ["ANTHROPIC_API_KEY"] = os.getenv('ANTHROPIC_API_KEY')
29
+
30
+ if os.getenv('LANGSMITH'):
31
+ os.environ['LANGCHAIN_TRACING_V2'] = 'true'
32
+ os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
33
+ os.environ[
34
+ 'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
35
+ os.environ['LANGCHAIN_PROJECT'] = os.getenv('LANGSMITH_PROJECT')
36
+ db_uri = os.getenv('DATABASE_PATH')
37
+ db_uri = f"sqlite:///{db_uri}"
38
+ db = SQLDatabase.from_uri(db_uri)
39
+
40
+ # from langchain_anthropic import ChatAnthropic
41
+ class Extractor():
42
+ # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
43
+ #gpt-3.5-turbo
44
+ def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
45
+ # model = "gpt-4-0125-preview"
46
+ if custom_extractor_prompt:
47
+ cust_promt = ChatPromptTemplate.from_template(custom_extractor_prompt)
48
+
49
+ self.llm = ChatOpenAI(model=model, temperature=0)
50
+ # self.llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
51
+ self.schema = schema_config or {}
52
+ self.chain = create_extraction_chain(self.schema, self.llm, prompt=cust_promt)
53
+
54
+ def extract(self, query):
55
+ return self.chain.invoke(query)
56
+
57
+
58
+ class Retriever():
59
+ def __init__(self, db, config):
60
+ self.db = db
61
+ self.config = config
62
+ self.table = config.get('db_table')
63
+ self.column = config.get('db_column')
64
+ self.pk_column = config.get('pk_column')
65
+ self.numeric = config.get('numeric', False)
66
+ self.response = []
67
+ self.query = f"SELECT {self.column} FROM {self.table}"
68
+ self.augmented_table = config.get('augmented_table', None)
69
+ self.augmented_column = config.get('augmented_column', None)
70
+ self.augmented_fk = config.get('augmented_fk', None)
71
+
72
+ def query_as_list(self):
73
+ # Execute the query
74
+ response = self.db.run(self.query)
75
+ response = [el for sub in ast.literal_eval(response) for el in sub if el]
76
+ if not self.numeric:
77
+ response = [re.sub(r"\b\d+\b", "", string).strip() for string in response]
78
+ self.response = list(set(response))
79
+ # print(self.response)
80
+ return self.response
81
+
82
+ def get_augmented_items(self, prompt):
83
+ if self.augmented_table is None:
84
+ return None
85
+ else:
86
+ # Construct the query to search for the prompt in the augmented table
87
+ query = f"SELECT {self.augmented_fk} FROM {self.augmented_table} WHERE LOWER({self.augmented_column}) = LOWER('{prompt}')"
88
+
89
+ # Execute the query
90
+ fk_response = self.db.run(query)
91
+ if fk_response:
92
+ # Extract the FK value
93
+ fk_response = ast.literal_eval(fk_response)
94
+ fk_value = fk_response[0][0]
95
+ query = f"SELECT {self.column} FROM {self.table} WHERE {self.pk_column} = {fk_value}"
96
+ # Execute the query
97
+ matching_response = self.db.run(query)
98
+ # Extract the matching response
99
+ matching_response = ast.literal_eval(matching_response)
100
+ matching_response = matching_response[0][0]
101
+ return matching_response
102
+ else:
103
+ return None
104
+
105
+ def find_close_matches(self, target_string, n=3, method="difflib", threshold=70):
106
+ """
107
+ Find and return the top n close matches to target_string in the database query results.
108
+
109
+ Args:
110
+ - target_string (str): The string to match against the database results.
111
+ - n (int): Number of top matches to return.
112
+
113
+ Returns:
114
+ - list of tuples: Each tuple contains a match and its score.
115
+ """
116
+ # Ensure we have the response list populated
117
+ if not self.response:
118
+ self.query_as_list()
119
+
120
+ # Find top n close matches
121
+ if method == "fuzzy":
122
+ # Use the fuzzy_string method to get matches and their scores
123
+ # If the threshold is met, return the best match; otherwise, return all matches meeting the threshold
124
+ top_matches = self.fuzzy_string(target_string, limit=n, threshold=threshold)
125
+
126
+
127
+ else:
128
+ # Use difflib's get_close_matches to get the top n matches
129
+ top_matches = difflib.get_close_matches(target_string, self.response, n=n, cutoff=0.2)
130
+
131
+ return top_matches
132
+
133
+ def fuzzy_string(self, prompt, limit, threshold=80, low_threshold=30):
134
+
135
+ # Get matches and their scores, limited by the specified 'limit'
136
+ matches = process.extract(prompt, self.response, limit=limit)
137
+
138
+
139
+ filtered_matches = [match for match in matches if match[1] >= threshold]
140
+
141
+ # If no matches meet the threshold, return the list of all matches' strings
142
+ if not filtered_matches:
143
+ # Return matches above the low_threshold
144
+ # Fix for wrong properties being returned
145
+ return [match[0] for match in matches if match[1] >= low_threshold]
146
+
147
+
148
+ # If there's only one match meeting the threshold, return it as a string
149
+ if len(filtered_matches) == 1:
150
+ return filtered_matches[0][0] # Return the matched string directly
151
+
152
+ # If there's more than one match meeting the threshold or ties, return the list of matches' strings
153
+ highest_score = filtered_matches[0][1]
154
+ ties = [match for match in filtered_matches if match[1] == highest_score]
155
+
156
+ # Return the strings of tied matches directly, ignoring the scores
157
+ m = [match[0] for match in ties]
158
+ if len(m) == 1:
159
+ return m[0]
160
+ return [match[0] for match in ties]
161
+
162
+ def fetch_pk(self, property_name, property_value):
163
+ # Some properties do not have a primary key
164
+ # Return the property value if no primary key is specified
165
+ pk_list = []
166
+
167
+ # Check if the property_value is a list; if not, make it a list for uniform processing
168
+ if not isinstance(property_value, list):
169
+ property_value = [property_value]
170
+
171
+ # Some properties do not have a primary key
172
+ # Return None for each property_value if no primary key is specified
173
+ if self.pk_column is None:
174
+ return [None for _ in property_value]
175
+
176
+ for value in property_value:
177
+ query = f"SELECT {self.pk_column} FROM {self.table} WHERE {self.column} = '{value}' LIMIT 1"
178
+ response = self.db.run(query)
179
+
180
+ # Append the response (PK or None) to the pk_list
181
+ pk_list.append(response)
182
+
183
+ return pk_list
184
+
185
+
186
+ def setup_retrievers(db, schema_config):
187
+ # retrievers = {}
188
+ # for prop, config in schema_config["properties"].items():
189
+ # retrievers[prop] = Retriever(db=db, config=config)
190
+ # return retrievers
191
+
192
+ retrievers = {}
193
+ # Iterate over each property in the schema_config's properties
194
+ for prop, config in schema_config["properties"].items():
195
+ # Access the 'items' dictionary for the configuration of the array's elements
196
+ item_config = config['items']
197
+ # Create a Retriever instance using the item_config
198
+ retrievers[prop] = Retriever(db=db, config=item_config)
199
+ return retrievers
200
+
201
+
202
+ def extract_properties(prompt, schema_config, custom_extractor_prompt=None):
203
+ """Extract properties from the prompt."""
204
+ # modify schema_conf to only include the required properties
205
+ schema_stripped = {'properties': {}}
206
+ for key, value in schema_config['properties'].items():
207
+ schema_stripped['properties'][key] = {
208
+ 'type': value['type'],
209
+ 'items': {'type': value['items']['type']}
210
+ }
211
+
212
+ extractor = Extractor(schema_config=schema_stripped, custom_extractor_prompt=custom_extractor_prompt)
213
+ extraction_result = extractor.extract(prompt)
214
+ # print("Extraction Result:", extraction_result)
215
+
216
+ if 'text' in extraction_result and extraction_result['text']:
217
+ properties = extraction_result['text']
218
+ return properties
219
+ else:
220
+ print("No properties extracted.")
221
+ return None
222
+
223
+
224
+ def recheck_property_value(properties, property_name, retrievers, input_func):
225
+ while True:
226
+ new_value = input_func(f"Enter new value for {property_name} or type 'quit' to stop: ")
227
+ if new_value.lower() == 'quit':
228
+ break # Exit the loop and do not update the property
229
+
230
+ new_top_matches = retrievers[property_name].find_close_matches(new_value, n=3)
231
+ if new_top_matches:
232
+ # Display new top matches and ask for confirmation or re-entry
233
+ print("\nNew close matches found:")
234
+ for i, match in enumerate(new_top_matches, start=1):
235
+ print(f"[{i}] {match}")
236
+ print("[4] Re-enter value")
237
+ print("[5] Quit without updating")
238
+
239
+ selection = input_func("Select the best match (1-3), choose 4 to re-enter value, or 5 to quit: ")
240
+ if selection in ['1', '2', '3']:
241
+ selected_match = new_top_matches[int(selection) - 1]
242
+ properties[property_name] = selected_match # Update the dictionary directly
243
+ print(f"Updated {property_name} to {selected_match}")
244
+ break # Successfully updated, exit the loop
245
+ elif selection == '5':
246
+ break # Quit without updating
247
+ # Loop will continue if user selects 4 or inputs invalid selection
248
+ else:
249
+ print("No close matches found. Please try again or type 'quit' to stop.")
250
+
251
+
252
+ def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func=input):
253
+ """
254
+ Checks and updates the properties in the properties list based on close matches found in the database.
255
+ The function iterates through each property in each property dictionary within the list,
256
+ finds close matches for it in the database using the retrievers, and updates the property
257
+ value based on user selection.
258
+
259
+ Args:
260
+ properties_list (list of dict): A list of dictionaries, where each dictionary contains properties
261
+ to check and potentially update based on database matches.
262
+ retrievers (dict): A dictionary of Retriever objects keyed by property name, used to find close matches in the database.
263
+ input_func (function, optional): A function to capture user input. Defaults to the built-in input function.
264
+
265
+ The function updates the properties_list in place based on user choices for updating property values
266
+ with close matches found by the retrievers.
267
+ """
268
+
269
+ for index, properties in enumerate(properties_list):
270
+ for property_name, retriever in retrievers.items(): # Iterate using items to get both key and value
271
+ property_values = properties.get(property_name, [])
272
+ if not property_values: # Skip if the property is not present or is an empty list
273
+ continue
274
+
275
+ updated_property_values = [] # To store updated list of values
276
+
277
+ for value in property_values:
278
+ if retriever.augmented_table:
279
+ augmented_value = retriever.get_augmented_items(value)
280
+ if augmented_value:
281
+ updated_property_values.append(augmented_value)
282
+ continue
283
+ # Since property_value is now expected to be a list, we handle each value individually
284
+ top_matches = retriever.find_close_matches(value, method=method, n=3)
285
+
286
+ # Check if the closest match is the same as the current value
287
+ if top_matches and top_matches[0] == value:
288
+ updated_property_values.append(value)
289
+ continue
290
+
291
+ if not top_matches:
292
+ updated_property_values.append(value) # Keep the original value if no matches found
293
+ continue
294
+
295
+ if type(top_matches) == str and method == "fuzzy":
296
+ # If the top_matches is a string, it means that the threshold was met and only one item was returned
297
+ # In this case, we can directly update the property with the top match
298
+ updated_property_values.append(top_matches)
299
+ properties[property_name] = updated_property_values
300
+ continue
301
+
302
+ print(f"\nCurrent {property_name}: {value}")
303
+ for i, match in enumerate(top_matches, start=1):
304
+ print(f"[{i}] {match}")
305
+ print("[4] Enter new value")
306
+
307
+ # hmm = input_func(f"Fix for Pycharm, press enter to continue")
308
+
309
+ choice = input_func(f"Select the best match for {property_name} (1-4): ")
310
+ if choice in ['1', '2', '3']:
311
+ selected_match = top_matches[int(choice) - 1]
312
+ updated_property_values.append(selected_match) # Update with the selected match
313
+ print(f"Updated {property_name} to {selected_match}")
314
+ elif choice == '4':
315
+ # Allow re-entry of value for this specific item
316
+ recheck_property_value(properties, property_name, value, retrievers, input_func)
317
+ # Note: Implement recheck_property_value to handle individual value updates within the list
318
+ else:
319
+ print("Invalid selection. Property not updated.")
320
+ updated_property_values.append(value) # Keep the original value
321
+
322
+ # Update the entire list for the property after processing all values
323
+ properties[property_name] = updated_property_values
324
+
325
+
326
+ # Function to remove duplicates
327
+ def remove_duplicates(dicts):
328
+ seen = {} # Dictionary to keep track of seen values for each key
329
+ for d in dicts:
330
+ for key in list(d.keys()): # Use list to avoid RuntimeError for changing dict size during iteration
331
+ value = d[key]
332
+ if key in seen and value == seen[key]:
333
+ del d[key] # Remove key-value pair if duplicate is found
334
+ else:
335
+ seen[key] = value # Update seen values for this key
336
+ return dicts
337
+
338
+
339
+ def fetch_pks(properties_list, retrievers):
340
+ all_pk_attributes = [] # Initialize a list to store dictionaries of _pk attributes for each item in properties_list
341
+
342
+ # Iterate through each properties dictionary in the list
343
+ for properties in properties_list:
344
+ pk_attributes = {} # Initialize a dictionary for the current set of properties
345
+ for property_name, property_value in properties.items():
346
+ if property_name in retrievers:
347
+ # Fetch the primary key using the retriever for the current property
348
+ pk = retrievers[property_name].fetch_pk(property_name, property_value)
349
+ # Store it in the dictionary with a modified key name
350
+ pk_attributes[f"{property_name}_pk"] = pk
351
+
352
+ # Add the dictionary of _pk attributes for the current set of properties to the list
353
+ all_pk_attributes.append(pk_attributes)
354
+
355
+ # Return a list of dictionaries, where each dictionary contains _pk attributes for a set of properties
356
+ return all_pk_attributes
357
+
358
+
359
+ def update_prompt(prompt, properties, pk, properties_original):
360
+ # Replace the original prompt with the updated properties and pk
361
+ prompt = prompt.replace("{{properties}}", str(properties))
362
+ prompt = prompt.replace("{{pk}}", str(pk))
363
+ return prompt
364
+
365
+
366
+ def update_prompt_enhanced(prompt, properties, pk, properties_original):
367
+ updated_info = ""
368
+ for prop, pk_info, prop_orig in zip(properties, pk, properties_original):
369
+ for key in prop.keys():
370
+ # Extract original and updated values
371
+ orig_values = prop_orig.get(key, [])
372
+ updated_values = prop.get(key, [])
373
+
374
+ # Ensure both original and updated values are lists for uniform processing
375
+ if not isinstance(orig_values, list):
376
+ orig_values = [orig_values]
377
+ if not isinstance(updated_values, list):
378
+ updated_values = [updated_values]
379
+
380
+ # Extract primary key detail for this key, handling various pk formats carefully
381
+ pk_key = f"{key}_pk" # Construct pk key name based on the property key
382
+ pk_details = pk_info.get(pk_key, [])
383
+ if not isinstance(pk_details, list):
384
+ pk_details = [pk_details]
385
+
386
+ for orig_value, updated_value, pk_detail in zip(orig_values, updated_values, pk_details):
387
+ pk_value = None
388
+ if isinstance(pk_detail, str):
389
+ pk_value = pk_detail.strip("[]()").split(",")[0].replace("'", "").replace('"', '')
390
+
391
+ update_statement = ""
392
+ # Skip updating if there's no change in value to avoid redundant info
393
+ if orig_value != updated_value and pk_value:
394
+ update_statement = f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
395
+ elif orig_value != updated_value:
396
+ update_statement = f"\n- {orig_value} (now referred to as {updated_value})."
397
+ elif pk_value:
398
+ update_statement = f"\n- {orig_value} has a primary key: {pk_value}."
399
+
400
+ updated_info += update_statement
401
+
402
+ if updated_info:
403
+ prompt += "\nUpdated Information:" + updated_info
404
+
405
+ return prompt
406
+
407
+
408
+ def prompt_cleaner(prompt, db, schema_config):
409
+ """Main function to clean the prompt."""
410
+
411
+ retrievers = setup_retrievers(db, schema_config)
412
+
413
+ properties = extract_properties(prompt, schema_config)
414
+ # Keep original properties for later use
415
+ properties_original = deepcopy(properties)
416
+ # Remove duplicates - Happens when there are more than one player or team in the prompt
417
+ properties = remove_duplicates(properties)
418
+ if properties:
419
+ check_and_update_properties(properties, retrievers)
420
+
421
+ pk = fetch_pks(properties, retrievers)
422
+ properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
423
+
424
+ return properties, pk
425
+
426
+
427
+ class PromptCleaner:
428
+ """
429
+ A class designed to clean and process prompts by extracting properties, removing duplicates,
430
+ and updating these properties based on a predefined schema configuration and database interactions.
431
+
432
+ Attributes:
433
+ db: A database connection object used to execute queries and fetch data.
434
+ schema_config: A dictionary defining the schema configuration for the extraction process.
435
+ schema_config = {
436
+ "properties": {
437
+ # Property name
438
+ "person_name": {"type": "string", "db_table": "players", "db_column": "name", "pk_column": "hash",
439
+ # if mostly numeric, such as 2015-2016 set true
440
+ "numeric": False},
441
+ "team_name": {"type": "string", "db_table": "teams", "db_column": "name", "pk_column": "id",
442
+ "numeric": False},
443
+ # Add more as needed
444
+ },
445
+ # Parameter to extractor, if person_name is required, add it here and the extractor will
446
+ # return an error if it is not found
447
+ "required": [],
448
+ }
449
+
450
+ Methods:
451
+ clean(prompt): Cleans the given prompt by extracting and updating properties based on the database.
452
+ Returns a tuple containing the updated properties and their primary keys.
453
+ """
454
+
455
+ def __init__(self, db=db, schema_config=None, custom_extractor_prompt=None):
456
+ """
457
+ Initializes the PromptCleaner with a database connection and a schema configuration.
458
+
459
+ Args:
460
+ db: The database connection object to be used for querying. (if none, it will use the default db)
461
+ schema_config: A dictionary defining properties and their database mappings for extraction and updating.
462
+ """
463
+ self.db = db
464
+ self.schema_config = schema_config
465
+ self.retrievers = setup_retrievers(self.db, self.schema_config)
466
+ self.cust_extractor_prompt = custom_extractor_prompt
467
+
468
+ def clean(self, prompt, return_pk=False, test=False, verbose = False):
469
+ """
470
+ Processes the given prompt to extract properties, remove duplicates, update the properties
471
+ based on close matches within the database, and fetch primary keys for these properties.
472
+
473
+ The method first extracts properties from the prompt using the schema configuration,
474
+ then checks these properties against the database to find and update close matches.
475
+ It also fetches primary keys for the updated properties where applicable.
476
+
477
+ Args:
478
+ prompt (str): The prompt text to be cleaned and processed.
479
+ return_pk (bool): A flag to indicate whether to return primary keys along with the properties.
480
+ test (bool): A flag to indicate whether to return the original properties for testing purposes.
481
+ verbose (bool): A flag to indicate whether to return the original properties for debugging.
482
+
483
+ Returns:
484
+ tuple: A tuple containing two elements:
485
+ - The first element is the original prompt, with updated information that exists in the db.
486
+ - The second element is a list of dictionaries, each containing primary keys for the properties,
487
+ where applicable.
488
+
489
+ """
490
+ if self.cust_extractor_prompt:
491
+
492
+ properties = extract_properties(prompt, self.schema_config, self.cust_extractor_prompt)
493
+
494
+ else:
495
+ properties = extract_properties(prompt, self.schema_config)
496
+ # Keep original properties for later use
497
+ properties_original = deepcopy(properties)
498
+ if test:
499
+ return properties_original
500
+ # Remove duplicates - Happens when there are more than one player or team in the prompt
501
+ # properties = remove_duplicates(properties)
502
+ pk = None
503
+ if properties:
504
+ check_and_update_properties(properties, self.retrievers)
505
+ pk = fetch_pks(properties, self.retrievers)
506
+ properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
507
+
508
+
509
+
510
+ if return_pk:
511
+ return properties, pk
512
+ elif verbose:
513
+ return properties, properties_original
514
+ else:
515
+ return properties
516
+
517
+
518
+ def load_json(file_path: str) -> dict:
519
+ with open(file_path, 'r') as file:
520
+ return json.load(file)
521
+
522
+
523
+ def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = db_uri):
524
+ schema_config = load_json(schema)
525
+ db = SQLDatabase.from_uri(db)
526
+ pre_prompt = """Extract and save the relevant entities mentioned \
527
+ in the following passage together with their properties.
528
+
529
+ Only extract the properties mentioned in the 'information_extraction' function.
530
+
531
+ The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
532
+ Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
533
+
534
+ If a property is not present and is not required in the function parameters, do not include it in the output.
535
+ If no properties are found, return an empty list.
536
+
537
+ Here are some exampels:
538
+ 'How many goals did Henry score for Arsnl in the 2015 season?'
539
+ person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
540
+
541
+ Passage:
542
+ {input}
543
+ """
544
+
545
+ return PromptCleaner(db, schema_config, custom_extractor_prompt=pre_prompt)
546
+
547
+
548
+ if __name__ == "__main__":
549
+
550
+
551
+ schema_config = load_json("src/conf/schema.json")
552
+ # Add game and league to the schema_config
553
+
554
+ # prompter = PromptCleaner(db, schema_config, custom_extractor_prompt=extract_prompt)
555
+ prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
556
+ prompt= prompter.clean("Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa")
557
+
558
+
559
+ print(prompt)
560
+
main.py CHANGED
@@ -1,8 +1,15 @@
1
  from src.extractor import create_extractor
2
  from src.sql_chain import create_agent
 
 
 
 
 
 
 
 
3
  ex = create_extractor()
4
- ag = create_agent(llm_model="gpt-3.5-turbo-0125", verbose=False)
5
- # ag = create_agent(llm_model = "gpt-4-0125-preview")
6
 
7
  def query(prompt):
8
  clean = ex.clean(prompt)
 
1
  from src.extractor import create_extractor
2
  from src.sql_chain import create_agent
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
# Load the local .env file before the model name is read from the environment.
load_dotenv(".env")

model = os.getenv('OPENAI_MODEL')

# Build the extractor exactly once, after the environment has been loaded.
ex = create_extractor()
ag = create_agent(llm_model=model)
 
13
 
14
  def query(prompt):
15
  clean = ex.clean(prompt)
main_cli.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.extractor import create_extractor
2
+ from src.sql_chain import create_agent
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
# Read settings from the local .env file first.
load_dotenv(".env")

# Extractor that cleans the user prompt before it is sent to the agent.
ex = create_extractor()

# The agent uses the model named by the OPENAI_MODEL environment variable.
model = os.getenv('OPENAI_MODEL')
ag = create_agent(llm_model=model)
12
+
13
+
14
def query(prompt):
    """Clean a user prompt with the extractor and pass the result to the agent.

    Args:
        prompt: The natural-language question entered by the user.

    Returns:
        The first element of the pair returned by ``ag.ask``; the command-line
        entry point below prints its ``"output"`` entry.
    """
    # Both calls return a pair; only the first element of each is used here.
    clean, _ = ex.clean(prompt, verbose=True)
    ans, _ = ag.ask(clean)
    return ans
18
+
19
if __name__ == '__main__':
    import argparse

    # Command-line entry point: a single required -q/--query option.
    cli = argparse.ArgumentParser(description="Process a user query.")
    cli.add_argument('-q', '--query', type=str, required=True, help='A query string to process')

    answer = query(cli.parse_args().query)
    print(answer["output"])
media/chainlit.png ADDED
requirements.txt CHANGED
@@ -10,5 +10,7 @@ rapidfuzz==3.6.1
10
  thefuzz==0.22.1
11
  faiss-cpu
12
  Levenshtein==0.25.0
13
- langsmith~=0.1.54
14
  python-dotenv==1.0.1
 
 
 
10
  thefuzz==0.22.1
11
  faiss-cpu
12
  Levenshtein==0.25.0
13
+ langsmith~=0.0.92
14
  python-dotenv==1.0.1
15
+ chainlit~=1.0.506
16
+ pandas
src/database.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import os
5
  import json
6
 
7
- engine = create_engine('sqlite:///../../data/games.db', echo=False)
8
  Base = declarative_base()
9
 
10
 
@@ -25,6 +25,7 @@ class Game(Base):
25
  season = Column(String)
26
  league_id = Column(Integer, ForeignKey('leagues.id'))
27
 
 
28
  class GameLineup(Base):
29
  __tablename__ = 'game_lineup'
30
  id = Column(Integer, primary_key=True)
@@ -46,6 +47,7 @@ class Team(Base):
46
  id = Column(Integer, primary_key=True)
47
  name = Column(String)
48
 
 
49
  class Player(Base):
50
  __tablename__ = 'players'
51
  hash = Column(String, primary_key=True)
@@ -75,11 +77,13 @@ class Commentary(Base):
75
  event_time_end = Column(Float)
76
  description = Column(Text)
77
 
 
78
  class League(Base):
79
  __tablename__ = 'leagues'
80
  id = Column(Integer, primary_key=True)
81
  name = Column(String)
82
 
 
83
  class Event(Base):
84
  __tablename__ = 'events'
85
  id = Column(Integer, primary_key=True)
@@ -92,36 +96,36 @@ class Event(Base):
92
  label = Column(String)
93
  visibility = Column(Boolean)
94
 
 
95
  class Augmented_Team(Base):
96
  __tablename__ = 'augmented_teams'
97
  id = Column(Integer, primary_key=True)
98
  team_id = Column(Integer, ForeignKey('teams.id'))
99
  augmented_name = Column(String)
100
 
 
101
  class Augmented_League(Base):
102
  __tablename__ = 'augmented_leagues'
103
  id = Column(Integer, primary_key=True)
104
  league_id = Column(Integer, ForeignKey('leagues.id'))
105
  augmented_name = Column(String)
106
 
 
107
  class Player_Event_Label(Base):
108
  __tablename__ = 'player_event_labels'
109
  id = Column(Integer, primary_key=True)
110
  label = Column(String)
111
 
 
112
  class Player_Event(Base):
113
  __tablename__ = 'player_events'
114
  id = Column(Integer, primary_key=True)
115
  game_id = Column(Integer, ForeignKey('games.id'))
116
  player_id = Column(Integer, ForeignKey('players.hash'))
117
- time = Column(String) # Time in minutes of the game
118
  type = Column(Integer, ForeignKey('player_event_labels.id'))
119
- linked_player = Column(Integer, ForeignKey('players.hash')) # If the event is linked to another player, for example a substitution
120
-
121
-
122
-
123
-
124
-
125
 
126
 
127
  # Create Tables
@@ -130,11 +134,13 @@ Base.metadata.create_all(engine)
130
  # Session setup
131
  Session = sessionmaker(bind=engine)
132
 
133
- def extract_time_from_player_event(time:str)->str:
 
134
  # Extract the time from the string
135
- time = time.split("'")[0] # Need to keep it str because of overtime eg. (45+2)
136
  return time
137
 
 
138
  def get_or_create(session, model, **kwargs):
139
  instance = session.query(model).filter_by(**kwargs).first()
140
  if instance:
@@ -145,7 +151,8 @@ def get_or_create(session, model, **kwargs):
145
  session.commit()
146
  return instance
147
 
148
- def process_game_data(data,data2, league, season):
 
149
  session = Session()
150
  # Caption = d and v2 = d2
151
  home_team = data["gameHomeTeam"]
@@ -169,7 +176,8 @@ def process_game_data(data,data2, league, season):
169
  # Check if league exists
170
  league = get_or_create(session, League, name=league)
171
  if not game:
172
- game = Game(timestamp=timestamp, score=score, goal_home=home_score, goal_away=away_score, round=round_, home_team_id=home_team.id, away_team_id=away_team.id,
 
173
  venue=venue, date=date, attendance=attendance, season=season, league_id=league.id, referee=referee)
174
  session.add(game)
175
  session.commit()
@@ -187,22 +195,19 @@ def process_game_data(data,data2, league, season):
187
  for player_data in team_lineup["players"]:
188
  player_hash = player_data["hash"]
189
  name = player_data["long_name"]
190
- if " " not in name: # Since some players are missing their first name, do this to help with the search
191
  name = "NULL " + name
192
  number = player_data["shirt_number"]
193
  captain = player_data["captain"] == "(C)"
194
  starting = player_data["starting"]
195
  country = player_data["country"]
196
  position = player_data["lineup"]
197
- facts = player_data.get("facts", None) # Facts might be empty
198
-
199
-
200
-
201
-
202
 
203
  player = get_or_create(session, Player, hash=player_hash, name=name, country=country)
204
  game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=player.hash,
205
- shirt_number=number, position=position, starting=starting, captain=captain, coach=False, tactics=tactic)
 
206
  if facts:
207
  for fact in facts:
208
  type = fact["type"]
@@ -210,7 +215,8 @@ def process_game_data(data,data2, league, season):
210
  event = get_or_create(session, Player_Event_Label, id=int(type))
211
  linked_player = fact.get("linked_player_hash", None)
212
 
213
- player_event = Player_Event(game_id=game.id, player_id=player.hash, time=time, type=event.id, linked_player=linked_player)
 
214
  session.add(player_event)
215
  session.add(game_lineup)
216
 
@@ -223,7 +229,8 @@ def process_game_data(data,data2, league, season):
223
  coach_country = coach["country"]
224
  coach_player = get_or_create(session, Player, hash=coach_hash, name=coach_name, country=coach_country)
225
  game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=coach_player.hash,
226
- shirt_number=None, position=None, starting=None, captain=False, coach=True, tactics=tactic)
 
227
  session.add(game_lineup)
228
 
229
  # Commit all changes at once
@@ -241,7 +248,7 @@ def process_game_data(data,data2, league, season):
241
  label = "yellow card"
242
  elif label == "r-card":
243
  label = "red card"
244
-
245
  description = event["description"]
246
  important = event["important"] == "true"
247
  visible = event["visibility"]
@@ -257,9 +264,11 @@ def process_game_data(data,data2, league, season):
257
 
258
  return game.id, home_team.id, away_team.id
259
 
 
260
  def process_player_data(data):
261
  pass
262
 
 
263
  def process_ASR_data(data, game_id, period):
264
  session = Session()
265
  seg = data["segments"]
@@ -277,6 +286,7 @@ def process_ASR_data(data, game_id, period):
277
  session.commit()
278
  session.close()
279
 
 
280
  def convert_to_seconds(time_str):
281
  # Split the string into its components
282
  period, time = time_str.split(" - ")
@@ -321,17 +331,14 @@ def parse_labels_v2(data, session, home_team_id, away_team_id, game_id):
321
  game_time=game_time, # Already in seconds
322
  frame_stamp=position, # Make sure this is an integer or None
323
  team_id=team_id, # Integer ID of the team
324
- visibility=visibility, # Boolean
325
- label=label # String with information
326
  )
327
  session.add(annotation_entry)
328
 
329
  session.commit()
330
 
331
 
332
-
333
-
334
-
335
  def process_json_files(directory):
336
  session = Session()
337
  fill_player_events(session)
@@ -355,7 +362,7 @@ def process_json_files(directory):
355
  lb_cap = json.load(f)
356
  with open(os.path.join(root, "Labels-v2.json"), 'r') as f:
357
  lb_v2 = json.load(f)
358
- game_id, home_team_id, away_team_id = process_game_data(lb_cap,lb_v2, league, season)
359
 
360
  for file in asr_files:
361
  with open(os.path.join(root, file), 'r') as f:
@@ -368,19 +375,18 @@ def process_json_files(directory):
368
  elif '1_half-ASR' in file:
369
  period = 1
370
  # Parse and commit the data
371
- process_ASR_data(data=asr, game_id = game_id, period=period)
372
 
373
  elif '2_half-ASR' in file:
374
  period = 2
375
  # Parse and commit the data
376
- process_ASR_data(data=asr, game_id = game_id, period=period)
377
-
378
 
379
  session.commit()
380
  session.close()
381
 
382
- def fill_player_events(session):
383
 
 
384
  fact_id2label = {
385
  "1": "Yellow card",
386
  # Example: "time": "71' Ivanovic B. (Unsportsmanlike conduct)", "description": "Yellow Card"
@@ -397,9 +403,7 @@ def fill_player_events(session):
397
  session.commit()
398
 
399
 
400
-
401
  def fill_Augmented_Team(file_path):
402
-
403
  df = pd.read_csv(file_path)
404
  # the df should have two columns, team_name and augmented_name
405
 
@@ -417,6 +421,7 @@ def fill_Augmented_Team(file_path):
417
  session.commit()
418
  session.close()
419
 
 
420
  def fill_Augmented_League(file_path):
421
  # Read the csv file
422
  df = pd.read_csv(file_path)
@@ -432,14 +437,16 @@ def fill_Augmented_League(file_path):
432
  augmented_name = augmented_name.strip()
433
  league = session.query(League).filter_by(name=league_name).first()
434
  if league:
435
- augmented_league = get_or_create(session, Augmented_League, league_id=league.id, augmented_name=augmented_name)
 
436
  session.commit()
437
  session.close()
438
 
 
439
  if __name__ == "__main__":
440
  # Example directory path
441
- process_json_files('../data/Dataset/SoccerNet/')
442
- fill_Augmented_Team('../data/dataset/augmented_teams.csv')
443
- fill_Augmented_League('../data/dataset/augmented_leagues.csv')
444
  # Rename the event/annotation table to something more descriptive. Events are fucking everything else over
445
 
 
4
  import os
5
  import json
6
 
7
+ engine = create_engine('sqlite:///../data/games.db', echo=False)
8
  Base = declarative_base()
9
 
10
 
 
25
  season = Column(String)
26
  league_id = Column(Integer, ForeignKey('leagues.id'))
27
 
28
+
29
  class GameLineup(Base):
30
  __tablename__ = 'game_lineup'
31
  id = Column(Integer, primary_key=True)
 
47
  id = Column(Integer, primary_key=True)
48
  name = Column(String)
49
 
50
+
51
  class Player(Base):
52
  __tablename__ = 'players'
53
  hash = Column(String, primary_key=True)
 
77
  event_time_end = Column(Float)
78
  description = Column(Text)
79
 
80
+
81
  class League(Base):
82
  __tablename__ = 'leagues'
83
  id = Column(Integer, primary_key=True)
84
  name = Column(String)
85
 
86
+
87
  class Event(Base):
88
  __tablename__ = 'events'
89
  id = Column(Integer, primary_key=True)
 
96
  label = Column(String)
97
  visibility = Column(Boolean)
98
 
99
+
100
  class Augmented_Team(Base):
101
  __tablename__ = 'augmented_teams'
102
  id = Column(Integer, primary_key=True)
103
  team_id = Column(Integer, ForeignKey('teams.id'))
104
  augmented_name = Column(String)
105
 
106
+
107
  class Augmented_League(Base):
108
  __tablename__ = 'augmented_leagues'
109
  id = Column(Integer, primary_key=True)
110
  league_id = Column(Integer, ForeignKey('leagues.id'))
111
  augmented_name = Column(String)
112
 
113
+
114
  class Player_Event_Label(Base):
115
  __tablename__ = 'player_event_labels'
116
  id = Column(Integer, primary_key=True)
117
  label = Column(String)
118
 
119
+
120
class Player_Event(Base):
    """ORM model for the ``player_events`` table.

    Each row links a game, a player and an event label, with the time of the
    event stored as a string.
    """
    __tablename__ = 'player_events'
    id = Column(Integer, primary_key=True)
    game_id = Column(Integer, ForeignKey('games.id'))
    # players.hash is a String primary key, so the columns that reference it
    # are declared as String as well.
    player_id = Column(String, ForeignKey('players.hash'))
    time = Column(String)  # Time in minutes of the game
    type = Column(Integer, ForeignKey('player_event_labels.id'))
    # If the event is linked to another player, for example a substitution
    linked_player = Column(String, ForeignKey('players.hash'))
 
 
 
 
129
 
130
 
131
  # Create Tables
 
134
  # Session setup
135
  Session = sessionmaker(bind=engine)
136
 
137
+
138
def extract_time_from_player_event(time: str) -> str:
    """Return the part of a player-event string that precedes the first apostrophe.

    The result stays a string because stoppage-time values such as "45+2"
    are not plain integers.
    """
    minute, _, _ = time.partition("'")
    return minute
142
 
143
+
144
  def get_or_create(session, model, **kwargs):
145
  instance = session.query(model).filter_by(**kwargs).first()
146
  if instance:
 
151
  session.commit()
152
  return instance
153
 
154
+
155
+ def process_game_data(data, data2, league, season):
156
  session = Session()
157
  # Caption = d and v2 = d2
158
  home_team = data["gameHomeTeam"]
 
176
  # Check if league exists
177
  league = get_or_create(session, League, name=league)
178
  if not game:
179
+ game = Game(timestamp=timestamp, score=score, goal_home=home_score, goal_away=away_score, round=round_,
180
+ home_team_id=home_team.id, away_team_id=away_team.id,
181
  venue=venue, date=date, attendance=attendance, season=season, league_id=league.id, referee=referee)
182
  session.add(game)
183
  session.commit()
 
195
  for player_data in team_lineup["players"]:
196
  player_hash = player_data["hash"]
197
  name = player_data["long_name"]
198
+ if " " not in name: # Since some players are missing their first name, do this to help with the search
199
  name = "NULL " + name
200
  number = player_data["shirt_number"]
201
  captain = player_data["captain"] == "(C)"
202
  starting = player_data["starting"]
203
  country = player_data["country"]
204
  position = player_data["lineup"]
205
+ facts = player_data.get("facts", None) # Facts might be empty
 
 
 
 
206
 
207
  player = get_or_create(session, Player, hash=player_hash, name=name, country=country)
208
  game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=player.hash,
209
+ shirt_number=number, position=position, starting=starting, captain=captain,
210
+ coach=False, tactics=tactic)
211
  if facts:
212
  for fact in facts:
213
  type = fact["type"]
 
215
  event = get_or_create(session, Player_Event_Label, id=int(type))
216
  linked_player = fact.get("linked_player_hash", None)
217
 
218
+ player_event = Player_Event(game_id=game.id, player_id=player.hash, time=time, type=event.id,
219
+ linked_player=linked_player)
220
  session.add(player_event)
221
  session.add(game_lineup)
222
 
 
229
  coach_country = coach["country"]
230
  coach_player = get_or_create(session, Player, hash=coach_hash, name=coach_name, country=coach_country)
231
  game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=coach_player.hash,
232
+ shirt_number=None, position=None, starting=None, captain=False, coach=True,
233
+ tactics=tactic)
234
  session.add(game_lineup)
235
 
236
  # Commit all changes at once
 
248
  label = "yellow card"
249
  elif label == "r-card":
250
  label = "red card"
251
+
252
  description = event["description"]
253
  important = event["important"] == "true"
254
  visible = event["visibility"]
 
264
 
265
  return game.id, home_team.id, away_team.id
266
 
267
+
268
  def process_player_data(data):
269
  pass
270
 
271
+
272
  def process_ASR_data(data, game_id, period):
273
  session = Session()
274
  seg = data["segments"]
 
286
  session.commit()
287
  session.close()
288
 
289
+
290
  def convert_to_seconds(time_str):
291
  # Split the string into its components
292
  period, time = time_str.split(" - ")
 
331
  game_time=game_time, # Already in seconds
332
  frame_stamp=position, # Make sure this is an integer or None
333
  team_id=team_id, # Integer ID of the team
334
+ visibility=visibility, # Boolean
335
+ label=label # String with information
336
  )
337
  session.add(annotation_entry)
338
 
339
  session.commit()
340
 
341
 
 
 
 
342
  def process_json_files(directory):
343
  session = Session()
344
  fill_player_events(session)
 
362
  lb_cap = json.load(f)
363
  with open(os.path.join(root, "Labels-v2.json"), 'r') as f:
364
  lb_v2 = json.load(f)
365
+ game_id, home_team_id, away_team_id = process_game_data(lb_cap, lb_v2, league, season)
366
 
367
  for file in asr_files:
368
  with open(os.path.join(root, file), 'r') as f:
 
375
  elif '1_half-ASR' in file:
376
  period = 1
377
  # Parse and commit the data
378
+ process_ASR_data(data=asr, game_id=game_id, period=period)
379
 
380
  elif '2_half-ASR' in file:
381
  period = 2
382
  # Parse and commit the data
383
+ process_ASR_data(data=asr, game_id=game_id, period=period)
 
384
 
385
  session.commit()
386
  session.close()
387
 
 
388
 
389
+ def fill_player_events(session):
390
  fact_id2label = {
391
  "1": "Yellow card",
392
  # Example: "time": "71' Ivanovic B. (Unsportsmanlike conduct)", "description": "Yellow Card"
 
403
  session.commit()
404
 
405
 
 
406
  def fill_Augmented_Team(file_path):
 
407
  df = pd.read_csv(file_path)
408
  # the df should have two columns, team_name and augmented_name
409
 
 
421
  session.commit()
422
  session.close()
423
 
424
+
425
  def fill_Augmented_League(file_path):
426
  # Read the csv file
427
  df = pd.read_csv(file_path)
 
437
  augmented_name = augmented_name.strip()
438
  league = session.query(League).filter_by(name=league_name).first()
439
  if league:
440
+ augmented_league = get_or_create(session, Augmented_League, league_id=league.id,
441
+ augmented_name=augmented_name)
442
  session.commit()
443
  session.close()
444
 
445
+
446
  if __name__ == "__main__":
447
  # Example directory path
448
+ process_json_files('../data/Dataset/SN-ASR_captions_and_actions/')
449
+ fill_Augmented_Team('../data/Dataset/augmented_teams.csv')
450
+ fill_Augmented_League('../data/Dataset/augmented_leagues.csv')
451
  # Rename the event/annotation table to something more descriptive. Events are fucking everything else over
452
 
src/extractor.py CHANGED
@@ -32,13 +32,16 @@ if os.getenv('LANGSMITH'):
32
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
33
  os.environ[
34
  'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
35
- os.environ['LANGCHAIN_PROJECT'] = 'master-theses'
36
- db = SQLDatabase.from_uri("sqlite:///data/games.db")
 
 
 
37
 
38
  # from langchain_anthropic import ChatAnthropic
39
  class Extractor():
40
  # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
41
- #gpt-3.5-turbo
42
  def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
43
  # model = "gpt-4-0125-preview"
44
  if custom_extractor_prompt:
@@ -133,7 +136,6 @@ class Retriever():
133
  # Get matches and their scores, limited by the specified 'limit'
134
  matches = process.extract(prompt, self.response, limit=limit)
135
 
136
-
137
  filtered_matches = [match for match in matches if match[1] >= threshold]
138
 
139
  # If no matches meet the threshold, return the list of all matches' strings
@@ -142,7 +144,6 @@ class Retriever():
142
  # Fix for wrong properties being returned
143
  return [match[0] for match in matches if match[1] >= low_threshold]
144
 
145
-
146
  # If there's only one match meeting the threshold, return it as a string
147
  if len(filtered_matches) == 1:
148
  return filtered_matches[0][0] # Return the matched string directly
@@ -247,7 +248,7 @@ def recheck_property_value(properties, property_name, retrievers, input_func):
247
  print("No close matches found. Please try again or type 'quit' to stop.")
248
 
249
 
250
- def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func=input):
251
  """
252
  Checks and updates the properties in the properties list based on close matches found in the database.
253
  The function iterates through each property in each property dictionary within the list,
@@ -263,7 +264,7 @@ def check_and_update_properties(properties_list, retrievers, method="fuzzy", inp
263
  The function updates the properties_list in place based on user choices for updating property values
264
  with close matches found by the retrievers.
265
  """
266
-
267
  for index, properties in enumerate(properties_list):
268
  for property_name, retriever in retrievers.items(): # Iterate using items to get both key and value
269
  property_values = properties.get(property_name, [])
@@ -279,7 +280,11 @@ def check_and_update_properties(properties_list, retrievers, method="fuzzy", inp
279
  updated_property_values.append(augmented_value)
280
  continue
281
  # Since property_value is now expected to be a list, we handle each value individually
282
- top_matches = retriever.find_close_matches(value, method=method, n=3)
 
 
 
 
283
 
284
  # Check if the closest match is the same as the current value
285
  if top_matches and top_matches[0] == value:
@@ -296,30 +301,38 @@ def check_and_update_properties(properties_list, retrievers, method="fuzzy", inp
296
  updated_property_values.append(top_matches)
297
  properties[property_name] = updated_property_values
298
  continue
299
-
300
- print(f"\nCurrent {property_name}: {value}")
301
- for i, match in enumerate(top_matches, start=1):
302
- print(f"[{i}] {match}")
303
- print("[4] Enter new value")
304
-
305
- # hmm = input_func(f"Fix for Pycharm, press enter to continue")
306
-
307
- choice = input_func(f"Select the best match for {property_name} (1-4): ")
308
- if choice in ['1', '2', '3']:
309
- selected_match = top_matches[int(choice) - 1]
310
- updated_property_values.append(selected_match) # Update with the selected match
311
- print(f"Updated {property_name} to {selected_match}")
312
- elif choice == '4':
313
- # Allow re-entry of value for this specific item
314
- recheck_property_value(properties, property_name, value, retrievers, input_func)
315
- # Note: Implement recheck_property_value to handle individual value updates within the list
316
- else:
317
- print("Invalid selection. Property not updated.")
318
- updated_property_values.append(value) # Keep the original value
 
 
 
319
 
320
  # Update the entire list for the property after processing all values
321
  properties[property_name] = updated_property_values
322
 
 
 
 
 
 
323
 
324
  # Function to remove duplicates
325
  def remove_duplicates(dicts):
@@ -354,18 +367,21 @@ def fetch_pks(properties_list, retrievers):
354
  return all_pk_attributes
355
 
356
 
357
- def update_prompt(prompt, properties, pk, properties_original):
358
- # Replace the original prompt with the updated properties and pk
359
- prompt = prompt.replace("{{properties}}", str(properties))
360
- prompt = prompt.replace("{{pk}}", str(pk))
361
- return prompt
362
 
363
 
364
- def update_prompt_enhanced(prompt, properties, pk, properties_original):
365
  updated_info = ""
366
  for prop, pk_info, prop_orig in zip(properties, pk, properties_original):
367
  for key in prop.keys():
368
  # Extract original and updated values
 
 
 
369
  orig_values = prop_orig.get(key, [])
370
  updated_values = prop.get(key, [])
371
 
@@ -391,9 +407,13 @@ def update_prompt_enhanced(prompt, properties, pk, properties_original):
391
  if orig_value != updated_value and pk_value:
392
  update_statement = f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
393
  elif orig_value != updated_value:
394
- update_statement = f"\n- {orig_value} (now referred to as {updated_value})."
395
  elif pk_value:
396
  update_statement = f"\n- {orig_value} has a primary key: {pk_value}."
 
 
 
 
397
 
398
  updated_info += update_statement
399
 
@@ -417,7 +437,7 @@ def prompt_cleaner(prompt, db, schema_config):
417
  check_and_update_properties(properties, retrievers)
418
 
419
  pk = fetch_pks(properties, retrievers)
420
- properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
421
 
422
  return properties, pk
423
 
@@ -462,8 +482,9 @@ class PromptCleaner:
462
  self.schema_config = schema_config
463
  self.retrievers = setup_retrievers(self.db, self.schema_config)
464
  self.cust_extractor_prompt = custom_extractor_prompt
 
465
 
466
- def clean(self, prompt, return_pk=False, test=False, verbose = False):
467
  """
468
  Processes the given prompt to extract properties, remove duplicates, update the properties
469
  based on close matches within the database, and fetch primary keys for these properties.
@@ -493,24 +514,50 @@ class PromptCleaner:
493
  properties = extract_properties(prompt, self.schema_config)
494
  # Keep original properties for later use
495
  properties_original = deepcopy(properties)
 
496
  if test:
497
  return properties_original
498
  # Remove duplicates - Happens when there are more than one player or team in the prompt
499
  # properties = remove_duplicates(properties)
500
  pk = None
 
501
  if properties:
502
  check_and_update_properties(properties, self.retrievers)
503
  pk = fetch_pks(properties, self.retrievers)
504
- properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
505
-
506
 
507
-
508
- if return_pk:
 
 
509
  return properties, pk
510
  elif verbose:
511
  return properties, properties_original
 
 
 
 
 
 
 
 
512
  else:
513
- return properties
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
 
516
  def load_json(file_path: str) -> dict:
@@ -518,24 +565,24 @@ def load_json(file_path: str) -> dict:
518
  return json.load(file)
519
 
520
 
521
- def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = "sqlite:///data/games.db", ):
522
  schema_config = load_json(schema)
523
  db = SQLDatabase.from_uri(db)
524
  pre_prompt = """Extract and save the relevant entities mentioned \
525
  in the following passage together with their properties.
526
-
527
  Only extract the properties mentioned in the 'information_extraction' function.
528
-
529
  The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
530
  Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
531
-
532
  If a property is not present and is not required in the function parameters, do not include it in the output.
533
  If no properties are found, return an empty list.
534
-
535
  Here are some exampels:
536
  'How many goals did Henry score for Arsnl in the 2015 season?'
537
  person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
538
-
539
  Passage:
540
  {input}
541
  """
@@ -544,15 +591,19 @@ def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = "sq
544
 
545
 
546
  if __name__ == "__main__":
547
-
548
-
549
  schema_config = load_json("src/conf/schema.json")
550
  # Add game and league to the schema_config
551
 
552
  # prompter = PromptCleaner(db, schema_config, custom_extractor_prompt=extract_prompt)
553
  prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
554
- prompt= prompter.clean("Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa")
555
-
556
 
557
  print(prompt)
 
 
 
 
 
 
558
 
 
32
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
33
  os.environ[
34
  'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
35
+ os.environ['LANGCHAIN_PROJECT'] = os.getenv('LANGSMITH_PROJECT')
36
+ db_uri = os.getenv('DATABASE_PATH')
37
+ db_uri = f"sqlite:///{db_uri}"
38
+ db = SQLDatabase.from_uri(db_uri)
39
+
40
 
41
  # from langchain_anthropic import ChatAnthropic
42
  class Extractor():
43
  # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
44
+ # gpt-3.5-turbo
45
  def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
46
  # model = "gpt-4-0125-preview"
47
  if custom_extractor_prompt:
 
136
  # Get matches and their scores, limited by the specified 'limit'
137
  matches = process.extract(prompt, self.response, limit=limit)
138
 
 
139
  filtered_matches = [match for match in matches if match[1] >= threshold]
140
 
141
  # If no matches meet the threshold, return the list of all matches' strings
 
144
  # Fix for wrong properties being returned
145
  return [match[0] for match in matches if match[1] >= low_threshold]
146
 
 
147
  # If there's only one match meeting the threshold, return it as a string
148
  if len(filtered_matches) == 1:
149
  return filtered_matches[0][0] # Return the matched string directly
 
248
  print("No close matches found. Please try again or type 'quit' to stop.")
249
 
250
 
251
+ def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func="input"):
252
  """
253
  Checks and updates the properties in the properties list based on close matches found in the database.
254
  The function iterates through each property in each property dictionary within the list,
 
264
  The function updates the properties_list in place based on user choices for updating property values
265
  with close matches found by the retrievers.
266
  """
267
+ return_list = []
268
  for index, properties in enumerate(properties_list):
269
  for property_name, retriever in retrievers.items(): # Iterate using items to get both key and value
270
  property_values = properties.get(property_name, [])
 
280
  updated_property_values.append(augmented_value)
281
  continue
282
  # Since property_value is now expected to be a list, we handle each value individually
283
+ if input_func == "chainlit":
284
+ n = 5
285
+ else:
286
+ n = 3
287
+ top_matches = retriever.find_close_matches(value, method=method, n=n)
288
 
289
  # Check if the closest match is the same as the current value
290
  if top_matches and top_matches[0] == value:
 
301
  updated_property_values.append(top_matches)
302
  properties[property_name] = updated_property_values
303
  continue
304
+ if input_func == "input":
305
+ print(f"\nCurrent {property_name}: {value}")
306
+ for i, match in enumerate(top_matches, start=1):
307
+ print(f"[{i}] {match}")
308
+ print("[4] Enter new value")
309
+
310
+ # hmm = input(f"Fix for Pycharm, press enter to continue")
311
+
312
+ choice = input(f"Select the best match for {property_name} (1-4): ")
313
+ if choice in ['1', '2', '3']:
314
+ selected_match = top_matches[int(choice) - 1]
315
+ updated_property_values.append(selected_match) # Update with the selected match
316
+ print(f"Updated {property_name} to {selected_match}")
317
+ elif choice == '4':
318
+ # Allow re-entry of value for this specific item
319
+ recheck_property_value(properties, property_name, value, retrievers, input_func)
320
+ # Note: Implement recheck_property_value to handle individual value updates within the list
321
+ else:
322
+ print("Invalid selection. Property not updated.")
323
+ updated_property_values.append(value) # Keep the original value
324
+ elif input_func == "chainlit": # If we use UI, just return the list of top matches, and then let the user select
325
+ options = {property_name: value, "top_matches": top_matches}
326
+ return_list.append(options)
327
 
328
  # Update the entire list for the property after processing all values
329
  properties[property_name] = updated_property_values
330
 
331
+ if input_func == "chainlit":
332
+ return properties, return_list
333
+ else:
334
+ return properties
335
+
336
 
337
  # Function to remove duplicates
338
  def remove_duplicates(dicts):
 
367
  return all_pk_attributes
368
 
369
 
370
+ # def update_prompt(prompt, properties, pk, properties_original):
371
+ # # Replace the original prompt with the updated properties and pk
372
+ # prompt = prompt.replace("{{properties}}", str(properties))
373
+ # prompt = prompt.replace("{{pk}}", str(pk))
374
+ # return prompt
375
 
376
 
377
+ def update_prompt(prompt, properties, pk, properties_original, retrievers):
378
  updated_info = ""
379
  for prop, pk_info, prop_orig in zip(properties, pk, properties_original):
380
  for key in prop.keys():
381
  # Extract original and updated values
382
+ if key in retrievers:
383
+ # Fetch the primary key using the retriever for the current property
384
+ table = retrievers[key].table
385
  orig_values = prop_orig.get(key, [])
386
  updated_values = prop.get(key, [])
387
 
 
407
  if orig_value != updated_value and pk_value:
408
  update_statement = f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
409
  elif orig_value != updated_value:
410
+ update_statement = f"\n- {orig_value} (now referred to as {updated_value}."
411
  elif pk_value:
412
  update_statement = f"\n- {orig_value} has a primary key: {pk_value}."
413
+ elif orig_value == updated_value and pk_value:
414
+ update_statement = f"\n- {orig_value} has a primary key: {pk_value}."
415
+ elif orig_value == updated_value:
416
+ update_statement = f"\n- {orig_value}."
417
 
418
  updated_info += update_statement
419
 
 
437
  check_and_update_properties(properties, retrievers)
438
 
439
  pk = fetch_pks(properties, retrievers)
440
+ properties = update_prompt(prompt, properties, pk, properties_original)
441
 
442
  return properties, pk
443
 
 
482
  self.schema_config = schema_config
483
  self.retrievers = setup_retrievers(self.db, self.schema_config)
484
  self.cust_extractor_prompt = custom_extractor_prompt
485
+ self.properties_original = None
486
 
487
+ def clean(self, prompt, return_pk=False, test=False, verbose=False):
488
  """
489
  Processes the given prompt to extract properties, remove duplicates, update the properties
490
  based on close matches within the database, and fetch primary keys for these properties.
 
514
  properties = extract_properties(prompt, self.schema_config)
515
  # Keep original properties for later use
516
  properties_original = deepcopy(properties)
517
+
518
  if test:
519
  return properties_original
520
  # Remove duplicates - Happens when there are more than one player or team in the prompt
521
  # properties = remove_duplicates(properties)
522
  pk = None
523
+ # VALIDATE PROPERTIES
524
  if properties:
525
  check_and_update_properties(properties, self.retrievers)
526
  pk = fetch_pks(properties, self.retrievers)
527
+ properties = update_prompt(prompt=prompt, properties=properties, pk=pk, properties_original=properties_original,
528
+ retrievers=self.retrievers)
529
 
530
+ # Prepare additional data if requested
531
+ if return_pk and verbose:
532
+ return (properties, pk), (properties, properties_original)
533
+ elif return_pk:
534
  return properties, pk
535
  elif verbose:
536
  return properties, properties_original
537
+
538
+ return properties
539
+
540
def extract_chainlit(self, prompt):
    """Extract properties from ``prompt`` for the Chainlit flow.

    Runs ``extract_properties`` with the schema config, adding the custom
    extractor prompt as a third argument when one was configured. A deep
    copy of the result is stored in ``self.properties_original`` so the
    untouched values are still available after validation.

    Args:
        prompt: The user's question.

    Returns:
        The properties returned by ``extract_properties``.
    """
    extractor_args = [prompt, self.schema_config]
    if self.cust_extractor_prompt:
        extractor_args.append(self.cust_extractor_prompt)

    extracted = extract_properties(*extractor_args)
    # Snapshot before any validation step mutates the extracted values.
    self.properties_original = deepcopy(extracted)
    return extracted
549
+
550
def validate_chainlit(self, properties):
    """Validate extracted properties without prompting on the console.

    Delegates to ``check_and_update_properties`` with
    ``input_func="chainlit"``. In that mode the helper returns the
    properties together with a list of entries of the form
    ``{property_name: value, "top_matches": [...]}`` for values that still
    need a choice from the user.

    Args:
        properties: Property dictionaries as returned by ``extract_chainlit``.

    Returns:
        A ``(properties, need_validation)`` tuple.
    """
    validated, pending_choices = check_and_update_properties(
        properties, self.retrievers, input_func="chainlit"
    )
    return validated, pending_choices
553
+
554
def build_prompt_chainlit(self, properties, prompt):
    """Build the enriched prompt from properties validated in the UI.

    Fetches primary keys for ``properties`` with ``fetch_pks`` and passes
    them, together with the original extraction stored by
    ``extract_chainlit``, to ``update_prompt``.

    Args:
        properties: Validated property dictionaries; may be empty or None.
        prompt: The user's question.

    Returns:
        The value produced by ``update_prompt``, or ``prompt`` unchanged
        when there are no properties to attach.
    """
    if not properties:
        # Nothing was extracted: there are no primary keys to look up, and
        # update_prompt zips over the pk list so it cannot accept pk=None.
        return prompt

    originals = self.properties_original
    if originals is None:
        # extract_chainlit was not called on this instance; treat the given
        # properties as the originals so update_prompt can still zip them.
        originals = deepcopy(properties)

    pk = fetch_pks(properties, self.retrievers)
    return update_prompt(prompt, properties, pk, originals, self.retrievers)
561
 
562
 
563
  def load_json(file_path: str) -> dict:
 
565
  return json.load(file)
566
 
567
 
568
+ def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = db_uri):
569
  schema_config = load_json(schema)
570
  db = SQLDatabase.from_uri(db)
571
  pre_prompt = """Extract and save the relevant entities mentioned \
572
  in the following passage together with their properties.
573
+
574
  Only extract the properties mentioned in the 'information_extraction' function.
575
+
576
  The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
577
  Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
578
+
579
  If a property is not present and is not required in the function parameters, do not include it in the output.
580
  If no properties are found, return an empty list.
581
+
582
  Here are some exampels:
583
  'How many goals did Henry score for Arsnl in the 2015 season?'
584
  person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
585
+
586
  Passage:
587
  {input}
588
  """
 
591
 
592
 
593
if __name__ == "__main__":
    # Manual smoke test: build an extractor against the local games database
    # and run one deliberately misspelled question through the cleaning step.
    schema_path = "src/conf/schema.json"
    schema_config = load_json(schema_path)

    prompter = create_extractor(schema_path, "sqlite:///data/games.db")
    demo_question = "Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa and Manchester City"
    prompt = prompter.clean(demo_question)

    print(prompt)
609
 
src/sql_chain.py CHANGED
@@ -29,7 +29,7 @@ if os.getenv('LANGSMITH'):
29
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
30
  os.environ[
31
  'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
32
- os.environ['LANGCHAIN_PROJECT'] = 'master-theses'
33
 
34
 
35
  def load_json(file_path: str) -> dict:
@@ -38,7 +38,8 @@ def load_json(file_path: str) -> dict:
38
 
39
 
40
  class SqlChain:
41
- def __init__(self, few_shot_prompts: str, llm_model="gpt-3.5-turbo", db_uri="sqlite:///data/games.db", few_shot_k=2, verbose=True):
 
42
  self.llm = ChatOpenAI(model=llm_model, temperature=0)
43
  self.db = SQLDatabase.from_uri(db_uri)
44
  self.few_shot_k = few_shot_k
@@ -50,13 +51,12 @@ class SqlChain:
50
  db=self.db,
51
  prompt=self.full_prompt,
52
  max_iterations=10,
53
- verbose=verbose,
54
  agent_type="openai-tools",
55
  # Default to 10 examples - Can be overwritten with the prompt
56
  top_k=30,
57
  )
58
 
59
-
60
  def _set_up_few_shot_prompts(self, few_shot_prompts: dict) -> None:
61
  few_shots = SemanticSimilarityExampleSelector.from_examples(
62
  few_shot_prompts,
@@ -68,6 +68,7 @@ class SqlChain:
68
  return few_shots
69
 
70
  def few_prompt_construct(self, query: str, top_k=5, dialect="SQLite") -> str:
 
71
  system_prefix = """You are an agent designed to interact with a SQL database.
72
  Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
73
  ALWAYS query the database before returning an answer.
@@ -77,7 +78,7 @@ class SqlChain:
77
  You have access to tools for interacting with the database.
78
  Only use the given tools. Only use the information returned by the tools to construct your final answer.
79
  You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
80
-
81
  DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
82
 
83
  If the question does not seem related to the database, just return 'I don't know' as the answer.
@@ -86,10 +87,17 @@ class SqlChain:
86
  Here are some examples of user inputs and their corresponding SQL queries. They are tested and works.
87
  Use them as a guide when creating your own queries:"""
88
 
 
 
 
 
 
 
89
  SUFFIX = """Begin!
90
 
91
  Question: {input}
92
- Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
 
93
  I will not stop until I query the database and return the answer.
94
  {agent_scratchpad}"""
95
 
@@ -117,6 +125,7 @@ class SqlChain:
117
  "agent_scratchpad": [],
118
  }
119
  )
 
120
  def prompt_no_few_shot(self, query: str, dialect="SQLite") -> str:
121
  system_prefix = """You are an agent designed to interact with a SQL database.
122
  Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
@@ -134,10 +143,22 @@ class SqlChain:
134
 
135
  return f"{system_prefix}\n{query}"
136
 
137
-
138
-
139
-
140
- def ask(self, query: str, few_prompt:bool=True) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
141
  if few_prompt:
142
  self.few_prompt_construct(query)
143
  return self.agent.invoke({"input": self.full_prompt}), self.full_prompt
@@ -146,15 +167,19 @@ class SqlChain:
146
  return self.agent.invoke(self.prompt_no_few_shot(query)), self.prompt_no_few_shot(query)
147
 
148
 
149
-
150
-
151
  def create_agent(few_shot_prompts: str = "src/conf/sqls.json", llm_model="gpt-3.5-turbo-0125",
152
- db_uri="sqlite:///data/games.db", few_shot_k=2, verbose=True):
153
  """ Create an agent with the given few_shot_prompts, llm_model and db_uri
154
  Call it with agent.ask(prompt)"""
155
- return SqlChain(few_shot_prompts, llm_model, db_uri, few_shot_k, verbose)
 
 
 
 
 
 
156
 
157
 
158
  if __name__ == "__main__":
159
  chain = SqlChain("src/conf/sqls.json")
160
- chain.ask("Is Manchester United in the database?", False)
 
29
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
30
  os.environ[
31
  'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
32
+ os.environ['LANGCHAIN_PROJECT'] = os.getenv('LANGSMITH_PROJECT')
33
 
34
 
35
  def load_json(file_path: str) -> dict:
 
38
 
39
 
40
  class SqlChain:
41
+ def __init__(self, few_shot_prompts: str, llm_model="gpt-3.5-turbo", db_uri="sqlite:///data/games.db",
42
+ few_shot_k=2):
43
  self.llm = ChatOpenAI(model=llm_model, temperature=0)
44
  self.db = SQLDatabase.from_uri(db_uri)
45
  self.few_shot_k = few_shot_k
 
51
  db=self.db,
52
  prompt=self.full_prompt,
53
  max_iterations=10,
54
+ verbose=True,
55
  agent_type="openai-tools",
56
  # Default to 10 examples - Can be overwritten with the prompt
57
  top_k=30,
58
  )
59
 
 
60
  def _set_up_few_shot_prompts(self, few_shot_prompts: dict) -> None:
61
  few_shots = SemanticSimilarityExampleSelector.from_examples(
62
  few_shot_prompts,
 
68
  return few_shots
69
 
70
  def few_prompt_construct(self, query: str, top_k=5, dialect="SQLite") -> str:
71
+
72
  system_prefix = """You are an agent designed to interact with a SQL database.
73
  Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
74
  ALWAYS query the database before returning an answer.
 
78
  You have access to tools for interacting with the database.
79
  Only use the given tools. Only use the information returned by the tools to construct your final answer.
80
  You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
81
+
82
  DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
83
 
84
  If the question does not seem related to the database, just return 'I don't know' as the answer.
 
87
  Here are some examples of user inputs and their corresponding SQL queries. They are tested and works.
88
  Use them as a guide when creating your own queries:"""
89
 
90
+ # SUFFIX = """Begin!
91
+ #
92
+ # Question: {input}
93
+ # Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
94
+ # I will not stop until I query the database and return the answer.
95
+ # {agent_scratchpad}"""
96
  SUFFIX = """Begin!
97
 
98
  Question: {input}
99
+ Thought: I should look at the examples provided and see if I can use them to identify tables and how to build the query.
100
+ Then I should query the schema of the most relevant tables.
101
  I will not stop until I query the database and return the answer.
102
  {agent_scratchpad}"""
103
 
 
125
  "agent_scratchpad": [],
126
  }
127
  )
128
+
129
  def prompt_no_few_shot(self, query: str, dialect="SQLite") -> str:
130
  system_prefix = """You are an agent designed to interact with a SQL database.
131
  Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
 
143
 
144
  return f"{system_prefix}\n{query}"
145
 
146
+ def ask(self, query: str, few_prompt: bool = True, rag_test=False) -> str:
147
+ if rag_test:
148
+ self.few_prompt_construct(query)
149
+ # Alter the self.full_prompt to only include whats added by the RAG system
150
+ # Get content in self.full_prompt[messages][0][content]
151
+ prompt = self.full_prompt.messages
152
+ prompt = prompt[0].content
153
+
154
+ prompt = prompt.split("Use them as a guide when creating your own queries:\n\n")[1]
155
+ # Then remove everything after \n\nBegin!\n\n
156
+ prompt = prompt.split("\n\nBegin!\n\n")[0]
157
+ # Lets split it to a list. One element for each "User input: {input}\nSQL query: {query}"
158
+ prompt = prompt.split("User input: ")
159
+ # Then remove the first element
160
+ prompt = prompt[1:]
161
+ return prompt
162
  if few_prompt:
163
  self.few_prompt_construct(query)
164
  return self.agent.invoke({"input": self.full_prompt}), self.full_prompt
 
167
  return self.agent.invoke(self.prompt_no_few_shot(query)), self.prompt_no_few_shot(query)
168
 
169
 
 
 
170
def create_agent(few_shot_prompts: str = "src/conf/sqls.json", llm_model="gpt-3.5-turbo-0125",
                 db_uri="config", few_shot_k=2):
    """Create a ``SqlChain`` agent; call it with ``agent.ask(prompt)``.

    Args:
        few_shot_prompts: Path to the JSON file with few-shot SQL examples.
        llm_model: Name of the chat model handed to ``SqlChain``.
        db_uri: Database URI. The sentinel ``"config"`` means "build a
            SQLite URI from the ``DATABASE_PATH`` environment variable".
        few_shot_k: Number of few-shot examples handed to ``SqlChain``.

    Returns:
        The configured ``SqlChain`` instance.
    """
    if db_uri == "config":
        # DATABASE_PATH is a plain file path. When it is unset, fall back to
        # the path SqlChain itself defaults to instead of silently building
        # the bogus URI "sqlite:///None".
        db_path = os.getenv('DATABASE_PATH') or "data/games.db"
        db_uri = f"sqlite:///{db_path}"
    return SqlChain(few_shot_prompts, llm_model, db_uri, few_shot_k)
181
 
182
 
183
if __name__ == "__main__":
    # Manual check of the few-shot retrieval: with rag_test=True, ask() returns
    # the selected example prompts instead of invoking the agent.
    chain = SqlChain(few_shot_prompts="src/conf/sqls.json")
    chain.ask(query="Is Manchester United in the database?", rag_test=True)