import json
import logging
import streamlit as st
import torch
from lagent.actions import ActionExecutor
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import ActionReturn, AgentReturn
from lmdeploy import GenerationConfig
from utils.digital_human.digital_human_worker import gen_digital_human_video_in_spinner
from utils.rag.rag_worker import build_rag_prompt
from utils.tts.tts_worker import gen_tts_in_spinner


def prepare_generation_config(skip_special_tokens=True):
    gen_config = GenerationConfig(
        top_p=0.8,
        temperature=0.7,
        repetition_penalty=1.005,
        skip_special_tokens=skip_special_tokens,
    )  # top_k=40, min_new_tokens=200
    return gen_config
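
# Minimal usage sketch (assumes `model_pipe` is an lmdeploy pipeline created elsewhere):
#   gen_config = prepare_generation_config()
#   for item in model_pipe.stream_infer("Hello", gen_config=gen_config):
#       print(item.text, end="")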
def combine_history(prompt, meta_instruction, history_msg=None, first_input_str=""):
total_prompt = [{"role": "system", "content": meta_instruction}]
if first_input_str != "":
total_prompt.append({"role": "user", "content": first_input_str})
if history_msg is not None:
for message in history_msg:
total_prompt.append({"role": message["role"], "content": message["content"]})
total_prompt.append({"role": "user", "content": prompt})
return [total_prompt]
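
# Example with illustrative values: combine_history("Is it in stock?", "You are a live-stream host.")
# returns a single-element batch:
#   [[{"role": "system", "content": "You are a live-stream host."},
#     {"role": "user", "content": "Is it in stock?"}]]
# The outer list appears to be the batch dimension expected by the pipeline's stream_infer.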


'''
@st.cache_resource
# def init_handlers(departure_place, delivery_company_name):
def init_handlers():
    # from utils.agent.delivery_time_query import DeliveryTimeQueryAction  # isort:skip

    META_CN = "当开启工具以及代码时,根据需求选择合适的工具进行调用"

    INTERPRETER_CN = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )

    PLUGIN_CN = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )

    protocol_handler = Internlm2Protocol(
        meta_prompt=META_CN,
        interpreter_prompt=INTERPRETER_CN,
        plugin_prompt=PLUGIN_CN,
        tool=dict(
            begin="{start_token}{name}\n",
            start_token="<|action_start|>",
            name_map=dict(plugin="<|plugin|>", interpreter="<|interpreter|>"),
            belong="assistant",
            end="<|action_end|>\n",
        ),
    )

    action_list = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]
    # plugin_map = {action.name: action for action in action_list}
    # plugin_name = [action.name for action in action_list]
    # plugin_action = [plugin_map[name] for name in plugin_name]
    # action_executor = ActionExecutor(actions=plugin_action)

    # return action_executor, protocol_handler
'''


# def get_agent_result(model_pipe, prompt_input, departure_place, delivery_company_name):
def get_agent_result(model_pipe, prompt_input):
    # action_executor, protocol_handler = init_handlers(departure_place, delivery_company_name)

    # NOTE TEST !!! `protocol_handler` and `action_executor` are expected to come from
    # init_handlers(), which is currently disabled above.
    inner_history = [{"role": "user", "content": prompt_input}]
    interpreter_executor = None

    max_turn = 7
    for _ in range(max_turn):
        prompt = protocol_handler.format(  # build the agent prompt
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""

        agent_return = AgentReturn()

        for item in model_pipe.stream_infer(prompt, gen_config=prepare_generation_config(skip_special_tokens=False)):
            if "~" in item.text:
                # normalize "~" to a Chinese full stop and collapse duplicates
                item.text = item.text.replace("~", "。").replace("。。", "。")
            cur_response += item.text

        name, language, action = protocol_handler.parse(
            message=cur_response,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        if name:  # "plugin"
            if name == "plugin":
                if action_executor:
                    executor = action_executor
                else:
                    logging.info(msg="No plugin is instantiated!")
                    continue
                try:
                    action = json.loads(action)
                except Exception as e:
                    logging.info(msg=f"Invalid action {e}")
                    continue
            elif name == "interpreter":
                if interpreter_executor:
                    executor = interpreter_executor
                else:
                    logging.info(msg="No interpreter is instantiated!")
                    continue

        agent_return.response = action

        print(f"Agent response: {cur_response}")
        if name:
            print(f"Agent action: {action}")
            action_return: ActionReturn = executor(action["name"], action["parameters"])
            try:
                return_str = action_return.result[0]["content"]
                return return_str
            except Exception as e:
                logging.info(msg=f"Failed to read action result: {e}")
                return ""

        if not name:
            agent_return.response = language
            break

    return ""


def get_turbomind_response(
    prompt,
    meta_instruction,
    user_avator,
    robot_avator,
    model_pipe,
    session_messages,
    add_session_msg=True,
    first_input_str="",
    rag_retriever=None,
    product_name="",
    enable_agent=True,
    # departure_place=None,
    # delivery_company_name=None,
):

    # ====================== Agent ======================
    agent_response = ""
    if enable_agent:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt template
        )
        # agent_response = get_agent_result(model_pipe, prompt, departure_place, delivery_company_name)
        agent_response = get_agent_result(model_pipe, prompt)
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, prompt)
            print(f"Agent response: {agent_response}")

    prompt_pro = agent_response

    # ====================== RAG ======================
    if rag_retriever is not None and prompt_pro == "":
        # If the agent produced nothing, query the knowledge base via RAG
        prompt_pro = build_rag_prompt(rag_retriever, product_name, prompt)

    # ====================== Append chat history ======================
    real_prompt = combine_history(
        prompt_pro if prompt_pro != "" else prompt,
        meta_instruction,
        history_msg=session_messages,
        first_input_str=first_input_str,
    )  # optionally prepend the conversation history
    print(real_prompt)

    # Add user message to chat history
    if add_session_msg:
        session_messages.append({"role": "user", "content": prompt, "avatar": user_avator})

    with st.chat_message("assistant", avatar=robot_avator):
        message_placeholder = st.empty()
        cur_response = ""
        for item in model_pipe.stream_infer(real_prompt, gen_config=prepare_generation_config()):
            if "~" in item.text:
                # normalize "~" to a Chinese full stop and collapse duplicates
                item.text = item.text.replace("~", "。").replace("。。", "。")
            cur_response += item.text
            message_placeholder.markdown(cur_response + "▌")
        message_placeholder.markdown(cur_response)
        tts_save_path = gen_tts_in_spinner(cur_response)  # generate TTS for the full response
        gen_digital_human_video_in_spinner(tts_save_path)

    # Add robot response to chat history
    session_messages.append(
        {
            "role": "assistant",
            "content": cur_response,  # pylint: disable=undefined-loop-variable
            "avatar": robot_avator,
            "wav": tts_save_path,
        }
    )

    torch.cuda.empty_cache()
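
# Usage sketch inside a Streamlit page (illustrative names; `pipe`, `retriever`, and the
# avatar image paths are assumed to be set up elsewhere in the app):
#   if user_input := st.chat_input("Ask about the product"):
#       get_turbomind_response(
#           user_input,
#           meta_instruction="You are a live-stream sales host.",
#           user_avator="assets/user.png",
#           robot_avator="assets/robot.png",
#           model_pipe=pipe,
#           session_messages=st.session_state.messages,
#           rag_retriever=retriever,
#           product_name="phone",
#       )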