Spaces:
Sleeping
Sleeping
Decoding doesn't require number of symbols anymore
Browse files
- numpyAc/backend/numpyAc_backend.cpp +42 -53
- numpyAc/numpyAc.py +2 -5
- test.py +1 -1
numpyAc/backend/numpyAc_backend.cpp
CHANGED
|
@@ -138,7 +138,6 @@ private:
|
|
| 138 |
public:
|
| 139 |
int dataID=0;
|
| 140 |
const int Lp;// To calculate offset
|
| 141 |
-
const int N_sym;// To know the # of syms to decode. Is encoded in the stream!
|
| 142 |
const int max_symbol;
|
| 143 |
uint32_t low = 0;
|
| 144 |
uint32_t high = 0xFFFFFFFFU;
|
|
@@ -147,71 +146,61 @@ public:
|
|
| 147 |
cdf_t sym_i = 0;
|
| 148 |
uint32_t value = 0;
|
| 149 |
InCacheString in_cache;
|
| 150 |
-
decode(const std::string &in, const int&
|
| 151 |
in_cache.initialize(value);
|
| 152 |
|
| 153 |
};
|
| 154 |
|
| 155 |
int16_t decodeAsym(py::list cdf) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
| 161 |
-
// always < 0x10000 ???
|
| 162 |
-
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
|
| 163 |
|
| 164 |
-
|
|
|
|
| 165 |
|
| 166 |
-
|
|
|
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
|
|
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
| 178 |
-
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
| 179 |
-
|
| 180 |
-
while (true) {
|
| 181 |
-
if (low >= 0x80000000U || high < 0x80000000U) {
|
| 182 |
-
low <<= 1;
|
| 183 |
-
high <<= 1;
|
| 184 |
-
high |= 1;
|
| 185 |
-
|
| 186 |
-
in_cache.get(value);
|
| 187 |
-
|
| 188 |
-
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
| 189 |
-
/**
|
| 190 |
-
* 0100 0000 ... <= value < 1100 0000 ...
|
| 191 |
-
* <=>
|
| 192 |
-
* 0100 0000 ... <= value <= 1011 1111 ...
|
| 193 |
-
* <=>
|
| 194 |
-
* value starts with 01 or 10.
|
| 195 |
-
* 01 - 01 == 00 | 10 - 01 == 01
|
| 196 |
-
* i.e., with shifts
|
| 197 |
-
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
|
| 198 |
-
* near convergence
|
| 199 |
-
*/
|
| 200 |
-
low <<= 1;
|
| 201 |
-
low &= 0x7FFFFFFFU; // make MSB 0
|
| 202 |
-
high <<= 1;
|
| 203 |
-
high |= 0x80000001U; // add 1 at the end, retain MSB = 1
|
| 204 |
-
value -= 0x40000000U;
|
| 205 |
-
|
| 206 |
-
in_cache.get(value);
|
| 207 |
-
|
| 208 |
-
} else {
|
| 209 |
-
break;
|
| 210 |
-
}
|
| 211 |
}
|
| 212 |
-
|
| 213 |
-
return (int16_t)sym_i;
|
| 214 |
}
|
|
|
|
|
|
|
| 215 |
}
|
| 216 |
|
| 217 |
};
|
|
@@ -340,8 +329,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
| 340 |
m.def("encode_cdf", &encode_cdf, "Encode from CDF");
|
| 341 |
|
| 342 |
py::class_<decode>(m, "decode")
|
| 343 |
-
.def(py::init([] (const std::string in, const int&
|
| 344 |
-
return new decode(in,
|
| 345 |
}))
|
| 346 |
.def("decodeAsym", &decode::decodeAsym);
|
| 347 |
}
|
|
|
|
| 138 |
public:
|
| 139 |
int dataID=0;
|
| 140 |
const int Lp;// To calculate offset
|
|
|
|
| 141 |
const int max_symbol;
|
| 142 |
uint32_t low = 0;
|
| 143 |
uint32_t high = 0xFFFFFFFFU;
|
|
|
|
| 146 |
cdf_t sym_i = 0;
|
| 147 |
uint32_t value = 0;
|
| 148 |
InCacheString in_cache;
|
| 149 |
+
decode(const std::string &in, const int&sysNumDim_):in_cache(in),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
|
| 150 |
in_cache.initialize(value);
|
| 151 |
|
| 152 |
};
|
| 153 |
|
| 154 |
int16_t decodeAsym(py::list cdf) {
|
| 155 |
+
|
| 156 |
+
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
| 157 |
+
// always < 0x10000 ???
|
| 158 |
+
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
|
| 159 |
|
| 160 |
+
int offset = 0;
|
| 161 |
|
| 162 |
+
sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
|
| 165 |
+
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
|
| 166 |
|
| 167 |
+
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
| 168 |
+
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
| 169 |
|
| 170 |
+
while (true) {
|
| 171 |
+
if (low >= 0x80000000U || high < 0x80000000U) {
|
| 172 |
+
low <<= 1;
|
| 173 |
+
high <<= 1;
|
| 174 |
+
high |= 1;
|
| 175 |
|
| 176 |
+
in_cache.get(value);
|
| 177 |
+
|
| 178 |
+
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
| 179 |
+
/**
|
| 180 |
+
* 0100 0000 ... <= value < 1100 0000 ...
|
| 181 |
+
* <=>
|
| 182 |
+
* 0100 0000 ... <= value <= 1011 1111 ...
|
| 183 |
+
* <=>
|
| 184 |
+
* value starts with 01 or 10.
|
| 185 |
+
* 01 - 01 == 00 | 10 - 01 == 01
|
| 186 |
+
* i.e., with shifts
|
| 187 |
+
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
|
| 188 |
+
* near convergence
|
| 189 |
+
*/
|
| 190 |
+
low <<= 1;
|
| 191 |
+
low &= 0x7FFFFFFFU; // make MSB 0
|
| 192 |
+
high <<= 1;
|
| 193 |
+
high |= 0x80000001U; // add 1 at the end, retain MSB = 1
|
| 194 |
+
value -= 0x40000000U;
|
| 195 |
|
| 196 |
+
in_cache.get(value);
|
| 197 |
|
| 198 |
+
} else {
|
| 199 |
+
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
}
|
|
|
|
|
|
|
| 201 |
}
|
| 202 |
+
|
| 203 |
+
return (int16_t)sym_i;
|
| 204 |
}
|
| 205 |
|
| 206 |
};
|
|
|
|
| 329 |
m.def("encode_cdf", &encode_cdf, "Encode from CDF");
|
| 330 |
|
| 331 |
py::class_<decode>(m, "decode")
|
| 332 |
+
.def(py::init([] (const std::string in, const int&sysNumDim_) {
|
| 333 |
+
return new decode(in,sysNumDim_);
|
| 334 |
}))
|
| 335 |
.def("decodeAsym", &decode::decodeAsym);
|
| 336 |
}
|
numpyAc/numpyAc.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
-
from torch.autograd.grad_mode import F
|
| 5 |
from torch.utils.cpp_extension import load
|
| 6 |
|
| 7 |
|
|
@@ -151,18 +150,16 @@ class arithmeticDeCoding():
|
|
| 151 |
"""
|
| 152 |
Decoding class
|
| 153 |
byte_stream: the bin file stream.
|
| 154 |
-
sysNum: the Number of symbols that you are going to decode. This value should be
|
| 155 |
-
saved in other ways.
|
| 156 |
sysDim: the Number of the possible symbols.
|
| 157 |
binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
|
| 158 |
and copy to Cpp backend Class 'InCacheString'
|
| 159 |
"""
|
| 160 |
-
def __init__(self,byte_stream,
|
| 161 |
if binfile is not None:
|
| 162 |
with open(binfile, 'rb') as fin:
|
| 163 |
byte_stream = fin.read()
|
| 164 |
self.byte_stream = byte_stream
|
| 165 |
-
self.decoder = numpyAc_backend.decode(self.byte_stream,
|
| 166 |
|
| 167 |
def decode(self,pdf):
|
| 168 |
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from torch.utils.cpp_extension import load
|
| 5 |
|
| 6 |
|
|
|
|
| 150 |
"""
|
| 151 |
Decoding class
|
| 152 |
byte_stream: the bin file stream.
|
|
|
|
|
|
|
| 153 |
sysDim: the Number of the possible symbols.
|
| 154 |
binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
|
| 155 |
and copy to Cpp backend Class 'InCacheString'
|
| 156 |
"""
|
| 157 |
+
def __init__(self,byte_stream,symDim,binfile=None) -> None:
|
| 158 |
if binfile is not None:
|
| 159 |
with open(binfile, 'rb') as fin:
|
| 160 |
byte_stream = fin.read()
|
| 161 |
self.byte_stream = byte_stream
|
| 162 |
+
self.decoder = numpyAc_backend.decode(self.byte_stream,symDim+1)
|
| 163 |
|
| 164 |
def decode(self,pdf):
|
| 165 |
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
test.py
CHANGED
|
@@ -20,7 +20,7 @@ print('real_bits',real_bits)
|
|
| 20 |
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
| 21 |
|
| 22 |
# Decode from bytestream.
|
| 23 |
-
decodec = numpyAc.arithmeticDeCoding(None,
|
| 24 |
|
| 25 |
# Autoregressive decoding and output will be equal to the input.
|
| 26 |
for i,s in enumerate(sym):
|
|
|
|
| 20 |
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
| 21 |
|
| 22 |
# Decode from bytestream.
|
| 23 |
+
decodec = numpyAc.arithmeticDeCoding(None,dim,'out.b')
|
| 24 |
|
| 25 |
# Autoregressive decoding and output will be equal to the input.
|
| 26 |
for i,s in enumerate(sym):
|