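"""Streaming chat helpers: lmdeploy inference with optional lagent tool calling,
RAG retrieval, TTS synthesis, and digital-human video generation, rendered in a
Streamlit chat UI."""
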
import json
import logging

import streamlit as st
import torch
from lagent.actions import ActionExecutor
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import ActionReturn, AgentReturn
from lmdeploy import GenerationConfig

from utils.digital_human.digital_human_worker import gen_digital_human_video_in_spinner
from utils.rag.rag_worker import build_rag_prompt
from utils.tts.tts_worker import gen_tts_in_spinner



def prepare_generation_config(skip_special_tokens=True):
    """Build the lmdeploy sampling config used for every inference call."""
    gen_config = GenerationConfig(
        top_p=0.8,
        temperature=0.7,
        repetition_penalty=1.005,
        skip_special_tokens=skip_special_tokens,
    )  # other options worth tuning: top_k=40, min_new_tokens=200
    return gen_config


def combine_history(prompt, meta_instruction, history_msg=None, first_input_str=""):
    """Assemble the system prompt, optional first input, chat history, and the new
    user turn into the nested message list the lmdeploy pipeline expects."""
    total_prompt = [{"role": "system", "content": meta_instruction}]

    if first_input_str != "":
        total_prompt.append({"role": "user", "content": first_input_str})

    if history_msg is not None:
        for message in history_msg:
            total_prompt.append({"role": message["role"], "content": message["content"]})

    total_prompt.append({"role": "user", "content": prompt})
    return [total_prompt]  # wrapped in a list: the pipeline takes a batch of conversations
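
# Illustrative shape of the return value (all values are made up): for
# meta_instruction="你是主播", history_msg=[{"role": "assistant", "content": "您好"}],
# and prompt="多少钱", the returned batch is:
#   [[{"role": "system", "content": "你是主播"},
#     {"role": "assistant", "content": "您好"},
#     {"role": "user", "content": "多少钱"}]]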

@st.cache_resource
def init_handlers(departure_place, delivery_company_name):
    """Build and cache the plugin executor and the InternLM2 agent protocol handler."""
    from utils.agent.delivery_time_query import DeliveryTimeQueryAction  # isort:skip

    # System / interpreter / plugin prompts, kept in Chinese because they are sent
    # verbatim to the Chinese-language model.
    META_CN = "当开启工具以及代码时,根据需求选择合适的工具进行调用"

    INTERPRETER_CN = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )

    PLUGIN_CN = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )

    protocol_handler = Internlm2Protocol(
        meta_prompt=META_CN,
        interpreter_prompt=INTERPRETER_CN,
        plugin_prompt=PLUGIN_CN,
        tool=dict(
            begin="{start_token}{name}\n",
            start_token="<|action_start|>",
            name_map=dict(plugin="<|plugin|>", interpreter="<|interpreter|>"),
            belong="assistant",
            end="<|action_end|>\n",
        ),
    )

    action_list = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]
    action_executor = ActionExecutor(actions=action_list)

    return action_executor, protocol_handler
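
# Usage sketch (argument values are illustrative): Streamlit caches the result, so
# repeated calls with the same arguments reuse the same handlers:
#   action_executor, protocol_handler = init_handlers("广州", "顺丰")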

def get_agent_result(model_pipe, prompt_input, departure_place, delivery_company_name):
    """Run the tool-calling agent for up to `max_turn` steps and return the plugin
    result text, or "" when the model answers without calling a tool."""
    action_executor, protocol_handler = init_handlers(departure_place, delivery_company_name)

    inner_history = [{"role": "user", "content": prompt_input}]
    interpreter_executor = None
    name, language, action, executor = None, None, None, None  # defined even if the stream is empty
    max_turn = 7
    for _ in range(max_turn):

        prompt = protocol_handler.format(  # build this turn's agent prompt
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""

        agent_return = AgentReturn()
        for item in model_pipe.stream_infer(prompt, gen_config=prepare_generation_config(skip_special_tokens=False)):
            if "~" in item.text:
                item.text = item.text.replace("~", "。").replace("。。", "。")

            cur_response += item.text

            name, language, action = protocol_handler.parse(
                message=cur_response,
                plugin_executor=action_executor,
                interpreter_executor=interpreter_executor,
            )
            if name:  # the model emitted a tool call ("plugin" or "interpreter")
                if name == "plugin":
                    if action_executor:
                        executor = action_executor
                    else:
                        logging.info(msg="No plugin is instantiated!")
                        continue
                    try:
                        action = json.loads(action)
                    except Exception as e:
                        logging.info(msg=f"Invaild action {e}")
                        continue
                elif name == "interpreter":
                    if interpreter_executor:
                        executor = interpreter_executor
                    else:
                        logging.info(msg="No interpreter is instantiated!")
                        continue
                agent_return.response = action

        print(f"Agent response: {cur_response}")

        if name:
            print(f"Agent action: {action}")
            action_return: ActionReturn = executor(action["name"], action["parameters"])

            try:
                return_str = action_return.result[0]["content"]
                return return_str
            except Exception as e:
                logging.info(msg=f"Failed to read action result: {e}")
                return ""

        if not name:
            agent_return.response = language
            break

    return ""


def get_turbomind_response(
    prompt,
    meta_instruction,
    user_avator,
    robot_avator,
    model_pipe,
    session_messages,
    add_session_msg=True,
    first_input_str="",
    rag_retriever=None,
    product_name="",
    enable_agent=True,
    departure_place=None,
    delivery_company_name=None,
):
    """Answer one user turn: try the tool-calling agent first, fall back to RAG,
    stream the reply in the Streamlit chat UI, then synthesize TTS and the
    digital-human video, appending both turns to session_messages."""

    # ====================== Agent ======================
    agent_response = ""
    if enable_agent:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # agent answer template, sent verbatim to the Chinese-language model
        )
        agent_response = get_agent_result(model_pipe, prompt, departure_place, delivery_company_name)
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, prompt)
            print(f"Agent response: {agent_response}")
    prompt_pro = agent_response

    # ====================== RAG ======================
    if rag_retriever is not None and prompt_pro == "":
        # The agent produced nothing, so query the RAG database instead
        prompt_pro = build_rag_prompt(rag_retriever, product_name, prompt)

    # ====================== Append chat history ======================
    real_prompt = combine_history(
        prompt_pro if prompt_pro != "" else prompt,
        meta_instruction,
        history_msg=session_messages,
        first_input_str=first_input_str,
    )  # optionally include prior conversation turns

    print(real_prompt)  # debug: final prompt sent to the model

    # Add user message to chat history
    if add_session_msg:
        session_messages.append({"role": "user", "content": prompt, "avatar": user_avator})

    with st.chat_message("assistant", avatar=robot_avator):
        message_placeholder = st.empty()
        cur_response = ""
        for item in model_pipe.stream_infer(real_prompt, gen_config=prepare_generation_config()):

            if "~" in item.text:
                item.text = item.text.replace("~", "。").replace("。。", "。")

            cur_response += item.text
            message_placeholder.markdown(cur_response + "▌")
        message_placeholder.markdown(cur_response)

        tts_save_path = gen_tts_in_spinner(cur_response)  # synthesize the whole reply in one pass
        gen_digital_human_video_in_spinner(tts_save_path)

        # Add robot response to chat history
        session_messages.append(
            {
                "role": "assistant",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avator,
                "wav": tts_save_path,
            }
        )
    torch.cuda.empty_cache()
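

# Minimal wiring sketch inside a Streamlit app. The pipeline name and avatars are
# assumptions for illustration; only model_pipe and session_messages are required state.
#
#   from lmdeploy import pipeline
#
#   pipe = pipeline("internlm/internlm2-chat-7b")
#   st.session_state.setdefault("messages", [])
#   if user_input := st.chat_input("请输入您的问题"):
#       get_turbomind_response(
#           prompt=user_input,
#           meta_instruction="你是一位热情的带货主播",
#           user_avator="🧑",
#           robot_avator="🤖",
#           model_pipe=pipe,
#           session_messages=st.session_state["messages"],
#           enable_agent=False,  # the agent also needs departure_place / delivery_company_name
#       )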