Inference failed with error "Setting cache_implementation to 'sliding_window' requires the model config supporting sliding window attention..."
Hello,
I'm facing a strange issue when running inference with the "gemma-2-2b-it" model. The error is:
"
Setting cache_implementation to 'sliding_window' requires the model config supporting sliding window attention, please check if there is a sliding_window field in the model config and it's not set to None.
"
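The error states a precondition on the model config: a `sliding_window` key must be present and must not be `None`. As a rough sketch, the condition can be tested against the raw `config.json` contents as shown below. The function name `supports_sliding_window` is hypothetical, and this is not the check transformers itself runs; it only mirrors what the message says.

```python
import json


def supports_sliding_window(config: dict) -> bool:
    """Hypothetical helper: True if the config has a non-None 'sliding_window' entry.

    Mirrors the condition stated in the error message; it is not the
    actual validation code inside transformers.
    """
    return config.get("sliding_window") is not None


# A config that lacks the field, and the same config with it added.
broken = json.loads('{"model_type": "gemma2", "cache_implementation": "hybrid"}')
fixed = {**broken, "sliding_window": 4096}

print(supports_sliding_window(broken))  # False
print(supports_sliding_window(fixed))   # True
```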
My inference code is the same as the example on the Hugging Face website:
# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
How can I fix it?
Thanks
Which version of transformers are you using? I don't see this issue with 4.44.2.
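To answer the version question, here is a minimal sketch assuming Python 3.8+ (for `importlib.metadata`). `version_tuple` is a hypothetical helper that only handles plain numeric versions such as `4.44.2` (it drops suffixes like `.dev0`), and 4.44.2 is used only because it is the version reported above as working.

```python
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v: str) -> tuple:
    """Hypothetical helper: '4.44.2' -> (4, 44, 2); stops at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # e.g. 'dev0' or 'rc1': ignore the rest
        parts.append(int(digits))
    return tuple(parts)


try:
    installed = version("transformers")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("transformers is not installed")
else:
    ok = version_tuple(installed) >= version_tuple("4.44.2")
    print(f"transformers {installed}; at least 4.44.2: {ok}")
```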
It looks like the issue comes from the "config.json" file: it is missing the "sliding_window" parameter.
The correct content is:
{
  "architectures": [
    "Gemma2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": 50.0,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "eos_token_id": [
    1,
    107
  ],
  "final_logit_softcapping": 30.0,
  "head_dim": 256,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 2304,
  "initializer_range": 0.02,
  "intermediate_size": 9216,
  "max_position_embeddings": 8192,
  "model_type": "gemma2",
  "num_attention_heads": 8,
  "num_hidden_layers": 26,
  "num_key_value_heads": 4,
  "pad_token_id": 0,
  "query_pre_attn_scalar": 256,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.4",
  "use_cache": true,
  "vocab_size": 256000
}
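If you have a local copy of the model directory, one way to apply this fix is to add the missing key to its `config.json`. The sketch below uses a hypothetical helper, `ensure_sliding_window`, which is not part of transformers; the default of 4096 is taken from the config shown above, and the demo writes to a temporary directory rather than a real model folder.

```python
import json
import tempfile
from pathlib import Path


def ensure_sliding_window(config_path: Path, window: int = 4096) -> bool:
    """Add 'sliding_window' to a config.json if it is missing or null.

    Returns True if the file was modified. Hypothetical helper for illustration.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))
    if config.get("sliding_window") is not None:
        return False  # already set; leave the file untouched
    config["sliding_window"] = window
    config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
    return True


# Demo on a throwaway config that lacks the field.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    path.write_text(json.dumps({"model_type": "gemma2", "cache_implementation": "hybrid"}))
    changed = ensure_sliding_window(path)
    patched = json.loads(path.read_text())
    changed_again = ensure_sliding_window(path)  # second call is a no-op
    print(changed, patched["sliding_window"], changed_again)
```

Editing files by hand is a workaround; the reply above reports no error with transformers 4.44.2, so upgrading may be the simpler route.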
Closing this, as it is not a real bug :)