Skip to content

Instantly share code, notes, and snippets.

@orlp
Created May 22, 2016 17:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save orlp/255a2a84b6023434b6662c5c0d4bc243 to your computer and use it in GitHub Desktop.
Save orlp/255a2a84b6023434b6662c5c0d4bc243 to your computer and use it in GitHub Desktop.
import bisect
import io
# Fixed-point layout of the coder: code values are 17 bits wide and model
# frequencies are 15 bits wide.
CODE_VALUE_BITS = 17
# Largest code value, and the quarter / half / three-quarter marks of the
# code range that drive renormalisation.
FP_MAX = (1 << CODE_VALUE_BITS) - 1
FP_QUARTER = (FP_MAX + 1) >> 2
FP_HALF = FP_QUARTER << 1
FP_3QUARTER = FP_HALF + FP_QUARTER
# Frequencies get two bits fewer than code values, which keeps a model's
# total count strictly below FP_QUARTER.
FREQ_BITS = CODE_VALUE_BITS - 2
MAX_FREQ = 2**FREQ_BITS - 1
# Turn per-symbol counts into a cumulative table: entry 0 is 0 and entry
# i + 1 is the sum of the first i + 1 counts, so the result has one more
# entry than `hist`.
def hist_to_model(hist):
    cumulative = [0]
    for count in hist:
        # Each new entry extends the running sum held in the last slot.
        cumulative.append(cumulative[-1] + count)
    return cumulative
class Model:
    # Holds a cumulative integer frequency table in `self.model`. The first
    # entry must be 0 and the last entry (the total count) may not exceed
    # MAX_FREQ.

    # Install an already-cumulative table after checking both end points.
    # The list is stored as-is, not copied.
    def set_raw_model(self, model):
        assert model[0] == 0
        assert model[-1] <= MAX_FREQ
        self.model = model

    # Build the cumulative table from per-symbol counts and install it.
    def set_model_from_hist(self, hist):
        cumulative = hist_to_model(hist)
        self.set_raw_model(cumulative)
class Encoder(Model):
    # Arithmetic encoder that emits its code bits to a byte stream.
    #
    # Install a model with set_raw_model / set_model_from_hist before the
    # first encode(), call encode() once per symbol, then finish() once.

    # `io` is the output stream; only its write(bytes) method is used.
    def __init__(self, io):
        self.io = io
        # Current coding interval [lo, hi], both bounds inclusive.
        self.lo = 0
        self.hi = FP_MAX
        # Number of deferred bits; each one is emitted as the inverse of the
        # next bit that gets decided.
        self.opposite = 0
        # Partially filled output byte and the number of bits it holds.
        self.buf = 0
        self.bits_in_buf = 0
        # Bits emitted through output_bit (final byte padding not counted).
        self.total_bits = 0

    # Narrow the interval to `symbol`'s slice of the model and emit every
    # bit that is thereby settled.
    #
    # Valid symbols are 0 .. len(self.model) - 2 with a non-zero frequency;
    # anything else raises AssertionError before any state is touched.
    def encode(self, symbol):
        model = self.model
        # Symbol i spans model[i] .. model[i + 1], so the table has one more
        # entry than there are symbols. Negative indices would wrap around
        # silently and corrupt the interval, so reject them as well.
        assert 0 <= symbol < len(model) - 1
        # A zero-width slice would yield hi == lo - 1, after which the
        # renormalisation loop below never terminates.
        assert model[symbol] < model[symbol + 1]
        total = model[-1]
        r = self.hi - self.lo + 1
        self.hi = self.lo + r*model[symbol + 1]//total - 1
        self.lo = self.lo + r*model[symbol]//total
        while True:
            if self.hi < FP_HALF:
                # Interval lies in the lower half: the next bit is 0.
                self.output_and_opposite(0)
            elif self.lo >= FP_HALF:
                # Interval lies in the upper half: the next bit is 1.
                self.output_and_opposite(1)
                self.lo -= FP_HALF
                self.hi -= FP_HALF
            elif self.lo >= FP_QUARTER and self.hi < FP_3QUARTER:
                # Interval straddles the midpoint inside the middle half:
                # defer the decision and remember one pending bit.
                self.opposite += 1
                self.lo -= FP_QUARTER
                self.hi -= FP_QUARTER
            else:
                break
            # Scale the interval back up by a factor of two.
            self.lo = 2*self.lo
            self.hi = 2*self.hi + 1

    # Emit the closing bits that pin down the final interval, then flush a
    # partially filled byte, zero-padded on the right.
    def finish(self):
        self.opposite += 1
        self.output_and_opposite(1 if self.lo >= FP_QUARTER else 0)
        if self.bits_in_buf:
            self.io.write(bytes([self.buf << (8 - self.bits_in_buf)]))

    # Append one bit (most significant bit first) and write the byte out
    # once eight bits have accumulated.
    def output_bit(self, bit):
        self.buf <<= 1
        self.buf |= bit
        self.bits_in_buf += 1
        self.total_bits += 1
        if self.bits_in_buf == 8:
            self.io.write(bytes([self.buf]))
            self.bits_in_buf = 0
            self.buf = 0

    # Emit `bit`, followed by every pending deferred bit as its inverse.
    def output_and_opposite(self, bit):
        self.output_bit(bit)
        while self.opposite:
            self.output_bit(not bit)
            self.opposite -= 1
class Decoder(Model):
    # Arithmetic decoder that pulls its code bits from a byte stream.
    #
    # The constructor already consumes the first CODE_VALUE_BITS bits; a
    # model must be installed before the first decode().

    # `io` is the input stream; only its read(1) method is used.
    def __init__(self, io):
        self.io = io
        # Current coding interval [lo, hi], both bounds inclusive.
        self.lo, self.hi = 0, FP_MAX
        # Window of code bits currently being compared against the interval.
        self.value = 0
        # Byte being unpacked and the number of unread bits left in it.
        self.buf = 0
        self.bits_in_buf = 0
        # Count of zero bits made up after the stream ran out.
        self.garbage_bits = 0
        for _ in range(CODE_VALUE_BITS):
            self.value = 2 * self.value + self.input_bit()

    # Decode and return the next symbol index, then renormalise the interval
    # while shifting fresh bits into `value`.
    def decode(self):
        total = self.model[-1]
        span = self.hi - self.lo + 1
        # Map the code value back onto the model's cumulative scale.
        target = ((self.value - self.lo + 1) * total - 1) // span
        # The symbol is the last table entry that does not exceed the target.
        symbol = bisect.bisect_right(self.model, target) - 1
        base = self.lo
        self.hi = base + span * self.model[symbol + 1] // total - 1
        self.lo = base + span * self.model[symbol] // total
        while True:
            # Pick the amount to subtract before doubling; the three cases
            # mirror the encoder's lower-half / upper-half / middle-half tests.
            if self.hi < FP_HALF:
                offset = 0
            elif self.lo >= FP_HALF:
                offset = FP_HALF
            elif self.lo >= FP_QUARTER and self.hi < FP_3QUARTER:
                offset = FP_QUARTER
            else:
                break
            self.value -= offset
            self.lo = 2 * (self.lo - offset)
            self.hi = 2 * (self.hi - offset) + 1
            # Read the next bit last so lo/hi are already updated if the
            # stream turns out to be exhausted.
            self.value = 2 * self.value + self.input_bit()
        return symbol

    # Return the next bit of the stream, most significant bit of each byte
    # first. Past the end of the stream up to CODE_VALUE_BITS - 2 zero bits
    # are supplied; one more raises RuntimeError.
    def input_bit(self):
        if not self.bits_in_buf:
            chunk = self.io.read(1)
            if not chunk:
                self.garbage_bits += 1
                if self.garbage_bits > CODE_VALUE_BITS - 2:
                    raise RuntimeError("more input expected in arithmetic decoder")
                return 0
            self.buf = chunk[0]
            self.bits_in_buf = 8
        bit = (self.buf >> 7) & 1
        self.buf <<= 1
        self.bits_in_buf -= 1
        return bit
class AdaptiveModel:
    # Mixin that keeps a histogram of the symbols seen so far and refreshes
    # the cumulative model after every symbol. It relies on the class it is
    # mixed into to provide set_model_from_hist().

    # Start from a copy of `init_hist`, so the caller's sequence is untouched.
    def __init__(self, init_hist):
        self.hist = init_hist[:]
        self.hist_total = sum(self.hist)
        self.set_model_from_hist(self.hist)

    # Count one occurrence of `symbol` and rebuild the model.
    def encounter_symbol(self, symbol):
        self.hist[symbol] += 1
        self.hist_total += 1
        if self.hist_total >= MAX_FREQ:
            # Roughly halve every count once the total reaches the limit;
            # non-zero counts stay at least 1, zero counts stay zero.
            self.hist = [count // 2 + 1 if count else 0 for count in self.hist]
            self.hist_total = sum(self.hist)
        self.set_model_from_hist(self.hist)
class SimpleAdaptiveEncoder(Encoder, AdaptiveModel):
    # Encoder whose model adapts to the symbols encoded so far. The last
    # histogram slot serves as the end-of-stream symbol: finish() encodes it
    # before flushing.

    # `io` is the output stream handed to Encoder. `init_hist` holds the
    # initial per-symbol counts; when omitted it is 256 symbols with count 1
    # followed by the end-of-stream symbol with count 8. The default is built
    # per call rather than shared as a mutable default argument.
    def __init__(self, io, init_hist=None):
        if init_hist is None:
            init_hist = [1]*256 + [8]
        Encoder.__init__(self, io)
        AdaptiveModel.__init__(self, init_hist)

    # Encode `symbol` with the current model, then count it so the model
    # used for the next symbol reflects it.
    def encode(self, symbol):
        Encoder.encode(self, symbol)
        self.encounter_symbol(symbol)

    # Encode the end-of-stream symbol (without counting it) and flush.
    def finish(self):
        end_symbol = len(self.hist) - 1
        Encoder.encode(self, end_symbol)
        Encoder.finish(self)
class SimpleAdaptiveDecoder(Decoder, AdaptiveModel):
    # Decoder whose model adapts to the symbols decoded so far; it must be
    # created with the same initial histogram as the encoder that produced
    # the stream.

    # `io` is the input stream handed to Decoder. `init_hist` holds the
    # initial per-symbol counts; when omitted it is 256 symbols with count 1
    # followed by one extra symbol with count 8. The default is built per
    # call rather than shared as a mutable default argument.
    def __init__(self, io, init_hist=None):
        if init_hist is None:
            init_hist = [1]*256 + [8]
        Decoder.__init__(self, io)
        AdaptiveModel.__init__(self, init_hist)

    # Decode the next symbol with the current model, count it, and return
    # its index. The caller is responsible for recognising the last
    # histogram index as the end of the stream.
    def decode(self):
        symbol = Decoder.decode(self)
        self.encounter_symbol(symbol)
        return symbol
# Examples:
# Round-trip a short symbol sequence through a static four-symbol model.
hist = [1, 2, 8, 1]
i = [2, 2, 2, 2, 2, 2, 1, 0, 2, 1, 2, 2, 2, 2]

enc = Encoder(io.BytesIO())
enc.set_model_from_hist(hist)
for s in i:
    enc.encode(s)
enc.finish()

# Decode exactly as many symbols as were encoded and compare.
dec = Decoder(io.BytesIO(enc.io.getvalue()))
dec.set_model_from_hist(hist)
assert i == [dec.decode() for _ in i]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment