Skip to content

Instantly share code, notes, and snippets.

@jaburns
Created June 11, 2021 16:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaburns/4c4a597021ac884dab9f7ec65d06a270 to your computer and use it in GitHub Desktop.
Save jaburns/4c4a597021ac884dab9f7ec65d06a270 to your computer and use it in GitHub Desktop.
Rust implementation of an arithmetic coder with some basic models
// Reference: https://github.com/rygorous/gaffer_net/blob/master/main.cpp
/// Number of fractional bits in the fixed-point probabilities handed to the coder.
const PROBABILITY_BITS: u32 = 15;
/// Fixed-point scale of a probability, i.e. `1 << PROBABILITY_BITS`.
pub const PROBABILITY_MAX: u32 = 1 << PROBABILITY_BITS;
/// Binary arithmetic encoder state: the current 32-bit interval plus the bytes emitted so far.
struct BinaryArithCoder {
/// Inclusive lower bound of the current coding interval.
lo: u32,
/// Inclusive upper bound of the current coding interval.
hi: u32,
/// Output bytes that have already been shifted out of the interval.
bytes: Vec<u8>,
}
impl BinaryArithCoder {
pub fn new() -> Self {
Self {
lo: 0,
hi: !0,
bytes: Vec::new(),
}
}
pub fn encode(&mut self, bit: bool, prob: u32) {
let lo64 = self.lo as u64;
let hi64 = self.hi as u64;
let prob64 = prob as u64;
let prob_bits64 = PROBABILITY_BITS as u64;
let x = self.lo + (((hi64 - lo64) * prob64) >> prob_bits64) as u32;
if bit {
self.hi = x;
} else {
self.lo = x + 1;
}
while (self.lo ^ self.hi) < (1u32 << 24) {
self.bytes.push((self.lo >> 24) as u8);
self.lo <<= 8;
self.hi = (self.hi << 8) | 0xff;
}
}
pub fn finalize(mut self) -> Vec<u8> {
let mut round_up: u32 = 0xffffffu32;
while round_up > 0 {
if (self.lo | round_up) != !0u32 {
let rounded: u32 = (self.lo + round_up) & !round_up;
if rounded <= self.hi {
self.lo = rounded;
break;
}
}
round_up >>= 8;
}
while self.lo > 0 {
self.bytes.push((self.lo >> 24) as u8);
self.lo <<= 8;
}
self.bytes
}
}
/// Binary arithmetic decoder state over a borrowed slice of encoded bytes.
struct BinaryArithDecoder<'a> {
/// The last 32 bits shifted in from the input; compared against the interval split point.
code: u32,
/// Inclusive lower bound of the current coding interval.
lo: u32,
/// Inclusive upper bound of the current coding interval.
hi: u32,
/// Encoded input being decoded.
bytes: &'a [u8],
/// Index into `bytes` of the next byte to consume.
read_pos: usize,
}
impl<'a> BinaryArithDecoder<'a> {
pub fn new(bytes: &'a [u8]) -> Self {
let mut ret = Self {
lo: 0,
hi: !0,
code: 0,
bytes,
read_pos: 0,
};
for _ in 0..4 {
ret.code = (ret.code << 8) | ret.get_byte() as u32;
}
ret
}
fn get_byte(&mut self) -> u8 {
if self.read_pos < self.bytes.len() {
let i = self.read_pos;
self.read_pos += 1;
self.bytes[i]
} else {
0
}
}
pub fn decode(&mut self, prob: u32) -> bool {
let lo64 = self.lo as u64;
let hi64 = self.hi as u64;
let prob64 = prob as u64;
let prob_bits64 = PROBABILITY_BITS as u64;
let x = self.lo + (((hi64 - lo64) * prob64) >> prob_bits64) as u32;
let bit;
if self.code <= x {
self.hi = x;
bit = true;
} else {
self.lo = x + 1;
bit = false;
}
while (self.lo ^ self.hi) < (1u32 << 24) {
self.code = (self.code << 8) | self.get_byte() as u32;
self.lo <<= 8;
self.hi = (self.hi << 8) | 0xff;
}
bit
}
}
/// A single-bit model that can drive both the arithmetic encoder and the decoder.
///
/// Both methods take `&mut self`, so an implementation may update its own state as it codes.
trait EncDec {
/// Encodes `bit` into `coder`.
fn encode(&mut self, coder: &mut BinaryArithCoder, bit: bool);
/// Decodes and returns one bit from `coder`.
fn decode(&mut self, coder: &mut BinaryArithDecoder) -> bool;
}
/// Adaptive bit-probability model built from two shift-based estimators whose sum is the
/// probability handed to the coder.
///
/// `INERTIA_0` and `INERTIA_1` are the right-shift amounts applied to each estimator's update
/// step, so a larger value makes that estimator move more slowly.
#[derive(Clone)]
struct TwoBinShiftModel<const INERTIA_0: u32, const INERTIA_1: u32> {
/// First estimator; its update step is shifted right by `INERTIA_0`.
prob0: u32,
/// Second estimator; its update step is shifted right by `INERTIA_1`.
prob1: u32,
}
impl<const INERTIA_0: u32, const INERTIA_1: u32> Default
for TwoBinShiftModel<INERTIA_0, INERTIA_1>
{
fn default() -> Self {
Self {
prob0: PROBABILITY_MAX / 4,
prob1: PROBABILITY_MAX / 4,
}
}
}
impl<const INERTIA_0: u32, const INERTIA_1: u32> TwoBinShiftModel<INERTIA_0, INERTIA_1> {
pub fn new() -> Self {
Self {
prob0: PROBABILITY_MAX / 4,
prob1: PROBABILITY_MAX / 4,
}
}
fn adapt(&mut self, bit: bool) {
if bit {
self.prob0 += (PROBABILITY_MAX / 2 - self.prob0) >> INERTIA_0;
self.prob1 += (PROBABILITY_MAX / 2 - self.prob1) >> INERTIA_1;
} else {
self.prob0 -= self.prob0 >> INERTIA_0;
self.prob1 -= self.prob1 >> INERTIA_1;
}
}
}
impl<const INERTIA_0: u32, const INERTIA_1: u32> EncDec for TwoBinShiftModel<INERTIA_0, INERTIA_1> {
    /// Encodes `bit` using the sum of both estimators as its probability, then adapts.
    fn encode(&mut self, coder: &mut BinaryArithCoder, bit: bool) {
        // The probability must be taken before adapting, mirroring what `decode` sees.
        let prob = self.prob0 + self.prob1;
        coder.encode(bit, prob);
        self.adapt(bit);
    }

    /// Decodes one bit using the sum of both estimators as its probability, adapts to the
    /// decoded bit and returns it.
    fn decode(&mut self, coder: &mut BinaryArithDecoder) -> bool {
        let prob = self.prob0 + self.prob1;
        let decoded = coder.decode(prob);
        self.adapt(decoded);
        decoded
    }
}
/// Codes `NUM_BITS`-bit symbols one bit at a time, picking a separate `MODEL` for each bit
/// based on the bits of the symbol coded before it.
struct BitTreeModel<MODEL, const NUM_BITS: usize> {
/// Per-context models; the context numbered `ctx` (counting from 1) is stored at `ctx - 1`.
model: Vec<MODEL>,
}
impl<MODEL: Clone + Default + EncDec, const NUM_BITS: usize> BitTreeModel<MODEL, NUM_BITS> {
const NUM_SYMS: usize = 1 << NUM_BITS;
const MSB: usize = 1 << (NUM_BITS - 1);
pub fn new() -> Self {
Self {
model: vec![MODEL::default(); Self::NUM_SYMS - 1],
}
}
pub fn encode(&mut self, coder: &mut BinaryArithCoder, mut value: usize) {
std::assert!(value < Self::NUM_SYMS);
let mut ctx: usize = 1;
while ctx < Self::NUM_SYMS {
let bit = (value & Self::MSB) != 0;
value += value;
self.model[ctx - 1].encode(coder, bit);
ctx = ctx + ctx + bit as usize;
}
}
pub fn decode(&mut self, coder: &mut BinaryArithDecoder) -> usize {
let mut ctx: usize = 1;
while ctx < Self::NUM_SYMS {
ctx = ctx + ctx + self.model[ctx - 1].decode(coder) as usize;
}
ctx - Self::NUM_SYMS
}
}
/// Compresses `test.txt` into `test.out`.
///
/// The output is a 4-byte header holding the input length in native byte order, followed by
/// the arithmetic-coded payload produced by an 8-bit `BitTreeModel` of `TwoBinShiftModel<3, 7>`.
///
/// # Panics
///
/// Panics if either file operation fails or if the input is too long for the 32-bit header.
fn test_encode() {
    let bytes = std::fs::read("test.txt").unwrap();

    // The header only has room for 32 bits; fail loudly instead of silently truncating the
    // length and producing a file that decodes to the wrong size.
    assert!(
        bytes.len() <= u32::MAX as usize,
        "input is too large for the 32-bit length header"
    );
    // Native byte order keeps the on-disk layout identical to what the decoder reads back.
    let size_bytes = (bytes.len() as u32).to_ne_bytes();

    let mut coder = BinaryArithCoder::new();
    let mut model = BitTreeModel::<TwoBinShiftModel<3, 7>, 8>::new();
    for &byte in &bytes {
        model.encode(&mut coder, byte as usize);
    }
    let payload = coder.finalize();

    // Build header + payload in one pass rather than shifting the payload to prepend it.
    let mut out = Vec::with_capacity(size_bytes.len() + payload.len());
    out.extend_from_slice(&size_bytes);
    out.extend_from_slice(&payload);
    std::fs::write("test.out", out).unwrap();
}
/// Decompresses `test.out` into `test.out.txt`.
///
/// The input is expected to start with a 4-byte header holding the decoded length in native
/// byte order, followed by the arithmetic-coded payload; it is decoded with an 8-bit
/// `BitTreeModel` of `TwoBinShiftModel<3, 7>`.
///
/// # Panics
///
/// Panics if either file operation fails or if the input is shorter than the header.
fn test_decode() {
    let bytes = std::fs::read("test.out").unwrap();
    assert!(
        bytes.len() >= 4,
        "test.out is too short to contain the 4-byte length header"
    );

    // Count whole symbols rather than bits: multiplying the length by 8 in `u32` overflows
    // for payloads of 512 MiB or more.
    let byte_count = u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);

    let mut coder = BinaryArithDecoder::new(&bytes[4..]);
    let mut model = BitTreeModel::<TwoBinShiftModel<3, 7>, 8>::new();
    let mut out_bytes = Vec::<u8>::new();
    for _ in 0..byte_count {
        out_bytes.push(model.decode(&mut coder) as u8);
    }
    std::fs::write("test.out.txt", out_bytes).unwrap();
}
/// Runs the encode step followed by the decode step.
fn main() {
test_encode();
test_decode();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment