|
import json |
|
from typing import Callable, Dict, List, Union |
|
|
|
from pydantic import BaseModel, Field |
|
|
|
from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction |
|
from lagent.agents.agent import Agent, AsyncAgent |
|
from lagent.agents.aggregator import DefaultAggregator |
|
from lagent.hooks import ActionPreprocessor |
|
from lagent.llms import BaseLLM |
|
from lagent.memory import Memory |
|
from lagent.prompts.parsers.json_parser import JSONParser |
|
from lagent.prompts.prompt_template import PromptTemplate |
|
from lagent.schema import AgentMessage |
|
from lagent.utils import create_object |
|
|
|
# System prompt for the tool-selecting inner agent (Chinese). English gloss:
# "You are an assistant that can call external tools; the available tools are:
#  {action_info} {output_format} Begin!"
# Its placeholders match the `action_info` / `output_format` keyword fields
# that ReAct and AsyncReAct pass to `template.format(...)`; the __main__ demo
# wraps it in a PromptTemplate.
select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:

{action_info}

{output_format}

开始!"""



# Reply-format instructions (Chinese) given to JSONParser in the __main__ demo.
# English gloss: "If you use a tool, reply in the following format:
#  {function_format}  If you already know the answer, or you don't need a
#  tool, reply in the following format {finish_format}".
# The demo passes pydantic models as `function_format` / `finish_format`;
# presumably JSONParser renders them into a schema here — confirm in JSONParser.
output_format_template = """如果使用工具请遵循以下格式回复:

{function_format}



如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复

{finish_format}"""
|
|
|
|
|
class ReAct(Agent):
    """Synchronous ReAct agent alternating between tool selection and execution.

    Each turn, an inner ``select_agent`` produces a message. Its system prompt
    lists the available tools and the expected reply format. If
    ``finish_condition`` accepts that message, it is returned. Otherwise it is
    passed to the action executor, and the executor's reply becomes the next
    turn's input. After ``max_turn`` turns without finishing, the last
    executor reply is returned.

    Args:
        llm: LLM instance or config dict for the inner select agent.
        actions: Tool(s) given to :class:`ActionExecutor`.
        template: System-prompt template exposing ``action_info`` and
            ``output_format`` fields; a :class:`PromptTemplate` or a plain
            ``str``. ``None`` falls back to ``select_action_template``.
        memory: Memory config for the inner select agent.
        output_format: Output parser instance, or a config dict that is
            instantiated with ``create_object`` before use.
        aggregator: Aggregator config for the inner select agent.
        hooks: Hook configs passed to both the action executor and the
            select agent.
        finish_condition: Predicate deciding whether a select-agent message
            is the final answer.
        max_turn: Maximum number of select/act rounds.
        **kwargs: Forwarded to ``Agent.__init__``.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str, None] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        # NOTE: the dict/list defaults above are shared across calls; this
        # method only reads them and never mutates them.
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        # Build the executor first: its tool descriptions go into the prompt.
        self.actions: ActionExecutor = create_object(
            dict(type=ActionExecutor, actions=actions, hooks=hooks))
        # The default `output_format` is a config dict, which has no
        # `format_instruction()`; instantiate it before asking for one.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)
        # `None` is the declared default; use the module's prompt, whose
        # placeholders are exactly the two fields formatted below.
        if template is None:
            template = select_action_template
        system_prompt = template.format(
            action_info=json.dumps(self.actions.description()),
            output_format=output_format.format_instruction())
        self.select_agent = create_object(
            dict(
                type=Agent,
                llm=llm,
                template=system_prompt,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=hooks,
            ))
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run up to ``max_turn`` select/act rounds starting from ``message``.

        Returns the first select-agent message accepted by
        ``finish_condition``. If none is accepted, returns the last
        action-executor reply.
        """
        # NOTE(review): `**kwargs` is not forwarded to the select agent or the
        # executor; confirm that is intended (e.g. for per-session memory).
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        return message
|
|
|
|
|
class AsyncReAct(AsyncAgent):
    """Asynchronous ReAct agent alternating between tool selection and execution.

    This is the ``async`` counterpart of the synchronous ReAct agent. Each
    turn, an inner ``select_agent`` (an :class:`AsyncAgent`) is awaited. If
    ``finish_condition`` accepts its message, that message is returned.
    Otherwise the message is handed to the :class:`AsyncActionExecutor`, and
    the awaited executor reply feeds the next turn. After ``max_turn`` turns
    without finishing, the last executor reply is returned.

    Args:
        llm: LLM instance or config dict for the inner select agent.
        actions: Tool(s) given to :class:`AsyncActionExecutor`.
        template: System-prompt template exposing ``action_info`` and
            ``output_format`` fields; a :class:`PromptTemplate` or a plain
            ``str``. ``None`` falls back to ``select_action_template``.
        memory: Memory config for the inner select agent.
        output_format: Output parser instance, or a config dict that is
            instantiated with ``create_object`` before use.
        aggregator: Aggregator config for the inner select agent.
        hooks: Hook configs passed to both the action executor and the
            select agent.
        finish_condition: Predicate deciding whether a select-agent message
            is the final answer.
        max_turn: Maximum number of select/act rounds.
        **kwargs: Forwarded to ``AsyncAgent.__init__``.
    """

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str, None] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        # NOTE: the dict/list defaults above are shared across calls; this
        # method only reads them and never mutates them.
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        # Build the executor first: its tool descriptions go into the prompt.
        self.actions: AsyncActionExecutor = create_object(
            dict(type=AsyncActionExecutor, actions=actions, hooks=hooks))
        # The default `output_format` is a config dict, which has no
        # `format_instruction()`; instantiate it before asking for one.
        if isinstance(output_format, dict):
            output_format = create_object(output_format)
        # `None` is the declared default; use the module's prompt, whose
        # placeholders are exactly the two fields formatted below.
        if template is None:
            template = select_action_template
        system_prompt = template.format(
            action_info=json.dumps(self.actions.description()),
            output_format=output_format.format_instruction())
        self.select_agent = create_object(
            dict(
                type=AsyncAgent,
                llm=llm,
                template=system_prompt,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=hooks,
            ))
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Await up to ``max_turn`` select/act rounds starting from ``message``.

        Returns the first select-agent message accepted by
        ``finish_condition``. If none is accepted, returns the last
        action-executor reply.
        """
        # NOTE(review): `**kwargs` is not forwarded to the select agent or the
        # executor; confirm that is intended (e.g. for per-session memory).
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        return message
|
|
|
|
|
if __name__ == '__main__':
    # Demo: a ReAct agent backed by a GPT model with a Python interpreter
    # tool, queried twice in a row (the second question is a follow-up).
    # Requires network access and whatever credentials GPTAPI uses when
    # `key=None` (not visible here).
    #
    # Component types are given as class objects rather than bare strings
    # ('PythonInterpreter', 'DefaultAggregator'), so the configs do not depend
    # on how create_object resolves unqualified class names.
    from lagent.actions import PythonInterpreter
    from lagent.llms import GPTAPI

    class ActionCall(BaseModel):
        """A single tool call: function name plus its arguments."""

        # Description (Chinese): "name of the function to call".
        name: str = Field(description='调用的函数名称')
        # Description: "arguments for the function call".
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        """Reply schema used when the model decides to call a tool."""

        # Description: "Describe the current state and known information; this
        # helps clarify what is known and where to search next."
        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        # Description: "The operation for this step, including function name
        # and arguments."
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        """Reply schema used when the model answers without a tool."""

        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        # Description: "Summarize the current results and answer the
        # question." The `conclusion` key is what ReAct's default
        # finish_condition checks for.
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    prompt_template = PromptTemplate(select_action_template)
    output_format = JSONParser(
        output_format_template,
        function_format=ActionFormat,
        finish_format=FinishFormat)

    llm = dict(
        type=GPTAPI,
        model_type='gpt-4o-2024-05-13',
        key=None,
        max_new_tokens=4096,
        proxies=dict(),
        retry=1000)

    agent = ReAct(
        llm=llm,
        template=prompt_template,
        output_format=output_format,
        aggregator=dict(type=DefaultAggregator),
        actions=[dict(type=PythonInterpreter)],
    )

    response = agent(
        AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
    print(response)
    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
    print(response)
|
|