import math

import numpy as np
import torch
import transformers
from auto_gptq.nn_modules.triton_utils.mixin import TritonModuleMixin
from torch import nn


def weight_quant(weight, num_bits=1):
    # Absmean ternary quantization, rescaled back to the original range.
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)


def weight_quant2(weight, num_bits=1):
    # Like weight_quant, but returns the ternary tensor and the scale (1 / s) separately.
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1)
    return result.type(dtype), 1 / s


def activation_quant(x, num_bits=8):
    # Per-token absmax quantization to signed num_bits integers (int8 by default).
    dtype = x.dtype
    x = x.float()
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)


def optimized_linear(quant_input, weight, scale):
    # Exploit the sparsity of the ternary weight matrix: replace multiplications
    # with additions and subtractions over the +1 / -1 positions.
    pos_mask = weight == 1
    neg_mask = weight == -1
    # Accumulate the additions and subtractions, then combine the results.
    pos_sum = torch.matmul(quant_input, pos_mask.to(quant_input.dtype))
    neg_sum = torch.matmul(quant_input, neg_mask.to(quant_input.dtype))
    result = pos_sum - neg_sum
    # Apply the scale.
    result *= scale
    return result


class BitLinear(nn.Linear):
    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        """RMSNorm is placed outside BitLinear."""
        super(BitLinear, self).__init__(*kargs, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.quant_initialized = False
        self.quant_scale = None

    def quantize(self):
        if not self.quant_initialized:
            quant_weight, quant_scale = weight_quant2(self.weight, self.weight_bits)
            # Straight-through estimator around the ternary weights.
            quant_weight = self.weight + (quant_weight - self.weight).detach()
            # print(f'Quantized weight: {quant_weight}')
            # print(f'Quantized scale: {quant_scale}')
            self.weight.data = quant_weight
            self.quant_scale = quant_scale
            self.quant_initialized = True

    def forward(self, input):
        if not self.quant_initialized:
            self.quantize()
        # Straight-through estimator: the forward pass uses quantized activations,
        # the backward pass lets gradients flow through unchanged.
        quant_input = (
            input + (activation_quant(input, self.input_bits) - input).detach()
        )
        out = nn.functional.linear(quant_input, self.weight) * self.quant_scale
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
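

# Illustrative notes (added for clarity; the numbers and sizes below are a
# worked example and arbitrary assumptions, not values from this module).
#
# weight_quant / weight_quant2 implement an absmean ternary rule: for a weight
# row [0.4, -0.2, 0.05] the scale is s = 1 / mean(|w|) ≈ 4.62, so
# round(w * s).clamp(-1, 1) gives [1, -1, 0]. weight_quant rescales back by
# 1 / s, while weight_quant2 returns the ternary tensor and 1 / s separately
# so the scale can be applied after the matmul. activation_quant is a
# per-token absmax quantizer: s = 127 / max(|x|) along the last dimension for
# the default 8 bits.
#
# Minimal BitLinear usage sketch:
#
#     layer = BitLinear(128, 256, bias=True)
#     y = layer(torch.randn(4, 128))  # first call quantizes the weights in place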

# Original code from https://github.com/AutoGPTQ/AutoGPTQ
# MIT License
try:
    from auto_gptq.nn_modules.triton_utils.kernels import (
        QuantLinearFunction,
        QuantLinearInferenceOnlyFunction,
        quant_matmul_248,
        quant_matmul_inference_only_248,
        transpose_quant_matmul_248,
    )
except ImportError as e:
    triton_import_exception = e

    def error_raiser_triton(*args, **kwargs):
        raise ValueError(
            f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
        )

    class FakeTriton:
        def __getattr__(self, name):
            raise ImportError(
                f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
            )

    quant_matmul_248 = error_raiser_triton
    transpose_quant_matmul_248 = error_raiser_triton
    quant_matmul_inference_only_248 = error_raiser_triton
    QuantLinearFunction = FakeTriton
    QuantLinearInferenceOnlyFunction = FakeTriton


class QuantizedBitLinear(nn.Module, TritonModuleMixin):
    QUANT_TYPE = "triton"

    def __init__(
        self,
        infeatures,
        outfeatures,
        bias,
        weight_bits=1,
        input_bits=8,
        quant_bits=2,
        group_size=128,
        trainable=False,
        **kwargs,
    ):
        super().__init__()
        if quant_bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        if infeatures % 32 != 0 or outfeatures % 32 != 0:
            raise NotImplementedError(
                "in_feature and out_feature must be divisible by 32."
            )
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.quant_bits = quant_bits
        self.group_size = group_size if group_size != -1 else infeatures
        self.maxq = 2**self.quant_bits - 1

        self.register_buffer(
            "qweight",
            torch.zeros(
                (infeatures // 32 * self.quant_bits, outfeatures), dtype=torch.int32
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    math.ceil(infeatures / self.group_size),
                    outfeatures // 32 * self.quant_bits,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (math.ceil(infeatures / self.group_size), outfeatures),
                dtype=torch.float16,
            ),
        )
        self.register_buffer(
            "g_idx",
            torch.tensor(
                [i // self.group_size for i in range(infeatures)], dtype=torch.int32
            ),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros((outfeatures), dtype=torch.float16)
            )
        else:
            self.bias = None
        self.register_buffer("scale", torch.tensor(1.0, dtype=torch.float16))

        self.trainable = trainable

    def post_init(self):
        pass

    def pack(self, bitlinear: BitLinear):
        device = bitlinear.weight.device
        bitlinear = bitlinear.cpu()
        W = bitlinear.weight.data.clone()
        if isinstance(bitlinear, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(bitlinear, transformers.pytorch_utils.Conv1D):
            W = W.t()

        self.scale = torch.tensor(bitlinear.quant_scale, dtype=torch.float16)
        # self.scales.fill_(self.scale).half()
        # self.scales.fill_(1).half()
        scales = torch.ones(
            self.outfeatures,
            math.ceil(self.infeatures / self.group_size),
        )
        zero = 1
        zeros = torch.zeros(
            self.outfeatures,
            math.ceil(self.infeatures / self.group_size),
        )
        zeros.fill_(zero)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if bitlinear.bias is not None:
            self.bias = bitlinear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (W[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        print(f"Int weight: {intweight}")

        i = 0
        row = 0
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.quant_bits, intweight.shape[1]),
            dtype=np.uint32,
        )
        while row < qweight.shape[0]:
            if self.quant_bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.quant_bits)):
                    qweight[row] |= intweight[j] << (self.quant_bits * (j - i))
                i += 32 // self.quant_bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)
        print(f"Quantized weight: {self.qweight}")

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        # qzeros shape: (math.ceil(infeatures / group_size), outfeatures // 32 * quant_bits)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.quant_bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.quant_bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.quant_bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.quant_bits * (j - i))
                i += 32 // self.quant_bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

        self.to(device)
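
    # Packing layout (explanatory note added for clarity; the arithmetic is a
    # worked example for the default quant_bits=2). Each int32 in qweight holds
    # 32 // 2 = 16 consecutive input-channel values, stored lowest-index first:
    # within a group, element (j - i) occupies bits 2*(j - i) and 2*(j - i) + 1.
    # Because the ternary weights {-1, 0, +1} are shifted by the zero point
    # (zero = 1) to {0, 1, 2}, a run of sixteen zero weights packs to
    # 0b01_01_..._01 = 0x55555555. qzeros stores (zero - 1) packed the same way
    # along the output dimension.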

    def forward(self, x):
        # out_shape = x.shape[:-1] + (self.outfeatures,)
        # quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
        # out = quant_linear_fn.apply(
        #     x.reshape(-1, x.shape[-1]),
        #     self.qweight,
        #     self.scales,
        #     self.qzeros,
        #     self.g_idx,
        #     self.quant_bits,
        #     self.maxq,
        # )
        # out = out.half().reshape(out_shape)
        # out = out + self.bias if self.bias is not None else out
        # return out

        # Quantize activations (straight-through), run the packed matmul, then
        # apply the BitLinear weight scale.
        x = x + (activation_quant(x, self.input_bits) - x).detach()
        out_shape = x.shape[:-1] + (self.outfeatures,)
        quant_linear_fn = (
            QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
        )
        out = quant_linear_fn.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.quant_bits,
            self.maxq,
        )
        out *= self.scale
        out = out.half().reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """
        Pre-tunes the quantized kernel
        """
        from tqdm import tqdm

        kn_values = {}

        for _, m in model.named_modules():
            if not isinstance(m, cls):
                continue

            k = m.infeatures
            n = m.outfeatures

            if (k, n) not in kn_values:
                kn_values[(k, n)] = (
                    m.qweight,
                    m.scales,
                    m.qzeros,
                    m.g_idx,
                    m.quant_bits,
                    m.maxq,
                )

        # logger.info(f"Found {len(kn_values)} unique KN Linear values.")
        # logger.info("Warming up autotune cache ...")
        with torch.no_grad():
            for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
                m = 2**m
                for (k, n), (
                    qweight,
                    scales,
                    qzeros,
                    g_idx,
                    bits,
                    maxq,
                ) in kn_values.items():
                    if transpose:
                        a = torch.randn(m, k, dtype=torch.float16, device=model.device)
                        quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                        a = torch.randn(m, n, dtype=torch.float16, device=model.device)
                        transpose_quant_matmul_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
                    else:
                        a = torch.randn(m, k, dtype=torch.float16, device=model.device)
                        quant_matmul_inference_only_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
        del kn_values
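

# Illustrative smoke test (added for clarity, not part of the original module).
# It assumes the module's own dependencies (torch, numpy, transformers,
# auto_gptq) are importable and exercises only BitLinear quantization and the
# CPU-side pack(); the Triton forward of QuantizedBitLinear is skipped because
# it requires a GPU and the auto_gptq Triton kernels.
if __name__ == "__main__":
    torch.manual_seed(0)

    fp_layer = BitLinear(128, 64, bias=False)
    x = torch.randn(2, 128)
    y = fp_layer(x)  # first call quantizes the weights to {-1, 0, +1}
    print("BitLinear output shape:", y.shape)

    # pack() expects an already-quantized BitLinear (quant_scale must be set).
    q_layer = QuantizedBitLinear(128, 64, bias=False)
    q_layer.pack(fp_layer)
    print("Packed qweight shape:", q_layer.qweight.shape)  # (128 // 32 * 2, 64)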