Spaces:
Sleeping
Sleeping
Decoding doesn't require number of symbols anymore
Browse files- numpyAc/backend/numpyAc_backend.cpp +42 -53
- numpyAc/numpyAc.py +2 -5
- test.py +1 -1
numpyAc/backend/numpyAc_backend.cpp
CHANGED
@@ -138,7 +138,6 @@ private:
|
|
138 |
public:
|
139 |
int dataID=0;
|
140 |
const int Lp;// To calculate offset
|
141 |
-
const int N_sym;// To know the # of syms to decode. Is encoded in the stream!
|
142 |
const int max_symbol;
|
143 |
uint32_t low = 0;
|
144 |
uint32_t high = 0xFFFFFFFFU;
|
@@ -147,71 +146,61 @@ public:
|
|
147 |
cdf_t sym_i = 0;
|
148 |
uint32_t value = 0;
|
149 |
InCacheString in_cache;
|
150 |
-
decode(const std::string &in, const int&
|
151 |
in_cache.initialize(value);
|
152 |
|
153 |
};
|
154 |
|
155 |
int16_t decodeAsym(py::list cdf) {
|
|
|
|
|
|
|
|
|
156 |
|
|
|
157 |
|
158 |
-
|
159 |
-
|
160 |
-
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
161 |
-
// always < 0x10000 ???
|
162 |
-
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
|
163 |
|
164 |
-
|
|
|
165 |
|
166 |
-
|
|
|
167 |
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
|
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
178 |
-
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
179 |
-
|
180 |
-
while (true) {
|
181 |
-
if (low >= 0x80000000U || high < 0x80000000U) {
|
182 |
-
low <<= 1;
|
183 |
-
high <<= 1;
|
184 |
-
high |= 1;
|
185 |
-
|
186 |
-
in_cache.get(value);
|
187 |
-
|
188 |
-
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
189 |
-
/**
|
190 |
-
* 0100 0000 ... <= value < 1100 0000 ...
|
191 |
-
* <=>
|
192 |
-
* 0100 0000 ... <= value <= 1011 1111 ...
|
193 |
-
* <=>
|
194 |
-
* value starts with 01 or 10.
|
195 |
-
* 01 - 01 == 00 | 10 - 01 == 01
|
196 |
-
* i.e., with shifts
|
197 |
-
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
|
198 |
-
* near convergence
|
199 |
-
*/
|
200 |
-
low <<= 1;
|
201 |
-
low &= 0x7FFFFFFFU; // make MSB 0
|
202 |
-
high <<= 1;
|
203 |
-
high |= 0x80000001U; // add 1 at the end, retain MSB = 1
|
204 |
-
value -= 0x40000000U;
|
205 |
-
|
206 |
-
in_cache.get(value);
|
207 |
-
|
208 |
-
} else {
|
209 |
-
break;
|
210 |
-
}
|
211 |
}
|
212 |
-
|
213 |
-
return (int16_t)sym_i;
|
214 |
}
|
|
|
|
|
215 |
}
|
216 |
|
217 |
};
|
@@ -340,8 +329,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
340 |
m.def("encode_cdf", &encode_cdf, "Encode from CDF");
|
341 |
|
342 |
py::class_<decode>(m, "decode")
|
343 |
-
.def(py::init([] (const std::string in, const int&
|
344 |
-
return new decode(in,
|
345 |
}))
|
346 |
.def("decodeAsym", &decode::decodeAsym);
|
347 |
}
|
|
|
138 |
public:
|
139 |
int dataID=0;
|
140 |
const int Lp;// To calculate offset
|
|
|
141 |
const int max_symbol;
|
142 |
uint32_t low = 0;
|
143 |
uint32_t high = 0xFFFFFFFFU;
|
|
|
146 |
cdf_t sym_i = 0;
|
147 |
uint32_t value = 0;
|
148 |
InCacheString in_cache;
|
149 |
+
decode(const std::string &in, const int&sysNumDim_):in_cache(in),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
|
150 |
in_cache.initialize(value);
|
151 |
|
152 |
};
|
153 |
|
154 |
int16_t decodeAsym(py::list cdf) {
|
155 |
+
|
156 |
+
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
157 |
+
// always < 0x10000 ???
|
158 |
+
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
|
159 |
|
160 |
+
int offset = 0;
|
161 |
|
162 |
+
sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
|
165 |
+
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
|
166 |
|
167 |
+
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
168 |
+
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
169 |
|
170 |
+
while (true) {
|
171 |
+
if (low >= 0x80000000U || high < 0x80000000U) {
|
172 |
+
low <<= 1;
|
173 |
+
high <<= 1;
|
174 |
+
high |= 1;
|
175 |
|
176 |
+
in_cache.get(value);
|
177 |
+
|
178 |
+
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
179 |
+
/**
|
180 |
+
* 0100 0000 ... <= value < 1100 0000 ...
|
181 |
+
* <=>
|
182 |
+
* 0100 0000 ... <= value <= 1011 1111 ...
|
183 |
+
* <=>
|
184 |
+
* value starts with 01 or 10.
|
185 |
+
* 01 - 01 == 00 | 10 - 01 == 01
|
186 |
+
* i.e., with shifts
|
187 |
+
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
|
188 |
+
* near convergence
|
189 |
+
*/
|
190 |
+
low <<= 1;
|
191 |
+
low &= 0x7FFFFFFFU; // make MSB 0
|
192 |
+
high <<= 1;
|
193 |
+
high |= 0x80000001U; // add 1 at the end, retain MSB = 1
|
194 |
+
value -= 0x40000000U;
|
195 |
|
196 |
+
in_cache.get(value);
|
197 |
|
198 |
+
} else {
|
199 |
+
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
}
|
|
|
|
|
201 |
}
|
202 |
+
|
203 |
+
return (int16_t)sym_i;
|
204 |
}
|
205 |
|
206 |
};
|
|
|
329 |
m.def("encode_cdf", &encode_cdf, "Encode from CDF");
|
330 |
|
331 |
py::class_<decode>(m, "decode")
|
332 |
+
.def(py::init([] (const std::string in, const int&sysNumDim_) {
|
333 |
+
return new decode(in,sysNumDim_);
|
334 |
}))
|
335 |
.def("decodeAsym", &decode::decodeAsym);
|
336 |
}
|
numpyAc/numpyAc.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
from torch.autograd.grad_mode import F
|
5 |
from torch.utils.cpp_extension import load
|
6 |
|
7 |
|
@@ -151,18 +150,16 @@ class arithmeticDeCoding():
|
|
151 |
"""
|
152 |
Decoding class
|
153 |
byte_stream: the bin file stream.
|
154 |
-
sysNum: the Number of symbols that you are going to decode. This value should be
|
155 |
-
saved in other ways.
|
156 |
sysDim: the Number of the possible symbols.
|
157 |
binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
|
158 |
and copy to Cpp backend Class 'InCacheString'
|
159 |
"""
|
160 |
-
def __init__(self,byte_stream,
|
161 |
if binfile is not None:
|
162 |
with open(binfile, 'rb') as fin:
|
163 |
byte_stream = fin.read()
|
164 |
self.byte_stream = byte_stream
|
165 |
-
self.decoder = numpyAc_backend.decode(self.byte_stream,
|
166 |
|
167 |
def decode(self,pdf):
|
168 |
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import numpy as np
|
|
|
4 |
from torch.utils.cpp_extension import load
|
5 |
|
6 |
|
|
|
150 |
"""
|
151 |
Decoding class
|
152 |
byte_stream: the bin file stream.
|
|
|
|
|
153 |
sysDim: the Number of the possible symbols.
|
154 |
binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
|
155 |
and copy to Cpp backend Class 'InCacheString'
|
156 |
"""
|
157 |
+
def __init__(self,byte_stream,symDim,binfile=None) -> None:
|
158 |
if binfile is not None:
|
159 |
with open(binfile, 'rb') as fin:
|
160 |
byte_stream = fin.read()
|
161 |
self.byte_stream = byte_stream
|
162 |
+
self.decoder = numpyAc_backend.decode(self.byte_stream,symDim+1)
|
163 |
|
164 |
def decode(self,pdf):
|
165 |
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
test.py
CHANGED
@@ -20,7 +20,7 @@ print('real_bits',real_bits)
|
|
20 |
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
21 |
|
22 |
# Decode from bytestream.
|
23 |
-
decodec = numpyAc.arithmeticDeCoding(None,
|
24 |
|
25 |
# Autoregressive decoding and output will be equal to the input.
|
26 |
for i,s in enumerate(sym):
|
|
|
20 |
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
21 |
|
22 |
# Decode from bytestream.
|
23 |
+
decodec = numpyAc.arithmeticDeCoding(None,dim,'out.b')
|
24 |
|
25 |
# Autoregressive decoding and output will be equal to the input.
|
26 |
for i,s in enumerate(sym):
|