hari-huynh
Update KG Search Tool
be52c8f
raw
history blame
4.63 kB
import os
import yaml
from dotenv import load_dotenv
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.tools import BaseTool, StructuredTool, tool
from langchain_community.graphs import Neo4jGraph
# from utils import utils
# Question-Cypher pair examples
with open("prompts/cypher_examples.yaml", "r") as f:
example_pairs = yaml.safe_load(f)
examples = example_pairs["examples"]
# LLM for choose the best similar examples
load_dotenv()
os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")
embedding_model = GoogleGenerativeAIEmbeddings(
model= "models/text-embedding-004"
)
example_selector = SemanticSimilarityExampleSelector.from_examples(
examples = examples,
embeddings = embedding_model,
vectorstore_cls = FAISS,
k = 1
)
# Load schema, prefix, suffix
with open("prompts/schema.txt", "r") as file:
schema = file.read()
with open("prompts/cypher_instruct.yaml", "r") as file:
instruct = yaml.safe_load(file)
example_prompt = PromptTemplate(
input_variables = ["question_example", "cypher_example"],
template = instruct["example_template"]
)
dynamic_prompt = FewShotPromptTemplate(
example_selector = example_selector,
example_prompt = example_prompt,
prefix = instruct["prefix"],
suffix = instruct["suffix"].format(schema=schema),
input_variables = ["question"]
)
def generate_cypher(question: str) -> str:
"""Make Cypher query from given question."""
load_dotenv()
# Set up Neo4J & Gemini API
os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")
gemini_chat = ChatGoogleGenerativeAI(
model= "gemini-1.5-flash-latest"
)
chat_messages = [
SystemMessage(content= dynamic_prompt.format(question=question)),
]
output_parser = StrOutputParser()
cypher_statement = []
chain = dynamic_prompt | gemini_chat | output_parser
cypher_statement = chain.invoke({"question": question})
cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
return cypher_statement
def run_cypher(question, cypher_statement: str) -> str:
"""Return result of Cypher query from Knowledge Graph."""
knowledge_graph = Neo4jGraph()
result = knowledge_graph.query(cypher_statement)
print(f"\nCypher Result:\n{result}")
gemini_chat = ChatGoogleGenerativeAI(
model= "gemini-1.5-flash-latest"
)
answer_prompt = f"""
Generate a concise and informative summary of the results in a polite and easy-to-understand manner based on question and Cypher query response.
Question: {question}
Response: {str(result)}
Avoid repeat information.
If response is empty, you should answer "Knowledge graph doesn't have enough information".
Answer:
"""
sys_answer_prompt = [
SystemMessage(content= answer_prompt),
HumanMessage(content="Provide information about question from knowledge graph")
]
response = gemini_chat.invoke(sys_answer_prompt)
answer = response.content
return answer
def lookup_kg(question: str) -> str:
"""Based on question, make and run Cypher statements.
question: str
Raw question from user input
"""
cypher_statement = generate_cypher(question)
cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
print(f"\nQuery:\n {cypher_statement}")
try:
answer = run_cypher(question, cypher_statement)
except:
answer = "Knowledge graph doesn't have enough information\n"
return answer
if __name__ == "__main__":
question = "Have any company is recruiting Machine Learning jobs?"
# Test few-shot template
# print(dynamic_prompt.format(question = "What does the Software Engineer job usually require?"))
# # Test generate Cypher
# result = generate_cypher(question)
# # Test return information from Cypher
# final_result = run_cypher(result)
# print(final_result)
# Test lookup_kg tool
kg_info = lookup_kg(question)
print(kg_info)