from __future__ import annotations
import contextlib
import functools
import os
from typing import Callable, Iterator, Optional, Sequence
import torch
import torch.nn.functional as F
ellipsis = type(...)
def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int = 2):
"""
Calculates average code length in codebooks.
:param codes: codebook codes
:param huffman_group_size: huffman compresssion dimension count
"""
import huffman
_, codebook_size, num_codebooks = codes.shape
flat_codes_by_codebook = codes.permute(2, 0, 1).flatten(1, 2)
code_counts = torch.zeros(
num_codebooks, codebook_size, device=flat_codes_by_codebook.device, dtype=flat_codes_by_codebook.dtype
).scatter_add(
-1, flat_codes_by_codebook, torch.ones_like(flat_codes_by_codebook)
    )  # shape: [num_codebooks, codebook_size]
code_probs = code_counts / code_counts.sum(dim=-1, keepdim=True).float()
code_probs = code_probs.cpu().numpy()
assert num_codebooks % huffman_group_size == 0
mean_code_lengths = []
for group_index in range(num_codebooks // huffman_group_size):
group_code_probs = {(): 1}
for codebook_index in range(group_index * huffman_group_size, (group_index + 1) * huffman_group_size):
new_group_code_probs = {}
for group, group_prob in group_code_probs.items():
for code, code_prob in tuple(enumerate(code_probs[codebook_index])):
new_group_code_probs[group + (code,)] = group_prob * code_prob
group_code_probs = new_group_code_probs
huffman_codebook_i = huffman.codebook(list(group_code_probs.items()))
codebook_mean_code_length_i = sum(
len(huffman_codebook_i[code]) * prob for code, prob in group_code_probs.items()
)
mean_code_lengths.append(codebook_mean_code_length_i)
return mean_code_lengths
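

# Minimal usage sketch (hypothetical shapes, not part of the original module): estimate the
# average Huffman code length for pairs of codebooks, given integer codes laid out as
# [num_groups, codebook_size, num_codebooks] with values in [0, codebook_size).
# Relies on the `huffman` package imported inside get_mean_nbits_by_codebook.
def _example_mean_nbits_by_codebook() -> list:
    codes = torch.randint(0, 16, (128, 16, 4), dtype=torch.int64)
    return get_mean_nbits_by_codebook(codes, huffman_group_size=2)
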
def get_int_dtype(nbits: int) -> torch.dtype:
if nbits <= 8:
return torch.int8
if nbits <= 16:
return torch.int16
if nbits <= 32:
return torch.int32
if nbits <= 64:
return torch.int64
raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
data[data >= 2 ** (nbits - 1)] -= 2**nbits
return data.to(get_int_dtype(nbits))
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
return data.to(torch.int64) % (2**nbits)
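

# Minimal round-trip sketch (hypothetical values, not part of the original module): unsigned
# nbits-bit codes are wrapped into the narrowest signed dtype (see get_int_dtype) by
# pack_int_data and recovered exactly by unpack_int_data.
# Note that pack_int_data modifies its input in place, hence the .clone() below.
def _example_pack_unpack_roundtrip(nbits: int = 8) -> bool:
    codes = torch.randint(0, 2**nbits, (4, 4), dtype=torch.int64)
    packed = pack_int_data(codes.clone(), nbits)  # stored as torch.int8 for nbits=8
    restored = unpack_int_data(packed, nbits)  # back to int64 values in [0, 2**nbits)
    return bool(torch.equal(codes, restored))
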
@functools.lru_cache()
def maybe_script(fn: Callable) -> Callable:
    """Apply torch.jit.script to a function unless running on TPU (TPU does not support torch.jit.script)."""
using_tpu = bool(os.environ.get("TPU_NAME"))
# this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
return torch.jit.script(fn) if should_script else fn
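

# Minimal usage sketch (hypothetical helper, not part of the original module): maybe_script is
# meant to be used as a decorator, as _dequantize_weight below does. Setting AQ_USE_JIT=0 in
# the environment keeps the eager (unscripted) version.
@maybe_script
def _example_scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # TorchScript-friendly: explicit annotations and tensor-only operations
    return x + alpha * y
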
@contextlib.contextmanager
def using_tf32(enabled: bool):
    """Temporarily set the TF32 matmul/cuDNN flags, restoring the previous values on exit."""
    was_cudnn = torch.backends.cudnn.allow_tf32
    was_matmul = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cudnn.allow_tf32 = enabled
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        torch.backends.cudnn.allow_tf32 = was_cudnn
        torch.backends.cuda.matmul.allow_tf32 = was_matmul
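

# Minimal usage sketch (hypothetical tensors, not part of the original module): allow TF32
# kernels for one matmul, then restore whatever backend flags were active before.
def _example_tf32_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    with using_tf32(enabled=True):
        return a @ b
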
def iterate_minibatches(
*tensors: torch.Tensor,
batch_size: int,
allow_incomplete: bool = True,
device: Optional[torch.device] = None,
callback: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]] = lambda x: x,
) -> Iterator[Sequence[torch.Tensor]]:
"""
    Samples data points *forever*, in random order, with less overhead than DataLoader;
    probably implemented over9000 times in transformers, torch, etc.
    Adapted from https://github.com/stanis-morozov/unq/blob/master/lib/utils.py
    :param tensors: one or more tensors with the same 0-th dimension
    :param batch_size: sample this many data points with each yield
    :param allow_incomplete: if True and the dataset size is not divisible by batch_size, the last batch
        may contain fewer than batch_size samples so that the entire dataset is covered; if False, that batch is dropped
    :param device: if specified, each minibatch is moved to this device (non-blocking) before the callback is applied
    :param callback: optional function applied to each tuple of minibatch tensors before it is yielded to the user
    :returns: generates a tuple of minibatches from each tensor, same length as the input *tensors;
        if only one tensor is given, yields that tensor directly instead of a one-element tuple/list
"""
num_samples = len(tensors[0])
assert all(len(x) == num_samples for x in tensors)
indices = torch.randperm(num_samples, device=tensors[0].device)
while True:
        prev_batch = None  # yield each batch one step late so the next batch's non-blocking copy can overlap with compute
for batch_start in range(0, len(indices), batch_size):
if not allow_incomplete and batch_start + batch_size > len(indices):
break
batch_ix = indices[batch_start : batch_start + batch_size]
batch = callback(tuple(tensor[batch_ix].to(device, non_blocking=True) for tensor in tensors))
if prev_batch is not None:
yield prev_batch
prev_batch = batch if isinstance(batch, (list, tuple)) and len(tensors) > 1 else batch[0]
del batch
yield prev_batch
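

# Minimal usage sketch (hypothetical data, not part of the original module): the generator is
# infinite, so the consumer decides when to stop.
def _example_minibatch_loop(num_steps: int = 4) -> None:
    xs, ys = torch.randn(100, 8), torch.randn(100)
    loader = iterate_minibatches(xs, ys, batch_size=32, device=torch.device("cpu"))
    for step, (x_batch, y_batch) in enumerate(loader):
        print(step, x_batch.shape, y_batch.shape)  # e.g. torch.Size([32, 8]) torch.Size([32])
        if step + 1 >= num_steps:
            break
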
@maybe_script
def _dequantize_weight(
codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size, device=codes.device
) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
)
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
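

# Minimal shape sketch (hypothetical sizes, not part of the original module): 2 codebooks of
# 256 vectors with 1x8 groups reconstruct a 16x64 weight matrix.
def _example_dequantize_weight() -> torch.Tensor:
    num_codebooks, codebook_size, out_group_size, in_group_size = 2, 256, 1, 8
    num_out_groups, num_in_groups = 16, 8  # -> out_features = 16, in_features = 64
    codes = torch.randint(0, codebook_size, (num_out_groups, num_in_groups, num_codebooks))
    codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
    scales = torch.ones(num_out_groups, 1, 1, 1)  # one scale per output group
    weight = _dequantize_weight(codes, codebooks, scales)
    assert weight.shape == (num_out_groups * out_group_size, num_in_groups * in_group_size)
    return weight
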