Created
May 22, 2016 17:46
-
-
Save orlp/255a2a84b6023434b6662c5c0d4bc243 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bisect | |
import io | |
# Code values (the lo/hi interval bounds) are 17 bits wide; cumulative
# frequencies get two bits fewer, i.e. 15 bits.
CODE_VALUE_BITS = 17
FP_MAX = 2**CODE_VALUE_BITS - 1
FP_QUARTER = (FP_MAX + 1) // 4
FP_HALF = FP_QUARTER + FP_QUARTER
FP_3QUARTER = FP_HALF + FP_QUARTER
FREQ_BITS = CODE_VALUE_BITS - 2
MAX_FREQ = 2**FREQ_BITS - 1
def hist_to_model(hist):
    """Turn a histogram into a cumulative frequency table.

    The result has one more entry than ``hist``: it starts with 0 and entry
    ``k + 1`` is the sum of the first ``k + 1`` histogram counts.
    """
    model = [0]
    for count in hist:
        # The running total is always the most recently appended entry.
        model.append(model[-1] + count)
    return model
class Model:
    """Holder for a cumulative integer frequency table.

    The table must start at 0 and its final entry (the total) must not
    exceed MAX_FREQ.
    """

    def set_raw_model(self, model):
        """Install an already-cumulative table after checking both ends."""
        first, last = model[0], model[-1]
        assert first == 0
        assert last <= MAX_FREQ
        self.model = model

    def set_model_from_hist(self, hist):
        """Build the cumulative table from per-symbol counts and install it."""
        cumulative = hist_to_model(hist)
        self.set_raw_model(cumulative)
class Encoder(Model):
    """Arithmetic encoder that writes its bitstream to a byte sink.

    A cumulative frequency table has to be installed through the inherited
    ``set_raw_model`` / ``set_model_from_hist`` before ``encode`` is called.
    """

    def __init__(self, io):
        """Start from the full interval [0, FP_MAX] with an empty bit buffer.

        ``io`` is stored as-is; the encoder only ever calls ``io.write``
        with ``bytes`` objects.
        """
        self.io = io
        self.lo = 0
        self.hi = FP_MAX
        # Number of deferred bits: they are emitted, inverted, right after
        # the next bit that is actually decided.
        self.opposite = 0
        # Partially filled output byte and how many bits it holds so far.
        self.buf = 0
        self.bits_in_buf = 0
        # Bits emitted through output_bit (final-byte padding not included).
        self.total_bits = 0

    def encode(self, symbol):
        """Narrow the interval to ``symbol`` and emit every settled bit.

        Raises:
            ValueError: if ``symbol`` is not an index in
                ``0 .. len(self.model) - 2`` or if its frequency in the
                current model is zero.
        """
        table = self.model
        # The table holds one entry more than there are symbols, so the last
        # valid symbol is len(table) - 2. Negative values are rejected as
        # well: they would silently index from the end of the table.
        if not 0 <= symbol < len(table) - 1:
            raise ValueError("symbol %r is outside the model" % (symbol,))
        # A zero-width slot would yield hi == lo - 1, an empty interval that
        # no decoder can map back to this symbol.
        if table[symbol] == table[symbol + 1]:
            raise ValueError(
                "symbol %r has zero frequency in the model" % (symbol,))
        total = table[-1]
        r = self.hi - self.lo + 1
        # hi must be computed first: both bounds are relative to the old lo.
        self.hi = self.lo + r*table[symbol + 1]//total - 1
        self.lo = self.lo + r*table[symbol]//total
        while True:
            if self.hi < FP_HALF:
                # Interval entirely in the lower half: next bit is 0.
                self.output_and_opposite(0)
            elif self.lo >= FP_HALF:
                # Interval entirely in the upper half: next bit is 1.
                self.output_and_opposite(1)
                self.lo -= FP_HALF
                self.hi -= FP_HALF
            elif self.lo >= FP_QUARTER and self.hi < FP_3QUARTER:
                # Interval straddles the midpoint inside the middle half:
                # defer the bit and expand around the middle.
                self.opposite += 1
                self.lo -= FP_QUARTER
                self.hi -= FP_QUARTER
            else:
                break
            self.lo = 2*self.lo
            self.hi = 2*self.hi + 1

    def finish(self):
        """Emit the closing bits and flush a partially filled last byte.

        One more deferred bit is added and a final bit is chosen by whether
        ``lo`` lies at or above the first quarter. A remaining partial byte
        is left-aligned, so its unused low bits are zero.
        """
        self.opposite += 1
        self.output_and_opposite(self.lo >= FP_QUARTER)
        if self.bits_in_buf:
            self.io.write(bytes([self.buf << (8 - self.bits_in_buf)]))

    def output_bit(self, bit):
        """Append one bit to the buffer, writing a byte once 8 are collected."""
        self.buf <<= 1
        self.buf |= bit
        self.bits_in_buf += 1
        self.total_bits += 1
        if self.bits_in_buf == 8:
            self.io.write(bytes([self.buf]))
            self.bits_in_buf = 0
            self.buf = 0

    def output_and_opposite(self, bit):
        """Emit ``bit`` followed by all deferred bits, each inverted."""
        self.output_bit(bit)
        while self.opposite:
            self.output_bit(not bit)
            self.opposite -= 1
class Decoder(Model):
    """Arithmetic decoder that pulls its bitstream from a byte source.

    A cumulative frequency table has to be installed through the inherited
    ``set_raw_model`` / ``set_model_from_hist`` before ``decode`` is called.
    """

    def __init__(self, io):
        """Reset the interval and preload the first CODE_VALUE_BITS bits.

        ``io`` is stored as-is; the decoder only ever calls ``io.read(1)``.
        """
        self.io = io
        self.lo = 0
        self.hi = FP_MAX
        self.value = 0
        # Byte currently being consumed and how many of its bits remain.
        self.buf = 0
        self.bits_in_buf = 0
        # Zero bits handed out after the source ran dry.
        self.garbage_bits = 0
        for _ in range(CODE_VALUE_BITS):
            self.value = 2*self.value + self.input_bit()

    def decode(self):
        """Return the next symbol and renormalise the interval."""
        table = self.model
        total = table[-1]
        r = self.hi - self.lo + 1
        cumulative = ((self.value - self.lo + 1)*total - 1)//r
        # Rightmost table entry that is <= cumulative marks the symbol.
        symbol = bisect.bisect(table, cumulative) - 1
        # hi must be computed first: both bounds are relative to the old lo.
        self.hi = self.lo + r*table[symbol + 1]//total - 1
        self.lo = self.lo + r*table[symbol]//total
        while True:
            if self.hi < FP_HALF:
                offset = 0
            elif self.lo >= FP_HALF:
                offset = FP_HALF
            elif self.lo >= FP_QUARTER and self.hi < FP_3QUARTER:
                offset = FP_QUARTER
            else:
                break
            # Shift value first, then the bounds; the next bit is read last
            # so that everything else is already updated if reading raises.
            self.value -= offset
            self.lo = 2*(self.lo - offset)
            self.hi = 2*(self.hi - offset) + 1
            self.value = 2*self.value + self.input_bit()
        return symbol

    def input_bit(self):
        """Return the next bit of the stream, most significant bit first.

        Once the source is exhausted, zero bits are supplied instead.

        Raises:
            RuntimeError: when more than CODE_VALUE_BITS - 2 such filler
                bits have been requested.
        """
        if not self.bits_in_buf:
            chunk = self.io.read(1)
            if not chunk:
                self.garbage_bits += 1
                if self.garbage_bits > CODE_VALUE_BITS - 2:
                    raise RuntimeError("more input expected in arithmetic decoder")
                return 0
            self.bits_in_buf = 8
            self.buf = chunk[0]
        top_bit = (self.buf >> 7) & 1
        self.buf <<= 1
        self.bits_in_buf -= 1
        return top_bit
class AdaptiveModel:
    """Mixin that keeps a symbol histogram and rebuilds the model from it.

    The class it is combined with must provide ``set_model_from_hist``.
    """

    def __init__(self, init_hist):
        """Take a copy of ``init_hist`` as starting counts and install it."""
        counts = init_hist[:]
        self.hist = counts
        self.hist_total = sum(counts)
        self.set_model_from_hist(counts)

    def encounter_symbol(self, symbol):
        """Count one occurrence of ``symbol`` and refresh the model."""
        self.hist[symbol] += 1
        self.hist_total += 1
        if self.hist_total >= MAX_FREQ:
            # Scale the counts down; a nonzero count never drops to zero.
            scaled = [count//2 + 1 if count else 0 for count in self.hist]
            self.hist = scaled
            self.hist_total = sum(scaled)
        self.set_model_from_hist(self.hist)
class SimpleAdaptiveEncoder(Encoder, AdaptiveModel):
    """Encoder whose frequency model is updated after every symbol.

    ``finish`` encodes the last symbol index (``len(hist) - 1``) before
    flushing; presumably that slot serves as an end-of-stream marker.
    """

    def __init__(self, io, init_hist=None):
        """Set up the encoder state and the starting histogram.

        Args:
            io: byte sink handed to ``Encoder.__init__``.
            init_hist: starting symbol counts. When omitted (``None``), 256
                symbols with count 1 followed by one symbol with count 8.
        """
        # A None sentinel avoids sharing one mutable list object as the
        # default across every call.
        if init_hist is None:
            init_hist = [1]*256 + [8]
        Encoder.__init__(self, io)
        AdaptiveModel.__init__(self, init_hist)

    def encode(self, symbol):
        """Encode ``symbol`` with the current model, then count it."""
        Encoder.encode(self, symbol)
        self.encounter_symbol(symbol)

    def finish(self):
        """Encode the last symbol index, then flush the encoder."""
        # Goes through Encoder.encode directly so the histogram is not
        # updated for this final symbol.
        Encoder.encode(self, len(self.hist) - 1)
        Encoder.finish(self)
class SimpleAdaptiveDecoder(Decoder, AdaptiveModel):
    """Decoder whose frequency model is updated after every symbol."""

    def __init__(self, io, init_hist=None):
        """Set up the decoder state and the starting histogram.

        Args:
            io: byte source handed to ``Decoder.__init__``.
            init_hist: starting symbol counts. When omitted (``None``), 256
                symbols with count 1 followed by one symbol with count 8.
        """
        # A None sentinel avoids sharing one mutable list object as the
        # default across every call.
        if init_hist is None:
            init_hist = [1]*256 + [8]
        Decoder.__init__(self, io)
        AdaptiveModel.__init__(self, init_hist)

    def decode(self):
        """Decode one symbol with the current model, count it and return it."""
        symbol = Decoder.decode(self)
        self.encounter_symbol(symbol)
        return symbol
# Example: round-trip a short symbol sequence through a fixed model.
hist = [1, 2, 8, 1]
i = [2, 2, 2, 2, 2, 2, 1, 0, 2, 1, 2, 2, 2, 2]

# Encode into an in-memory byte buffer.
enc = Encoder(io.BytesIO())
enc.set_model_from_hist(hist)
for s in i:
    enc.encode(s)
enc.finish()

# Decode the produced bytes with the same model and compare.
dec = Decoder(io.BytesIO(enc.io.getvalue()))
dec.set_model_from_hist(hist)
assert [dec.decode() for _ in i] == i
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment