File size: 4,625 Bytes
3e0718c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1206897
37dc901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e0718c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
from copy import deepcopy
from typing import Any, Dict, List

from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow

from dataclasses import dataclass


@dataclass()
class Command:
    """A single command the controller may ask the LLM to choose.

    Instances are rendered into the controller's system prompt: the name and
    description are shown verbatim, and each entry of ``input_args`` becomes a
    key of the JSON argument schema presented to the model.
    """

    # identifier of the command as exposed to the model
    name: str
    # free-text explanation of what the command does
    description: str
    # names of the arguments the command expects
    input_args: List[str]

class CodeWriterCtrlFlow(ChatAtomicFlow):
    """refer to https://huggingface.co/aiflows/JarvisFlowModule/blob/main/Controller_JarvisFlow.py

    This class controls the execution of the CodeWriterFlow: it asks the LLM
    which command to run next and returns the model's JSON answer as a dict.

    *Input Interface Non Initialized*:
    - `goal`

    *Input Interface Initialized*:
    - `goal`
    - `code`
    - `feedback`

    *Output Interface*:
    - `command`
    - `command_args`

    *Config Parameters*:
    - `backend`: the backend used to call the LLM.
    - `commands`: a list of commands that the controller can use.
    - `system_message_prompt_template`: the template of the system message prompt.
    - `init_human_message_prompt_template`: the template of the init human (user) message prompt.
    - `human_message_prompt_template`: the template of the human (user) message prompt.
    - `previous_messages`: the sliding window of previous messages.

    """
    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """Create the controller.

        :param commands: commands the controller may choose from; they are rendered
            into the system prompt through its ``commands`` template variable.
        :param kwargs: forwarded unchanged to ``ChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Bind the rendered command manual once; the remaining template variables
        # are filled per call from the input data.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
        )
        # Appended to `goal` / `feedback` so the model is reminded of the
        # expected JSON response format on every turn.
        self.hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "command": "call code writer, the tester, or to finish",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a 1-based numbered list, one command per line.

        Each line shows the command's name, its description and a JSON object
        whose keys are the input argument names and whose values are
        ``YOUR_<ARG_NAME>`` placeholders. An empty list yields ``""``.
        """
        manual_lines = []
        for number, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            manual_lines.append(
                f"{number}. {command.name}: {command.description} "
                f"Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        return "".join(manual_lines)

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a configuration mapping (deep-copied, not mutated).

        ``config["commands"]`` maps each command name to a mapping with
        ``description`` and ``input_args`` entries; every entry becomes a ``Command``.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Append the response-format hint to ``goal`` and ``feedback`` IN PLACE.

        Only keys already present in ``input_data`` are touched.
        """
        for key in ('goal', 'feedback'):
            if key in input_data:
                input_data[key] += self.hint_for_model

    def _append_updated_system_message(self, input_data: Dict[str, Any]) -> None:
        """Render the system prompt with ``input_data`` and add it to the chat history."""
        updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
        self._state_update_add_chat_message(content=updated_system_message_content,
                                            role=self.flow_config["system_name"])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the LLM until it returns a response that ``json.loads`` accepts.

        :param input_data: the flow input (see the class docstring). It is not
            modified; the format hint is appended to a shallow copy.
        :return: the parsed JSON response of the model.

        On a parse failure the system prompt is re-added to the history and the
        model is called again with a ``feedback`` asking for pure JSON.
        NOTE(review): there is no retry limit, so a model that never emits valid
        JSON keeps this loop running indefinitely.
        """
        # Copy first so the hint does not leak into (and accumulate in) the
        # caller's dict across calls.
        input_data = dict(input_data)
        self._update_prompts_and_input(input_data)

        # ~~~when conversation is initialized, append the updated system prompts to the chat history ~~~
        if self._is_conversation_initialized():
            self._append_updated_system_message(input_data)

        while True:
            api_output = super().run(input_data)["api_output"].strip()
            try:
                response = json.loads(api_output)
            except json.JSONDecodeError:
                self._append_updated_system_message(input_data)
                new_goal = "The previous respond cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
                input_data = {**input_data, 'feedback': new_goal}
            else:
                return response