Spaces:
Build error
Build error
hari-huynh
commited on
Commit
•
be52c8f
1
Parent(s):
1b8f0b5
Update KG Search Tool
Browse files- tools/kg_search.py +11 -6
tools/kg_search.py
CHANGED
@@ -37,12 +37,12 @@ example_selector = SemanticSimilarityExampleSelector.from_examples(
|
|
37 |
# Load schema, prefix, suffix
|
38 |
with open("prompts/schema.txt", "r") as file:
|
39 |
schema = file.read()
|
40 |
-
|
41 |
with open("prompts/cypher_instruct.yaml", "r") as file:
|
42 |
instruct = yaml.safe_load(file)
|
43 |
|
44 |
example_prompt = PromptTemplate(
|
45 |
-
input_variables = ["
|
46 |
template = instruct["example_template"]
|
47 |
)
|
48 |
|
@@ -54,6 +54,7 @@ dynamic_prompt = FewShotPromptTemplate(
|
|
54 |
input_variables = ["question"]
|
55 |
)
|
56 |
|
|
|
57 |
def generate_cypher(question: str) -> str:
|
58 |
"""Make Cypher query from given question."""
|
59 |
load_dotenv()
|
@@ -69,12 +70,14 @@ def generate_cypher(question: str) -> str:
|
|
69 |
)
|
70 |
|
71 |
chat_messages = [
|
72 |
-
|
73 |
]
|
74 |
|
|
|
75 |
output_parser = StrOutputParser()
|
|
|
76 |
chain = dynamic_prompt | gemini_chat | output_parser
|
77 |
-
cypher_statement = chain.invoke(question)
|
78 |
cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
|
79 |
|
80 |
return cypher_statement
|
@@ -83,6 +86,7 @@ def run_cypher(question, cypher_statement: str) -> str:
|
|
83 |
"""Return result of Cypher query from Knowledge Graph."""
|
84 |
knowledge_graph = Neo4jGraph()
|
85 |
result = knowledge_graph.query(cypher_statement)
|
|
|
86 |
|
87 |
gemini_chat = ChatGoogleGenerativeAI(
|
88 |
model= "gemini-1.5-flash-latest"
|
@@ -114,11 +118,12 @@ def lookup_kg(question: str) -> str:
|
|
114 |
"""
|
115 |
cypher_statement = generate_cypher(question)
|
116 |
cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
|
|
|
117 |
|
118 |
try:
|
119 |
answer = run_cypher(question, cypher_statement)
|
120 |
except:
|
121 |
-
answer = "Knowledge graph doesn't have enough information"
|
122 |
|
123 |
return answer
|
124 |
|
@@ -137,5 +142,5 @@ if __name__ == "__main__":
|
|
137 |
# print(final_result)
|
138 |
|
139 |
# Test lookup_kg tool
|
140 |
-
kg_info = lookup_kg
|
141 |
print(kg_info)
|
|
|
37 |
# Load schema, prefix, suffix
|
38 |
with open("prompts/schema.txt", "r") as file:
|
39 |
schema = file.read()
|
40 |
+
|
41 |
with open("prompts/cypher_instruct.yaml", "r") as file:
|
42 |
instruct = yaml.safe_load(file)
|
43 |
|
44 |
example_prompt = PromptTemplate(
|
45 |
+
input_variables = ["question_example", "cypher_example"],
|
46 |
template = instruct["example_template"]
|
47 |
)
|
48 |
|
|
|
54 |
input_variables = ["question"]
|
55 |
)
|
56 |
|
57 |
+
|
58 |
def generate_cypher(question: str) -> str:
|
59 |
"""Make Cypher query from given question."""
|
60 |
load_dotenv()
|
|
|
70 |
)
|
71 |
|
72 |
chat_messages = [
|
73 |
+
SystemMessage(content= dynamic_prompt.format(question=question)),
|
74 |
]
|
75 |
|
76 |
+
|
77 |
output_parser = StrOutputParser()
|
78 |
+
cypher_statement = []
|
79 |
chain = dynamic_prompt | gemini_chat | output_parser
|
80 |
+
cypher_statement = chain.invoke({"question": question})
|
81 |
cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
|
82 |
|
83 |
return cypher_statement
|
|
|
86 |
"""Return result of Cypher query from Knowledge Graph."""
|
87 |
knowledge_graph = Neo4jGraph()
|
88 |
result = knowledge_graph.query(cypher_statement)
|
89 |
+
print(f"\nCypher Result:\n{result}")
|
90 |
|
91 |
gemini_chat = ChatGoogleGenerativeAI(
|
92 |
model= "gemini-1.5-flash-latest"
|
|
|
118 |
"""
|
119 |
cypher_statement = generate_cypher(question)
|
120 |
cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
|
121 |
+
print(f"\nQuery:\n {cypher_statement}")
|
122 |
|
123 |
try:
|
124 |
answer = run_cypher(question, cypher_statement)
|
125 |
except:
|
126 |
+
answer = "Knowledge graph doesn't have enough information\n"
|
127 |
|
128 |
return answer
|
129 |
|
|
|
142 |
# print(final_result)
|
143 |
|
144 |
# Test lookup_kg tool
|
145 |
+
kg_info = lookup_kg(question)
|
146 |
print(kg_info)
|