Skip to content

Instantly share code, notes, and snippets.

@mmalex
Created July 8, 2022 21:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mmalex/fb28b439e12841d9b1563a15adfb8242 to your computer and use it in GitHub Desktop.
Save mmalex/fb28b439e12841d9b1563a15adfb8242 to your computer and use it in GitHub Desktop.
An attempt to understand range coding by implementing the algorithm described on the Wikipedia page, of all places...
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <assert.h>
#include <vector>
// arithmetic coder/decoder by @mmalex based on the wikipedia page https://en.wikipedia.org/wiki/Range_coding as of July 2022
// Interval state shared by the range encoder and the range decoder.
//
// The coder tracks the interval [low, low + range) inside a fixed-width
// unsigned window of type `artype`. Whenever the top byte of the window is
// settled it is handed to the derived class through byte() and the window is
// shifted left by 8 bits.
struct encdec_state {
    using artype = uint16_t;
    // Number of bits to shift right to isolate the most significant byte of an artype.
    static constexpr int topshift = sizeof(artype) * 8 - 8;

    artype low = 0;
    artype range = ~artype(0);

    // Drop the top byte of the window: both ends of the interval move up by 8 bits.
    inline void shift() {
        low <<= 8;
        range <<= 8;
    }

    // The heart of the algorithm: narrow the interval to one symbol and renormalise.
    //
    // self  - an encoder or decoder; must expose low, range, shift() and byte().
    //         Passed as a template parameter (rather than via a virtual) so that
    //         the derived byte() can be inlined.
    // start - cumulative frequency at which the symbol begins.
    // end   - cumulative frequency at which the symbol ends (exclusive).
    // total - sum of all frequencies; must fit in artype.
    //
    // Precondition: the caller has already divided self.range by total.
    // Postcondition: self.range >= total, so the next division by total is non-zero.
    template <typename T>
    static void normalise(T &self, uint32_t start, uint32_t end, uint32_t total) {
        assert(start < end && end <= total);
        // A zero range can never renormalise: the loop below would spin forever.
        assert(self.range != 0);

        self.low += start * self.range;
        self.range *= end - start;

        // While both ends of the interval agree on the top byte, that byte is final.
        while ((self.low >> topshift) == ((self.low + self.range) >> topshift)) {
            self.byte();
            self.shift();
        }

        if (self.range < total) {
            // Underflow: the interval is too narrow for the next symbol but still
            // straddles a top-byte boundary. Keep only the part that shares low's
            // leading bytes, i.e. [low, all-ones], and emit those bytes.
            //
            // Shifting until ~low >= total (instead of testing the shifted range)
            // guarantees the postcondition. Otherwise ~low could be as small as
            // 0xFF, and with total > 255 the next `range /= total` would produce 0
            // and hang the coder. For total <= 255 this is exactly one shift.
            do {
                self.byte();
                self.shift();
            } while (static_cast<artype>(~self.low) < total);
            self.range = ~self.low;
        }
    }
};
struct encoder : public encdec_state {
std::vector<uint8_t> out;
inline void byte() { out.push_back(low>>topshift); } // push a byte to the output
inline void encode(uint32_t cdf_start, uint32_t cdf_end, uint32_t cdf_total) { // encode a symbol with the given range of cdf
range/=cdf_total;
normalise(*this,cdf_start,cdf_end,cdf_total);
}
inline void flush_encode() { // at the end of the stream, call this
artype ofs = artype(1)<<topshift;
while (range<ofs) { byte(); shift(); }
low+=ofs;
byte(); shift();
}
};
struct decoder : public encdec_state {
artype state=0;
const uint8_t *in=nullptr, *end=nullptr;
void byte() { state=(state<<8)+ ((in<end)?*in++ : 0); } // pull a byte from the input
decoder(const uint8_t *din, const uint8_t *dend) : encdec_state(), in(din), end(dend) {
for (int i=0;i<sizeof(artype);++i) byte();
range=~artype(0);
}
inline uint32_t decode(uint32_t cdf_total) { // call normalise after. returns a value from 0 to total that needs looking up in an inverse cdf
range /= cdf_total;
return (state - low) / range;
}
};
// Round-trip self-test: encodes 100 pseudo-random symbols with a static
// frequency model, decodes them again and checks every symbol matches.
// Returns 0 on success and EXIT_FAILURE on the first mismatch.
int main(int, char **) {
    const int message_length = 100;

    // After the prefix sum below, [cdf[s], cdf[s+1]) is the interval of byte value s.
    uint32_t cdf[257] = {};
    std::vector<uint8_t> plaintext;
    plaintext.reserve(message_length);
    for (int i = 0; i < message_length; ++i) {
        plaintext.push_back(static_cast<uint8_t>('A' + (rand() % 54)));
        cdf[1 + plaintext.back()]++;
    }
    for (int i = 0; i < 256; ++i) cdf[i + 1] += cdf[i];
    const uint32_t total = cdf[256];

    // Inverse cdf: maps every cumulative-frequency slot back to its symbol.
    std::vector<uint8_t> icdf(total);
    for (uint32_t s = 0; s < 256; ++s)
        for (uint32_t j = cdf[s]; j < cdf[s + 1]; ++j) icdf[j] = static_cast<uint8_t>(s);

    encoder enc;
    for (const uint8_t sym : plaintext) {
        printf("%c", sym);
        enc.encode(cdf[sym], cdf[sym + 1], total);
    }
    enc.flush_encode();
    printf("\n%d bytes\n", (int)enc.out.size());

    // Decode it and make sure it round trips.
    decoder dec(enc.out.data(), enc.out.data() + enc.out.size());
    for (const uint8_t orig_sym : plaintext) {
        const uint32_t k = dec.decode(total);
        if (k >= total) { // never index icdf with a slot outside the model
            printf("decode error, slot %u out of range (total %u)\n", (unsigned)k, (unsigned)total);
            return EXIT_FAILURE;
        }
        const uint8_t sym = icdf[k];
        encdec_state::normalise(dec, cdf[sym], cdf[sym + 1], total);
        printf("%c", sym);
        if (sym != orig_sym) {
            // Report through the exit code so the failure is also visible in NDEBUG builds.
            printf("decode error, %d vs %d, range %d\n", sym, orig_sym, (int)dec.range);
            return EXIT_FAILURE;
        }
    }
    printf("\n");
    return 0;
}
@mmalex
Copy link
Author

mmalex commented Jul 8, 2022

Change artype to uint64_t in practice; it's uint16_t here because that exercises the underflow case more thoroughly :) but it won't work well if the cdf total is too high relative to the width of artype.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment