tomaszki committed on
Commit
485fe95
1 Parent(s): 0b6ba53

Decoding doesn't require number of symbols anymore

Browse files
numpyAc/backend/numpyAc_backend.cpp CHANGED
@@ -138,7 +138,6 @@ private:
138
  public:
139
  int dataID=0;
140
  const int Lp;// To calculate offset
141
- const int N_sym;// To know the # of syms to decode. Is encoded in the stream!
142
  const int max_symbol;
143
  uint32_t low = 0;
144
  uint32_t high = 0xFFFFFFFFU;
@@ -147,71 +146,61 @@ public:
147
  cdf_t sym_i = 0;
148
  uint32_t value = 0;
149
  InCacheString in_cache;
150
- decode(const std::string &in, const int&sysNum_,const int&sysNumDim_):in_cache(in),N_sym(sysNum_),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
151
  in_cache.initialize(value);
152
 
153
  };
154
 
155
  int16_t decodeAsym(py::list cdf) {
 
 
 
 
156
 
 
157
 
158
- for (; dataID < N_sym; ++dataID) {
159
-
160
- const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
161
- // always < 0x10000 ???
162
- const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
163
 
164
- int offset = 0;
 
165
 
166
- sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
 
167
 
 
 
 
 
 
168
 
169
- if (dataID == N_sym-1) {
170
- break;
171
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
 
173
 
174
- const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
175
- const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
176
-
177
- high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
178
- low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
179
-
180
- while (true) {
181
- if (low >= 0x80000000U || high < 0x80000000U) {
182
- low <<= 1;
183
- high <<= 1;
184
- high |= 1;
185
-
186
- in_cache.get(value);
187
-
188
- } else if (low >= 0x40000000U && high < 0xC0000000U) {
189
- /**
190
- * 0100 0000 ... <= value < 1100 0000 ...
191
- * <=>
192
- * 0100 0000 ... <= value <= 1011 1111 ...
193
- * <=>
194
- * value starts with 01 or 10.
195
- * 01 - 01 == 00 | 10 - 01 == 01
196
- * i.e., with shifts
197
- * 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
198
- * near convergence
199
- */
200
- low <<= 1;
201
- low &= 0x7FFFFFFFU; // make MSB 0
202
- high <<= 1;
203
- high |= 0x80000001U; // add 1 at the end, retain MSB = 1
204
- value -= 0x40000000U;
205
-
206
- in_cache.get(value);
207
-
208
- } else {
209
- break;
210
- }
211
  }
212
-
213
- return (int16_t)sym_i;
214
  }
 
 
215
  }
216
 
217
  };
@@ -340,8 +329,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
340
  m.def("encode_cdf", &encode_cdf, "Encode from CDF");
341
 
342
  py::class_<decode>(m, "decode")
343
- .def(py::init([] (const std::string in, const int&sysNum_,const int&sysNumDim_) {
344
- return new decode(in,sysNum_,sysNumDim_);
345
  }))
346
  .def("decodeAsym", &decode::decodeAsym);
347
  }
 
138
  public:
139
  int dataID=0;
140
  const int Lp;// To calculate offset
 
141
  const int max_symbol;
142
  uint32_t low = 0;
143
  uint32_t high = 0xFFFFFFFFU;
 
146
  cdf_t sym_i = 0;
147
  uint32_t value = 0;
148
  InCacheString in_cache;
149
+ decode(const std::string &in, const int&sysNumDim_):in_cache(in),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
150
  in_cache.initialize(value);
151
 
152
  };
153
 
154
  int16_t decodeAsym(py::list cdf) {
155
+
156
+ const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
157
+ // always < 0x10000 ???
158
+ const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
159
 
160
+ int offset = 0;
161
 
162
+ sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
 
 
 
 
163
 
164
+ const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
165
+ const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
166
 
167
+ high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
168
+ low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
169
 
170
+ while (true) {
171
+ if (low >= 0x80000000U || high < 0x80000000U) {
172
+ low <<= 1;
173
+ high <<= 1;
174
+ high |= 1;
175
 
176
+ in_cache.get(value);
177
+
178
+ } else if (low >= 0x40000000U && high < 0xC0000000U) {
179
+ /**
180
+ * 0100 0000 ... <= value < 1100 0000 ...
181
+ * <=>
182
+ * 0100 0000 ... <= value <= 1011 1111 ...
183
+ * <=>
184
+ * value starts with 01 or 10.
185
+ * 01 - 01 == 00 | 10 - 01 == 01
186
+ * i.e., with shifts
187
+ * 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
188
+ * near convergence
189
+ */
190
+ low <<= 1;
191
+ low &= 0x7FFFFFFFU; // make MSB 0
192
+ high <<= 1;
193
+ high |= 0x80000001U; // add 1 at the end, retain MSB = 1
194
+ value -= 0x40000000U;
195
 
196
+ in_cache.get(value);
197
 
198
+ } else {
199
+ break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  }
 
 
201
  }
202
+
203
+ return (int16_t)sym_i;
204
  }
205
 
206
  };
 
329
  m.def("encode_cdf", &encode_cdf, "Encode from CDF");
330
 
331
  py::class_<decode>(m, "decode")
332
+ .def(py::init([] (const std::string in, const int&sysNumDim_) {
333
+ return new decode(in,sysNumDim_);
334
  }))
335
  .def("decodeAsym", &decode::decodeAsym);
336
  }
numpyAc/numpyAc.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import torch
3
  import numpy as np
4
- from torch.autograd.grad_mode import F
5
  from torch.utils.cpp_extension import load
6
 
7
 
@@ -151,18 +150,16 @@ class arithmeticDeCoding():
151
  """
152
  Decoding class
153
  byte_stream: the bin file stream.
154
- sysNum: the Number of symbols that you are going to decode. This value should be
155
- saved in other ways.
156
  sysDim: the Number of the possible symbols.
157
  binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
158
  and copy to Cpp backend Class 'InCacheString'
159
  """
160
- def __init__(self,byte_stream,sysNum,symDim,binfile=None) -> None:
161
  if binfile is not None:
162
  with open(binfile, 'rb') as fin:
163
  byte_stream = fin.read()
164
  self.byte_stream = byte_stream
165
- self.decoder = numpyAc_backend.decode(self.byte_stream,sysNum,symDim+1)
166
 
167
  def decode(self,pdf):
168
  cdfF = pdf_convert_to_cdf_and_normalize(pdf)
 
1
  import os
2
  import torch
3
  import numpy as np
 
4
  from torch.utils.cpp_extension import load
5
 
6
 
 
150
  """
151
  Decoding class
152
  byte_stream: the bin file stream.
 
 
153
  sysDim: the Number of the possible symbols.
154
  binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
155
  and copy to Cpp backend Class 'InCacheString'
156
  """
157
+ def __init__(self,byte_stream,symDim,binfile=None) -> None:
158
  if binfile is not None:
159
  with open(binfile, 'rb') as fin:
160
  byte_stream = fin.read()
161
  self.byte_stream = byte_stream
162
+ self.decoder = numpyAc_backend.decode(self.byte_stream,symDim+1)
163
 
164
  def decode(self,pdf):
165
  cdfF = pdf_convert_to_cdf_and_normalize(pdf)
test.py CHANGED
@@ -20,7 +20,7 @@ print('real_bits',real_bits)
20
  print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
21
 
22
  # Decode from bytestream.
23
- decodec = numpyAc.arithmeticDeCoding(None,symsNum,dim,'out.b')
24
 
25
  # Autoregressive decoding and output will be equal to the input.
26
  for i,s in enumerate(sym):
 
20
  print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
21
 
22
  # Decode from bytestream.
23
+ decodec = numpyAc.arithmeticDeCoding(None,dim,'out.b')
24
 
25
  # Autoregressive decoding and output will be equal to the input.
26
  for i,s in enumerate(sym):