import math

import numpy as np
import torch
import transformers
from auto_gptq.nn_modules.triton_utils.mixin import TritonModuleMixin
from torch import nn


def weight_quant(weight, num_bits=1):
    """Quantize weights to the ternary grid {-1, 0, 1} scaled by the mean absolute weight.

    `num_bits` is kept for interface consistency but is unused (weights are effectively ternary).
    """
    dtype = weight.dtype
    weight = weight.float()
    # Per-tensor scale: inverse of the mean absolute weight (absmean quantization).
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    # Round to {-1, 0, 1} in the scaled domain, then rescale back.
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)
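
# Illustrative example (values computed by hand, not from this repo):
# for w = [[0.7, -0.2], [0.1, -0.9]], mean |w| = 0.475, so weight_quant(w)
# returns [[0.475, 0.0], [0.0, -0.475]] (each entry snapped to {-0.475, 0, 0.475}).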


def weight_quant2(weight, num_bits=1):
    """Same absmean quantization as `weight_quant`, but returns the ternary integer
    weights and the scale separately instead of folding the scale back in."""
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    # Keep the weights as {-1, 0, 1}; the caller applies the returned scale (1 / s).
    result = (weight * s).round().clamp(-1, 1)
    return result.type(dtype), 1 / s
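
# With the same example matrix as above, weight_quant2 returns the ternary tensor
# [[1, 0], [0, -1]] together with the scale 0.475, so `ternary * scale` reproduces
# the output of weight_quant.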


def activation_quant(x, num_bits=8):
    """Symmetric per-token absmax quantization of activations to `num_bits` bits."""
    dtype = x.dtype
    x = x.float()
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    # One scale per token (last dimension), derived from the maximum absolute value.
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)
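
# For the default num_bits=8, each token (row along the last dimension) is mapped onto
# the integer grid [-128, 127] using its own absmax scale and then rescaled, so the
# output stays in the original floating-point range.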


def optimized_linear(quant_input, weight, scale):
    """Matrix multiply with ternary weights expressed as two boolean masks.

    Assumes `weight` is laid out as (in_features, out_features) so that
    `quant_input @ weight` is well formed.
    """
    pos_mask = weight == 1
    neg_mask = weight == -1

    # x @ W for ternary W reduces to summing inputs where W == 1 and
    # subtracting them where W == -1.
    pos_sum = torch.matmul(quant_input, pos_mask.to(quant_input.dtype))
    neg_sum = torch.matmul(quant_input, neg_mask.to(quant_input.dtype))
    result = pos_sum - neg_sum

    # Undo the weight-quantization scale.
    result *= scale

    return result
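
# Sketch (an assumption; this helper is not exercised elsewhere in this file): with
# w, s = weight_quant2(W.t()), optimized_linear(x, w, s) should match
# x @ weight_quant(W.t()) up to floating-point error.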


class BitLinear(nn.Linear):
    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        """
        RMSNorm is placed outside BitLinear
        """
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.quant_initialized = False
        self.quant_scale = None

    def quantize(self):
        if not self.quant_initialized:
            quant_weight, quant_scale = weight_quant2(self.weight, self.weight_bits)
            # Straight-through estimator: the forward pass sees the quantized weights
            # while gradients flow to the full-precision weights.
            quant_weight = self.weight + (quant_weight - self.weight).detach()

            self.weight.data = quant_weight
            self.quant_scale = quant_scale
            self.quant_initialized = True

    def forward(self, input):
        if not self.quant_initialized:
            self.quantize()

        # Straight-through estimator for the 8-bit activation quantization.
        quant_input = (
            input + (activation_quant(input, self.input_bits) - input).detach()
        )

        out = nn.functional.linear(quant_input, self.weight) * self.quant_scale
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out
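
# Usage sketch (shapes and names are illustrative, not taken from this repo):
#   layer = BitLinear(1024, 4096, bias=False, weight_bits=1, input_bits=8)
#   y = layer(torch.randn(2, 16, 1024))  # weights are quantized lazily on first call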


try:
    from auto_gptq.nn_modules.triton_utils.kernels import (
        QuantLinearFunction,
        QuantLinearInferenceOnlyFunction,
        quant_matmul_248,
        quant_matmul_inference_only_248,
        transpose_quant_matmul_248,
    )
except ImportError as e:
    triton_import_exception = e

    def error_raiser_triton(*args, **kwargs):
        raise ValueError(
            f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
        )

    class FakeTriton:
        def __getattr__(self, name):
            raise ImportError(
                f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
            )

    quant_matmul_248 = error_raiser_triton
    transpose_quant_matmul_248 = error_raiser_triton
    quant_matmul_inference_only_248 = error_raiser_triton
    QuantLinearFunction = FakeTriton
    QuantLinearInferenceOnlyFunction = FakeTriton


class QuantizedBitLinear(nn.Module, TritonModuleMixin):
    QUANT_TYPE = "triton"

    def __init__(
        self,
        infeatures,
        outfeatures,
        bias,
        weight_bits=1,
        input_bits=8,
        quant_bits=2,
        group_size=128,
        trainable=False,
        **kwargs,
    ):
        super().__init__()
        if quant_bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2, 4, 8 bits are supported.")
        if infeatures % 32 != 0 or outfeatures % 32 != 0:
            raise NotImplementedError(
                "infeatures and outfeatures must be divisible by 32."
            )
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.quant_bits = quant_bits
        self.group_size = group_size if group_size != -1 else infeatures
        self.maxq = 2**self.quant_bits - 1

        # GPTQ-style packed buffers: each int32 holds 32 // quant_bits values.
        self.register_buffer(
            "qweight",
            torch.zeros(
                (infeatures // 32 * self.quant_bits, outfeatures), dtype=torch.int32
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    math.ceil(infeatures / self.group_size),
                    outfeatures // 32 * self.quant_bits,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (math.ceil(infeatures / self.group_size), outfeatures),
                dtype=torch.float16,
            ),
        )
        self.register_buffer(
            "g_idx",
            torch.tensor(
                [i // self.group_size for i in range(infeatures)], dtype=torch.int32
            ),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros((outfeatures), dtype=torch.float16)
            )
        else:
            self.bias = None

        # Global weight-quantization scale produced by weight_quant2.
        self.register_buffer("scale", torch.tensor(1.0, dtype=torch.float16))

        self.trainable = trainable

    def post_init(self):
        pass

    def pack(self, bitlinear: BitLinear):
        device = bitlinear.weight.device
        bitlinear = bitlinear.cpu()

        W = bitlinear.weight.data.clone()
        if isinstance(bitlinear, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(bitlinear, transformers.pytorch_utils.Conv1D):
            W = W.t()

        # Carry over the global absmean scale computed during BitLinear.quantize().
        self.scale = bitlinear.quant_scale.clone().detach().to(torch.float16)

        # The ternary weights already live in {-1, 0, 1}, so every group uses a unit
        # scale and a zero point of 1 (weights are stored offset by +1 as {0, 1, 2}).
        scales = torch.ones(
            self.outfeatures,
            math.ceil(self.infeatures / self.group_size),
        )
        zero = 1
        zeros = torch.zeros(
            self.outfeatures,
            math.ceil(self.infeatures / self.group_size),
        )
        zeros.fill_(zero)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if bitlinear.bias is not None:
            self.bias = bitlinear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (W[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )

        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)

        print(f"Int weight: {intweight}")

        # Pack 32 // quant_bits integer weights into each int32 row of qweight.
        i = 0
        row = 0
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.quant_bits, intweight.shape[1]),
            dtype=np.uint32,
        )
        while row < qweight.shape[0]:
            if self.quant_bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.quant_bits)):
                    qweight[row] |= intweight[j] << (self.quant_bits * (j - i))
                i += 32 // self.quant_bits
                row += 1
            else:
                raise NotImplementedError("Only 2, 4, 8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        print(f"Quantized weight: {self.qweight}")

        # Zero points are stored minus one, matching the GPTQ packing convention.
        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.quant_bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.quant_bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.quant_bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.quant_bits * (j - i))
                i += 32 // self.quant_bits
                col += 1
            else:
                raise NotImplementedError("Only 2, 4, 8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

        self.to(device)
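
    # Conversion sketch (names and dimensions are illustrative): given a trained
    # BitLinear `layer` whose features are divisible by 32,
    #   layer.quantize()
    #   qlayer = QuantizedBitLinear(layer.in_features, layer.out_features,
    #                               bias=layer.bias is not None, quant_bits=2)
    #   qlayer.pack(layer)
    # after which qlayer.forward should approximate layer.forward on half-precision inputs.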

    def forward(self, x):
        # Straight-through estimator for the 8-bit activation quantization.
        x = x + (activation_quant(x, self.input_bits) - x).detach()
        out_shape = x.shape[:-1] + (self.outfeatures,)

        quant_linear_fn = (
            QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
        )
        out = quant_linear_fn.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.quant_bits,
            self.maxq,
        )

        # Apply the global weight-quantization scale saved during pack().
        out *= self.scale

        out = out.half().reshape(out_shape)
        out = out + self.bias if self.bias is not None else out

        return out

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """
        Pre-tunes the quantized kernel
        """
        from tqdm import tqdm

        kn_values = {}

        for _, m in model.named_modules():
            if not isinstance(m, cls):
                continue

            k = m.infeatures
            n = m.outfeatures

            if (k, n) not in kn_values:
                kn_values[(k, n)] = (
                    m.qweight,
                    m.scales,
                    m.qzeros,
                    m.g_idx,
                    m.quant_bits,
                    m.maxq,
                )

        with torch.no_grad():
            for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
                m = 2**m
                for (k, n), (
                    qweight,
                    scales,
                    qzeros,
                    g_idx,
                    bits,
                    maxq,
                ) in kn_values.items():
                    if transpose:
                        a = torch.randn(m, k, dtype=torch.float16, device=model.device)
                        quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                        a = torch.randn(m, n, dtype=torch.float16, device=model.device)
                        transpose_quant_matmul_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
                    else:
                        a = torch.randn(m, k, dtype=torch.float16, device=model.device)
                        quant_matmul_inference_only_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
        del kn_values