Spaces:
Running
Running
refactor to basic ReAct pattern
Browse files- app.py +16 -28
- docs/graph.png +0 -0
- scripts/generate_graph_image.py +1 -4
app.py
CHANGED
@@ -8,6 +8,8 @@ from langgraph.graph.message import add_messages
|
|
8 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
9 |
from langchain.schema.runnable.config import RunnableConfig
|
10 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
|
|
|
|
11 |
|
12 |
import chainlit as cl
|
13 |
from rag import create_rag_pipeline, add_urls_to_vectorstore
|
@@ -30,41 +32,29 @@ class AgentState(TypedDict):
|
|
30 |
messages: Annotated[list, add_messages]
|
31 |
context: list # Store retrieved context
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
tavily_tool = TavilySearchResults(max_results=5)
|
34 |
-
tool_belt = [tavily_tool]
|
35 |
|
36 |
-
|
37 |
-
model =
|
38 |
|
39 |
# Define system prompt
|
40 |
SYSTEM_PROMPT = SystemMessage(content="""
|
41 |
You are a helpful AI assistant that answers questions clearly and concisely.
|
42 |
If you don't know something, simply say you don't know.
|
43 |
Be engaging and professional in your responses.
|
44 |
-
Use the
|
|
|
45 |
""")
|
46 |
|
47 |
-
def retrieve(state: AgentState):
|
48 |
-
"""Retrieve relevant context from the vector store"""
|
49 |
-
# Get the last message's content
|
50 |
-
last_message = state["messages"][-1]
|
51 |
-
if isinstance(last_message, HumanMessage):
|
52 |
-
# Get relevant documents
|
53 |
-
docs = rag_components["retriever"].get_relevant_documents(last_message.content)
|
54 |
-
# Extract the content from documents
|
55 |
-
context = [doc.page_content for doc in docs]
|
56 |
-
return {"context": context}
|
57 |
-
return {"context": []}
|
58 |
-
|
59 |
def call_model(state: AgentState):
|
60 |
messages = [SYSTEM_PROMPT] + state["messages"]
|
61 |
-
|
62 |
-
# Add context to system message if available
|
63 |
-
if state.get("context"):
|
64 |
-
context_str = "\n".join(state["context"])
|
65 |
-
context_message = SystemMessage(content=f"Context:\n{context_str}")
|
66 |
-
messages = [messages[0], context_message] + messages[1:]
|
67 |
-
|
68 |
response = model.invoke(messages)
|
69 |
return {"messages": [response]}
|
70 |
|
@@ -82,14 +72,12 @@ def should_continue(state):
|
|
82 |
# Create the graph
|
83 |
builder = StateGraph(AgentState)
|
84 |
|
85 |
-
#
|
86 |
-
builder.add_node("retrieve", retrieve)
|
87 |
builder.add_node("agent", call_model)
|
88 |
builder.add_node("action", tool_node)
|
89 |
|
90 |
-
#
|
91 |
-
builder.set_entry_point("
|
92 |
-
builder.add_edge("retrieve", "agent")
|
93 |
builder.add_conditional_edges(
|
94 |
"agent",
|
95 |
should_continue,
|
|
|
8 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
9 |
from langchain.schema.runnable.config import RunnableConfig
|
10 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
11 |
+
from langchain.tools import Tool
|
12 |
+
from langchain_core.tools import tool
|
13 |
|
14 |
import chainlit as cl
|
15 |
from rag import create_rag_pipeline, add_urls_to_vectorstore
|
|
|
32 |
messages: Annotated[list, add_messages]
|
33 |
context: list # Store retrieved context
|
34 |
|
35 |
+
# Create a retrieve tool
|
36 |
+
@tool
|
37 |
+
def retrieve_context(query: str) -> list[str]:
|
38 |
+
"""Searches the knowledge base for relevant information about events and activities. Use this when you need specific details about events."""
|
39 |
+
return [doc.page_content for doc in rag_components["retriever"].get_relevant_documents(query)]
|
40 |
+
|
41 |
tavily_tool = TavilySearchResults(max_results=5)
|
42 |
+
tool_belt = [tavily_tool, retrieve_context]
|
43 |
|
44 |
+
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
45 |
+
model = llm.bind_tools(tool_belt)
|
46 |
|
47 |
# Define system prompt
|
48 |
SYSTEM_PROMPT = SystemMessage(content="""
|
49 |
You are a helpful AI assistant that answers questions clearly and concisely.
|
50 |
If you don't know something, simply say you don't know.
|
51 |
Be engaging and professional in your responses.
|
52 |
+
Use the retrieve_context tool when you need specific information about events and activities.
|
53 |
+
Use the tavily_search tool for general web searches.
|
54 |
""")
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def call_model(state: AgentState):
|
57 |
messages = [SYSTEM_PROMPT] + state["messages"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
response = model.invoke(messages)
|
59 |
return {"messages": [response]}
|
60 |
|
|
|
72 |
# Create the graph
|
73 |
builder = StateGraph(AgentState)
|
74 |
|
75 |
+
# Remove retrieve node and modify graph structure
|
|
|
76 |
builder.add_node("agent", call_model)
|
77 |
builder.add_node("action", tool_node)
|
78 |
|
79 |
+
# Update edges
|
80 |
+
builder.set_entry_point("agent")
|
|
|
81 |
builder.add_conditional_edges(
|
82 |
"agent",
|
83 |
should_continue,
|
docs/graph.png
CHANGED
![]() |
![]() |
scripts/generate_graph_image.py
CHANGED
@@ -4,20 +4,17 @@ import sys
|
|
4 |
# Add the parent directory to the Python path
|
5 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
6 |
|
7 |
-
from app import graph
|
8 |
from langchain_core.runnables.graph import MermaidDrawMethod
|
9 |
|
10 |
-
# Create docs directory if it doesn't exist
|
11 |
docs_dir = "docs"
|
12 |
if not os.path.exists(docs_dir):
|
13 |
os.makedirs(docs_dir)
|
14 |
|
15 |
-
# Generate and save the graph visualization
|
16 |
graph_image = graph.get_graph().draw_mermaid_png(
|
17 |
draw_method=MermaidDrawMethod.API,
|
18 |
)
|
19 |
|
20 |
-
# Save the image
|
21 |
with open(os.path.join(docs_dir, "graph.png"), "wb") as f:
|
22 |
f.write(graph_image)
|
23 |
|
|
|
4 |
# Add the parent directory to the Python path
|
5 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
6 |
|
7 |
+
from app import graph
|
8 |
from langchain_core.runnables.graph import MermaidDrawMethod
|
9 |
|
|
|
10 |
docs_dir = "docs"
|
11 |
if not os.path.exists(docs_dir):
|
12 |
os.makedirs(docs_dir)
|
13 |
|
|
|
14 |
graph_image = graph.get_graph().draw_mermaid_png(
|
15 |
draw_method=MermaidDrawMethod.API,
|
16 |
)
|
17 |
|
|
|
18 |
with open(os.path.join(docs_dir, "graph.png"), "wb") as f:
|
19 |
f.write(graph_image)
|
20 |
|