"""Gradio front-end for MindSearch: planner/searcher chat panes with a live search-graph view."""
import json
import mimetypes
import os
import tempfile

import gradio as gr
import requests
import schemdraw
from schemdraw import flow

from frontend.gradio_agentchatbot.agentchatbot import AgentChatbot
from frontend.gradio_agentchatbot.utils import ChatFileMessage, ChatMessage, ThoughtMetadata
from lagent.schema import AgentStatusCode

# Runtime setup: install the missing dependency, then launch the MindSearch
# backend (which serves the /solve endpoint used below) in the background.
os.system("pip show starlette")
os.system("pip install tenacity")
os.system("python -m mindsearch.app --lang en --model_format gpt4 --search_engine DuckDuckGoSearch &")

print("MindSearch is running on http://localhost:8002")

PLANNER_HISTORY = []
SEARCHER_HISTORY = []
# Login credentials for the Gradio auth prompt, read from the environment.
user = os.environ.get("USERNAME")
pwd = os.environ.get("PASSWORD")


def create_search_graph(adjacency_list: dict):
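    """Lay out the agent's search graph as a flowchart.

    Runs a breadth-first traversal from the "root" key of ``adjacency_list``
    (mapping node name -> list of {"name": ...} successors), placing each BFS
    layer in its own column. Nodes are drawn as schemdraw flowchart shapes;
    the terminal "response" node, when present, is pinned one column past the
    deepest layer that links to it, reached by dashed grey arrows.
    """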
    with schemdraw.Drawing(fontsize=10, unit=1) as graph:
        node_pos, nodes, edges = {}, {}, []
        if "root" in adjacency_list:
            queue, layer, response_level = ["root"], 0, 0
            while queue:
                layer_len = len(queue)
                for i in range(layer_len):
                    node_name = queue.pop(0)
                    node_pos[node_name] = (layer * 5, -i * 3)
                    # Leaf nodes may not appear as keys in the adjacency list.
                    for item in adjacency_list.get(node_name, []):
                        if item["name"] == "response":
                            response_level = max(response_level, (layer + 1) * 5)
                        else:
                            queue.append(item["name"])
                        edges.append((node_name, item["name"]))
                layer += 1
            for node_name, (x, y) in node_pos.items():
                if node_name == "root":
                    node = flow.Terminal().label(node_name).at((x, y)).color("pink")
                else:
                    node = flow.RoundBox(w=3.5, h=1.75).label(node_name).at((x, y)).color("teal")
                nodes[node_name] = node
            if response_level:
                response_node = (
                    flow.Terminal().label("response").at((response_level, 0)).color("orange")
                )
                nodes["response"] = response_node
            for start, end in edges:
                flow.Arc3(arrow="->").linestyle("--" if end == "response" else "-").at(
                    nodes[start].E
                ).to(nodes[end].W).color("grey" if end == "response" else "lightblue")
    return graph


def draw_search_graph(adjacency_list: dict, suffix=".png", dpi=360) -> str:
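    """Render ``adjacency_list`` to a temporary image file and return the path."""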
    g = create_search_graph(adjacency_list)
    # mkstemp rather than the deprecated, race-prone mktemp; schemdraw
    # simply overwrites the empty file created here.
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    g.save(path, dpi=dpi)
    return path


def rst_mem():
    """Reset the planner memory and clear both chat panes and the node count."""
    PLANNER_HISTORY.clear()
    return [], [], 0


def format_response(gr_history, message, response, idx=-1):
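    """Fold one streamed agent message into ``gr_history`` at position ``idx``.

    Dispatches on ``message["stream_state"]``: plain text streams into the
    current message; a starting tool call (code interpreter or web browser)
    truncates the thought at ``<|action_start|>`` and inserts a tool-tagged
    message; a tool return inserts the execution result followed by a fresh
    empty assistant message to stream the next thought into.
    """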
    if idx < 0:
        idx = len(gr_history) + idx
    if message["stream_state"] == AgentStatusCode.STREAM_ING:
        gr_history[idx].content = response
    elif message["stream_state"] == AgentStatusCode.CODING:
        if gr_history[idx].thought_metadata.tool_name is None:
            gr_history[idx].content = gr_history[idx].content.split("<|action_start|>")[0]
            gr_history.insert(
                idx + 1,
                ChatMessage(
                    role="assistant",
                    content=response,
                    thought_metadata=ThoughtMetadata(tool_name="🖥️ Code Interpreter"),
                ),
            )
        else:
            gr_history[idx].content = response
    elif message["stream_state"] == AgentStatusCode.PLUGIN_START:
        if isinstance(response, dict):
            response = json.dumps(response, ensure_ascii=False, indent=4)
        if gr_history[idx].thought_metadata.tool_name is None:
            gr_history[idx].content = gr_history[idx].content.split("<|action_start|>")[0]
            gr_history.insert(
                idx + 1,
                ChatMessage(
                    role="assistant",
                    content="```json\n" + response,
                    thought_metadata=ThoughtMetadata(tool_name="🌐 Web Browser"),
                ),
            )
        else:
            gr_history[idx].content = "```json\n" + response
    elif message["stream_state"] == AgentStatusCode.PLUGIN_END and isinstance(response, dict):
        gr_history[idx].content = (
            f"```json\n{json.dumps(response, ensure_ascii=False, indent=4)}\n```"
        )
    elif message["stream_state"] in [AgentStatusCode.CODE_RETURN, AgentStatusCode.PLUGIN_RETURN]:
        try:
            content = json.loads(message["content"])
        except json.decoder.JSONDecodeError:
            content = message["content"]
        if gr_history[idx].thought_metadata.tool_name:
            gr_history.insert(
                idx + 1,
                ChatMessage(
                    role="assistant",
                    content=(
                        content
                        if isinstance(content, str)
                        else f"\n```json\n{json.dumps(content, ensure_ascii=False, indent=4)}\n```\n"
                    ),
                    thought_metadata=ThoughtMetadata(tool_name="Execution"),
                ),
            )
            gr_history.insert(idx + 2, ChatMessage(role="assistant", content=""))


def predict(history_planner, history_searcher, node_cnt):
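    """Stream one round of MindSearch planning/searching into both chat panes.

    Posts the latest user query to the backend's /solve endpoint and, as
    server-sent events arrive, yields the updated planner history, searcher
    history, and node count so Gradio can refresh the UI incrementally.
    """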

    def streaming(raw_response):
        """Decode the server-sent event stream into (node, message, graph) tuples."""
        for chunk in raw_response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\n"
        ):
            if chunk:
                decoded = chunk.decode("utf-8")
                if decoded == "\r":
                    continue
                if decoded.startswith("data: "):
                    decoded = decoded[6:]
                elif decoded.startswith(": ping - "):
                    continue
                response = json.loads(decoded)
                yield (
                    response["current_node"],
                    (
                        response["response"]["formatted"]["node"][response["current_node"]]
                        if response["current_node"]
                        else response["response"]
                    ),
                    response["response"]["formatted"]["adjacency_list"],
                )

    global PLANNER_HISTORY
    # The last three planner messages are the user query, the empty assistant
    # placeholder, and the file message holding the search-graph image.
    PLANNER_HISTORY.extend(history_planner[-3:])
    search_graph_msg = history_planner[-1]

    url = "http://localhost:8002/solve"
    data = {"inputs": PLANNER_HISTORY[-3].content}
    raw_response = requests.post(url, json=data, timeout=60, stream=True)
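
    # node_id2msg_idx maps each searcher node to the index of its assistant
    # message in history_searcher, so streamed updates land in the right bubble.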
    node_id2msg_idx = {}
    for resp in streaming(raw_response):
        node_name, agent_message, adjacency_list = resp
        # Redraw the graph image only when the number of distinct nodes
        # (keys plus all successors) has changed since the last event.
        dedup_nodes = set(adjacency_list) | {
            val["name"] for vals in adjacency_list.values() for val in vals
        }
        if dedup_nodes and len(dedup_nodes) != node_cnt:
            node_cnt = len(dedup_nodes)
            graph_path = draw_search_graph(adjacency_list)
            search_graph_msg.file.path = graph_path
            search_graph_msg.file.mime_type = mimetypes.guess_type(graph_path)[0]
        if node_name:
            if node_name in ["root", "response"]:
                continue
            node_id = f'【{node_name}】{agent_message["content"]}'
            agent_message = agent_message["response"]
            response = (
                agent_message["formatted"]["action"]
                if agent_message["stream_state"]
                in [AgentStatusCode.PLUGIN_START, AgentStatusCode.PLUGIN_END]
                else agent_message["formatted"] and agent_message["formatted"].get("thought")
            )
            if node_id not in node_id2msg_idx:
                node_id2msg_idx[node_id] = len(history_searcher) + 1
                history_searcher.append(ChatMessage(role="user", content=node_id))
                history_searcher.append(ChatMessage(role="assistant", content=""))
            offset = len(history_searcher)
            format_response(history_searcher, agent_message, response, node_id2msg_idx[node_id])
            # format_response may insert extra messages; shift the saved indices
            # of this node and of every node recorded after it accordingly.
            flag, incr = False, len(history_searcher) - offset
            for key, value in node_id2msg_idx.items():
                if flag or key == node_id:
                    node_id2msg_idx[key] = value + incr
                    if not flag:
                        flag = True
            yield history_planner, history_searcher, node_cnt
        else:
            response = (
                agent_message["formatted"]["action"]
                if agent_message["stream_state"]
                in [AgentStatusCode.CODING, AgentStatusCode.CODE_END]
                else agent_message["formatted"] and agent_message["formatted"].get("thought")
            )
            # idx -2 targets the assistant placeholder; the final message in
            # history_planner holds the search-graph image file.
            format_response(history_planner, agent_message, response, -2)
            if agent_message["stream_state"] == AgentStatusCode.END:
                PLANNER_HISTORY = history_planner
            yield history_planner, history_searcher, node_cnt
    return history_planner, history_searcher, node_cnt


with gr.Blocks(css=os.path.join(os.path.dirname(__file__), "css", "gradio_front.css")) as demo:
    with gr.Column(elem_classes="chat-box"):
        gr.HTML("""<h1 align="center">Talk to me, Jack</h1>""")
        gr.HTML(
            """<p style="text-align: center; font-family: Arial, sans-serif;">
            Please be explicit in your request. It's not a perfect world just yet, so be patient for the time being.</p>"""
        )

        gr.HTML(
            """
            <h1 align='center'><img
            src='https://img.freepik.com/premium-vector/secretary-cartoon-character-white-background_1639-28887.jpg'
            alt="Maria's Waiting for You" class="logo" width="150"></h1>
            """
        )
    node_count = gr.State(0)
    with gr.Row():
        planner = AgentChatbot(
            label="planner",
            height=600,
            show_label=True,
            show_copy_button=True,
            bubble_full_width=False,
            render_markdown=True,
            elem_classes="chatbot-container",
        )
        searcher = AgentChatbot(
            label="searcher",
            height=600,
            show_label=True,
            show_copy_button=True,
            bubble_full_width=False,
            render_markdown=True,
            elem_classes="chatbot-container",
        )
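
    # Input row: the message box plus the submit / clear controls.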
    with gr.Row(elem_classes="chat-box"):
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message...",
            lines=1,
            container=False,
            elem_classes="editor",
            scale=4,
        )
        submitBtn = gr.Button("submit", variant="primary", elem_classes="toolbarButton", scale=1)
        clearBtn = gr.Button("clear", variant="secondary", elem_classes="toolbarButton", scale=1)
    with gr.Row(elem_classes="examples-container"):
        examples_component = gr.Examples(
            [
                [
                    "Help me find a portable battery bank for air travel. I'm looking for one "
                    "made or designed by an American company that offers the largest capacity "
                    "and the most safety features. Please provide a recommendation that meets "
                    "these requirements."
                ],
            ],
            inputs=user_input,
            label="Try these examples:",
        )

    def user(query, history):
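        """Append the user query, an assistant placeholder, and the initial graph image."""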
        history.append(ChatMessage(role="user", content=query))
        history.append(ChatMessage(role="assistant", content=""))
        graph_path = draw_search_graph({"root": []})
        history.append(
            ChatFileMessage(
                role="assistant",
                file=gr.FileData(path=graph_path, mime_type=mimetypes.guess_type(graph_path)[0]),
            )
        )
        return "", history

    submitBtn.click(user, [user_input, planner], [user_input, planner], queue=False).then(
        predict,
        [planner, searcher, node_count],
        [planner, searcher, node_count],
    )
    clearBtn.click(rst_mem, None, [planner, searcher, node_count], queue=False)

# Queueing is required for the generator-based predict() to stream updates.
demo.queue()


def same_auth(username, password):
    """Accept any username; check the password against the PASSWORD env var."""
    return password == pwd


demo.launch(share=True, auth=same_auth)