|
import re |
|
from typing import Any, Dict, Union |
|
|
|
from pydantic import BaseModel, ValidationError |
|
|
|
from lagent.prompts.parsers.str_parser import StrParser |
|
|
|
|
|
class CustomFormatParser(StrParser):
    """Parser for an XML-like ``<field type="...">value</field>`` format.

    Renders a pydantic model as a one-tag-per-field template and parses text
    written in that shape back into a validated pydantic model.

    NOTE(review): ``self.default_format``, ``self.unknown_format`` and
    ``self.fields`` are read by ``parse_response`` but are not assigned
    anywhere in this class; presumably ``StrParser`` or the caller provides
    them -- confirm against the base class.
    """

    def _extract_fields_with_metadata(
            self, model: BaseModel) -> Dict[str, Dict[str, Any]]:
        """Collect annotation, default and description for every field.

        Args:
            model: A pydantic model exposing ``model_fields``.

        Returns:
            A mapping of field name to a dict with the keys ``'annotation'``,
            ``'default'`` and ``'comment'``. ``'default'`` is the sentinel
            string ``'<required>'`` whenever the field has no literal default
            worth displaying (required field, ``default_factory`` field, or
            a default of ``None``).
        """
        fields_metadata = {}
        for field_name, field in model.model_fields.items():
            # A required pydantic field does not have ``default is None``;
            # it carries pydantic's "undefined" marker. Testing only for
            # ``None`` let that marker leak into the rendered template, so
            # ask the field itself whether it is required.
            no_literal_default = (
                field.is_required() or field.default_factory is not None
                or field.default is None)
            fields_metadata[field_name] = {
                'annotation': field.annotation,
                'default':
                '<required>' if no_literal_default else field.default,
                'comment': field.description or '',
            }
        return fields_metadata

    def format_to_string(self, format_model: BaseModel) -> str:
        """Render ``format_model`` as a tag-per-field template string.

        Each field becomes ``<name type="TYPE">DEFAULT</name>`` followed by a
        newline, preceded by ``<!-- description -->`` on its own line when
        the field has a description. Fields whose default is the
        ``'<required>'`` sentinel are rendered with an empty body.

        Args:
            format_model: A pydantic model exposing ``model_fields``.

        Returns:
            The concatenated template text (empty string for no fields).
        """
        fields = self._extract_fields_with_metadata(format_model)
        parts = []
        for field_name, metadata in fields.items():
            comment = metadata.get('comment', '')
            annotation = metadata['annotation']
            if annotation is None:
                type_name = 'Any'
            else:
                # Not every annotation object exposes ``__name__`` (e.g.
                # some ``typing`` constructs); fall back to its string form
                # instead of raising AttributeError.
                type_name = (
                    getattr(annotation, '__name__', None) or str(annotation))
            default = metadata['default']
            # Only a real string can be the sentinel; checking the type first
            # avoids invoking ``==`` on arbitrary default objects.
            is_placeholder = (
                isinstance(default, str) and default == '<required>')
            body = '' if is_placeholder else default
            if comment:
                parts.append(f'<!-- {comment} -->\n')
            parts.append(
                f'<{field_name} type="{type_name}">{body}</{field_name}>\n')
        return ''.join(parts)

    def parse_response(self, data: str) -> Union[dict, BaseModel]:
        """Parse tag-formatted text into a validated model.

        Every ``<key ...>value</key>`` pair found in ``data`` is collected
        into a dict (a later duplicate key overwrites an earlier one). The
        dict is validated against ``self.default_format``; if
        ``self.unknown_format`` is truthy and the dict does not validate
        against the default format, ``self.unknown_format`` is used instead.

        When a tag is directly preceded by an ``<!-- ... -->`` comment and
        ``self.fields`` has an entry for that tag, the comment text is stored
        under that entry's ``'comment'`` key.

        Args:
            data: Text containing the tagged fields.

        Returns:
            The result of ``model_validate`` on the selected model.
            Validation errors raised by that call propagate to the caller.
        """
        # The comment body is matched with a tempered ``(?:(?!-->).)*?`` so
        # it can never run past its own closing ``-->``. A plain ``.*?``
        # would, on backtracking, stretch across following tags up to a
        # later ``-->`` and silently swallow those fields.
        pattern = re.compile(
            r'(<!--\s*((?:(?!-->).)*?)\s*-->)?\s*<(\w+)[^>]*>(.*?)</\3>',
            re.DOTALL)
        matches = pattern.findall(data)

        # Comment bookkeeping is auxiliary: a tag without a metadata entry
        # (or a parser without ``fields``) must not abort the parse.
        known_fields = getattr(self, 'fields', None) or {}

        data_dict = {}
        for _, comment_text, key, value in matches:
            if comment_text and key in known_fields:
                known_fields[key]['comment'] = comment_text.strip()
            data_dict[key] = value

        model = self.default_format
        if self.unknown_format and not self._is_valid_format(
                data_dict, self.default_format):
            model = self.unknown_format

        return model.model_validate(data_dict)

    def _is_valid_format(self, data: Dict, format_model: BaseModel) -> bool:
        """Return whether ``data`` validates against ``format_model``.

        Args:
            data: Candidate field values.
            format_model: The pydantic model to validate against.

        Returns:
            ``True`` if ``model_validate`` succeeds, ``False`` if it raises
            ``ValidationError``.
        """
        try:
            format_model.model_validate(data)
        except ValidationError:
            return False
        return True
|
|
|
|
|
if __name__ == '__main__':
    # Small manual demo: build a parser with two candidate schemas and parse
    # one hand-written sample response.

    # Schema tried first when validating the parsed fields.
    class DefaultFormat(BaseModel):
        name: str
        age: int

    # Schema passed as the alternative ("unknown") format.
    class UnknownFormat(BaseModel):
        title: str
        year: int

    # Sample text in the tag format, with a comment line before each field.
    sample_response = '''
<!-- User's full name -->
<name type="str">John Doe</name>
<!-- User's age -->
<age type="int">30</age>
'''

    # Prompt template with one placeholder per format.
    prompt_template = """如果了解该问题请按照一下格式回复
```html
{format}
```
否则请回复
```html
{unknown_format}
```
"""

    parser = CustomFormatParser(
        prompt_template,
        default_format=DefaultFormat,
        unknown_format=UnknownFormat,
    )

    parsed = parser.parse_response(sample_response)
    print(parsed)
|
|