hari-huynh commited on
Commit
be52c8f
1 Parent(s): 1b8f0b5

Update KG Search Tool

Browse files
Files changed (1) hide show
  1. tools/kg_search.py +11 -6
tools/kg_search.py CHANGED
@@ -37,12 +37,12 @@ example_selector = SemanticSimilarityExampleSelector.from_examples(
37
  # Load schema, prefix, suffix
38
  with open("prompts/schema.txt", "r") as file:
39
  schema = file.read()
40
-
41
  with open("prompts/cypher_instruct.yaml", "r") as file:
42
  instruct = yaml.safe_load(file)
43
 
44
  example_prompt = PromptTemplate(
45
- input_variables = ["question", "cypher"],
46
  template = instruct["example_template"]
47
  )
48
 
@@ -54,6 +54,7 @@ dynamic_prompt = FewShotPromptTemplate(
54
  input_variables = ["question"]
55
  )
56
 
 
57
  def generate_cypher(question: str) -> str:
58
  """Make Cypher query from given question."""
59
  load_dotenv()
@@ -69,12 +70,14 @@ def generate_cypher(question: str) -> str:
69
  )
70
 
71
  chat_messages = [
72
- SystemMessage(content= dynamic_prompt.format(question=question)),
73
  ]
74
 
 
75
  output_parser = StrOutputParser()
 
76
  chain = dynamic_prompt | gemini_chat | output_parser
77
- cypher_statement = chain.invoke(question)
78
  cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
79
 
80
  return cypher_statement
@@ -83,6 +86,7 @@ def run_cypher(question, cypher_statement: str) -> str:
83
  """Return result of Cypher query from Knowledge Graph."""
84
  knowledge_graph = Neo4jGraph()
85
  result = knowledge_graph.query(cypher_statement)
 
86
 
87
  gemini_chat = ChatGoogleGenerativeAI(
88
  model= "gemini-1.5-flash-latest"
@@ -114,11 +118,12 @@ def lookup_kg(question: str) -> str:
114
  """
115
  cypher_statement = generate_cypher(question)
116
  cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
 
117
 
118
  try:
119
  answer = run_cypher(question, cypher_statement)
120
  except:
121
- answer = "Knowledge graph doesn't have enough information"
122
 
123
  return answer
124
 
@@ -137,5 +142,5 @@ if __name__ == "__main__":
137
  # print(final_result)
138
 
139
  # Test lookup_kg tool
140
- kg_info = lookup_kg.invoke(question)
141
  print(kg_info)
 
37
  # Load schema, prefix, suffix
38
  with open("prompts/schema.txt", "r") as file:
39
  schema = file.read()
40
+
41
  with open("prompts/cypher_instruct.yaml", "r") as file:
42
  instruct = yaml.safe_load(file)
43
 
44
  example_prompt = PromptTemplate(
45
+ input_variables = ["question_example", "cypher_example"],
46
  template = instruct["example_template"]
47
  )
48
 
 
54
  input_variables = ["question"]
55
  )
56
 
57
+
58
  def generate_cypher(question: str) -> str:
59
  """Make Cypher query from given question."""
60
  load_dotenv()
 
70
  )
71
 
72
  chat_messages = [
73
+ SystemMessage(content= dynamic_prompt.format(question=question)),
74
  ]
75
 
76
+
77
  output_parser = StrOutputParser()
78
+ cypher_statement = []
79
  chain = dynamic_prompt | gemini_chat | output_parser
80
+ cypher_statement = chain.invoke({"question": question})
81
  cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
82
 
83
  return cypher_statement
 
86
  """Return result of Cypher query from Knowledge Graph."""
87
  knowledge_graph = Neo4jGraph()
88
  result = knowledge_graph.query(cypher_statement)
89
+ print(f"\nCypher Result:\n{result}")
90
 
91
  gemini_chat = ChatGoogleGenerativeAI(
92
  model= "gemini-1.5-flash-latest"
 
118
  """
119
  cypher_statement = generate_cypher(question)
120
  cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
121
+ print(f"\nQuery:\n {cypher_statement}")
122
 
123
  try:
124
  answer = run_cypher(question, cypher_statement)
125
  except:
126
+ answer = "Knowledge graph doesn't have enough information\n"
127
 
128
  return answer
129
 
 
142
  # print(final_result)
143
 
144
  # Test lookup_kg tool
145
+ kg_info = lookup_kg(question)
146
  print(kg_info)