|
import json |
|
from enum import IntEnum |
|
|
|
|
|
from typing import Any, Callable, List, Optional |
|
|
|
from lagent.prompts.parsers import StrParser |
|
from lagent.utils import create_object, load_class_from_string |
|
|
|
|
|
def default_plugin_validate(plugin: str): |
|
plugin = plugin.strip() |
|
if not (plugin.startswith('{') and plugin.endswith("}")): |
|
raise json.decoder.JSONDecodeError |
|
return json.loads(plugin) |
|
|
|
|
|
class ToolStatusCode(IntEnum): |
|
NO_TOOL = 0 |
|
VALID_TOOL = 1 |
|
PARSING_ERROR = -1 |
|
|
|
|
|
class ToolParser(StrParser): |
|
|
|
def __init__(self, |
|
tool_type: str, |
|
template: str = '', |
|
begin: str = '<tool>\n', |
|
end: str = '</tool>\n', |
|
validate: Callable[[str], Any] = None, |
|
**kwargs): |
|
super().__init__(template, begin=begin, end=end, **kwargs) |
|
self.template = template |
|
self.tool_type = tool_type |
|
|
|
|
|
|
|
self.validate = load_class_from_string(validate) if isinstance( |
|
validate, str) else validate |
|
|
|
def parse_response(self, data: str) -> dict: |
|
if self.format_field['begin'] not in data: |
|
return dict( |
|
tool_type=None, |
|
thought=data, |
|
action=None, |
|
status=ToolStatusCode.NO_TOOL) |
|
thought, action, *_ = data.split(self.format_field["begin"]) |
|
action = action.split(self.format_field['end'])[0] |
|
status = ToolStatusCode.VALID_TOOL |
|
if self.validate: |
|
try: |
|
action = self.validate(action) |
|
except Exception: |
|
status = ToolStatusCode.PARSING_ERROR |
|
return dict( |
|
tool_type=self.tool_type, |
|
thought=thought, |
|
action=action, |
|
status=status) |
|
|
|
def format_response(self, parsed: dict) -> str: |
|
if parsed['action'] is None: |
|
return parsed['thought'] |
|
assert parsed['tool_type'] == self.tool_type |
|
if isinstance(parsed['action'], dict): |
|
action = json.dumps(parsed['action'], ensure_ascii=False) |
|
else: |
|
action = str(parsed['action']) |
|
return parsed['thought'] + self.format_field[ |
|
'begin'] + action + self.format_field['end'] |
|
|
|
|
|
class InterpreterParser(ToolParser): |
|
|
|
def __init__(self, |
|
tool_type: str = 'interpreter', |
|
template: str = '', |
|
begin: str = '<|action_start|><|interpreter|>\n', |
|
end: str = '<|action_end|>\n', |
|
validate: Callable[[str], Any] = None, |
|
**kwargs): |
|
super().__init__(tool_type, template, begin, end, validate, **kwargs) |
|
|
|
|
|
class PluginParser(ToolParser): |
|
|
|
def __init__(self, |
|
tool_type: str = 'plugin', |
|
template: str = '', |
|
begin: str = '<|action_start|><|plugin|>\n', |
|
end: str = '<|action_end|>\n', |
|
validate: Callable[[str], Any] = default_plugin_validate, |
|
**kwargs): |
|
super().__init__(tool_type, template, begin, end, validate, **kwargs) |
|
|
|
|
|
class MixedToolParser(StrParser): |
|
|
|
def __init__(self, |
|
tool_type: Optional[str] = None, |
|
template='', |
|
parsers: List[ToolParser] = None, |
|
**format_field): |
|
self.parsers = {} |
|
self.tool_type = tool_type |
|
for parser in parsers or []: |
|
parser = create_object(parser) |
|
self.parsers[parser.tool_type] = parser |
|
super().__init__(template, **format_field) |
|
|
|
def format_instruction(self) -> List[dict]: |
|
inst = [] |
|
content = super().format_instruction() |
|
if content.strip(): |
|
msg = dict(role='system', content=content) |
|
if self.tool_type: |
|
msg['name'] = self.tool_type |
|
inst.append(msg) |
|
for name, parser in self.parsers.items(): |
|
content = parser.format_instruction() |
|
if content.strip(): |
|
inst.append(dict(role='system', content=content, name=name)) |
|
return inst |
|
|
|
def parse_response(self, data: str) -> dict: |
|
res = dict( |
|
tool_type=None, |
|
thought=data, |
|
action=None, |
|
status=ToolStatusCode.NO_TOOL) |
|
for name, parser in self.parsers.items(): |
|
res = parser.parse_response(data) |
|
if res['tool_type'] == name: |
|
break |
|
return res |
|
|
|
def format_response(self, parsed: dict) -> str: |
|
if parsed['action'] is None: |
|
return parsed['thought'] |
|
assert parsed['tool_type'] in self.parsers |
|
return self.parsers[parsed['tool_type']].format_response(parsed) |
|
|