ric9176 commited on
Commit
7801115
·
1 Parent(s): 502c3a8

refactor to basic ReAct pattern

Browse files
Files changed (3) hide show
  1. app.py +16 -28
  2. docs/graph.png +0 -0
  3. scripts/generate_graph_image.py +1 -4
app.py CHANGED
@@ -8,6 +8,8 @@ from langgraph.graph.message import add_messages
8
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
9
  from langchain.schema.runnable.config import RunnableConfig
10
  from langchain_community.tools.tavily_search import TavilySearchResults
 
 
11
 
12
  import chainlit as cl
13
  from rag import create_rag_pipeline, add_urls_to_vectorstore
@@ -30,41 +32,29 @@ class AgentState(TypedDict):
30
  messages: Annotated[list, add_messages]
31
  context: list # Store retrieved context
32
 
 
 
 
 
 
 
33
  tavily_tool = TavilySearchResults(max_results=5)
34
- tool_belt = [tavily_tool]
35
 
36
- model = ChatOpenAI(model="gpt-4o", temperature=0)
37
- model = model.bind_tools(tool_belt)
38
 
39
  # Define system prompt
40
  SYSTEM_PROMPT = SystemMessage(content="""
41
  You are a helpful AI assistant that answers questions clearly and concisely.
42
  If you don't know something, simply say you don't know.
43
  Be engaging and professional in your responses.
44
- Use the provided context when available to give accurate information about events and activities.
 
45
  """)
46
 
47
- def retrieve(state: AgentState):
48
- """Retrieve relevant context from the vector store"""
49
- # Get the last message's content
50
- last_message = state["messages"][-1]
51
- if isinstance(last_message, HumanMessage):
52
- # Get relevant documents
53
- docs = rag_components["retriever"].get_relevant_documents(last_message.content)
54
- # Extract the content from documents
55
- context = [doc.page_content for doc in docs]
56
- return {"context": context}
57
- return {"context": []}
58
-
59
  def call_model(state: AgentState):
60
  messages = [SYSTEM_PROMPT] + state["messages"]
61
-
62
- # Add context to system message if available
63
- if state.get("context"):
64
- context_str = "\n".join(state["context"])
65
- context_message = SystemMessage(content=f"Context:\n{context_str}")
66
- messages = [messages[0], context_message] + messages[1:]
67
-
68
  response = model.invoke(messages)
69
  return {"messages": [response]}
70
 
@@ -82,14 +72,12 @@ def should_continue(state):
82
  # Create the graph
83
  builder = StateGraph(AgentState)
84
 
85
- # Add nodes
86
- builder.add_node("retrieve", retrieve)
87
  builder.add_node("agent", call_model)
88
  builder.add_node("action", tool_node)
89
 
90
- # Add edges
91
- builder.set_entry_point("retrieve")
92
- builder.add_edge("retrieve", "agent")
93
  builder.add_conditional_edges(
94
  "agent",
95
  should_continue,
 
8
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
9
  from langchain.schema.runnable.config import RunnableConfig
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain.tools import Tool
12
+ from langchain_core.tools import tool
13
 
14
  import chainlit as cl
15
  from rag import create_rag_pipeline, add_urls_to_vectorstore
 
32
  messages: Annotated[list, add_messages]
33
  context: list # Store retrieved context
34
 
35
+ # Create a retrieve tool
36
+ @tool
37
+ def retrieve_context(query: str) -> list[str]:
38
+ """Searches the knowledge base for relevant information about events and activities. Use this when you need specific details about events."""
39
+ return [doc.page_content for doc in rag_components["retriever"].get_relevant_documents(query)]
40
+
41
  tavily_tool = TavilySearchResults(max_results=5)
42
+ tool_belt = [tavily_tool, retrieve_context]
43
 
44
+ llm = ChatOpenAI(model="gpt-4o", temperature=0)
45
+ model = llm.bind_tools(tool_belt)
46
 
47
  # Define system prompt
48
  SYSTEM_PROMPT = SystemMessage(content="""
49
  You are a helpful AI assistant that answers questions clearly and concisely.
50
  If you don't know something, simply say you don't know.
51
  Be engaging and professional in your responses.
52
+ Use the retrieve_context tool when you need specific information about events and activities.
53
+ Use the tavily_search tool for general web searches.
54
  """)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def call_model(state: AgentState):
57
  messages = [SYSTEM_PROMPT] + state["messages"]
 
 
 
 
 
 
 
58
  response = model.invoke(messages)
59
  return {"messages": [response]}
60
 
 
72
  # Create the graph
73
  builder = StateGraph(AgentState)
74
 
75
+ # Remove retrieve node and modify graph structure
 
76
  builder.add_node("agent", call_model)
77
  builder.add_node("action", tool_node)
78
 
79
+ # Update edges
80
+ builder.set_entry_point("agent")
 
81
  builder.add_conditional_edges(
82
  "agent",
83
  should_continue,
docs/graph.png CHANGED
scripts/generate_graph_image.py CHANGED
@@ -4,20 +4,17 @@ import sys
4
  # Add the parent directory to the Python path
5
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
 
7
- from app import graph # This will import the compiled graph
8
  from langchain_core.runnables.graph import MermaidDrawMethod
9
 
10
- # Create docs directory if it doesn't exist
11
  docs_dir = "docs"
12
  if not os.path.exists(docs_dir):
13
  os.makedirs(docs_dir)
14
 
15
- # Generate and save the graph visualization
16
  graph_image = graph.get_graph().draw_mermaid_png(
17
  draw_method=MermaidDrawMethod.API,
18
  )
19
 
20
- # Save the image
21
  with open(os.path.join(docs_dir, "graph.png"), "wb") as f:
22
  f.write(graph_image)
23
 
 
4
  # Add the parent directory to the Python path
5
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
 
7
+ from app import graph
8
  from langchain_core.runnables.graph import MermaidDrawMethod
9
 
 
10
  docs_dir = "docs"
11
  if not os.path.exists(docs_dir):
12
  os.makedirs(docs_dir)
13
 
 
14
  graph_image = graph.get_graph().draw_mermaid_png(
15
  draw_method=MermaidDrawMethod.API,
16
  )
17
 
 
18
  with open(os.path.join(docs_dir, "graph.png"), "wb") as f:
19
  f.write(graph_image)
20