model.generate cannot handle past_key_values correctly

#13
by Zhuangl - opened

Here is my script.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextStreamer,
)

model_name_or_path = "./glm-4-9b-chat"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to(device).eval()

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)  # decode kwargs are passed through directly
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1, "streamer":streamer, "return_dict_in_generate": True}
past_key_values = None
inputs = None

system_message = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
history = [{"role": "system", "content": system_message}]

past_key_values = None  # KV cache carried across chat turns

while True:
    query = input("Human: ")
    if len(query.strip()) == 0:
        # an empty input resets the conversation, so the cache must be cleared too
        history = [{"role": "system", "content": system_message}]
        past_key_values = None
        continue
        
    history.append({"role": "user", "content": query})
    inputs = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = inputs.to(device)
    print(inputs['input_ids'].shape)

    with torch.no_grad():
        print("Assistant:")
        outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
        past_key_values = outputs["past_key_values"]

        # drop the prompt tokens and decode only the newly generated reply
        response_ids = outputs["sequences"][:, inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
        history.append({"role": "assistant", "content": response})

The first round of chat works fine, but the second round (where past_key_values is passed into model.generate) fails with this error:

Traceback (most recent call last):
  File "/home/ubuntu/llm/generation.txt", line 52, in <module>
    outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 997, in forward
    transformer_outputs = self.transformer(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 882, in forward
    full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 783, in get_masks
    full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
RuntimeError: The size of tensor a (130) must match the size of tensor b (69) at non-singleton dimension 2

I tested the same script with a Llama model and it works fine, so the problem seems specific to the ChatGLM modeling implementation. Could you share some suggestions?
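
Reading get_masks in modeling_chatglm.py, full_attention_mask spans the past plus current tokens, while padding_mask is just the tokenizer's attention_mask for the current prompt; 130 vs 69 would then correspond to a 61-token cache from round one colliding with the 69-token re-encoded prompt. As an untested, heavily hedged workaround sketch on the user side, the mask could be extended to cover the cached tokens before calling generate (cache_len here is an assumption: which dimension of the legacy cache tuple holds the sequence length is model-specific):

# untested sketch: make attention_mask cover cached + current tokens
if past_key_values is not None:
    cache_len = past_key_values[0][0].shape[0]  # assumption: dim 0 is the cached length
    inputs["attention_mask"] = torch.ones(
        1, cache_len + inputs["input_ids"].shape[1],
        dtype=torch.long, device=device,
    )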

Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University org

past_key_values = outputs['past_key_values'] no longer works. transformers 4.42 changed the way this is written, so it can't be done like this anymore. Please take a look at the latest version of our model implementation, where this part has been replaced.
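
For reference, under the new API the cache is a Cache object rather than a tuple of tensors. A minimal sketch using the stock transformers DynamicCache class (assuming the updated modeling code accepts Cache objects; this is not the repo's official example):

from transformers import DynamicCache

past_key_values = DynamicCache()  # new-style cache object, reused across turns
outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
# with return_dict_in_generate=True, the updated cache comes back on the output
past_key_values = outputs.past_key_values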

@zRzRzRzRzRzRzR Thanks for the reply.
I'm using transformers==4.42.3, and the script runs fine on Llama 3. Also, could you provide a link to the latest model implementation you mentioned? I'm still using the modeling_chatglm.py from this repo.

Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University org

The version is this one, and the weights haven't changed; just use the latest version of this repo (you need to pull the other configuration files as well), then try again.

The line full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) should be fine as it is.
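
If the two cache formats need to be bridged from the user side in the meantime, transformers ships helpers on DynamicCache for converting to and from the legacy tuple format; a hedged sketch (assuming the standard (batch, heads, seq, dim) tuple layout, which this repo's custom modeling code may not follow):

from transformers import DynamicCache

legacy = outputs["past_key_values"]             # tuple of per-layer (key, value) pairs
cache = DynamicCache.from_legacy_cache(legacy)  # wrap as a new-style Cache object
legacy_again = cache.to_legacy_cache()          # convert back if older code expects tuples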
