Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_openai import ChatOpenAI | |
from langgraph.graph import MessagesState | |
from langgraph.graph import START, StateGraph | |
from langgraph.prebuilt import tools_condition | |
from langgraph.prebuilt import ToolNode | |
from langchain_core.messages import HumanMessage, SystemMessage | |
import tempfile | |
# ------------------- Environment Variable Setup -------------------
# Both API keys come from the environment. Neither value is passed to a client
# below; the clients presumably read the same variables themselves, so this
# section only exists to fail fast with a clear message when one is missing.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

# Checked in this order (OpenAI first); an empty string counts as missing.
for _var_name, _var_value in (
    ("OPENAI_API_KEY", openai_api_key),
    ("TAVILY_API_KEY", tavily_api_key),
):
    if not _var_value:
        raise ValueError(f"Missing required environment variable: {_var_name}")
del _var_name, _var_value

# ------------------- Tool Definitions -------------------
# Tavily web-search tool, capped at five results per query.
_TAVILY_MAX_RESULTS = 5
tavily_tool = TavilySearchResults(max_results=_TAVILY_MAX_RESULTS)
def multiply(a: float, b: float) -> float:
    """Multiply two numbers.

    Args:
        a: First factor; integers and decimals are both accepted.
        b: Second factor; integers and decimals are both accepted.

    Returns:
        The product of a and b.
    """
    # Annotated as float (which also admits int, per PEP 484) because this
    # signature becomes the tool's argument schema for the LLM: with ``int``
    # a fractional intermediate result such as 14 / 5 = 2.8 could not be fed
    # back into another operation.
    return a * b
def add(a: float, b: float) -> float:
    """Add two numbers.

    Args:
        a: First addend; integers and decimals are both accepted.
        b: Second addend; integers and decimals are both accepted.

    Returns:
        The sum of a and b.
    """
    # float (not int) so the LLM-facing argument schema accepts fractional
    # values, e.g. the result of a previous division.
    return a + b
def divide(a: float, b: float) -> float:
    """Divide two numbers.

    Args:
        a: Dividend; integers and decimals are both accepted.
        b: Divisor; must be non-zero.

    Returns:
        The true-division quotient a / b.

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        # A descriptive ValueError instead of a bare ZeroDivisionError; the
        # tool node presumably relays the message back to the model.
        raise ValueError("Division by zero is not allowed.")
    return a / b
# Combine tools: the three arithmetic helpers plus Tavily web search. The same
# list is bound to the LLM here and also handed to the graph's ToolNode, so
# every tool the model may request is one the graph can execute.
tools = [add, multiply, divide, tavily_tool]

# ------------------- LLM and System Message Setup -------------------
llm = ChatOpenAI(model="gpt-4o-mini")
# At most one tool call per model turn; presumably so chained arithmetic
# ("add, then multiply the result") is carried out one step at a time.
llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
sys_msg = SystemMessage(
    content=(
        "You are a helpful assistant tasked with performing arithmetic "
        "and search on a set of inputs."
    )
)
# ------------------- LangGraph Workflow -------------------
def assistant(state: MessagesState):
    """Run one model step of the ReAct loop.

    Prepends the fixed system prompt to the conversation so far and asks the
    tool-bound LLM for its next message, which is either a final answer or a
    tool-call request.

    Args:
        state: Graph state whose "messages" entry holds the conversation.

    Returns:
        A partial state update containing the new model message in a
        one-element list, which the graph merges into the message history.
    """
    conversation = [sys_msg, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}
# Define the graph
def _build_workflow() -> StateGraph:
    """Assemble (but do not compile) the ReAct workflow.

    Topology: START -> assistant; from assistant, tools_condition routes to
    "tools" when the model requested a tool call and ends the run otherwise;
    "tools" always hands control back to "assistant", closing the loop.
    """
    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")
    return workflow


app_graph = _build_workflow()
react_graph = app_graph.compile()
# Save graph visualization as an image.
# Streamlit re-executes this whole script on every widget interaction. The
# rendering is therefore cached: uncached, every rerun re-rendered the diagram
# and left behind another undeleted temporary PNG (delete=False). The cached
# path is shared by all sessions of this server process.
@st.cache_resource(show_spinner=False)
def _render_graph_image() -> str:
    """Render the compiled workflow graph to a PNG file and return its path.

    The file is created with delete=False so it outlives the ``with`` block
    and can be read later by st.image.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        graph = react_graph.get_graph(xray=True)
        # NOTE(review): draw_mermaid_png() may need network access depending
        # on its default rendering backend — confirm for offline deployments.
        tmpfile.write(graph.draw_mermaid_png())  # Write binary image data to file
        return tmpfile.name


graph_image_path = _render_graph_image()
# ------------------- Streamlit Interface -------------------
st.title("ReAct Agent for Arithmetic Ops & Web Search")
# Display the workflow graph
st.image(graph_image_path, caption="Workflow Visualization")
# Prompt user for inputs
user_question = st.text_area(
    "Enter your question:",
    placeholder="Example: 'Add 3 and 4. Multiply the result by 2. Divide it by 5.'",
)
if st.button("Submit"):
    if not user_question.strip():
        st.error("Please enter a valid question.")
        st.stop()  # ends this script run; nothing below executes
    st.info("Processing your question...")
    messages = [HumanMessage(content=user_question)]
    # Runs the assistant <-> tools loop until the model stops requesting
    # tools; the returned state holds the entire conversation.
    response = react_graph.invoke({"messages": messages})
    # Display results step-by-step
    st.subheader("Response:")
    # Label each message by its author (BaseMessage.type) so the user's
    # question and tool results are not presented as AI output.
    message_labels = {
        "human": "**Human Message:**",
        "ai": "**AI Message:**",
        "tool": "**Tool Output:**",
    }
    for m in response["messages"]:
        content = getattr(m, "content", None)
        if content:
            label = message_labels.get(getattr(m, "type", None), "**Message:**")
            st.write(label, content)
        # Tool *requests* are dicts on AI messages (name/args, no output);
        # their results arrive as separate tool messages, shown above.
        for tool_call in getattr(m, "tool_calls", None) or []:
            st.write(f"**Tool Call:** `{tool_call['name']}`")
            st.json(tool_call["args"])  # Display tool arguments in JSON
    st.success("Processing complete!")
# Sidebar: sample questions the user can copy, plus a reference link.
st.sidebar.subheader("Example Questions")
for _example_question in (
    "Add 3 and 4. Multiply the result by 2. Divide it by 5.",
    "Tell me how many centuries Virat Kohli scored.",
    "Search for the tallest building in the world.",
):
    st.sidebar.write(f"- {_example_question}")
st.sidebar.title("References")
st.sidebar.markdown(
    "1. [LangGraph ReAct Agents]"
    "(https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_9_ReAct_Agents.ipynb)"
)