Andrei Panferov committed on
Commit 0110580 · 1 Parent(s): 5edaefc

deleted leftovers

inference_kernels/router.py DELETED
@@ -1,29 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from src.inference_kernels.triton_kernel import aqlm_gemm_stupid as triton_gemm
-from src.utils import _dequantize_weight, unpack_int_data
-
-
-def forward_pass_quantized_linear(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
-    if input.is_cuda:
-        matmul_result = triton_gemm(input, codes, codebooks, scales)
-        if bias is not None:
-            matmul_result += bias
-        return matmul_result
-    else:
-        dequantized_weight = _dequantize_weight(
-            unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),  # log2(codebook_size) bits per code
-            codebooks,
-            scales,
-        )
-        return F.linear(input, dequantized_weight, bias)
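For context on what was removed: the router picked between the Triton GEMM on CUDA and a dequantize-then-F.linear fallback on CPU. Below is a minimal, hypothetical call sketch, not part of the repo; the shapes and sizes are assumptions inferred from the asserts in triton_kernel.py and the _dequantize_weight docstring in utils.py, and forward_pass_quantized_linear is assumed to be importable from wherever the code now lives.

    import torch

    # Assumed AQLM-style shapes (illustrative only)
    num_codebooks, codebook_size = 1, 2**16            # 16-bit codes per codebook
    out_group_size, in_group_size = 1, 8
    in_features, out_features = 4096, 4096

    # uint16 code indices stored in an int16 tensor,
    # shape [num_out_groups, num_in_groups, num_codebooks]
    codes = torch.randint(
        -2**15, 2**15,
        (out_features // out_group_size, in_features // in_group_size, num_codebooks),
        dtype=torch.int16,
    )
    codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
    scales = torch.ones(out_features // out_group_size, 1, 1, 1)

    x = torch.randn(1, in_features)
    # On CPU this takes the dequantize + F.linear branch; on CUDA it would call the Triton kernel.
    y = forward_pass_quantized_linear(x, codes, codebooks, scales, bias=None)  # -> [1, out_features]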
inference_kernels/triton_kernel.py DELETED
@@ -1,170 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from torch.autograd import Function
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
-        for num_stages in (1, 2, 3, 4, 5)
-        for num_warps in (1, 2, 4, 8)
-    ],
-    key=[
-        "in_features",
-        "out_features",
-        "num_codebooks",
-        "codebook_size",
-        "out_group_size",
-        "in_group_size",
-        "num_input_groups",
-        "num_input_groups_next_power_of_2",
-        "compute_in_fp32",
-    ],
-)
-@triton.jit
-def _aqlm_gemv_simple(
-    input_vec_ptr,
-    output_vec_ptr,
-    codes_i16_ptr,
-    codebooks_ptr,
-    scales_ptr,
-    in_features: tl.constexpr,
-    out_features: tl.constexpr,
-    num_codebooks: tl.constexpr,
-    codebook_size: tl.constexpr,
-    out_group_size: tl.constexpr,
-    in_group_size: tl.constexpr,
-    num_input_groups: tl.constexpr,
-    num_input_groups_next_power_of_2: tl.constexpr,
-    compute_in_fp32: tl.constexpr,
-    UNUSED: tl.constexpr,
-):
-    # variables ending with "_i" mean "for i-th output unit"
-    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features-1}]
-
-    # Stage 1: load input data
-    input_vec = tl.load(
-        input_vec_ptr
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
-        + tl.arange(0, in_group_size)[None, None, :],
-        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
-    )
-    # [in_features//in_group_size, 1, group_size]
-    # Note: we could simply load input_vec then reshape
-    # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
-    # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
-    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
-
-    # Stage 2: load integer codes for the active row
-    # [in_features // in_group_size, num_codebooks]
-    codes_i_ptrs = (
-        codes_i16_ptr
-        + pid * num_input_groups * num_codebooks
-        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
-        + tl.arange(0, num_codebooks)[None, :]
-    )
-    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
-
-    codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])  # [in_features//in_group_size, num_codebooks]
-    if codes_i.dtype == tl.int16:
-        codes_i = codes_i.to(tl.int32)
-        codes_i = (codes_i) + (codes_i < 0) * codebook_size  # aka 2 ** nbits_per_codebook
-        # ^-- (because codes are int16 tensors that contain uint data)
-
-        # The following alternative does not work:
-        # codes_i = codes_i.to(tl.int32) % codebook_size  # aka 2 ** nbits_per_codebook
-    else:
-        codes_i = codes_i.to(tl.int32)
-
-    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
-    codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size  # aka 2 ** nbits_per_codebook
-    # ^-- [in_group_size, num_codebooks]
-
-    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
-    # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
-    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
-    weight_i_ptrs = (
-        codebooks_ptr
-        + codes_i[:, :, None, None] * out_group_size * in_group_size
-        + out_group_ix * in_group_size
-        + in_group_ix
-    )
-
-    # Stage 4: reconstruct weights, multiply by inputs and write out
-    weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
-    if compute_in_fp32:
-        weights_i = weights_i.to(tl.float32)
-        input_vec = input_vec.to(tl.float32)
-    # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
-    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
-    # ^-- [in_features // in_group_size, out_group_size, in_group_size]
-
-    if out_group_size == 1:
-        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
-        output_i = tl.sum(weights_i * input_vec) * scale
-        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
-    else:
-        output_i = tl.sum(tl.sum(weights_i * input_vec, axis=2), axis=0)  # [out_group_size]
-        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
-        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
-
-
-def next_power_of_2(x):
-    return 1 if x == 0 else 2 ** (x - 1).bit_length()
-
-
-def aqlm_gemv_simple(
-    input_vec: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    compute_in_fp32: bool = True,
-):
-
-    device, dtype = codebooks.device, codebooks.dtype
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    in_features = input_vec.shape[1]
-    out_features = codes_i16.shape[0] * out_group_size
-    num_input_groups = codes_i16.shape[1]
-    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
-    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
-    assert in_features % in_group_size == 0
-    assert codebooks.shape[1] == 2**16
-
-    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
-    # 1D launch kernel where each block computes output unit
-    grid = lambda META: (out_features // out_group_size,)
-    _aqlm_gemv_simple[grid](
-        input_vec,
-        output_vec,
-        codes_i16,
-        codebooks,
-        scales,
-        in_features,
-        out_features,
-        num_codebooks,
-        codebook_size,
-        out_group_size,
-        in_group_size,
-        num_input_groups,
-        next_power_of_2(num_input_groups),
-        compute_in_fp32,
-    )
-
-    return output_vec
-
-
-def aqlm_gemm_stupid(
-    input: torch.Tensor,
-    codes_i16: torch.ShortTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
-    compute_in_fp32: bool = True,
-):
-    original_shape = input.shape
-    input = input.reshape(-1, original_shape[-1])
-    return torch.cat(
-        [aqlm_gemv_simple(input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32) for input_vec in input]
-    ).reshape(original_shape[:-1] + (-1,))
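The only subtle step in the deleted kernel is Stage 2: codes are uint16 values stored in an int16 tensor, so negative entries are shifted up by codebook_size before they are used as codebook indices (the in-kernel comment notes that a plain modulo does not work here), and each codebook's indices are then offset by codebook_index * codebook_size so a single flat codebooks buffer can be gathered from. A tiny PyTorch sketch of the same sign-to-unsigned conversion, purely for illustration and outside the kernel:

    import torch

    codebook_size = 2**16
    codes_i16 = torch.tensor([0, 1, -1, -32768], dtype=torch.int16)  # uint16 data in int16 storage
    codes = codes_i16.to(torch.int32)
    codes = codes + (codes < 0).to(torch.int32) * codebook_size      # same shift as in the kernel
    print(codes.tolist())                                            # [0, 1, 65535, 32768]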
utils.py DELETED
@@ -1,159 +0,0 @@
-from __future__ import annotations
-
-import contextlib
-import functools
-import os
-from typing import Callable, Iterator, Optional, Sequence
-
-import torch
-import torch.nn.functional as F
-
-ellipsis = type(...)
-
-
-def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int = 2):
-
-    """
-    Calculates average code length in codebooks.
-    :param codes: codebook codes
-    :param huffman_group_size: Huffman compression dimension count
-    """
-    import huffman
-
-    _, codebook_size, num_codebooks = codes.shape
-    flat_codes_by_codebook = codes.permute(2, 0, 1).flatten(1, 2)
-    code_counts = torch.zeros(
-        num_codebooks, codebook_size, device=flat_codes_by_codebook.device, dtype=flat_codes_by_codebook.dtype
-    ).scatter_add(
-        -1, flat_codes_by_codebook, torch.ones_like(flat_codes_by_codebook)
-    )  # shape: [num_codebooks, codebook_size]
-    code_probs = code_counts / code_counts.sum(dim=-1, keepdim=True).float()
-    code_probs = code_probs.cpu().numpy()
-    assert num_codebooks % huffman_group_size == 0
-
-    mean_code_lengths = []
-    for group_index in range(num_codebooks // huffman_group_size):
-        group_code_probs = {(): 1}
-
-        for codebook_index in range(group_index * huffman_group_size, (group_index + 1) * huffman_group_size):
-            new_group_code_probs = {}
-            for group, group_prob in group_code_probs.items():
-                for code, code_prob in tuple(enumerate(code_probs[codebook_index])):
-                    new_group_code_probs[group + (code,)] = group_prob * code_prob
-            group_code_probs = new_group_code_probs
-
-        huffman_codebook_i = huffman.codebook(list(group_code_probs.items()))
-        codebook_mean_code_length_i = sum(
-            len(huffman_codebook_i[code]) * prob for code, prob in group_code_probs.items()
-        )
-        mean_code_lengths.append(codebook_mean_code_length_i)
-    return mean_code_lengths
-
-
-def get_int_dtype(nbits: int) -> torch.dtype:
-    if nbits <= 8:
-        return torch.int8
-    if nbits <= 16:
-        return torch.int16
-    if nbits <= 32:
-        return torch.int32
-    if nbits <= 64:
-        return torch.int64
-    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
-
-
-@torch.inference_mode()
-def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
-    data[data >= 2 ** (nbits - 1)] -= 2**nbits
-    return data.to(get_int_dtype(nbits))
-
-
-@torch.inference_mode()
-def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
-    return data.to(torch.int64) % (2**nbits)
-
-
-@functools.lru_cache()
-def maybe_script(fn: callable) -> callable:
-    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
-    using_tpu = bool(os.environ.get("TPU_NAME"))
-    # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
-    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
-    return torch.jit.script(fn) if should_script else fn
-
-
-@contextlib.contextmanager
-def using_tf32(enabled: bool):
-    was_cudnn = torch.backends.cudnn.allow_tf32
-    was_matmul = torch.backends.cuda.matmul.allow_tf32
-    torch.backends.cudnn.allow_tf32 = enabled
-    torch.backends.cuda.matmul.allow_tf32 = enabled
-    yield
-    torch.backends.cudnn.allow_tf32 = was_cudnn
-    torch.backends.cuda.matmul.allow_tf32 = was_matmul
-
-
-def iterate_minibatches(
-    *tensors: torch.Tensor,
-    batch_size: int,
-    allow_incomplete: bool = True,
-    device: Optional[torch.device] = None,
-    callback: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]] = lambda x: x,
-) -> Iterator[Sequence[torch.Tensor]]:
-    """
-    Samples data points *forever*, in random order, with less overhead than DataLoader;
-    Adapted from https://github.com/stanis-morozov/unq/blob/master/lib/utils.py
-    probably implemented over9000 times in transformers, torch, etc
-    :param tensors: one or more tensors with the same 0-th dimension
-    :param batch_size: sample this many points with each yield
-    :param allow_incomplete: if True and if dataset size is not divisible by batch size, the last batch
-        may have fewer than :batch_size: samples to cover the entire dataset. If False, the last batch is dropped
-    :param callback: optional function to be called on each batch of tensors before it is yielded to the user
-    :returns: generates a tuple of minibatches from each tensor, same length as input *tensors
-        If a batch contains only one tensor, this function will yield a tensor (and not a tuple/list with one tensor)
-    """
-    num_samples = len(tensors[0])
-    assert all(len(x) == num_samples for x in tensors)
-    indices = torch.randperm(num_samples, device=tensors[0].device)
-    while True:
-        prev_batch = None
-        for batch_start in range(0, len(indices), batch_size):
-            if not allow_incomplete and batch_start + batch_size > len(indices):
-                break
-            batch_ix = indices[batch_start : batch_start + batch_size]
-            batch = callback(tuple(tensor[batch_ix].to(device, non_blocking=True) for tensor in tensors))
-            if prev_batch is not None:
-                yield prev_batch
-            prev_batch = batch if isinstance(batch, (list, tuple)) and len(tensors) > 1 else batch[0]
-            del batch
-        yield prev_batch
-
-
-@maybe_script
-def _dequantize_weight(
-    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
-) -> torch.Tensor:
-    """
-    Decode float weights from quantization codes. Differentiable.
-    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
-    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
-    :param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
-    :return: reconstructed weight tensor of shape [*dims, out_features, in_features]
-    """
-    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
-    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    out_features = num_out_groups * out_group_size
-    in_features = num_in_groups * in_group_size
-    codebook_offsets = torch.arange(
-        0, num_codebooks * codebook_size, codebook_size, device=codes.device
-    )  # shape: [num_codebooks]
-    reconstructed_weight_flat = F.embedding_bag(
-        codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
-    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
-
-    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
-        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
-    )
-    if scales is not None:
-        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
-    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
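pack_int_data and unpack_int_data implement the storage convention that the Triton kernel undoes: values at or above 2**(nbits - 1) are wrapped into the negative range so they fit the signed dtype returned by get_int_dtype, and the modulo in unpack_int_data restores the unsigned values. A quick round-trip sketch, illustrative only and assuming the two helpers are importable (note that pack_int_data modifies its input in place, hence the clone):

    import torch

    nbits = 16
    original = torch.tensor([0, 1, 32767, 32768, 65535], dtype=torch.int32)
    packed = pack_int_data(original.clone(), nbits)    # int16 storage: [0, 1, 32767, -32768, -1]
    restored = unpack_int_data(packed, nbits)          # int64: [0, 1, 32767, 32768, 65535]
    assert torch.equal(restored, original.to(torch.int64))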