Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
from torch.autograd.grad_mode import F | |
from torch.utils.cpp_extension import load | |
# Bit width of the integer CDF representation (the CDF is scaled by
# 2**PRECISION before being handed to the backend).
PRECISION = 16  # DO NOT EDIT!

# The C++ backend is compiled and loaded on the fly (via ninja) when this
# module is imported; its source is looked up in the `backend` directory that
# sits next to this file.
torchac_dir = os.path.dirname(os.path.realpath(__file__))
backend_dir = os.path.join(torchac_dir, 'backend')
numpyAc_backend = load(
    name="numpyAc_backend",
    sources=[os.path.join(backend_dir, "numpyAc_backend.cpp")],
    verbose=False,
)
def _encode_float_cdf(cdf_float,
                      sym,
                      needs_normalization=True,
                      check_input_bounds=False):
    """Encode symbols `sym` with a potentially unnormalized floating point CDF.

    The CDF is first converted to 16-bit integers with
    `_convert_to_int_and_normalize` and then passed, together with the
    symbols, to `_encode_int16_normalized_cdf`.

    :param cdf_float: floating point CDF array with values in [0, 1].
        Shape (N1, ..., Nm, Lp).
    :param sym: The symbols to encode, int16. Shape (N1, ..., Nm).
    :param needs_normalization: if True, assume `cdf_float` is un-normalized
        and needs normalization. Otherwise only convert it, without
        normalizing.
    :param check_input_bounds: if True, ensure inputs have valid values.
        Important: may take significant time. Only enable to check.
    :return: whatever the backend encoder returns for `sym` (a byte-string).
    :raises ValueError: if `check_input_bounds` is set and the CDF leaves
        [0, 1] or a symbol lies outside [0, Lp - 2].
    """
    if check_input_bounds:
        cdf_min = cdf_float.min()
        if cdf_min < 0:
            raise ValueError(f'cdf_float.min() == {cdf_min}, should be >=0.!')
        cdf_max = cdf_float.max()
        if cdf_max > 1:
            raise ValueError(f'cdf_float.max() == {cdf_max}, should be <=1.!')
        # A CDF with Lp entries describes Lp - 1 symbols, so the valid symbol
        # range is 0 .. Lp - 2 (inclusive).
        Lp = cdf_float.shape[-1]
        sym_min = sym.min()
        if sym_min < 0:
            raise ValueError(f'sym.min() == {sym_min}, should be >=0.!')
        sym_max = sym.max()
        if sym_max >= Lp - 1:
            raise ValueError(
                f'sym.max() == {sym_max}, should be < Lp - 1 == {Lp - 1}.!')
    cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
    return _encode_int16_normalized_cdf(cdf_int, sym)
def _encode_int16_normalized_cdf(cdf_int, sym):
    """Encode symbols `sym` with a normalized integer CDF `cdf_int`.

    Both inputs are validated and flattened by `_check_and_reshape_inputs`,
    wrapped in `torch.ShortTensor`s and handed to the compiled backend.

    :param cdf_int: integer CDF. Shape (N1, ..., Nm, Lp).
    :param sym: The symbols to encode, int16. Shape (N1, ..., Nm).
    :return: the result of the backend's `encode_cdf` (a byte-string
        encoding `sym`).
    """
    flat_cdf, flat_sym = _check_and_reshape_inputs(cdf_int, sym)
    cdf_tensor = torch.ShortTensor(flat_cdf)
    sym_tensor = torch.ShortTensor(flat_sym)
    return numpyAc_backend.encode_cdf(cdf_tensor, sym_tensor)
def _check_and_reshape_inputs(cdf, sym=None):
    """Validate the symbol dtype and the cdf/sym shapes, then flatten both.

    :param cdf: CDF array of shape (N1, ..., Nm, Lp).
    :param sym: optional int16 symbol array of shape (N1, ..., Nm).
    :return: `cdf` reshaped to (-1, Lp) when `sym` is None, otherwise the
        tuple (`cdf` reshaped to (-1, Lp), `sym` reshaped to (-1,)).
    :raises ValueError: if `sym` is not int16 or its shape does not equal
        `cdf.shape[:-1]`.
    """
    has_symbols = sym is not None
    if has_symbols:
        if sym.dtype != np.int16:
            raise ValueError('Symbols must be int16!')
        one_extra_dim = len(cdf.shape) == len(sym.shape) + 1
        if not (one_extra_dim and cdf.shape[:-1] == sym.shape):
            raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
                             'The first m elements of cdf.shape must be equal to '
                             'sym.shape, and cdf should only have one more dimension.')
    # Collapse all leading dimensions; the last axis holds the CDF entries.
    flat_cdf = cdf.reshape(-1, cdf.shape[-1])
    if not has_symbols:
        return flat_cdf
    return flat_cdf, sym.reshape(-1)
# def _reshape_output(cdf_shape, sym): | |
# """Reshape single dimension `sym` back to the correct spatial dimensions.""" | |
# spatial_dimensions = cdf_shape[:-1] | |
# if len(sym) != np.prod(spatial_dimensions): | |
# raise ValueError() | |
# return sym.reshape(*spatial_dimensions) | |
def _convert_to_int_and_normalize(cdf_float, needs_normalization):
    r"""Convert a floating point CDF to 16-bit integers.

    The idea is the following:
    When we get the cdf here, it is (assumed to be) between 0 and 1, i.e,
      cdf \in [0, 1]
    We now want to convert this to int16 but make sure we do not get
    the same value twice, as this would break the arithmetic coder
    (you need a strictly monotonically increasing function).
    So, if needs_normalization==True, we multiply the input CDF
    with 2**16 - (Lp - 1). This means that now,
      cdf \in [0, 2**16 - (Lp - 1)].
    Then, in a final step, we add an arange(Lp), which is just a line with
    slope one. This ensures that entries which were non-decreasing before
    become strictly increasing, with values \in [0, 2**16].

    The result is stored in 16 bits, so values are reduced modulo 2**16 and
    returned as the int16 bit pattern: values >= 2**15 appear negative and a
    value of exactly 2**16 (a CDF entry of 1.0) becomes 0. Reinterpreting the
    result as uint16 recovers the unsigned values.

    :param cdf_float: NumPy floating point CDF; the last axis has length Lp.
    :param needs_normalization: if True, rescale and add the ramp described
        above; otherwise only scale by 2**16 and round.
    :return: int16 array with the same shape as `cdf_float`.
    """
    Lp = cdf_float.shape[-1]
    new_max_value = 2 ** PRECISION
    if needs_normalization:
        new_max_value -= Lp - 1
    # Do the integer arithmetic in int64: the scaled values reach 2**16, and
    # casting an out-of-range float straight to int16 is not well defined.
    cdf = np.round(cdf_float * new_max_value).astype(np.int64)
    if needs_normalization:
        cdf = cdf + np.arange(Lp)
    # Explicit wrap-around to the 16-bit storage type, then reinterpret the
    # bits as int16 (no value conversion involved).
    return (cdf & 0xFFFF).astype(np.uint16).view(np.int16)
def pdf_convert_to_cdf_and_normalize(pdf):
    """Turn per-row (possibly unnormalized) PDFs into normalized CDFs.

    Each row is cumulatively summed, divided by its total so the last entry
    is 1, and prefixed with a leading 0.

    :param pdf: 2-D NumPy array of shape (N, L); row i holds the weights of
        the L symbols for position i.
    :return: float array of shape (N, L + 1) with the row-wise CDFs.
    :raises ValueError: if a row sums to zero or to a non-finite value, in
        which case the normalization would silently produce NaN/inf entries.
    """
    assert pdf.ndim == 2
    cdfF = np.cumsum(pdf, axis=1)
    # Last cumulative value of each row == that row's total mass.
    totals = cdfF[:, -1:]
    if not np.all(np.isfinite(totals)) or np.any(totals == 0):
        raise ValueError(
            'Every row of pdf must have a finite, non-zero sum to be '
            'normalized into a CDF.')
    cdfF = cdfF / totals
    leading_zeros = np.zeros((pdf.shape[0], 1))
    return np.hstack((leading_zeros, cdfF))
class arithmeticCoding():
    """Arithmetic encoder front end.

    After a call to `encode`, `sysNum` holds the number of encoded symbols
    and `byte_stream` the encoded bytes. `binfile` is initialised to None and
    not modified by this class.
    """

    def __init__(self) -> None:
        self.binfile = None
        self.sysNum = None
        self.byte_stream = None

    def encode(self, pdf, sym, binfile=None):
        """Encode `sym` under the per-symbol distributions in `pdf`.

        :param pdf: 2-D array; row i is the (possibly unnormalized) PDF used
            for symbol i.
        :param sym: 1-D array of symbols with one entry per row of `pdf`.
        :param binfile: optional path; when given, the encoded bytes are also
            written to this file.
        :return: tuple (byte_stream, number of bits in byte_stream).
        """
        assert pdf.shape[0] == sym.shape[0]
        assert pdf.ndim == 2 and sym.ndim == 1

        self.sysNum = sym.shape[0]
        cdf = pdf_convert_to_cdf_and_normalize(pdf)
        self.byte_stream = _encode_float_cdf(cdf, sym, check_input_bounds=True)
        num_bits = 8 * len(self.byte_stream)

        # Optionally persist the encoded stream.
        if binfile is not None:
            with open(binfile, 'wb') as out_file:
                out_file.write(self.byte_stream)

        return self.byte_stream, num_bits
class arithmeticDeCoding():
    """
    Decoding class
    byte_stream: the encoded byte stream.
    sysNum: the number of symbols that you are going to decode. This value
            should be saved in other ways.
    symDim: the number of possible symbols (the backend is given symDim+1).
    binfile: bin file path; if it is not None, 'byte_stream' is read from this
             file (replacing the argument) before being passed to the backend
             decoder.
    """

    def __init__(self, byte_stream, sysNum, symDim, binfile=None) -> None:
        # A file on disk takes precedence over the in-memory stream.
        if binfile is not None:
            with open(binfile, 'rb') as in_file:
                byte_stream = in_file.read()
        self.byte_stream = byte_stream
        self.decoder = numpyAc_backend.decode(
            self.byte_stream, sysNum, symDim+1)

    def decode(self, pdf):
        """Decode the next symbol given its distribution.

        :param pdf: 2-D array with exactly one row (axis 0 is squeezed away),
            the (possibly unnormalized) PDF of the symbol to decode.
        :return: the value returned by the backend's `decodeAsym`.
        """
        cdf = pdf_convert_to_cdf_and_normalize(pdf)
        cdf_int = _convert_to_int_and_normalize(cdf, needs_normalization=True)
        # Reinterpret as unsigned 16-bit values and pass a plain Python list.
        cdf_list = cdf_int.squeeze(0).astype(np.uint16).tolist()
        return self.decoder.decodeAsym(cdf_list)