Instantly share code, notes, and snippets.

# dsevero/rans.py

Last active February 22, 2023 12:34
Show Gist options
• Save dsevero/7e02d96e079ce44b89ff33d7a1ce1738 to your computer and use it in GitHub Desktop.
Asymmetric Numeral Systems (ANS) codec in pure Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 def push(state, symbol, cdf_func, prec): cdf_low, cdf_high = cdf_func(symbol) freq = cdf_high - cdf_low return prec*(state // freq) + (state % freq) + cdf_low def pop(state, icdf_func, cdf_func, prec): cdf_value = state % prec symbol, cdf_low, cdf_high = icdf_func(cdf_value) freq = cdf_high - cdf_low return symbol, freq*(state // prec) + cdf_value - cdf_low
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 ''' Heavily inspired by https://github.com/j-towns/ans-notes ''' from math import log2 from functools import reduce from rans import push, pop initial_state = 0 precision = 8 alphabet = [0, 1, 2] pmf = [1/2, 1/4, 1/4] entropy = sum(p*log2(1/p) for p in pmf) # For pmf=[1/2, 1/4, 1/4] at precision=8, the quantized cdf=[0, 4, 6, 8] cdf = reduce(lambda acc,el: acc + [acc[-1] + round(el*precision)], pmf, [0]) # ANS requires these 2 functions. def cdf_func(symbol): ''' Function signature is symbol -> (cdf_low, cdf_high). This can be substituted for a more complex model like a neural network''' return cdf[symbol], cdf[symbol+1] def icdf_func(cdf_value): ''' Function signature is cdf_value -> (symbol, cdf_low, cdf_high). Finds the symbol where cdf_func(symbol) <= cdf_value < cdf_func(symbol+1) This can be substituted for a more complex model like a neural network''' for symbol in alphabet: cdf_low, cdf_high = cdf_func(symbol) if cdf_low <= cdf_value < cdf_high: return symbol, cdf_low, cdf_high # Some symbols to compress sequence = 100*[2, 0, 0, 1] # Encode state = initial_state for symbol in reversed(sequence): state = push(state, symbol, cdf_func, precision) rate = state.bit_length()/len(sequence) # Decode decoded_sequence = len(sequence)*[None] for i in range(len(sequence)): decoded_sequence[i], state = pop(state, icdf_func, cdf_func, precision) # Sanity checks assert decoded_sequence == sequence assert (rate - entropy) < 0.01 print(f''' - Encoded {len(sequence)} symbols - Rate: {rate} bits/symbol - Entropy: {entropy} bits ''') # - Encoded 400 symbols # - Rate: 1.5025 bits/symbol # - Entropy: 1.5 bits