"""Core inference math for Additive Quantization (AQ): finalized quantized linear layer, weight reconstruction and Triton matmul kernels"""

import functools
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


class FinalizedQuantizedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        in_group_size: int,
        out_group_size: int,
        num_codebooks: int,
        nbits_per_codebook: int,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        assert self.in_features % in_group_size == 0
        assert self.out_features % out_group_size == 0
        num_out_groups = out_features // out_group_size
        num_in_groups = in_features // in_group_size
        self.out_group_size, self.in_group_size = out_group_size, in_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.codebook_size = 2**nbits_per_codebook

        # learnable codebook vectors: [num_codebooks, codebook_size, out_group_size, in_group_size]
        self.codebooks = nn.Parameter(
            torch.empty(
                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
                **factory_kwargs,
            ),
            requires_grad=True,
        )
        # integer codes per (output group, input group), stored in the smallest int dtype that fits
        self.codes = nn.Parameter(
            torch.empty(
                (num_out_groups, num_in_groups, num_codebooks),
                device=device,
                dtype=get_int_dtype(nbits_per_codebook),
            ),
            requires_grad=False,
        )

        # one learnable scale per output group
        self.scales = nn.Parameter(
            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return forward_pass_quantized_linear(
            input, self.codes, self.codebooks, self.scales, self.bias
        )
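

# Illustrative sizing note (assumed hyperparameters, not defaults): replacing a
# 4096 x 4096 nn.Linear with num_codebooks=2, nbits_per_codebook=8, in_group_size=8,
# out_group_size=1 stores
#   codes:     [4096, 512, 2]  int8
#   codebooks: [2, 256, 1, 8]
#   scales:    [4096, 1, 1, 1]
# i.e. 2 * 8 = 16 bits of codes per group of 8 weights (2 bits per weight), plus
# the codebook and scale overhead. For example:
#
#   layer = FinalizedQuantizedLinear(
#       in_features=4096, out_features=4096,
#       in_group_size=8, out_group_size=1,
#       num_codebooks=2, nbits_per_codebook=8,
#   )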


def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    return data.to(torch.int64) % (2**nbits)
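# For example (illustrative, assuming nbits=8): a code stored as the int8 value -56
# is recovered as (-56) % 256 == 200, i.e. the same bit pattern reread as an
# unsigned code index.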


@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
    """Apply torch.jit.script to a function unless running on TPU; TPU does not support torch.jit.script."""
    using_tpu = bool(os.environ.get("TPU_NAME"))

    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
    return torch.jit.script(fn) if should_script else fn


@maybe_script
def _dequantize_weight(
    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.
    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    # offset the k-th codebook's codes by k * codebook_size so they index into
    # the flattened [num_codebooks * codebook_size, out_group_size * in_group_size] table
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size, device=codes.device
    )
    # sum the selected codebook vectors for each (output group, input group) pair
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum",
    )

    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3])
        + [num_out_groups, num_in_groups, out_group_size, in_group_size]
    )
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
        list(codes.shape[:-3]) + [out_features, in_features]
    )
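# Shape walkthrough (illustrative, assumed sizes): for codes of shape
# [num_out_groups=2, num_in_groups=3, num_codebooks=2] and codebooks of shape
# [2, 256, 1, 8], codebook_offsets is [0, 256]; codes.flatten(0, -2) has shape
# [6, 2], and each of its rows selects one entry from each codebook in the
# flattened [512, 8] table, which F.embedding_bag sums into one [8]-vector per
# (output group, input group) pair.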


def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    if input.is_cuda:
        matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result
    else:
        # nbits_per_codebook is recovered from codebook_size = codebooks.shape[1]
        dequantized_weight = _dequantize_weight(
            unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
            codebooks,
            scales,
        )
        return F.linear(input, dequantized_weight, bias)


@triton.autotune(
    configs=[
        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
        for num_stages in (1, 2, 3, 4, 5)
        for num_warps in (1, 2, 4, 8)
    ],
    key=[
        "in_features",
        "out_features",
        "num_codebooks",
        "codebook_size",
        "out_group_size",
        "in_group_size",
        "num_input_groups",
        "num_input_groups_next_power_of_2",
        "compute_in_fp32",
    ],
)
@triton.jit
def _aqlm_gemv_simple(
    input_vec_ptr,
    output_vec_ptr,
    codes_i16_ptr,
    codebooks_ptr,
    scales_ptr,
    in_features: tl.constexpr,
    out_features: tl.constexpr,
    num_codebooks: tl.constexpr,
    codebook_size: tl.constexpr,
    out_group_size: tl.constexpr,
    in_group_size: tl.constexpr,
    num_input_groups: tl.constexpr,
    num_input_groups_next_power_of_2: tl.constexpr,
    compute_in_fp32: tl.constexpr,
    UNUSED: tl.constexpr,
):
    # each program instance computes one output group (one output feature if out_group_size == 1)
    pid = tl.program_id(axis=0)

    # load the whole input vector as [num_input_groups (padded to a power of 2), 1, in_group_size]
    input_vec = tl.load(
        input_vec_ptr
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
        + tl.arange(0, in_group_size)[None, None, :],
        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None]
        < num_input_groups,
    )

    # gather this output group's codes: [num_input_groups (padded), num_codebooks]
    codes_i_ptrs = (
        codes_i16_ptr
        + pid * num_input_groups * num_codebooks
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
        + tl.arange(0, num_codebooks)[None, :]
    )
    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups

    codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])
    if codes_i.dtype == tl.int16:
        # codes are stored as signed int16; map negative values back to unsigned code indices
        codes_i = codes_i.to(tl.int32)
        codes_i = (codes_i) + (codes_i < 0) * codebook_size
    else:
        codes_i = codes_i.to(tl.int32)

    # offset the k-th codebook's codes by k * codebook_size to index the concatenated codebooks
    codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size

    # gather codebook entries: [num_input_groups (padded), num_codebooks, out_group_size, in_group_size]
    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
    weight_i_ptrs = (
        codebooks_ptr
        + codes_i[:, :, None, None] * out_group_size * in_group_size
        + out_group_ix * in_group_size
        + in_group_ix
    )

    weights_i = tl.load(
        weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
    )
    if compute_in_fp32:
        weights_i = weights_i.to(tl.float32)
        input_vec = input_vec.to(tl.float32)

    # additive quantization: sum the contributions of all codebooks
    weights_i = tl.sum(weights_i, axis=1)

    # multiply by the input groups, reduce, and apply this output group's scale
    if out_group_size == 1:
        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)
        output_i = tl.sum(weights_i * input_vec) * scale
        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
    else:
        output_i = tl.sum(tl.sum(weights_i * input_vec, axis=2), axis=0)
        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
        tl.store(
            output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
            output_i.to(input_vec.dtype),
        )


def next_power_of_2(x):
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def aqlm_gemv_simple(
    input_vec: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    device, dtype = codebooks.device, codebooks.dtype
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    in_features = input_vec.shape[1]
    out_features = codes_i16.shape[0] * out_group_size
    num_input_groups = codes_i16.shape[1]
    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
    assert in_features % in_group_size == 0
    # this simple kernel is specialized for 2**16-entry codebooks (codes stored as int16)
    assert codebooks.shape[1] == 2**16

    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)

    # launch one kernel instance per output group
    grid = lambda META: (out_features // out_group_size,)
    _aqlm_gemv_simple[grid](
        input_vec,
        output_vec,
        codes_i16,
        codebooks,
        scales,
        in_features,
        out_features,
        num_codebooks,
        codebook_size,
        out_group_size,
        in_group_size,
        num_input_groups,
        next_power_of_2(num_input_groups),
        compute_in_fp32,
    )

    return output_vec


def aqlm_gemm_stupid(
    input: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    # naive batched matmul: run one GEMV per row of the flattened input and stack the results
    original_shape = input.shape
    input = input.reshape(-1, original_shape[-1])
    return torch.cat(
        [
            aqlm_gemv_simple(
                input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32
            )
            for input_vec in input
        ]
    ).reshape(original_shape[:-1] + (-1,))
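

if __name__ == "__main__":
    # Minimal CPU smoke test: a sketch only, with arbitrary example sizes (64 -> 32
    # features, 2 codebooks of 2**8 entries); not part of the module's API.
    torch.manual_seed(0)
    layer = FinalizedQuantizedLinear(
        in_features=64,
        out_features=32,
        in_group_size=8,
        out_group_size=1,
        num_codebooks=2,
        nbits_per_codebook=8,
        bias=True,
        dtype=torch.float32,
    )
    with torch.no_grad():
        layer.codebooks.normal_()
        layer.scales.fill_(1.0)
        layer.bias.zero_()
        # codes are stored as signed int8; any bit pattern maps to a valid code via unpack_int_data
        layer.codes.copy_(torch.randint(-128, 128, layer.codes.shape, dtype=torch.int8))
    x = torch.randn(3, 64)
    y = layer(x)  # CPU input -> dequantize + F.linear path
    print(y.shape)  # expected: torch.Size([3, 32])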