Skip to content

Instantly share code, notes, and snippets.

@mreid
Last active August 29, 2015 14:07
Show Gist options
  • Save mreid/eacfa4dc78c6b142c78f to your computer and use it in GitHub Desktop.
Save mreid/eacfa4dc78c6b142c78f to your computer and use it in GitHub Desktop.
Toy implementation of arithmetic coding as described in Cover & Thomas §13.3
# Example implementation of simple arithmetic coding in Python (2.7+).
#
# USAGE
#
# python -i arithmetic.py
# >>> m = {'a': 1, 'b': 1, 'c': 1}
# >>> model = dirichlet(m)
# >>> encode(model, "aabbaacc")
# '00011110011110010'
#
# NOTES
#
# This implementation has many shortcomings, e.g.,
# - There are several inefficient tests, loops, and conversions
# - There are a few places where code is uncessarily duplicated
# - It does not output the coded message as a stream
# - It can only code short messages due to machine precision
# - The is no defensive coding against errors (e.g., out-of-model symbols)
# - I've not implemented a decoder!
#
# The aim was to make the implementation here as close as possible to
# the algorithm described in lectures while giving some extra detail about
# routines such as finding extensions to binary intervals.
#
# For a more sophisticated implementation, please refer to:
#
# "Arithmetic Coding for Data Compression"
# I. H. Witten, R. M. Neal, and J. G. Cleary
# Communications of the ACM, Col. 30 (6), 1987
#
# AUTHOR: Mark Reid
# CREATED: 2014-09-30
def encode(G, stream):
'''
Arithmetically encodes the given stream using the guesser function G
which returns probabilities over symbols P(x|xs) given a sequence xs.
'''
u, v = 0.0, 1.0 # The interval [u, v) for the message
xs, bs = "", "" # The message xs, and binary code bs
p = G(xs) # Compute the initial distribution over symbols
# Iterate through stream, repeatedly finding the longest binary code
# that surrounds the interval for the message so far
for x in stream:
# Record the new symbol
xs += x
# Find the interval for the message so far
F_lo, F_hi = cdf_interval(p, x)
u, v = u + (v-u)*F_lo, u + (v-u)*F_hi
# Find a binary code whose interval surrounds [u,v)
bs = extend_around(bs, u, v)
# Update the symbol probabilities
p = G(xs)
# Stream finished so find shortest extension of the code that sits inside
# the top half of [u, v)
bs = extend_inside(bs, u + (v-u)/2, v)
return bs
##############################################################################
# Models
def dirichlet(m):
'''
Returns a Dirichlet model (as a function) for probabilities with
prior counts given by the symbol to count dictionary m.
Probabilities returned by the returned functions are (symbol, prob)
dictionaries.
'''
# Build a function that returns P(x|xs) based on the priors in m
# and the counts of the symbols in xs
def p(xs):
counts = m.copy()
for x in xs:
counts[x] += 1
total = sum(counts.values())
return { a: float(c)/total for a, c in counts.items() }
# Return the constructed function
return p
##############################################################################
# Interval methods
def cdf_interval(p, a):
'''
Compute the cumulative distribution interval [F(a'), F(a)) for the
probabilities p (represented as a (symbol,prob) dict) where
F(a) = P(x <= a) and a' is the symbol preceeding a.
'''
F_lo, F_hi = 0, 0
A = sorted(p)
for x in A:
F_lo, F_hi = F_hi, F_hi + p[x]
if x == a:
break
return F_lo, F_hi
def binary_interval(bs):
'''
Returns an interval [n, m) for n and m integers, and denominator d
representing the interval [n/d, m/d) for the binary string bs.
'''
n, d = to_rational(bs)
return n, n + 1, d
def to_rational(bs):
'''Return numerator and denominator for ratio of 0.bs.'''
n = 0
for b in bs:
n *= 2
n += int(b)
return n, 2**len(bs)
def around(bs, u, v):
'''Tests whether [0.bs, 0.bs111...) contains [u, v).'''
n, m, d = binary_interval(bs)
return (n <= u*d) and (v*d <= m)
def extend_around(bs, u, v):
'''Find the longest extension of the given binary string so its interval
wraps around the interval [u, v).'''
contained = True
while contained:
if around(bs + "0", u, v):
bs += "0"
elif around(bs + "1", u, v):
bs += "1"
else:
contained = False
return bs
def inside(bs, u, v):
'''Tests whether [0.bs, 0.bs111...) is contained by [u, v).'''
n, m, d = binary_interval(bs)
return (u*d <= n) and (m <= v*d)
def extend_inside(bs, u, v):
'''Find the shortest extension of the given binary string so its interval
sits inside the interval [u, v).'''
while not inside(bs, u, v):
# Test whether gap between binary interval and [u,v) is bigger at the
# bottom than at the top
n, m, d = binary_interval(bs)
if u*d - n > m - v*d:
bs += "1" # If so, move bottom up by halving
else:
bs += "0" # If not, move top down by halving
return bs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment