File size: 6,141 Bytes
e389f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3e38f
e389f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import torch
import numpy as np
from torch.autograd.grad_mode import F
from torch.utils.cpp_extension import load


# Exponent of the integer CDF scale: `_convert_to_int_and_normalize` multiplies
# float CDFs by 2**PRECISION.
# NOTE(review): presumably has to match the precision compiled into the C++
# backend -- confirm before changing.
PRECISION = 16 # DO NOT EDIT!


# Load on-the-fly with ninja.
# The extension source `numpyAc_backend.cpp` is looked up in the `backend`
# directory that sits next to this file (symlinks resolved via realpath).
# This runs at import time, so importing this module triggers the load.
torchac_dir = os.path.dirname(os.path.realpath(__file__))
backend_dir = os.path.join(torchac_dir, 'backend')
numpyAc_backend = load(
  name="numpyAc_backend",
  sources=[os.path.join(backend_dir, "numpyAc_backend.cpp")],
  verbose=False)

def _encode_float_cdf(cdf_float,
                      sym,
                      needs_normalization=True,
                      check_input_bounds=False):
  """Encode symbols `sym` with a potentially unnormalized floating point CDF.

  Check the README for more details.

  :param cdf_float: floating point CDF array, values expected in [0, 1].
    Shape (N1, ..., Nm, Lp).
  :param sym: The symbols to encode, int16. Shape (N1, ..., Nm). Every symbol
    indexes a CDF interval, so valid values are 0 <= sym < Lp - 1.
  :param needs_normalization: if True, assume `cdf_float` is un-normalized and
    needs normalization. Otherwise only convert it, without normalizing.
  :param check_input_bounds: if True, ensure inputs have valid values.
    Important: may take significant time. Only enable to check.

  :return: byte-string, encoding `sym`.
  :raises ValueError: if `check_input_bounds` is set and `cdf_float` leaves
    [0, 1] or `sym` leaves [0, Lp - 2].
  """
  if check_input_bounds:
    # Each reduction is computed once and reused for the error message.
    cdf_min = cdf_float.min()
    if cdf_min < 0:
      raise ValueError(f'cdf_float.min() == {cdf_min}, should be >=0.!')
    cdf_max = cdf_float.max()
    if cdf_max > 1:
      raise ValueError(f'cdf_float.max() == {cdf_max}, should be <=1.!')
    Lp = cdf_float.shape[-1]
    # A CDF with Lp entries describes Lp - 1 symbols: 0, ..., Lp - 2.
    sym_min = sym.min()
    if sym_min < 0:
      raise ValueError(f'sym.min() == {sym_min}, should be >=0.!')
    sym_max = sym.max()
    if sym_max >= Lp - 1:
      raise ValueError(
          f'sym.max() == {sym_max}, should be <Lp - 1 == {Lp - 1}.!')
  cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
  return _encode_int16_normalized_cdf(cdf_int, sym)


def _encode_int16_normalized_cdf(cdf_int, sym):
  """Encode `sym` using an already normalized integer CDF.

  Both inputs are validated and flattened by `_check_and_reshape_inputs`,
  wrapped into `torch.ShortTensor`s and handed to the compiled backend.

  Check the README for more details.

  :param cdf_int: integer CDF. Shape (N1, ..., Nm, Lp).
  :param sym: The symbols to encode, int16. Shape (N1, ..., Nm).

  :return: whatever `numpyAc_backend.encode_cdf` returns (the encoded
    byte-string).
  """
  flat_cdf, flat_sym = _check_and_reshape_inputs(cdf_int, sym)
  cdf_tensor = torch.ShortTensor(flat_cdf)
  sym_tensor = torch.ShortTensor(flat_sym)
  return numpyAc_backend.encode_cdf(cdf_tensor, sym_tensor)


def _check_and_reshape_inputs(cdf, sym=None):
  """Validate `sym` against `cdf` and flatten both.

  :param cdf: CDF of shape (N1, ..., Nm, Lp).
  :param sym: optional symbols of shape (N1, ..., Nm); must be int16.

  :return: `cdf` reshaped to (-1, Lp) when `sym` is None, otherwise the tuple
    (`cdf` reshaped to (-1, Lp), `sym` reshaped to (-1,)).
  :raises ValueError: if `sym` is not int16 or its shape does not equal
    `cdf.shape[:-1]`.
  """
  if sym is None:
    # Nothing to validate against; only flatten the leading dimensions.
    return cdf.reshape(-1, cdf.shape[-1])

  if sym.dtype != np.int16:
    raise ValueError('Symbols must be int16!')

  shapes_agree = (len(cdf.shape) == len(sym.shape) + 1
                  and cdf.shape[:-1] == sym.shape)
  if not shapes_agree:
    raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
                     'The first m elements of cdf.shape must be equal to '
                     'sym.shape, and cdf should only have one more dimension.')

  entries_per_cdf = cdf.shape[-1]
  return cdf.reshape(-1, entries_per_cdf), sym.reshape(-1)


# def _reshape_output(cdf_shape, sym):
#   """Reshape single dimension `sym` back to the correct spatial dimensions."""
#   spatial_dimensions = cdf_shape[:-1]
#   if len(sym) != np.prod(spatial_dimensions):
#     raise ValueError()
#   return sym.reshape(*spatial_dimensions)


def _convert_to_int_and_normalize(cdf_float, needs_normalization):
  """Convert a floating point CDF to 16 bit integers. See README for more info.

  The idea is the following:
  When we get the cdf here, it is (assumed to be) between 0 and 1.
  We want to convert this to 16 bit integers but make sure we do not get
  the same value twice, as this would break the arithmetic coder
  (you need a strictly monotonically increasing function).
  So, if needs_normalization==True, we multiply the input CDF
  with 2**16 - (Lp - 1). This means that now,
    cdf in [0, 2**16 - (Lp - 1)].
  Then, in a final step, we add an arange(Lp), which is just a line with
  slope one. This ensures that a non-decreasing input yields a strictly
  increasing integer CDF with values in [0, 2**16].

  Values above 32767 do not fit a signed int16. They are stored modulo 2**16
  (two's complement), so viewing the result as uint16 recovers them, and a
  value of exactly 2**16 is stored as 0.

  :param cdf_float: floating point CDF array, shape (..., Lp). Not modified.
  :param needs_normalization: if True, rescale and add the slope-one ramp as
    described above; if False, only scale by 2**16 and round.

  :return: int16 array of the same shape as `cdf_float`.
  """
  Lp = cdf_float.shape[-1]
  new_max_value = 2**PRECISION
  if needs_normalization:
    new_max_value -= Lp - 1
  # Do the integer arithmetic in int64. Casting floats of up to 2**16 straight
  # to int16 is an out-of-range cast: the result is platform dependent and
  # recent NumPy versions emit "invalid value encountered in cast".
  cdf = np.round(cdf_float * new_max_value).astype(np.int64)
  if needs_normalization:
    cdf += np.arange(Lp)
  # Explicit, well-defined wrap to 16 bits, reinterpreted as signed.
  return (cdf & 0xFFFF).astype(np.uint16).view(np.int16)

def pdf_convert_to_cdf_and_normalize(pdf):
    """Turn per-row probability masses into normalized CDF rows.

    Each row is cumulatively summed, divided by its own total so that the
    last entry becomes 1, and prefixed with a 0 column.

    :param pdf: 2-D array of shape (N, K), one distribution per row.

    :return: array of shape (N, K + 1) holding the CDF of every row.
    """
    assert pdf.ndim==2
    running_sum = np.cumsum(pdf, axis=1)
    row_totals = running_sum[:, -1:]
    normalized = running_sum / row_totals
    zero_column = np.zeros((pdf.shape[0], 1))
    return np.hstack((zero_column, normalized))

class arithmeticCoding():
  """Encoder front end: turns per-symbol PDFs plus symbols into a byte stream."""

  def __init__(self) -> None:
    # `sysNum` and `byte_stream` are filled in by `encode`; `binfile` is only
    # initialised here and not assigned anywhere else in this class.
    self.binfile = None
    self.sysNum = None
    self.byte_stream = None

  def encode(self,pdf,sym,binfile=None):
    """Encode `sym` under the row-wise distributions `pdf`.

    :param pdf: 2-D array, one (unnormalized) distribution per symbol.
    :param sym: 1-D array of symbols with `pdf.shape[0]` entries.
    :param binfile: optional path; when given, the encoded stream is also
      written to that file in binary mode.

    :return: tuple (byte stream, length of the stream in bits).
    """
    assert pdf.shape[0]==sym.shape[0]
    assert pdf.ndim==2
    assert sym.ndim==1

    self.sysNum = sym.shape[0]

    cdf = pdf_convert_to_cdf_and_normalize(pdf)
    self.byte_stream = _encode_float_cdf(cdf, sym, check_input_bounds=True)
    encoded = self.byte_stream
    bit_count = 8 * len(encoded)

    # Optionally persist the stream to disk.
    if binfile is not None:
      with open(binfile, 'wb') as sink:
        sink.write(encoded)
    return encoded, bit_count

class arithmeticDeCoding():
  """
    Decoding class
    byte_stream: the encoded byte stream.
    sysNum: the number of symbols that you are going to decode. This value
            should be saved in other ways.
    symDim: the number of possible symbols; `symDim + 1` is what gets passed
            to the backend.
    binfile: bin file path; if it is not None, 'byte_stream' is read from this
            file (overriding the argument) before it is handed to
            `numpyAc_backend.decode`.
  """
  def __init__(self,byte_stream,sysNum,symDim,binfile=None) -> None:
    if binfile is not None:
      # A file on disk takes precedence over the in-memory argument.
      with open(binfile, 'rb') as source:
        byte_stream = source.read()
    self.byte_stream = byte_stream
    cdf_length = symDim + 1
    self.decoder = numpyAc_backend.decode(self.byte_stream, sysNum, cdf_length)

  def decode(self,pdf):
    """Decode the next symbol given its distribution.

    :param pdf: 2-D array with exactly one row (the leading axis is squeezed
      away), holding the distribution of the symbol to decode.

    :return: the result of the backend decoder's `decodeAsym` call.
    """
    cdf_float = pdf_convert_to_cdf_and_normalize(pdf)
    cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization=True)
    # The backend receives the single CDF row as a plain list of uint16 values.
    cdf_row = cdf_int.squeeze(0).astype(np.uint16).tolist()
    return self.decoder.decodeAsym(cdf_row)