File size: 3,143 Bytes
2a51e7d
 
 
 
 
 
 
 
 
0a8b0d4
2a51e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import Runnable
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.prebuilt import tools_condition

from langchain_groq import ChatGroq
from apps.agent.tools import tool_weweb, tool_xano    
from apps.agent.state import State, RequestAssistance
from apps.agent.constant import PROMPT


class Agent:
    """LangGraph agent that routes between an LLM chatbot node, tool
    execution, and a human-escalation node.

    The LLM is bound to the real tools plus the ``RequestAssistance``
    pseudo-tool; when the model calls it, the graph routes to the ``human``
    node and interrupts so a person can respond.
    """

    def __init__(self, llm: ChatGroq, memory=None, store=None, prompt=PROMPT):
        """Build and compile the agent graph.

        Args:
            llm: Groq chat model driving the chatbot node.
            memory: optional checkpointer; a fresh ``MemorySaver`` per
                instance when omitted.
            store: optional cross-thread store; a fresh ``InMemoryStore``
                per instance when omitted.
            prompt: prompt runnable composed in front of the LLM.
        """
        self.llm = llm
        # BUG FIX: the originals were mutable default arguments
        # (memory=MemorySaver(), store=InMemoryStore()), evaluated once at
        # definition time and therefore shared by every Agent instance.
        self.memory = memory if memory is not None else MemorySaver()
        self.store = store if store is not None else InMemoryStore()
        self.tools = [tool_xano, tool_weweb]
        # RequestAssistance is bound as an extra "tool" so the model can
        # signal that a human should take over.
        llm_with_tools = prompt | self.llm.bind_tools(self.tools + [RequestAssistance])

        builder = StateGraph(State)
        builder.add_node("chatbot", Assistant(llm_with_tools))
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_node("human", self._human_node)
        # BUG FIX: the condition must be _select_next_node, not the bare
        # tools_condition — tools_condition never returns "human", so the
        # "human" branch was unreachable and _select_next_node was dead code.
        builder.add_conditional_edges(
            "chatbot",
            self._select_next_node,
            {"human": "human", "tools": "tools", END: END},
        )

        builder.add_edge("tools", "chatbot")
        builder.add_edge("human", "chatbot")
        builder.add_edge(START, "chatbot")

        self.graph = builder.compile(
            checkpointer=self.memory,
            store=self.store,
            # NOTE(review): interrupt_after=["human"] pauses AFTER the human
            # node has run (so the placeholder ToolMessage may already be
            # appended); confirm interrupt_before was not intended.
            interrupt_after=["human"],
        )

    def _create_response(self, response: str, ai_message: AIMessage) -> ToolMessage:
        """Wrap *response* as a ToolMessage answering the first tool call
        of *ai_message*."""
        return ToolMessage(
            content=response,
            tool_call_id=ai_message.tool_calls[0]["id"],
        )

    def _human_node(self, state: State):
        """Graph node for the human interrupt point.

        Typically the user updates the state during the interrupt. If they
        did not (the last message is not a ToolMessage), append a placeholder
        ToolMessage so the LLM can continue, and clear the ask_human flag.
        """
        new_messages = []
        if not isinstance(state["messages"][-1], ToolMessage):
            new_messages.append(
                self._create_response("No response from human.", state["messages"][-1])
            )
        return {
            # Append the new messages
            "messages": new_messages,
            # Unset the flag
            "ask_human": False,
        }

    def _select_next_node(self, state: State):
        """Route to "human" when escalation was requested; otherwise defer
        to the standard tools/END routing."""
        if state["ask_human"]:
            return "human"
        return tools_condition(state)


class Assistant:
    """Callable graph node wrapping the prompted, tool-bound LLM runnable."""

    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state):
        """Invoke the LLM for the chatbot node.

        Returns a partial state update containing the new AIMessage and an
        ``ask_human`` flag that is True when the model called the
        RequestAssistance pseudo-tool.
        """
        while True:
            response = self.runnable.invoke(state)
            # BUG FIX: the original loop claimed to re-prompt on an empty
            # response but returned unconditionally on the first iteration.
            # Re-prompt only when the model produced neither content nor
            # tool calls.
            if not response.tool_calls and not response.content:
                messages = state["messages"] + [
                    ("user", "Respond with a real output.")
                ]
                state = {**state, "messages": messages}
                continue
            break

        ask_human = bool(
            response.tool_calls
            and response.tool_calls[0]["name"] == RequestAssistance.__name__
        )
        return {"messages": [response], "ask_human": ask_human}