@juliusgeo
Created July 10, 2023 23:45
Bits and Pieces: zlib compliant DEFLATE from scratch

zlib underlies most zip file decompressors, and DEFLATE is one of the binary formats used to store compressed data in a bitstream. The goal of this article is to walk through how my Python DEFLATE compressor implementation works. There are many guides on the internet that describe how to implement each step of DEFLATE, but very few end up producing a bitstream that a library like zlib can actually parse. This article assumes that you roughly know how each step of the DEFLATE algorithm is implemented, but are having trouble with some of the finer points that are often glossed over.

The code can be found in "deflate.py".

LZ77 Compression

This is probably the easiest part because there are many examples of "correct" LZ77 compression. Make sure, however, that your parameters fit what DEFLATE and zlib will accept at the end. Match lengths must be between 3 and 258, and match distances must not exceed the window size, which is at most 32768 bytes (wbits=15). A good implementation that you can use to check yours is here.
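
A slow brute-force reference can help when checking an LZ77 implementation against those limits. The sketch below is a plain greedy matcher of my own, not the code in deflate.py or the implementation linked above. The names lz77_greedy and lz77_decode are hypothetical, and the defaults (32768-byte window, lengths 3 to 258) are assumptions chosen to match DEFLATE's limits. It reuses the (offset, length, indicator) token shape, with offset 0 marking a literal.

```python
from collections import namedtuple

# Hypothetical copy of a token type with the same fields as the article's Token
Token = namedtuple("Token", ["offset", "length", "indicator"])


def lz77_greedy(data, window=32768, min_len=3, max_len=258):
    """Brute-force greedy LZ77 within DEFLATE's limits (slow, for checking only)."""
    tokens, i = [], 0
    while i < len(data):
        best_len, best_dist = 0, 0
        for j in range(max(0, i - window), i):
            k = 0
            # The match source may run past i, i.e. overlap the text being encoded
            while k < max_len and i + k < len(data) and data[j + k] == data[i + k]:
                k += 1
            if k > best_len:
                best_len, best_dist = k, i - j
        if best_len >= min_len:
            tokens.append(Token(best_dist, best_len, data[i]))
            i += best_len
        else:
            tokens.append(Token(0, 1, data[i]))  # literal
            i += 1
    return tokens


def lz77_decode(tokens):
    out = []
    for t in tokens:
        if t.offset == 0:
            out.append(t.indicator)
        else:
            # Copy one character at a time so overlapping matches work
            for _ in range(t.length):
                out.append(out[-t.offset])
    return "".join(out)
```

Round-tripping your own tokens through a decoder like lz77_decode, and checking that every match length and distance falls within DEFLATE's limits, catches most LZ77 mistakes before they turn into unreadable bitstreams.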

Huffman Coding

After LZ77 compression, you will have to convert those tokens into a bitstream. In my implementation, to keep things simple, I used fixed Huffman coding. This means the code tables are predefined in RFC 1951 (section 3.2.6) instead of being sent along with the data.
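
The fixed tables are not arbitrary. RFC 1951 builds every Huffman code canonically from a list of code lengths, including the fixed one. Below is a minimal sketch of that construction (section 3.2.2), applied to the fixed code lengths from section 3.2.6. The names canonical_codes, FIXED_LIT_LEN, and FIXED_DIST are illustrative. The sketch shows why symbols 256 to 279 get 7-bit codes while 280 to 287 get 8-bit ones.

```python
def canonical_codes(lengths):
    """Assign canonical Huffman codes from code lengths (RFC 1951, section 3.2.2)."""
    max_bits = max(lengths)
    # Count how many symbols use each code length (length 0 means "unused")
    bl_count = [0] * (max_bits + 1)
    for n in lengths:
        if n:
            bl_count[n] += 1
    # Smallest code for each bit length
    code, next_code = 0, [0] * (max_bits + 1)
    for bits in range(1, max_bits + 1):
        code = (code + bl_count[bits - 1]) << 1
        next_code[bits] = code
    # Consecutive symbols of the same length get consecutive codes
    codes = {}
    for sym, n in enumerate(lengths):
        if n:
            codes[sym] = (next_code[n], n)
            next_code[n] += 1
    return codes


# Fixed literal/length code lengths (RFC 1951, section 3.2.6)
FIXED_LIT_LEN = canonical_codes([8] * 144 + [9] * 112 + [7] * 24 + [8] * 8)
# Fixed distance codes: every one of the 32 distance symbols is 5 bits long
FIXED_DIST = canonical_codes([5] * 32)
```

A hand-written lookup such as huff_codes can be compared against FIXED_LIT_LEN for all 288 symbols, which is a cheap way to catch an off-by-one in one of the four ranges.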

DEFLATE uses two different alphabets: one shared by literals and lengths, and one for distances. In the fixed Huffman code, a literal/length symbol is encoded with 7, 8, or 9 bits depending on its value. For literals, that is it. For lengths, you must also append extra bits when the table calls for them. A length or distance symbol only selects a base value. Its code (7 or 8 bits for lengths, always 5 bits for distances) is followed by a set number of extra bits that give the offset from that base. Each symbol determines how many extra bits follow it.
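
The "base plus extra bits" scheme amounts to a table lookup. The sketch below transcribes the length and distance tables from RFC 1951 section 3.2.5. For valid inputs, it returns the same kind of (symbol, extra value, extra bit count) triple as length_code and distance_code. The names length_symbol and distance_symbol are hypothetical, and bisect is just one way to find the right row.

```python
from bisect import bisect_right

# (base, extra bits) for length symbols 257..285, RFC 1951 section 3.2.5
LENGTH_TABLE = [
    (3, 0), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0),
    (11, 1), (13, 1), (15, 1), (17, 1), (19, 2), (23, 2), (27, 2), (31, 2),
    (35, 3), (43, 3), (51, 3), (59, 3), (67, 4), (83, 4), (99, 4), (115, 4),
    (131, 5), (163, 5), (195, 5), (227, 5), (258, 0),
]
# (base, extra bits) for distance codes 0..29
DIST_TABLE = [
    (1, 0), (2, 0), (3, 0), (4, 0), (5, 1), (7, 1), (9, 2), (13, 2),
    (17, 3), (25, 3), (33, 4), (49, 4), (65, 5), (97, 5), (129, 6), (193, 6),
    (257, 7), (385, 7), (513, 8), (769, 8), (1025, 9), (1537, 9),
    (2049, 10), (3073, 10), (4097, 11), (6145, 11), (8193, 12), (12289, 12),
    (16385, 13), (24577, 13),
]


def _lookup(table, value, first_symbol):
    # The last row whose base is <= value is the one that covers value
    i = bisect_right([base for base, _ in table], value) - 1
    base, nbits = table[i]
    return first_symbol + i, value - base, nbits


def length_symbol(n):
    """Return (symbol, extra_value, extra_bits) for a match length of 3..258."""
    if not 3 <= n <= 258:
        raise ValueError(f"invalid length {n}")
    return _lookup(LENGTH_TABLE, n, 257)


def distance_symbol(d):
    """Return (code, extra_value, extra_bits) for a match distance of 1..32768."""
    if not 1 <= d <= 32768:
        raise ValueError(f"invalid distance {d}")
    return _lookup(DIST_TABLE, d, 0)
```

Note the irregular end of the length table. Symbol 284 covers 227 to 257 with 5 extra bits, while 258 gets its own symbol, 285, with no extra bits.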

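def tokens_to_stream(compressed):
    # Block header in bit order: BFINAL=1, then BTYPE=01 (fixed Huffman), LSB first
    it = "110"
    for tok in compressed:
        if tok.length <= 1:
            # Write a literal; Huffman codes are written most-significant bit first
            code, shift = huff_codes(tok.indicator)
            it += f"{code:0{shift}b}"
        else:
            # Length/distance pair, write `nbits` bits with value `ebits` in reverse order
            code, ebits, nbits = length_code(tok.length)
            if code < 280:
                it += f"{code - 256:07b}"  # length symbols 257-279 have 7-bit codes
            else:
                it += f"{code - 280 + 0b11000000:08b}"  # symbols 280-285 have 8-bit codes
            if nbits >= 1:
                it += f"{ebits:0{nbits}b}"[-nbits:][::-1]

            code, ebits, nbits = distance_code(tok.offset)
            it += f"{code:05b}"  # distance codes are always 5 bits
            if nbits >= 1:
                it += f"{ebits:0{nbits}b}"[-nbits:][::-1]
    # End-of-block symbol 256: seven 0 bits in the fixed code
    it += "0000000"
    # Pack 8 bits per byte, first bit in the least-significant position;
    # the last partial byte is zero-padded up to the byte boundary
    return b"".join(
        [
            int(it[i:i+8][::-1], 2).to_bytes(1, byteorder="big", signed=False)
            for i in range(0, len(it), 8)
        ]
    )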

Some important details that are often glossed over:

  • You must always write out nbits bits, even if the value of those bits is 0.
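  • Any extra bits are written least-significant bit first, the reverse of the order used for Huffman codes. They must be exactly nbits long, so if your ebits value can be larger, truncate it to its lowest nbits bits.
  • Double check that your distance and length codes produce the correct number of bits for all possible inputs. For example, length symbols 280 to 285 use 8-bit codes, not 7-bit ones. In my case, I just inverted the tables shown in the RFC and used those.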
  • The distance alphabet is a distinct alphabet and has very little relation to the length alphabet except in its construction (base + extra bits).
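  • Make sure the block ends with the end-of-block symbol 256 (seven 0 bits in the fixed code) and that the output is zero-padded to a byte boundary. Appending a null byte after the padding usually provides those zeros. However, if the padding alone already holds all seven, the extra byte ends up where zlib expects the checksum.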
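
The bit-order rules can be checked end to end against zlib. The sketch below uses a hypothetical BitWriter class and example_block helper. It packs bits as RFC 1951 section 3.1.1 describes and writes one fixed-Huffman block: six literals and a single length/distance pair, terminated by the end-of-block symbol. It then decompresses the block with zlib in raw mode. The match length was picked so that its 2-bit extra value would decode differently if written in the wrong order.

```python
import zlib


class BitWriter:
    """Accumulate bits in DEFLATE order: each byte is filled from its least-significant bit."""

    def __init__(self):
        self.out = bytearray()
        self.acc = 0   # pending bits not yet flushed to `out`
        self.nacc = 0  # how many pending bits there are

    def write_bits(self, value, n):
        """Write the low n bits of value, least-significant bit first (header fields, extra bits)."""
        self.acc |= (value & ((1 << n) - 1)) << self.nacc
        self.nacc += n
        while self.nacc >= 8:
            self.out.append(self.acc & 0xFF)
            self.acc >>= 8
            self.nacc -= 8

    def write_code(self, code, n):
        """Write an n-bit Huffman code, most-significant bit first."""
        for i in reversed(range(n)):
            self.write_bits((code >> i) & 1, 1)

    def finish(self):
        """Zero-pad the last partial byte and return everything written."""
        if self.nacc:
            self.out.append(self.acc & 0xFF)
        return bytes(self.out)


def fixed_lit_len_code(sym):
    """(code, bit length) of a literal/length symbol in the fixed Huffman code."""
    if sym < 144:
        return 0b00110000 + sym, 8
    if sym < 256:
        return 0b110010000 + sym - 144, 9
    if sym < 280:
        return sym - 256, 7
    return 0b11000000 + sym - 280, 8


def example_block():
    """Six literals, then a match of length 21 at distance 6, as one fixed-Huffman block."""
    data = (b"abcdef" * 5)[:27]
    w = BitWriter()
    w.write_bits(1, 1)  # BFINAL = 1: this is the last block
    w.write_bits(1, 2)  # BTYPE = 01: fixed Huffman codes
    for byte in data[:6]:
        w.write_code(*fixed_lit_len_code(byte))
    w.write_code(*fixed_lit_len_code(269))  # length symbol 269 covers lengths 19-22
    w.write_bits(21 - 19, 2)                # its 2 extra bits, LSB first
    w.write_code(4, 5)                      # distance code 4 covers distances 5-6
    w.write_bits(6 - 5, 1)                  # its 1 extra bit
    w.write_code(*fixed_lit_len_code(256))  # end-of-block symbol: seven 0 bits
    return data, w.finish()
```

If the length's extra bits were written most-significant bit first (with write_code instead of write_bits), zlib would read the value 2 as 1 and copy 20 bytes instead of 21. The decompressed output would then be one byte short.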

Making the zlib Stream

Everything covered so far results in a valid raw DEFLATE stream. However, for zlib to accept it in any mode other than raw, you need to add a two-byte header and a four-byte trailer. The header can essentially just be looked up, but the trailer is an Adler-32 checksum.

def adler32(data):
    a1,a2=1,0
    for b in data:
        a1=(a1+b) % 65521
        a2=(a2+a1) % 65521
    return a2.to_bytes(2, byteorder="big")+a1.to_bytes(2, byteorder="big")


bitstream_bytes = b"\x78\x01" + bitstream_bytes + adler32(strk.encode("ascii"))

The header is not quite a magic number. 0x78 says the data is DEFLATE-compressed with a 32 KiB window, and 0x01 is chosen so that the two bytes, read as a big-endian 16-bit number, are a multiple of 31. The adler32 function is a simple checksum calculation, but make sure it is calculated over the uncompressed data and written in big-endian byte order. Another source of confusion is the wbits parameter that must be given to zlib. To test your raw DEFLATE stream, decompress it with wbits=-15. Once you add the zlib header and trailer, switch to wbits=15.
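
Both points can be checked before your own encoder is finished. In the sketch below, Python's zlib module stands in for the encoder: zlib.compressobj with wbits=-15 emits a raw DEFLATE stream, which is then wrapped by hand. The helpers zlib_header and wrap_zlib are hypothetical names, and the FCHECK rule they implement comes from RFC 1950, section 2.2.

```python
import zlib


def zlib_header(cinfo=7, flevel=0):
    """Build the two zlib header bytes (CMF, FLG) for a DEFLATE stream."""
    cmf = (cinfo << 4) | 8  # CM = 8 (DEFLATE); window size = 2 ** (cinfo + 8)
    flg = flevel << 6       # FDICT = 0: no preset dictionary
    # FCHECK: smallest value making CMF * 256 + FLG a multiple of 31
    flg |= (31 - (cmf * 256 + flg) % 31) % 31
    return bytes([cmf, flg])


def wrap_zlib(raw_deflate, original):
    """Wrap a raw DEFLATE stream in a zlib header and big-endian Adler-32 trailer."""
    return zlib_header() + raw_deflate + zlib.adler32(original).to_bytes(4, "big")


data = b"Did you win your sword fight?" * 3
# Stand-in for your own encoder: zlib with negative wbits emits raw DEFLATE
co = zlib.compressobj(wbits=-15)
raw = co.compress(data) + co.flush()

assert zlib.decompress(raw, wbits=-15) == data                  # raw stream: wbits=-15
assert zlib.decompress(wrap_zlib(raw, data), wbits=15) == data  # zlib stream: wbits=15
```

With the defaults, zlib_header returns exactly b"\x78\x01". Raising flevel to 2 or 3 gives the other common headers, 78 9C and 78 DA. FLEVEL is informational only, so zlib decompresses the stream the same way regardless.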

Conclusion

Put together, this forms a valid zlib stream (a DEFLATE stream plus header and trailer) that zlib can decode:

strk = """
"Did you win your sword fight?"
"Of course I won the fucking sword fight," Hiro says. "I'm the greatest sword fighter in the world."
"And you wrote the software."
"Yeah. That, too," Hiro says."
"""
compressed = compress(strk)
bitstream_bytes = tokens_to_stream(compressed)

# Add `zlib` header and footer. Checksum is calculated on the uncompressed original
bitstream_bytes = b"\x78\x01" + bitstream_bytes + adler32(strk.encode("ascii"))
import zlib
decompressed_data = zlib.decompress(bitstream_bytes, wbits=15)
assert decompressed_data == strk.encode("ascii")
print(str(decompressed_data, "ascii"))
"Did you win your sword fight?"
"Of course I won the fucking sword fight," Hiro says. "I'm the greatest sword fighter in the world."
"And you wrote the software."
"Yeah. That, too," Hiro says."

Further reading:

from collections import namedtuple

Token = namedtuple("Token", ["offset", "length", "indicator"])

def compress(input_string, max_offset=2047, max_length=31):
    input_array = str(input_string[:])
    window, output = "", []
    while input_array != "":
        length, offset = blo(window, input_array, max_length, max_offset)
        output.append(Token(offset, length, input_array[0]))
        window += input_array[:length]
        input_array = input_array[length:]
    return output
def blo(window, input_string, max_length=15, max_offset=4095):
    # Find the longest match for the start of input_string in the window;
    # returns (length, offset), where (1, 0) means "emit a literal"
    if input_string is None or input_string == "":
        return 0, 0
    cut_window = window[-max_offset:] if max_offset < len(window) else window
    if input_string[0] not in cut_window:
        # New character: emit it as a literal. A (length > 1, offset 0) run token
        # cannot be encoded by tokens_to_stream, which needs a distance of at least 1.
        return (1, 0)
    length, offset, max_length = 0, 0, min(max_length, len(input_string))
    for index in range(1, len(cut_window) + 1):
        if cut_window[-index] == input_string[0] and \
                (found_length := rl_fs(cut_window[-index:], input_string)) > length:
            length, offset = found_length, index
    # Use the best match found (not the last one), and only if it is at least 3 long
    return (min(length, max_length), offset) if length > 2 else (1, 0)
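def rl_fs(window, input_string):
    # Count how many leading characters of input_string match window. Each matched
    # input character is appended to the window, so a match may overlap the input.
    # Iterative rather than recursive, so long matches do not hit the recursion limit.
    n = 0
    while window and n < len(input_string) and window[0] == input_string[n]:
        window = window[1:] + input_string[n]
        n += 1
    return n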
def decompress(compressed):
    output = ""
    for value in compressed:
        offset, length, char = value
        if length == 0:
            if char is not None:
                output += char
        else:
            if offset == 0:
                if char is not None:
                    output += char
                length -= 1
                offset = 1
            start_index = len(output) - offset
            for i in range(length):
                output += output[start_index + i]
    return output
#print(decompress(compressed))
def huff_codes(val):
    # Return (code, bit length) for a literal/length symbol in the fixed Huffman code
    # (RFC 1951, section 3.2.6)
    if isinstance(val, str):  # Literal byte
        val = ord(val)
    if val < 144:
        return val + 0b00110000, 8
    elif val < 256:
        return val - 144 + 0b110010000, 9
    elif val < 280:
        return val - 256, 7  # 256 (end of block) through 279 use 7-bit codes
    elif val < 288:
        return val - 280 + 0b11000000, 8
    else:
        raise ValueError("Value out of range")
def length_code(n):
    if n <= 2:
        return n, 0, 0
    if n <= 10:
        return 254 + n, 0, 0
    elif n <= 18:
        return 265 + (n - 11) // 2, (n - 11) % 2, 1
    elif n <= 34:
        return 269 + (n - 19) // 4, (n - 19) % 4, 2
    elif n <= 66:
        return 273 + (n - 35) // 8, (n - 35) % 8, 3
    elif n <= 130:
        return 277 + (n - 67) // 16, (n - 67) % 16, 4
    elif n < 258:
        return 281 + (n - 131) // 32, (n - 131) % 32, 5
    elif n == 258:
        return 285, 0, 0
    else:
        raise ValueError("Invalid length")
length_test_cases = [
    ((3, 3), 257, 0), ((4, 4), 258, 0), ((5, 5), 259, 0), ((6, 6), 260, 0),
    ((7, 7), 261, 0), ((8, 8), 262, 0), ((9, 9), 263, 0), ((10, 10), 264, 0),
    ((11, 12), 265, 1), ((13, 14), 266, 1), ((15, 16), 267, 1), ((17, 18), 268, 1),
    ((19, 22), 269, 2), ((23, 26), 270, 2), ((27, 30), 271, 2), ((31, 34), 272, 2),
    ((35, 42), 273, 3), ((43, 50), 274, 3), ((51, 58), 275, 3), ((59, 66), 276, 3),
    ((67, 82), 277, 4), ((83, 98), 278, 4), ((99, 114), 279, 4), ((115, 130), 280, 4),
    ((131, 162), 281, 5), ((163, 194), 282, 5), ((195, 226), 283, 5), ((227, 257), 284, 5),
    ((258, 258), 285, 0),
]
for val in sorted(length_test_cases):
    for i in range(val[0][0], val[0][1] + 1):
        a, _, b = length_code(i)
        _, c, d = val
        assert (a, b) == (c, d), f"Failed {i} for {a},{b} != {(c, d)}"
def distance_code(n):
    # Return (code, extra-bits value, number of extra bits) for distance n (RFC 1951, 3.2.5)
    if n < 1:
        raise ValueError("Invalid distance")
    if n <= 4:
        return n - 1, 0, 0
    elif n <= 8:
        return (n - 5) // 2 + 4, (n - 5) % 2, 1
    elif n <= 16:
        return (n - 9) // 4 + 6, (n - 9) % 4, 2
    elif n <= 32:
        return (n - 17) // 8 + 8, (n - 17) % 8, 3
    elif n <= 64:
        return (n - 33) // 16 + 10, (n - 33) % 16, 4
    elif n <= 128:
        return (n - 65) // 32 + 12, (n - 65) % 32, 5
    elif n <= 256:
        return (n - 129) // 64 + 14, (n - 129) % 64, 6
    elif n <= 512:
        return (n - 257) // 128 + 16, (n - 257) % 128, 7
    elif n <= 1024:
        return (n - 513) // 256 + 18, (n - 513) % 256, 8
    elif n <= 2048:
        return (n - 1025) // 512 + 20, (n - 1025) % 512, 9
    elif n <= 4096:
        return (n - 2049) // 1024 + 22, (n - 2049) % 1024, 10
    elif n <= 8192:
        return (n - 4097) // 2048 + 24, (n - 4097) % 2048, 11
    elif n <= 16384:
        return (n - 8193) // 4096 + 26, (n - 8193) % 4096, 12
    elif n <= 32768:
        return (n - 16385) // 8192 + 28, (n - 16385) % 8192, 13
    else:
        raise ValueError("Invalid distance")
test_dist = {
    ((1, 1), 0, 0), ((2, 2), 1, 0), ((3, 3), 2, 0), ((4, 4), 3, 0),
    ((5, 6), 4, 1), ((7, 8), 5, 1), ((9, 12), 6, 2), ((13, 16), 7, 2),
    ((17, 24), 8, 3), ((25, 32), 9, 3), ((33, 48), 10, 4), ((49, 64), 11, 4),
    ((65, 96), 12, 5), ((97, 128), 13, 5), ((129, 192), 14, 6), ((193, 256), 15, 6),
    ((257, 384), 16, 7), ((385, 512), 17, 7), ((513, 768), 18, 8), ((769, 1024), 19, 8),
    ((1025, 1536), 20, 9), ((1537, 2048), 21, 9), ((2049, 3072), 22, 10), ((3073, 4096), 23, 10),
    ((4097, 6144), 24, 11), ((6145, 8192), 25, 11), ((8193, 12288), 26, 12), ((12289, 16384), 27, 12),
}
for val in sorted(test_dist):
    for i in range(val[0][0], val[0][1] + 1):
        a, extra, b = distance_code(i)
        _, c, d = val
        assert (a, b) == (c, d), f"Failed {i} for {a},{b} != {(c, d)}"
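def tokens_to_stream(compressed):
    # Block header in bit order: BFINAL=1, then BTYPE=01 (fixed Huffman), LSB first
    it = "110"
    for tok in compressed:
        if tok.length <= 1:
            # Write a literal; Huffman codes are written most-significant bit first
            code, shift = huff_codes(tok.indicator)
            it += f"{code:0{shift}b}"
        else:
            # Length/distance pair, write `nbits` bits with value `ebits` in reverse order
            code, ebits, nbits = length_code(tok.length)
            if code < 280:
                it += f"{code - 256:07b}"  # length symbols 257-279 have 7-bit codes
            else:
                it += f"{code - 280 + 0b11000000:08b}"  # symbols 280-285 have 8-bit codes
            if nbits >= 1:
                it += f"{ebits:0{nbits}b}"[-nbits:][::-1]
            code, ebits, nbits = distance_code(tok.offset)
            it += f"{code:05b}"  # distance codes are always 5 bits
            if nbits >= 1:
                it += f"{ebits:0{nbits}b}"[-nbits:][::-1]
    # End-of-block symbol 256: seven 0 bits in the fixed code
    it += "0000000"
    # Pack 8 bits per byte, first bit in the least-significant position;
    # the last partial byte is zero-padded up to the byte boundary
    return b"".join(
        [
            int(it[i:i+8][::-1], 2).to_bytes(1, byteorder="big", signed=False)
            for i in range(0, len(it), 8)
        ]
    )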
def adler32(data):
    a1, a2 = 1, 0
    for b in data:
        a1 = (a1 + b) % 65521
        a2 = (a2 + a1) % 65521
    return a2.to_bytes(2, byteorder="big") + a1.to_bytes(2, byteorder="big")
strk = """
"Did you win your sword fight?"
"Of course I won the fucking sword fight," Hiro says. "I'm the greatest sword fighter in the world."
"And you wrote the software."
"Yeah. That, too," Hiro says."
"""
compressed = compress(strk)
bitstream_bytes = tokens_to_stream(compressed)
# Add `zlib` header and footer. Checksum is calculated on the uncompressed original
bitstream_bytes = b"\x78\x01" + bitstream_bytes + adler32(strk.encode("ascii"))
import zlib
decompressed_data = zlib.decompress(bitstream_bytes, wbits=15)
assert decompressed_data == strk.encode("ascii")
print(str(decompressed_data, "ascii"))