Ending early when tool calling (due to no thoughts?)

#28
by curltron - opened

I've noticed that the official mistral-common prompt and the vLLM prompts both assume that an assistant message contains either tool calls or content, never both. (The chat template provided here doesn't allow tool calling at all, for some reason.)

My chat template (a custom tool-calling template based on vLLM's) allows an assistant message to contain thoughts before the tool calls. The result is that the model writes its thoughts and then ends its turn early without a tool call, which prematurely ends simple ReAct agents.
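To make the failure mode concrete, here is a minimal sketch of the kind of ReAct loop affected. The names `run_agent`, `fake_model`, and the `get_weather` tool are hypothetical and exist only for illustration; they are not part of vLLM or mistral-common, and the model is replaced by a stub so the snippet runs on its own.

```python
import json

TOOL_MARKER = "[TOOL_CALLS]"

def run_agent(generate, tools, max_steps=5):
    """Call `generate` until it returns a turn that contains no tool call."""
    transcript = []
    for _ in range(max_steps):
        output = generate(transcript)
        transcript.append(("assistant", output))
        if TOOL_MARKER not in output:
            # No tool call: the loop treats this turn as the final answer.
            return transcript
        payload = output.split(TOOL_MARKER, 1)[1].strip()
        call = json.loads(payload)[0]  # single tool call per turn
        result = tools[call["name"]](**call["arguments"])
        transcript.append(("tool", json.dumps(result)))
    return transcript

def fake_model(transcript):
    # Stub for the behaviour described above: thoughts, then end of turn.
    return "I should look up the weather first."

history = run_agent(fake_model, {"get_weather": lambda city: {"temp_c": 20}})
print(len(history))  # the agent stops after a single assistant turn
```

Because the loop uses "no tool call" as its stop condition, a turn that contains only thoughts is indistinguishable from a final answer.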

I think this is ultimately a big mistake: in my experience, letting the model "think" before calling tools is beneficial. It also lines up with the new "chain of thought" feature in reasoning models.

It's a shame IMO, because this model is otherwise really great at tool calling and is particularly fast and performant (running on a 3090 Ti at 4-bit). I think fine-tuning on some instruct sequences that contain both thoughts AND tool calls would be hugely beneficial. The vLLM parsing would still work in that case (though their chat template would need to be tuned, maybe as I do below).
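As a rough illustration of why thoughts before the tool call need not break parsing, the sketch below splits a raw completion on the `[TOOL_CALLS]` marker. This is not vLLM's actual parser; `split_thoughts_and_calls` is a hypothetical helper, and it assumes the marker is followed by a JSON list as in the format described in this post.

```python
import json

TOOL_MARKER = "[TOOL_CALLS]"

def split_thoughts_and_calls(output):
    """Return (thoughts, tool_calls) from a raw assistant completion."""
    if TOOL_MARKER not in output:
        # Plain content, no tool call in this turn.
        return output.strip(), []
    thoughts, _, payload = output.partition(TOOL_MARKER)
    return thoughts.strip(), json.loads(payload.strip())

raw = (
    "I need the forecast before answering.\n"
    '[TOOL_CALLS] [{"name": "get_weather", "arguments": {"city": "Paris"}}]'
)
thoughts, calls = split_thoughts_and_calls(raw)
print(thoughts)
print(calls[0]["name"], calls[0]["arguments"])
```

Everything before the marker is kept as the assistant's content, and everything after it is parsed as the tool calls, so a message can carry both.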

My single-tool chat template, for reference (based on vLLM's). It works well apart from the early stopping:

{%- if messages[0]["role"] == "system" %}
    {%- set system_message = messages[0]["content"] %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- elif tools is not none %}
    {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools, one at a time. Use the format: `[TOOL_CALLS] [{\"name\": tool call name, \"arguments\": tool call arguments}]`. Do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
    {%- if system_message is defined %}
        {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
    {%- else %}
        {%- set system_message = parallel_tool_prompt %}
    {%- endif %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}

{{- bos_token }}
{%- for message in loop_messages %}
    {%- if message["role"] == "user" %}
        {%- if tools is not none and (message == user_messages[-1]) %}
            {{- "[AVAILABLE_TOOLS] [" }}
            {%- for tool in tools %}
                {%- set tool = tool.function %}
                {{- '{"type": "function", "function": {' }}
                {%- for key, val in tool.items() if key != "return" %}
                    {%- if val is string %}
                        {{- '"' + key + '": "' + val + '"' }}
                    {%- else %}
                        {{- '"' + key + '": ' + val|tojson }}
                    {%- endif %}
                    {%- if not loop.last %}
                        {{- ", " }}
                    {%- endif %}
                {%- endfor %}
                {{- "}}" }}
                {%- if not loop.last %}
                    {{- ", " }}
                {%- else %}
                    {{- "]" }}
                {%- endif %}
            {%- endfor %}
            {{- "[/AVAILABLE_TOOLS]\n\n" }}
        {%- endif %}
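        {%- if system_message is defined %}
            {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]\n\n" }}
        {%- else %}
            {{- "[INST] " + message["content"] + "[/INST]\n\n" }}
        {%- endif %}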
    {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
        {%- if message.tool_calls is defined %}
            {%- set tool_calls = message.tool_calls %}
        {%- else %}
            {%- set tool_calls = message.content %}
        {%- endif %}
        {%- if message["role"] == "assistant" and message["content"] %}
            {{- message["content"] }}
        {%- endif %}
        {{- "\n[TOOL_CALLS][" }}
        {%- for tool_call in tool_calls %}
            {%- set out = tool_call.function|tojson %}
            {{- out[:-1] }}
            {%- if not tool_call.id is defined or tool_call.id|length < 9 %}
                {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
            {%- endif %}
            {{- ', "id": "' + tool_call.id[-9:] + '"}' }}
            {%- if not loop.last %}
                {{- ", " }}
            {%- else %}
                {{- "]" + eos_token }}
            {%- endif %}
        {%- endfor %}
    {%- elif message["role"] == "assistant" %}
        {{- " " + message["content"] + eos_token }}
    {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
        {%- if message.content is defined and message.content.content is defined %}
            {%- set content = message.content.content %}
        {%- else %}
            {%- set content = message.content %}
        {%- endif %}
        {{- '\n[TOOL_RESULTS] {"content": ' + content|string + ", " }}
        {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
            {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
        {%- endif %}
        {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
    {%- else %}
        {{- raise_exception("Only user, assistant, and tool roles are supported, with the exception of an initial optional system message!") }}
    {%- endif %}
{%- endfor %}
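For reference, a conversation shaped the way the template above expects might look like the following sketch. The message contents, the `get_weather` tool, and the `short_id` helper are made up for illustration; `short_id` only mirrors the template's ID rule (reject IDs shorter than 9 characters, keep the last 9) in Python.

```python
def short_id(tool_call_id):
    """Mirror the template's ID check: length >= 9, last 9 characters kept."""
    if tool_call_id is None or len(tool_call_id) < 9:
        raise ValueError("Tool call IDs should have length >= 9")
    return tool_call_id[-9:]

messages = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is the weather in Paris?"},
    {
        # Assistant turn with BOTH thoughts (content) and a tool call.
        "role": "assistant",
        "content": "I should check the weather first.",
        "tool_calls": [
            {
                "id": "call_abc123456",
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
            }
        ],
    },
    # The tool result refers back to the call through the same ID.
    {"role": "tool", "tool_call_id": "call_abc123456", "content": '{"temp_c": 20}'},
]

call_id = short_id(messages[2]["tool_calls"][0]["id"])
result_id = short_id(messages[3]["tool_call_id"])
print(call_id, result_id)
```

The assistant message is the case under discussion: the template renders its content first and then the `[TOOL_CALLS]` list.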
