import os

import yaml
from dotenv import load_dotenv
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import HumanMessage, SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.tools import tool
from langchain_community.graphs import Neo4jGraph
# from utils import utils

# Question-Cypher pair examples used for few-shot prompting
with open("prompts/cypher_examples.yaml", "r") as f:
    example_pairs = yaml.safe_load(f)
examples = example_pairs["examples"]

# Embedding model and selector that pick the most similar example for a question
load_dotenv()
os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")  # langchain-google-genai expects GOOGLE_API_KEY
embedding_model = GoogleGenerativeAIEmbeddings(
    model="models/text-embedding-004"
)
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=embedding_model,
    vectorstore_cls=FAISS,
    k=1,  # keep only the single most similar question-Cypher pair
)

# Load graph schema and the prompt prefix/suffix
with open("prompts/schema.txt", "r") as file:
    schema = file.read()
with open("prompts/cypher_instruct.yaml", "r") as file:
    instruct = yaml.safe_load(file)

example_prompt = PromptTemplate(
    input_variables=["question", "cypher"],
    template=instruct["example_template"],
)
dynamic_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=instruct["prefix"],
    suffix=instruct["suffix"].format(schema=schema),
    input_variables=["question"],
)


def generate_cypher(question: str) -> str:
    """Generate a Cypher query from the given question."""
    load_dotenv()
    # Set up Neo4j & Gemini credentials from the .env file
    os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
    os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
    os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
    os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")
    gemini_chat = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash-latest"
    )
    output_parser = StrOutputParser()
    # Few-shot prompt -> Gemini -> plain text
    chain = dynamic_prompt | gemini_chat | output_parser
    cypher_statement = chain.invoke({"question": question})
    # Strip Markdown code fences (```cypher ... ```) from the model output
    cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
    return cypher_statement


def run_cypher(question: str, cypher_statement: str) -> str:
    """Run the Cypher query on the knowledge graph and summarize the result."""
    load_dotenv()  # Neo4jGraph reads NEO4J_URI/NEO4J_USERNAME/NEO4J_PASSWORD from the environment
    knowledge_graph = Neo4jGraph()
    result = knowledge_graph.query(cypher_statement)
    gemini_chat = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash-latest"
    )
    answer_prompt = f"""
    Generate a concise and informative summary of the results in a polite and easy-to-understand manner, based on the question and the Cypher query response.
    Question: {question}
    Response: {result}
    Avoid repeating information. If the response is empty, answer "Knowledge graph doesn't have enough information".
    Answer:
    """
    sys_answer_prompt = [
        SystemMessage(content=answer_prompt),
        HumanMessage(content="Provide information about the question from the knowledge graph"),
    ]
    response = gemini_chat.invoke(sys_answer_prompt)
    return response.content


@tool  # expose lookup_kg as a LangChain tool, so it can be called with .invoke() below
def lookup_kg(question: str) -> str:
    """Generate and run a Cypher statement for the given question.

    Args:
        question: Raw question from the user input.
    """
    cypher_statement = generate_cypher(question)
    try:
        answer = run_cypher(question, cypher_statement)
    except Exception:
        answer = "Knowledge graph doesn't have enough information"
    return answer


if __name__ == "__main__":
    question = "Is any company recruiting for Machine Learning jobs?"
    # Test few-shot template
    # print(dynamic_prompt.format(question="What does the Software Engineer job usually require?"))
    # # Test generating Cypher
    # result = generate_cypher(question)
    # # Test returning information from Cypher
    # final_result = run_cypher(question, result)
    # print(final_result)
    # Test the lookup_kg tool
    kg_info = lookup_kg.invoke(question)
    print(kg_info)