from transformers import LlamaConfig as OrigLlamaConfig


class LlamaConfig(OrigLlamaConfig):
    # Distinct model_type so AutoConfig routes "llama_aqlm" checkpoints
    # to this class instead of the stock Llama config.
    model_type = "llama_aqlm"

    def __init__(
        self,
        nbits_per_codebook: int = 16,
        num_codebooks: int = 1,
        out_group_size: int = 1,
        in_group_size: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Bundle the AQLM quantization hyperparameters into one dict so
        # they are serialized into config.json alongside the base Llama fields.
        self.aqlm = {
            "nbits_per_codebook": nbits_per_codebook,
            "num_codebooks": num_codebooks,
            "out_group_size": out_group_size,
            "in_group_size": in_group_size,
        }
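

# A minimal usage sketch (not part of the original file): it assumes the
# `transformers` package is installed, builds the config with its default
# AQLM settings, and prints the stored quantization parameters.
if __name__ == "__main__":
    config = LlamaConfig(nbits_per_codebook=16, num_codebooks=1)
    print(config.model_type)  # -> "llama_aqlm"
    print(config.aqlm)        # -> {"nbits_per_codebook": 16, "num_codebooks": 1, ...}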