Skip to content

Instantly share code, notes, and snippets.

@dsevero
Last active February 22, 2023 12:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dsevero/7e02d96e079ce44b89ff33d7a1ce1738 to your computer and use it in GitHub Desktop.
Save dsevero/7e02d96e079ce44b89ff33d7a1ce1738 to your computer and use it in GitHub Desktop.
Asymmetric Numeral Systems (ANS) codec in pure Python
def push(state, symbol, cdf_func, prec):
    ''' Encode `symbol` onto the rANS `state` and return the grown state.

    `cdf_func(symbol)` must return the symbol's quantized interval
    (cdf_low, cdf_high) on the integer scale [0, prec); the interval width is
    the symbol's frequency. The new state is
    prec * (state // freq) + (state % freq) + cdf_low.
    '''
    start, end = cdf_func(symbol)
    width = end - start
    # Quotient goes to the "high" digits, remainder lands inside the
    # symbol's slot [start, end) of the low digit (base `prec`).
    quotient, remainder = divmod(state, width)
    return prec * quotient + remainder + start
def pop(state, icdf_func, cdf_func, prec):
    ''' Decode one symbol from the rANS `state`; return (symbol, new_state).

    The low base-`prec` digit of `state` selects the symbol via
    `icdf_func(cdf_value) -> (symbol, cdf_low, cdf_high)`; the rest of the
    state is rescaled by the symbol's frequency (cdf_high - cdf_low).
    `cdf_func` is accepted for signature symmetry with `push` but not used.
    '''
    higher, slot = divmod(state, prec)
    symbol, start, end = icdf_func(slot)
    return symbol, (end - start) * higher + slot - start
''' Heavily inspired by https://github.com/j-towns/ans-notes
'''
from functools import reduce
from itertools import accumulate
from math import log2

from rans import push, pop
initial_state = 0
precision = 8
alphabet = [0, 1, 2]
pmf = [1/2, 1/4, 1/4]
# Shannon entropy in bits. Zero-probability symbols contribute nothing
# (p*log2(1/p) -> 0 as p -> 0), so skip them instead of dividing by zero.
entropy = sum(p*log2(1/p) for p in pmf if p > 0)


def quantized_cdf(probs, prec):
    ''' Quantize a pmf into an integer cdf [0, ..., prec] with len(probs)+1 entries.

    Rounding the cumulative sums (rather than each probability on its own)
    keeps the last entry at `prec`, so the symbol intervals tile [0, prec)
    exactly, which push/pop rely on. Rounding probabilities independently can
    overshoot or undershoot: [1/3, 1/3, 1/3] at prec=8 would give [0, 3, 6, 9].

    Raises ValueError if `probs` does not sum to ~1 at this precision, or if a
    symbol with positive probability is quantized to an empty interval
    (increase `prec` in that case).
    '''
    quantized = [0] + [round(c*prec) for c in accumulate(probs)]
    if quantized[-1] != prec:
        raise ValueError(
            f'pmf quantizes to a total of {quantized[-1]}, expected {prec}')
    for p, low, high in zip(probs, quantized, quantized[1:]):
        if p > 0 and high <= low:
            raise ValueError(
                f'precision {prec} too low: probability {p} quantized to zero width')
    return quantized


# For pmf=[1/2, 1/4, 1/4] at precision=8, the quantized cdf=[0, 4, 6, 8]
cdf = quantized_cdf(pmf, precision)
# The ANS codec (push/pop) only needs these two model functions.
def cdf_func(symbol):
    ''' Return the quantized interval (cdf_low, cdf_high) owned by `symbol`.

    Signature: symbol -> (cdf_low, cdf_high); the interval width is the
    symbol's frequency. Reads the module-level `cdf` table; a richer model
    (e.g. a neural network) could be substituted here.
    '''
    lower = cdf[symbol]
    upper = cdf[symbol + 1]
    return lower, upper
def icdf_func(cdf_value):
    ''' Function signature is cdf_value -> (symbol, cdf_low, cdf_high).

    Finds the symbol whose interval (cdf_low, cdf_high) = cdf_func(symbol)
    satisfies cdf_low <= cdf_value < cdf_high, scanning `alphabet` in order.
    This can be substituted for a more complex model like a neural network.

    Raises ValueError if no symbol's interval contains cdf_value.
    '''
    for symbol in alphabet:
        cdf_low, cdf_high = cdf_func(symbol)
        if cdf_low <= cdf_value < cdf_high:
            return symbol, cdf_low, cdf_high
    # No interval contains cdf_value (e.g. the quantized cdf does not cover
    # [0, precision)). Fail loudly here rather than returning None, which
    # pop() would only report as an opaque tuple-unpacking error.
    raise ValueError(
        f'cdf_value {cdf_value} does not fall inside any symbol interval')
# Some symbols to compress
sequence = 100*[2, 0, 0, 1]

# Encode. ANS is last-in-first-out, so push in reverse to pop in order.
state = initial_state
for symbol in reversed(sequence):
    state = push(state, symbol, cdf_func, precision)
# The whole message is held in one Python int; its bit length is the
# compressed size.
rate = state.bit_length()/len(sequence)

# Decode: each pop yields the next symbol in original order.
decoded_sequence = []
for _ in sequence:
    decoded_symbol, state = pop(state, icdf_func, cdf_func, precision)
    decoded_sequence.append(decoded_symbol)

# Sanity checks
assert decoded_sequence == sequence
# pop exactly inverts push, so decoding everything restores the start state.
assert state == initial_state
assert (rate - entropy) < 0.01
print(f'''
- Encoded {len(sequence)} symbols
- Rate: {rate} bits/symbol
- Entropy: {entropy} bits
''')
# - Encoded 400 symbols
# - Rate: 1.5025 bits/symbol
# - Entropy: 1.5 bits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment