import os
import torch
import numpy as np
from torch.utils.cpp_extension import load
PRECISION = 16 # DO NOT EDIT!
# Load on-the-fly with ninja.
torchac_dir = os.path.dirname(os.path.realpath(__file__))
backend_dir = os.path.join(torchac_dir, 'backend')
numpyAc_backend = load(
    name="numpyAc_backend",
    sources=[os.path.join(backend_dir, "numpyAc_backend.cpp")],
    verbose=False)
def _encode_float_cdf(cdf_float,
                      sym,
                      needs_normalization=True,
                      check_input_bounds=False):
    """Encode symbols `sym` with a potentially unnormalized floating-point CDF.
    Check the README for more details.
    :param cdf_float: CDF array (numpy), float32. Shape (N1, ..., Nm, Lp).
    :param sym: The symbols to encode, int16. Shape (N1, ..., Nm).
    :param needs_normalization: if True, assume `cdf_float` is un-normalized and
      needs normalization. Otherwise only convert it, without normalizing.
    :param check_input_bounds: if True, ensure inputs have valid values.
      Important: may take significant time. Only enable to check.
    :return: byte-string, encoding `sym`.
    """
    if check_input_bounds:
        if cdf_float.min() < 0:
            raise ValueError(f'cdf_float.min() == {cdf_float.min()}, should be >= 0.')
        if cdf_float.max() > 1:
            raise ValueError(f'cdf_float.max() == {cdf_float.max()}, should be <= 1.')
        Lp = cdf_float.shape[-1]
        if sym.max() >= Lp - 1:
            raise ValueError(f'sym.max() == {sym.max()}, should be < Lp - 1.')
    cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
    return _encode_int16_normalized_cdf(cdf_int, sym)
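# Shape sketch for the helper above (hypothetical numbers, for illustration
# only): for N = 2 symbols over an L = 4 alphabet there are Lp = L + 1 = 5 CDF
# entries per symbol, so
#   cdf_float: shape (2, 5), float, values in [0, 1]
#   sym:       shape (2,),   int16, values in {0, ..., 3}
# and _encode_float_cdf returns a Python bytes object.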
def _encode_int16_normalized_cdf(cdf_int, sym):
    """Encode symbols `sym` with a normalized integer CDF `cdf_int`.
    Check the README for more details.
    :param cdf_int: CDF array (numpy), int16. Shape (N1, ..., Nm, Lp).
    :param sym: The symbols to encode, int16. Shape (N1, ..., Nm).
    :return: byte-string, encoding `sym`.
    """
    cdf_int, sym = _check_and_reshape_inputs(cdf_int, sym)
    return numpyAc_backend.encode_cdf(torch.ShortTensor(cdf_int), torch.ShortTensor(sym))
def _check_and_reshape_inputs(cdf, sym=None):
    """Check dtype and shapes, and flatten the batch dimensions."""
    if sym is not None and sym.dtype != np.int16:
        raise ValueError('Symbols must be int16!')
    if sym is not None:
        if len(cdf.shape) != len(sym.shape) + 1 or cdf.shape[:-1] != sym.shape:
            raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
                             'The first m elements of cdf.shape must be equal to '
                             'sym.shape, and cdf should only have one more dimension.')
    Lp = cdf.shape[-1]
    cdf = cdf.reshape(-1, Lp)
    if sym is None:
        return cdf
    sym = sym.reshape(-1)
    return cdf, sym
# def _reshape_output(cdf_shape, sym):
# """Reshape single dimension `sym` back to the correct spatial dimensions."""
# spatial_dimensions = cdf_shape[:-1]
# if len(sym) != np.prod(spatial_dimensions):
# raise ValueError()
# return sym.reshape(*spatial_dimensions)
def _convert_to_int_and_normalize(cdf_float, needs_normalization):
    r"""Convert a floating-point CDF to integers. See README for more info.
    The idea is the following:
    When we get the cdf here, it is (assumed to be) between 0 and 1, i.e.,
    cdf \in [0, 1)
    (note that 1 should not be included.)
    We now want to convert this to int16 but make sure we do not get
    the same value twice, as this would break the arithmetic coder
    (you need a strictly monotonically increasing function).
    So, if needs_normalization == True, we multiply the input CDF
    with 2**16 - (Lp - 1). This means that now,
    cdf \in [0, 2**16 - (Lp - 1)].
    Then, in a final step, we add an arange(Lp), which is just a line with
    slope one. This ensures that we get unique, strictly monotonically
    increasing CDFs, which are \in [0, 2**16).
    """
    Lp = cdf_float.shape[-1]
    factor = 2**PRECISION
    new_max_value = factor
    if needs_normalization:
        new_max_value = new_max_value - (Lp - 1)
    cdf_float = cdf_float * new_max_value
    cdf_float = np.round(cdf_float)
    # The int16 cast wraps values >= 2**15; the backend reads the bits back
    # as uint16 (see `decode` below).
    cdf = cdf_float.astype(np.int16)
    if needs_normalization:
        # Add a slope-one ramp so consecutive CDF entries are strictly increasing.
        r = np.arange(Lp)
        cdf += r
    return cdf
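# Worked example of the conversion above (a sketch, assuming PRECISION == 16
# and Lp == 4, i.e. a 3-symbol alphabet):
#   cdf_float            = [0.0, 0.25, 0.75, 1.0]
#   * (2**16 - (Lp - 1)) = [0.0, 16383.25, 49149.75, 65533.0]
#   np.round(...)        = [0, 16383, 49150, 65533]
#   + np.arange(4)       = [0, 16384, 49152, 65536]
# Viewed as uint16 bits, consecutive entries now always differ, so every
# symbol keeps a non-empty interval. (The top entry 65536 wraps to 0 in
# 16 bits; accounting for that is left to the C++ backend.)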
def pdf_convert_to_cdf_and_normalize(pdf):
    """Turn each row of `pdf` into a normalized CDF with a leading zero."""
    assert pdf.ndim == 2
    cdfF = np.cumsum(pdf, axis=1)
    cdfF = cdfF / cdfF[:, -1:]  # normalize so the last column is exactly 1
    cdfF = np.hstack((np.zeros((pdf.shape[0], 1)), cdfF))
    return cdfF
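# Example (a sketch): a single pdf row [0.2, 0.3, 0.5] becomes the CDF row
# [0.0, 0.2, 0.5, 1.0]; the output always has one more column than the input.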
class arithmeticCoding():
    """Arithmetic encoder: encodes int16 symbols given per-symbol PDFs."""
    def __init__(self) -> None:
        self.binfile = None
        self.sysNum = None
        self.byte_stream = None
    def encode(self, pdf, sym, binfile=None):
        assert pdf.shape[0] == sym.shape[0]
        assert pdf.ndim == 2 and sym.ndim == 1
        self.sysNum = sym.shape[0]
        cdfF = pdf_convert_to_cdf_and_normalize(pdf)
        # Theoretical bit cost, for reference:
        # pdf = np.diff(cdfF)
        # print(-np.log2(pdf[range(0, self.sysNum), sym]).sum())
        self.byte_stream = _encode_float_cdf(cdfF, sym, check_input_bounds=True)
        real_bits = len(self.byte_stream) * 8
        # Optionally write the byte stream to a file.
        if binfile is not None:
            with open(binfile, 'wb') as fout:
                fout.write(self.byte_stream)
        return self.byte_stream, real_bits
class arithmeticDeCoding():
    """
    Decoding class.
    byte_stream: the encoded byte stream.
    sysNum: the number of symbols to decode. This value is not stored in the
        stream and must be saved separately.
    symDim: the number of possible symbols (alphabet size).
    binfile: bin file path; if it is not None, `byte_stream` is read from this
        file and copied to the C++ backend class 'InCacheString'.
    """
    def __init__(self, byte_stream, sysNum, symDim, binfile=None) -> None:
        if binfile is not None:
            with open(binfile, 'rb') as fin:
                byte_stream = fin.read()
        self.byte_stream = byte_stream
        self.decoder = numpyAc_backend.decode(self.byte_stream, sysNum, symDim + 1)
    def decode(self, pdf):
        # `pdf` is the distribution for one symbol, shape (1, symDim).
        cdfF = pdf_convert_to_cdf_and_normalize(pdf)
        pro = _convert_to_int_and_normalize(cdfF, needs_normalization=True)
        pro = pro.squeeze(0).astype(np.uint16).tolist()
        sym_out = self.decoder.decodeAsym(pro)
        return sym_out
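# Minimal end-to-end sketch (not part of the original API; the names and the
# random pdf below are illustrative only). It assumes the backend compiled
# successfully and that `decode` is called once per symbol with a (1, symDim)
# pdf row, as the squeeze(0) above implies.
if __name__ == '__main__':
    np.random.seed(0)
    sysNum, symDim = 100, 8
    pdf = np.random.rand(sysNum, symDim).astype(np.float32)
    pdf = pdf / pdf.sum(axis=1, keepdims=True)  # each row sums to 1
    sym = np.random.randint(0, symDim, size=sysNum, dtype=np.int16)
    codec = arithmeticCoding()
    byte_stream, real_bits = codec.encode(pdf, sym)
    decoder = arithmeticDeCoding(byte_stream, sysNum, symDim)
    decoded = np.array([decoder.decode(pdf[i:i + 1]) for i in range(sysNum)],
                       dtype=np.int16)
    assert np.array_equal(decoded, sym)
    print(f'round trip ok: {sysNum} symbols -> {real_bits} bits')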