"""Extract soccer entities from a user prompt and reconcile them with a SQLite database.

An LLM extraction chain pulls candidate property values (player names, teams,
seasons, ...) out of the prompt. Each value is then matched against the
corresponding database column (fuzzy or difflib matching). The prompt is
annotated with the corrected names and their primary keys.
"""
import ast
import difflib
import json
import logging
import os
import re
from copy import deepcopy
from typing import Optional

from dotenv import load_dotenv
from langchain.chains import create_extraction_chain
from langchain.chains import create_extraction_chain_pydantic
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from thefuzz import process

load_dotenv(".env")

# Set up logging: console via basicConfig, plus a file handler for this module's logger.
logging.basicConfig(level=logging.INFO)
handler = logging.FileHandler('extractor.log')
logger = logging.getLogger(__name__)
# Without this the handler is created (and the file opened) but never receives records.
logger.addHandler(handler)


def _export_env(name, value):
    """Set os.environ[name] = value, skipping unset values.

    Assigning None to os.environ raises TypeError, which previously made the
    module fail to import whenever one of these variables was missing.
    """
    if value is not None:
        os.environ[name] = value


_export_env("OPENAI_API_KEY", os.getenv('OPENAI_API_KEY'))
# _export_env("ANTHROPIC_API_KEY", os.getenv('ANTHROPIC_API_KEY'))

if os.getenv('LANGSMITH'):
    # NOTE(review): tracing is explicitly set to 'false' even when LANGSMITH is
    # enabled -- confirm this is intended.
    os.environ['LANGCHAIN_TRACING_V2'] = 'false'
    os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
    _export_env('LANGCHAIN_API_KEY', os.getenv("LANGSMITH_API_KEY"))
    _export_env('LANGCHAIN_PROJECT', os.getenv('LANGSMITH_PROJECT'))

db_uri = f"sqlite:///{os.getenv('DATABASE_PATH')}"
db = SQLDatabase.from_uri(db_uri)

# Number of candidate matches shown to the user / requested from the matchers.
# Falls back to 3 when FEW_SHOT is not configured instead of crashing on int(None).
few_shot_n = int(os.getenv('FEW_SHOT') or 3)


class Extractor:
    """Thin wrapper around a LangChain extraction chain backed by an OpenAI chat model."""

    def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
        """Build the extraction chain.

        Args:
            model: OpenAI chat model name (alternatives used during development:
                "gpt-4-0125-preview", or ChatAnthropic "claude-3-opus-20240229").
            schema_config: JSON-schema-like dict passed to create_extraction_chain.
            custom_extractor_prompt: optional template string (must contain {input});
                when omitted the chain's built-in prompt is used.
        """
        # Previously the prompt variable was only bound inside the `if`, so omitting
        # custom_extractor_prompt raised UnboundLocalError.
        custom_prompt = (
            ChatPromptTemplate.from_template(custom_extractor_prompt)
            if custom_extractor_prompt
            else None
        )
        self.llm = ChatOpenAI(model=model, temperature=0)
        self.schema = schema_config or {}
        self.chain = create_extraction_chain(self.schema, self.llm, prompt=custom_prompt)

    def extract(self, query):
        """Run the extraction chain on `query` and return the raw chain output."""
        return self.chain.invoke(query)


class Retriever:
    """Look up and fuzzy-match values of one database column described by `config`.

    Recognised config keys: db_table, db_column, pk_column, numeric, and the
    optional augmented_table / augmented_column / augmented_fk, which describe
    an alias table whose FK points back at pk_column.
    """

    def __init__(self, db, config):
        self.db = db
        self.config = config
        self.table = config.get('db_table')
        self.column = config.get('db_column')
        self.pk_column = config.get('pk_column')
        self.numeric = config.get('numeric', False)
        self.response = []  # cached distinct column values, filled by query_as_list
        self.query = f"SELECT {self.column} FROM {self.table}"
        self.augmented_table = config.get('augmented_table', None)
        self.augmented_column = config.get('augmented_column', None)
        self.augmented_fk = config.get('augmented_fk', None)

    @staticmethod
    def _sql_literal(value):
        """Render `value` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so that names such as "N'Golo Kante"
        do not terminate the literal early (or inject SQL).
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _run_rows(self, query):
        """Run `query` and parse SQLDatabase.run's string result into a list of row tuples.

        SQLDatabase.run returns an empty string when there are no rows, which
        ast.literal_eval cannot parse; that case yields an empty list.
        """
        response = self.db.run(query)
        return ast.literal_eval(response) if response else []

    def query_as_list(self):
        """Fetch distinct non-empty values of the configured column and cache them.

        For non-numeric columns, standalone digit runs are stripped from each value.
        """
        values = [item for row in self._run_rows(self.query) for item in row if item]
        if not self.numeric:
            values = [re.sub(r"\b\d+\b", "", value).strip() for value in values]
        self.response = list(set(values))
        return self.response

    def get_augmented_items(self, prompt):
        """Resolve `prompt` through the alias (augmented) table, if one is configured.

        Returns the matching value of the main column, or None when no alias
        table is configured or no alias/row matches.
        """
        if self.augmented_table is None:
            return None
        query = (
            f"SELECT {self.augmented_fk} FROM {self.augmented_table} "
            f"WHERE LOWER({self.augmented_column}) = LOWER({self._sql_literal(prompt)})"
        )
        fk_rows = self._run_rows(query)
        if not fk_rows:
            return None
        fk_value = fk_rows[0][0]
        # Text keys must be quoted; numeric keys are interpolated as-is.
        fk_literal = self._sql_literal(fk_value) if isinstance(fk_value, str) else fk_value
        query = f"SELECT {self.column} FROM {self.table} WHERE {self.pk_column} = {fk_literal}"
        matching_rows = self._run_rows(query)
        return matching_rows[0][0] if matching_rows else None

    def find_close_matches(self, target_string, n=3, method="difflib", threshold=70):
        """
        Find the top n close matches to target_string in the database column.

        Args:
        - target_string (str): The string to match against the database results.
        - n (int): Number of top matches to return.
        - method (str): "fuzzy" to use thefuzz scoring, anything else for difflib.
        - threshold (int): Minimum fuzzy score for a confident match (fuzzy only).

        Returns:
        - list of str: candidate matches; with method="fuzzy" a single str is
          returned when exactly one candidate reaches the top score above threshold.
        """
        # Lazily populate the cached column values.
        if not self.response:
            self.query_as_list()
        if method == "fuzzy":
            return self.fuzzy_string(target_string, limit=n, threshold=threshold)
        return difflib.get_close_matches(target_string, self.response, n=n, cutoff=0.2)

    def fuzzy_string(self, prompt, limit, threshold=80, low_threshold=30):
        """Fuzzy-match `prompt` against the cached column values.

        Returns:
            - a single str when exactly one match has the top score >= threshold;
            - a list of str for ties at the top score >= threshold;
            - otherwise a list of the matches scoring >= low_threshold (possibly empty).
        """
        matches = process.extract(prompt, self.response, limit=limit)
        strong = [match for match in matches if match[1] >= threshold]
        if not strong:
            # Weak matches only; the low threshold filters out unrelated values.
            return [match[0] for match in matches if match[1] >= low_threshold]
        best_score = max(match[1] for match in strong)
        ties = [match[0] for match in strong if match[1] == best_score]
        return ties[0] if len(ties) == 1 else ties

    def fetch_pk(self, property_name, property_value):
        """Look up the primary key for each value of `property_value`.

        Returns a list with one entry per value: the raw string returned by
        SQLDatabase.run (e.g. "[(1,)]", or "" when nothing matches), or None
        for every value when no pk_column is configured.
        """
        values = property_value if isinstance(property_value, list) else [property_value]
        if self.pk_column is None:
            return [None for _ in values]
        pk_list = []
        for value in values:
            query = (
                f"SELECT {self.pk_column} FROM {self.table} "
                f"WHERE {self.column} = {self._sql_literal(value)} LIMIT 1"
            )
            pk_list.append(self.db.run(query))
        return pk_list


def setup_retrievers(db, schema_config):
    """Create one Retriever per schema property, configured from the property's 'items' dict."""
    return {
        prop: Retriever(db=db, config=config['items'])
        for prop, config in schema_config["properties"].items()
    }


def extract_properties(prompt, schema_config, custom_extractor_prompt=None):
    """Extract properties from the prompt.

    Only the 'type' information of each property (and of its 'items') is sent to
    the extractor; the database settings are stripped out.

    Returns the extractor's list of property dicts, or None if nothing was extracted.
    """
    schema_stripped = {'properties': {}}
    for key, value in schema_config['properties'].items():
        schema_stripped['properties'][key] = {
            'type': value['type'],
            'items': {'type': value['items']['type']},
        }

    extractor = Extractor(schema_config=schema_stripped, custom_extractor_prompt=custom_extractor_prompt)
    extraction_result = extractor.extract(prompt)

    if 'text' in extraction_result and extraction_result['text']:
        return extraction_result['text']
    print("No properties extracted.")
    return None


def recheck_property_value(properties, property_name, value, retrievers):
    """Interactively ask the user for a replacement value and match it against the database.

    Args:
        properties: dict of properties; on a successful selection
            properties[property_name] is set to the chosen match.
        property_name: name of the property being corrected.
        value: the current (unmatched) value, shown in the prompt.
        retrievers: the single Retriever for this property (despite the name).

    Returns:
        The selected match, or None if the user quit without choosing one.
    """
    while True:
        print(property_name)
        new_value = input(f"Enter new value for {property_name} - {value} or type 'quit' to stop: ")
        if new_value.lower() == 'quit':
            return None

        new_top_matches = retrievers.find_close_matches(new_value, n=few_shot_n)
        if not new_top_matches:
            print("No close matches found. Please try again or type 'quit' to stop.")
            continue

        print("\nNew close matches found:")
        for position, match in enumerate(new_top_matches, start=1):
            print(f"[{position}] {match}")
        count = len(new_top_matches)
        print(f"[{count + 1}] Re-enter value")
        print(f"[{count + 2}] Quit without updating")
        selection = input(
            f"Select the best match (1-{count}), choose {count + 1} to re-enter value, "
            f"or {count + 2} to quit: "
        )
        if selection in [str(k) for k in range(1, count + 1)]:
            selected_match = new_top_matches[int(selection) - 1]
            properties[property_name] = selected_match
            print(f"Updated {property_name} to {selected_match}")
            return selected_match
        if selection == f'{count + 2}':
            return None
        # "Re-enter value" or an invalid selection: loop and ask again.


def _select_match_interactively(properties, property_name, value, top_matches, retriever):
    """Prompt the user to choose among `top_matches` for `value`; return the value to keep."""
    print(f"\nCurrent {property_name}: {value}")
    for position, match in enumerate(top_matches, start=1):
        print(f"[{position}] {match}")
    reenter_option = len(top_matches) + 1
    print(f"[{reenter_option}] Enter new value")
    choice = input(f"Select the best match for {property_name} (1-{reenter_option}): ")

    if choice in [str(k) for k in range(1, reenter_option)]:
        selected_match = top_matches[int(choice) - 1]
        print(f"Updated {property_name} to {selected_match}")
        return selected_match
    if choice == str(reenter_option):
        # Previously the rechecked value was discarded and the item dropped from
        # the list; keep the user's choice, or the original value if they quit.
        rechecked = recheck_property_value(properties, property_name, value, retriever)
        return value if rechecked is None else rechecked
    print("Invalid selection. Property not updated.")
    return value


def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func="input"):
    """
    Check each extracted property value against the database and update it in place.

    For every property with a Retriever, each value is resolved in this order:
    alias (augmented) table, exact match, a single confident fuzzy match, or
    user selection (console prompt for input_func="input"; deferred to the UI
    for input_func="chainlit").

    Args:
        properties_list (list of dict): Extracted property dicts; each value is a
            list of strings. The dicts are updated in place.
        retrievers (dict): Retriever objects keyed by property name.
        method (str): Matching method passed to Retriever.find_close_matches.
        input_func (str): "input" for console interaction, "chainlit" to collect
            ambiguous values for the UI instead of prompting.

    Returns:
        The last processed property dict (an empty dict if properties_list is
        empty). For "chainlit" a tuple (that dict, list of
        {property_name: value, "top_matches": [...]}) describing values that
        still need user validation. Those values are not kept in the updated
        property list; the UI is expected to supply them.
    """
    return_list = []
    properties = {}
    for properties in properties_list:
        for property_name, retriever in retrievers.items():
            property_values = properties.get(property_name, [])
            if not property_values:
                continue

            updated_property_values = []
            for value in property_values:
                if retriever.augmented_table:
                    augmented_value = retriever.get_augmented_items(value)
                    if augmented_value:
                        updated_property_values.append(augmented_value)
                        continue

                top_matches = retriever.find_close_matches(value, method=method, n=few_shot_n)

                if not top_matches:
                    # Nothing similar in the database: keep the original value.
                    updated_property_values.append(value)
                    continue
                if isinstance(top_matches, str):
                    # Fuzzy matching returned one confident match. This must be checked
                    # before indexing, since top_matches[0] would be its first character.
                    updated_property_values.append(top_matches)
                    continue
                if top_matches[0] == value:
                    updated_property_values.append(value)
                    continue

                if input_func == "input":
                    updated_property_values.append(
                        _select_match_interactively(properties, property_name, value, top_matches, retriever)
                    )
                elif input_func == "chainlit":
                    return_list.append({property_name: value, "top_matches": top_matches})

            properties[property_name] = updated_property_values

    if input_func == "chainlit":
        return properties, return_list
    return properties


def remove_duplicates(dicts):
    """Drop a key from a dict when its value equals the last value seen for that key.

    Mutates the dicts in place and returns the same list (None/empty passes through).
    """
    if not dicts:
        return dicts
    seen = {}
    for d in dicts:
        # Iterate over a snapshot of the keys because entries are deleted during the loop.
        for key in list(d.keys()):
            value = d[key]
            if key in seen and value == seen[key]:
                del d[key]
            else:
                seen[key] = value
    return dicts


def fetch_pks(properties_list, retrievers):
    """Return, per properties dict, a dict mapping "<property>_pk" to Retriever.fetch_pk output."""
    all_pk_attributes = []
    for properties in properties_list:
        all_pk_attributes.append({
            f"{property_name}_pk": retrievers[property_name].fetch_pk(property_name, property_value)
            for property_name, property_value in properties.items()
            if property_name in retrievers
        })
    return all_pk_attributes


def _as_list(value):
    """Wrap a non-list value in a list so single values and lists are handled uniformly."""
    return value if isinstance(value, list) else [value]


def _parse_pk(pk_detail):
    """Extract the first column of the first row from SQLDatabase.run's string output."""
    if not isinstance(pk_detail, str):
        return None
    return pk_detail.strip("[]()").split(",")[0].replace("'", "").replace('"', '')


def _describe_update(orig_value, updated_value, pk_value):
    """Build one 'Updated Information' bullet for a value."""
    if orig_value != updated_value and pk_value:
        return f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
    if orig_value != updated_value:
        return f"\n- {orig_value} (now referred to as {updated_value}.)"
    if pk_value:
        return f"\n- {orig_value} has a primary key: {pk_value}."
    return f"\n- {orig_value}."


def update_prompt(prompt, properties, pk, properties_original, retrievers):
    """Append an 'Updated Information' section describing corrected values and primary keys.

    Args:
        prompt: the user's original prompt.
        properties: list of updated property dicts.
        pk: list of "<property>_pk" dicts from fetch_pks (None is treated as empty).
        properties_original: list of property dicts as first extracted.
        retrievers: dict of Retriever objects; only properties present here are reported.

    Returns:
        The prompt, with the extra section appended when there is anything to report.
    """
    updated_info = ""
    for prop, pk_info, prop_orig in zip(properties or [], pk or [], properties_original or []):
        for key in prop:
            if key not in retrievers:
                continue
            orig_values = _as_list(prop_orig.get(key, []))
            updated_values = _as_list(prop.get(key, []))
            pk_details = _as_list(pk_info.get(f"{key}_pk", []))
            for orig_value, updated_value, pk_detail in zip(orig_values, updated_values, pk_details):
                updated_info += _describe_update(orig_value, updated_value, _parse_pk(pk_detail))

    if updated_info:
        prompt += "\nUpdated Information:" + updated_info
    return prompt


def prompt_cleaner(prompt, db, schema_config):
    """Extract, de-duplicate, validate and annotate the prompt's properties.

    Returns a tuple (updated prompt, pk list); (prompt, None) when nothing was extracted.
    """
    retrievers = setup_retrievers(db, schema_config)
    properties = extract_properties(prompt, schema_config)
    if not properties:
        return prompt, None

    # Keep the original properties so the prompt can describe what changed.
    properties_original = deepcopy(properties)
    # Remove duplicates - happens when there is more than one player or team in the prompt.
    properties = remove_duplicates(properties)
    check_and_update_properties(properties, retrievers)
    pk = fetch_pks(properties, retrievers)
    return update_prompt(prompt, properties, pk, properties_original, retrievers), pk


class PromptCleaner:
    """
    Clean and process prompts by extracting properties, validating them against
    the database, and annotating the prompt with corrected values and primary keys.

    Attributes:
        db: A database connection object used to execute queries and fetch data.
        schema_config: A dictionary defining the schema configuration for the extraction process.
        Each property is an array whose 'items' dict holds the database settings:
            schema_config = {
                "properties": {
                    "person_name": {
                        "type": "array",
                        "items": {"type": "string", "db_table": "players", "db_column": "name",
                                  "pk_column": "hash",
                                  # set True if values are mostly numeric, such as 2015-2016
                                  "numeric": False},
                    },
                    "team_name": {
                        "type": "array",
                        "items": {"type": "string", "db_table": "teams", "db_column": "name",
                                  "pk_column": "id", "numeric": False},
                    },
                    # Optional alias lookup inside "items": augmented_table,
                    # augmented_column, augmented_fk.
                },
                "required": [],
            }

    Methods:
        clean(prompt): Extract and validate properties, then return the annotated prompt
            (optionally with primary keys / original properties).
    """

    def __init__(self, db=db, schema_config=None, custom_extractor_prompt=None):
        """
        Args:
            db: The database connection object used for querying (defaults to the module db).
            schema_config: Properties and their database mappings (see class docstring).
            custom_extractor_prompt: Optional prompt template for the extractor.
        """
        self.db = db
        self.schema_config = schema_config
        self.retrievers = setup_retrievers(self.db, self.schema_config)
        self.cust_extractor_prompt = custom_extractor_prompt
        self.properties_original = None

    def clean(self, prompt, return_pk=False, test=False, verbose=False):
        """
        Extract properties from the prompt, validate them against the database and
        return the prompt annotated with the updated information.

        Args:
            prompt (str): The prompt text to be cleaned and processed.
            return_pk (bool): Also return the primary keys.
            test (bool): Return only the originally extracted properties.
            verbose (bool): Also return the originally extracted properties.

        Returns:
            The updated prompt (the unchanged prompt when nothing was extracted);
            (prompt, pk) if return_pk; (prompt, properties_original) if verbose;
            ((prompt, pk), (prompt, properties_original)) if both.
        """
        properties = extract_properties(prompt, self.schema_config, self.cust_extractor_prompt)
        properties_original = deepcopy(properties)
        if test:
            return properties_original

        pk = None
        updated_prompt = prompt
        if properties:
            check_and_update_properties(properties, self.retrievers)
            pk = fetch_pks(properties, self.retrievers)
            updated_prompt = update_prompt(prompt=prompt, properties=properties, pk=pk,
                                           properties_original=properties_original,
                                           retrievers=self.retrievers)

        if return_pk and verbose:
            return (updated_prompt, pk), (updated_prompt, properties_original)
        if return_pk:
            return updated_prompt, pk
        if verbose:
            return updated_prompt, properties_original
        return updated_prompt

    def extract_chainlit(self, prompt):
        """Extract properties and remember them as the originals for build_prompt_chainlit."""
        properties = extract_properties(prompt, self.schema_config, self.cust_extractor_prompt)
        self.properties_original = deepcopy(properties)
        return properties

    def validate_chainlit(self, properties):
        """Validate properties without console prompts; return (properties, values needing UI choice)."""
        properties, need_val = check_and_update_properties(properties, self.retrievers, input_func="chainlit")
        return properties, need_val

    def build_prompt_chainlit(self, properties, prompt):
        """Annotate `prompt` with the (UI-validated) properties and their primary keys."""
        if not properties:
            # Nothing to annotate; update_prompt previously crashed on pk=None here.
            return prompt
        pk = fetch_pks(properties, self.retrievers)
        return update_prompt(prompt, properties, pk, self.properties_original, self.retrievers)


def load_json(file_path: str) -> dict:
    """Load and return a JSON document from `file_path`."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def create_extractor(schema: str = "src/conf/schema.json", db: str = db_uri):
    """Build a PromptCleaner for the soccer schema.

    Args:
        schema: path to the JSON schema config.
        db: SQLAlchemy database URI (a string, e.g. "sqlite:///data/games.db").
    """
    schema_config = load_json(schema)
    database = SQLDatabase.from_uri(db)
    pre_prompt = """Extract and save the relevant entities mentioned \
in the following passage together with their properties.

Only extract the properties mentioned in the 'information_extraction' function.

The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.

Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.

If a property is not present and is not required in the function parameters, do not include it in the output.

If no properties are found, return an empty list.

Here are some exampels:
'How many goals did Henry score for Arsnl in the 2015 season?'
person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],

Passage:
{input}
"""
    return PromptCleaner(database, schema_config, custom_extractor_prompt=pre_prompt)


if __name__ == "__main__":
    prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
    prompt = prompter.clean(
        "Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa and Manchester City")
    print(prompt)

    # Chainlit-style usage:
    # ex = create_extractor()
    # val_list = [{'person_name': ['Cristiano Ronaldo'], 'team_name': ['Manchester City']}]
    # user_prompt = "Did ronaldo play for city?"
    # p = ex.build_prompt_chainlit(val_list, user_prompt)
    # print(p)