Skip to content

Instantly share code, notes, and snippets.

@seece
Last active April 19, 2020 18:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seece/5d1e704f66cb03e4d720e901be998d8d to your computer and use it in GitHub Desktop.
Save seece/5d1e704f66cb03e4d720e901be998d8d to your computer and use it in GitHub Desktop.
An arithmetic coder
"""
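A pretty terrible arithmetic coder with a 0th order model.
(Note: the model's context is the previous 8 coded bits, a sliding window
that spans byte boundaries, so it is not strictly order-0 over bytes.)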
Based on Matt Mahoney's fpaq0 implementation available at
http://mattmahoney.net/dc/#fpaq0
"""
import sys
import os
import struct
# Sentinel getc() returns once the input is exhausted (the value that
# f.read(1) yields at end of a binary file).
EOF = bytes()
# Mask that keeps the coder's interval bounds within 32 unsigned bits.
INTMASK = (1 << 32) - 1


def getc(f):
    """Read a single byte from the binary file ``f``.

    Returns the byte's integer value, or ``EOF`` when nothing is left.
    """
    byte = f.read(1)
    return EOF if byte == EOF else ord(byte)


def putc(c, f):
    """Write the integer ``c`` (0-255) to the binary file ``f`` as one byte."""
    packed = struct.pack('B', c)
    f.write(packed)
class Model:
    """Adaptive model predicting the probability that the next bit is 1.

    For every context it keeps a pair of counts [zeros, ones] of the bits
    that followed it. The context (``history``) is the last 8 coded bits,
    so there are at most 256 contexts.
    """

    def __init__(self):
        self.history = 0  # last 8 coded bits, most recent bit lowest
        self.counts = {}  # context -> [zero_count, one_count]

    def predict(self):
        """Return P(next bit == 1) in the current context, scaled by 4096.

        Uses the add-one estimate (n1 + 1) / (n0 + n1 + 2), truncated to an
        int in [0, 4096). An empty count pair is created for unseen contexts.
        """
        zeros, ones = self.counts.setdefault(self.history, [0, 0])
        return 4096 * (ones + 1) // (zeros + ones + 2)

    def update(self, y):
        """Record bit ``y`` (0 or 1) in the current context, then advance it."""
        count = self.counts.setdefault(self.history, [0, 0])
        count[y] += 1
        # Halve both counts once one grows large, so the model keeps adapting
        # to recent data and the counts stay bounded. Integer halving (as in
        # fpaq0) keeps them ints; `/= 2` silently made them floats.
        if count[y] > 65534:
            count[0] >>= 1
            count[1] >>= 1
        self.history = ((self.history << 1) | y) & 0xff
class Coder:
    """State shared by the arithmetic encoder and decoder.

    Both sides track the current coding interval [x1, x2] as 32-bit
    unsigned integers and drive their own Model instance, which is updated
    with the same bit sequence on both sides so their predictions agree.
    """

    def __init__(self, fp):
        # The interval starts out covering the entire 32-bit range.
        self.x1, self.x2 = 0, 0xffffffff
        # Binary file the compressed stream is written to / read from.
        self.fp = fp
        self.model = Model()
class Encoder(Coder):
    """Binary arithmetic encoder writing compressed bytes to ``fp``.

    Call ``encode`` once per input bit and ``flush`` once after the last bit.
    """

    def encode(self, y):
        """Encode one bit ``y`` (0/1 or a bool) and update the model.

        The interval [x1, x2] is split at ``mid`` according to the model's
        12-bit probability that the bit is 1: [x1, mid] codes a 1 and
        [mid + 1, x2] codes a 0.
        """
        x1 = self.x1
        x2 = self.x2
        p = self.model.predict()  # P(y == 1) scaled to [0, 4096)
        # p * ((x2 - x1) >> 12) <= (x2 - x1) * 4095 / 4096 < x2 - x1, so mid
        # stays inside [x1, x2) and needs no 32-bit mask.
        mid = x1 + p * ((x2 - x1) >> 12)
        assert x1 <= mid < x2
        if y:
            x2 = mid
        else:
            # The +1 keeps the two halves disjoint: mid belongs only to the
            # "1" half, so the decoder can always tell which half a value is in.
            x1 = mid + 1
        self.model.update(y)
        # Once x1 and x2 share their top byte, that byte is fixed for every
        # value left in the interval: emit it and shift in fresh precision
        # (zeros into x1, ones into x2).
        while (x1 ^ x2) & 0xff000000 == 0:
            putc(x2 >> 24, self.fp)
            x1 = (x1 << 8) & INTMASK
            x2 = ((x2 << 8) | 0xff) & INTMASK
        self.x1 = x1
        self.x2 = x2

    def flush(self):
        """Write the final byte that identifies a value inside [x1, x2].

        With the top bytes of x1 and x2 different (and x1 <= x2), the value
        ``x2 & 0xff000000`` satisfies x1 < value <= x2. Writing only the top
        byte of x2 therefore suffices when the missing trailing bytes are
        read back as zeros.
        """
        x1 = self.x1
        x2 = self.x2
        # Normally a no-op (encode already shifted out equal leading bytes);
        # masked like encode so the state stays within 32 bits regardless.
        while (x1 ^ x2) & 0xff000000 == 0:
            putc(x2 >> 24, self.fp)
            x1 = (x1 << 8) & INTMASK
            x2 = ((x2 << 8) | 0xff) & INTMASK
        putc((x2 >> 24) & 0xff, self.fp)
        self.x1 = x1
        self.x2 = x2
class Decoder(Coder):
    """Binary arithmetic decoder reading compressed bytes from ``fp``.

    Mirrors the encoder's interval [x1, x2] and model, and additionally
    keeps ``x``, a 32-bit window into the compressed stream. For a stream
    produced by Encoder, x stays within [x1, x2].
    """

    def __init__(self, fp):
        super().__init__(fp)
        # Prime the window with the first four compressed bytes; bytes past
        # the end of the stream are read as zero.
        self.x = 0
        for _ in range(4):
            c = getc(self.fp)
            if c == EOF:
                c = 0
            self.x = (self.x << 8) | c

    def decode(self):
        """Decode and return the next bit (0 or 1), updating the model.

        The interval is split at ``mid`` exactly as Encoder.encode splits it:
        a window value in [x1, mid] means a 1 was coded, [mid + 1, x2] a 0.
        """
        x1 = self.x1
        x2 = self.x2
        x = self.x
        p = self.model.predict()
        mid = x1 + p * ((x2 - x1) >> 12)  # always < x2, no mask needed
        assert x1 <= mid < x2
        # mid itself belongs to the "1" half (the encoder sets x2 = mid for a
        # 1 and x1 = mid + 1 for a 0), so this comparison must be inclusive.
        if x <= mid:
            y = 1
            x2 = mid
        else:
            y = 0
            x1 = mid + 1
        self.model.update(y)
        # Shift out the top byte shared by x1 and x2, pulling the next
        # compressed byte into the window.
        while (x1 ^ x2) & 0xff000000 == 0:
            x1 = (x1 << 8) & INTMASK
            x2 = ((x2 << 8) | 0xff) & INTMASK
            c = getc(self.fp)
            if c == EOF:
                c = 0  # past the end of the stream: pad with zeros
            x = ((x << 8) | c) & INTMASK
        self.x1 = x1
        self.x2 = x2
        self.x = x
        return y
if __name__ == "__main__":
    # Usage: <script> c|d <input path> <output path>
    #   c: compress the input into the output; d: decompress it.
    if len(sys.argv) != 4:
        print(f"Usage: {sys.argv[0]} c|d <input> <output>")
        sys.exit(2)
    cmd, path, outpath = sys.argv[1:4]
    print(cmd, path, outpath)
    if cmd == 'c':
        size = os.path.getsize(path)
        with open(path, 'rb') as fin, open(outpath, 'wb') as fout:
            print(f"Writing file size {size}")
            # Header: original length as an explicit little-endian 32-bit
            # unsigned int, so archives are portable across hosts (identical
            # to the previous native "I" on little-endian machines).
            fout.write(struct.pack("<I", size))
            e = Encoder(fout)
            c = getc(fin)
            while c != EOF:
                # Most significant bit first; decoding rebuilds bytes the same way.
                for shift in range(7, -1, -1):
                    e.encode((c >> shift) & 1)
                c = getc(fin)
            e.flush()
    elif cmd == 'd':
        with open(path, 'rb') as fin:
            size, = struct.unpack("<I", fin.read(4))
            print(f"Uncompressed file size {size}")
            with open(outpath, 'wb') as fout:
                d = Decoder(fin)
                for _ in range(size):
                    c = 0
                    for _ in range(8):
                        c = (c << 1) | d.decode()
                    putc(c, fout)
    else:
        print(f"Invalid command {cmd}!")
        sys.exit(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment