Skip to content

Instantly share code, notes, and snippets.

@robey
Created February 3, 2020 18:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save robey/0b748439a5fa76a9c763bd9f04b15c35 to your computer and use it in GitHub Desktop.
Save robey/0b748439a5fa76a9c763bd9f04b15c35 to your computer and use it in GitHub Desktop.
Python snappy test code for measuring how effective different lookback window sizes are.
#!/usr/bin/env python3
import argparse
import array
import itertools
import struct
import sys
from typing import Iterator, List
__title__ = "kindasnappy"
__description__ = "generate snappy-compressed stdout from stdin with configurable window size"
__version__ = "1.0"

# Results of test runs with lookback windows of 8 to 20 bits. The figures are
# presumably compressed output sizes in bytes for one fixed input, with
# "original" being the baseline compressor -- TODO confirm. A 16-bit window
# gave the smallest output; main() only accepts windows of 8..16 bits.
#
#   run        size
#   original   1286792
#   test8      1071334
#   test9      1042274
#   test10     1019489
#   test11      998654
#   test12      992215
#   test13      986921
#   test14      980365
#   test15      968564
#   test16      960698
#   test18      962707
#   test20      963237

# Multiplier for the multiplicative hash of 4-byte chunks in compress().
HASH_MULT = 0x1E35A7BD
# Hash tables never exceed 2**MAX_TABLE_BITS slots, however large the input.
MAX_TABLE_BITS = 20
def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= ``n``.

    Exact powers of two are returned unchanged, and ``n <= 0`` yields 0.
    Works for ints of any size; a 32-bit bit-smearing trick would silently
    give wrong answers above 2**32 because Python ints are unbounded.
    """
    if n <= 0:
        return 0
    # (n - 1).bit_length() is the number of bits needed for n - 1, so shifting
    # 1 by it gives the first power of two strictly greater than n - 1.
    return 1 << (n - 1).bit_length()
def emit_varint(n: int) -> bytes:
    """Encode ``n`` as an unsigned LEB128-style varint.

    Seven bits go in each byte, least-significant group first. Every byte
    except the last has its high bit (0x80) set to signal that more bytes
    follow. Zero encodes as the single byte 0x00.
    """
    encoded = bytearray()
    while n >= 0x80:
        encoded.append(0x80 | (n & 0x7F))
        n >>= 7
    encoded.append(n)
    return bytes(encoded)
def emit_literal(b: bytes) -> List[bytes]:
n = len(b) - 1
tag: bytes
if n < 60:
tag = bytes([ n << 2 ])
elif n < (1 << 8):
tag = bytes([ 60 << 2, n ])
elif n < (1 << 16):
tag = bytes([ 61 << 2, n & 0xff, n >> 8 ])
elif n < (1 << 24):
tag = bytes([ 62 << 2, n & 0xff, (n >> 8) & 0xff, n >> 16 ])
else:
tag = bytes([ 63 << 2, n & 0xff, (n >> 8) & 0xff, (n >> 16) & 0xff, n >> 24 ])
return [ tag, b ]
def emit_copy(offset: int, length: int) -> List[bytes]:
rv = []
while length > 0:
n = min(length, 64)
if n >= 4 and n < 12 and offset < 2048:
rv.append(bytes([ 1 | ((n - 4) << 2) | (offset & 0x700) >> 3, offset & 0xff ]))
elif offset < (1 << 16):
rv.append(bytes([ 2 | ((n - 1) << 2), offset & 0xff, offset >> 8 ]))
else:
# the "official" snappy will never emit this code, because it
# uses a 15-bit lookback. i've done a few quick tests and it's
# never actually worth using: the encoding is too bloaty to do
# much good. it's better to look for closer matches.
rv.append(bytes([ 3 | ((n - 1) << 2), offset & 0xff, (offset >> 8) & 0xff, (offset >> 16) & 0xff, offset >> 24 ]))
length -= n
return rv
def compress(data: bytes, max_lookback_bits: int) -> Iterator[bytes]:
    """Compress ``data`` into the raw snappy block format, piece by piece.

    Yields the uncompressed length as a varint, then a sequence of literal and
    copy elements; concatenating everything yielded gives the compressed block.

    Matches are found with a hash table keyed on 4-byte chunks. Each bucket
    remembers only the most recent position that hashed to it. A candidate is
    used only if it lies at most ``1 << max_lookback_bits`` bytes back and its
    four bytes really are equal. Positions covered by an emitted copy are not
    added to the table.

    Args:
        data: bytes to compress.
        max_lookback_bits: log2 of the largest copy offset that may be used.
    """
    yield emit_varint(len(data))

    # edge cases: nothing, or too little to ever match a 4-byte chunk
    if len(data) == 0:
        return
    if len(data) <= 4:
        yield from emit_literal(data)
        return

    # hashtable should be much bigger than the lookup limit.
    # table uses 0 to mean "none" and stores absolute offset + 1.
    table_bits = min(MAX_TABLE_BITS, max(len(data).bit_length(), max_lookback_bits + 2))
    table_size = 1 << table_bits
    table_mask = table_size - 1
    table_shift = 32 - table_bits
    table = array.array("I", [0]) * table_size
    max_lookback = 1 << max_lookback_bits
    # Explicit little-endian so the hash, and therefore the output, does not
    # depend on the host byte order; unpack_from also avoids slicing per byte.
    read_u32 = struct.Struct("<I").unpack_from

    i = 0
    end = len(data)
    start = 0  # first byte not yet covered by an emitted element

    # look for 4-byte sequences we've seen before
    while i + 3 < end:
        word = read_u32(data, i)[0]
        # Keeping bits [table_shift, 32) of the full product is equivalent to
        # a 32-bit wrapping multiply followed by a right shift.
        bucket = ((word * HASH_MULT) >> table_shift) & table_mask
        candidate = table[bucket] - 1
        table[bucket] = i + 1
        if (candidate < 0
                or i - candidate > max_lookback
                or read_u32(data, candidate)[0] != word):
            i += 1
            continue

        # find the extent of the duplicated run; the source may overlap it
        copy_start = i
        offset = i - candidate
        i += 4
        while i < end and data[i] == data[i - offset]:
            i += 1

        if offset >= (1 << 16) and i - copy_start < 6:
            # Offsets that don't fit in 16 bits need the 5-byte copy form
            # (see emit_copy); for a 4-5 byte match that isn't worth emitting.
            i = copy_start + 1
            continue

        # emit any pending literal, then the copy
        if start != copy_start:
            yield from emit_literal(data[start:copy_start])
        yield from emit_copy(offset, i - copy_start)
        start = i

    # The last copy may have run to the very end, leaving no trailing literal;
    # an empty literal cannot be encoded.
    if start < end:
        yield from emit_literal(data[start:])
def main() -> None:
    """Compress all of stdin and write the raw snappy block to stdout.

    The lookback window is set with ``-w/--window`` (bits, 8..16, default 16).
    An out-of-range value prints an error to stderr and exits with status 1.
    """
    parser = argparse.ArgumentParser(description=f"{__title__} ({__version__}): {__description__}")
    # The default must lie inside the 8..16 range accepted below; otherwise
    # running without -w always fails the check.
    parser.add_argument(
        "-w", "--window", metavar="BITS", type=int,
        help="set window size, in bits (8 - 16)", default=16,
    )
    args = parser.parse_intermixed_args()
    if args.window < 8 or args.window > 16:
        sys.stderr.write("window must be 8 to 16 inclusive\n")
        sys.exit(1)
    data = sys.stdin.buffer.read()
    # compress() yields byte chunks; join them directly into a single write.
    sys.stdout.buffer.write(b"".join(compress(data, args.window)))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment