from transformers import LlamaConfig as OrigLlamaConfig


class LlamaConfig(OrigLlamaConfig):
    """Llama configuration extended with AQLM quantization hyperparameters.

    Behaves exactly like the base ``LlamaConfig`` except that it registers a
    distinct ``model_type`` and records the AQLM settings under ``self.aqlm``.

    Args:
        nbits_per_codebook: Bits per codebook index (codebook size is 2**n).
        num_codebooks: Number of codebooks per weight group.
        out_group_size: Output-dimension group size for quantization.
        in_group_size: Input-dimension group size for quantization.
        **kwargs: Forwarded unchanged to the base ``LlamaConfig``.
    """

    model_type = "llama_aqlm"

    def __init__(
        self,
        nbits_per_codebook: int = 16,
        num_codebooks: int = 1,
        out_group_size: int = 1,
        in_group_size: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Keep all AQLM settings together in a single dict attribute.
        self.aqlm = dict(
            nbits_per_codebook=nbits_per_codebook,
            num_codebooks=num_codebooks,
            out_group_size=out_group_size,
            in_group_size=in_group_size,
        )