Skip to content

Instantly share code, notes, and snippets.

@DavidBuchanan314
Last active May 17, 2022 02:50
Show Gist options
  • Star 14 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save DavidBuchanan314/fe7d87548332a34991f7b258962a845d to your computer and use it in GitHub Desktop.
Save DavidBuchanan314/fe7d87548332a34991f7b258962a845d to your computer and use it in GitHub Desktop.
import zlib
import io
import sys
# Signature bytes expected at the very start of the input file; parse_png()
# compares the first len(PNG_MAGIC) bytes of the stream against this value.
PNG_MAGIC = (
    b"\x89"   # single non-ASCII lead byte
    b"PNG"    # format name in ASCII
    b"\r\n"   # CR LF
    b"\x1a"   # 0x1A
    b"\n"     # LF
)
def parse_png_chunk(stream):
    """Read one PNG chunk from *stream* and return ``(ctype, body)``.

    A chunk is laid out as a 4-byte big-endian body length, a 4-byte chunk
    type, the body itself and a 4-byte big-endian CRC-32 computed over
    ``ctype + body``.

    Args:
        stream: binary file-like object positioned at the start of a chunk.

    Returns:
        Tuple ``(ctype, body)`` of ``bytes``.

    Raises:
        ValueError: if the stream ends inside the chunk or the CRC does not
            match.
    """
    size_field = stream.read(4)
    ctype = stream.read(4)
    # Without this check an exhausted stream yields four empty reads, and
    # because crc32(b"") == 0 the CRC comparison below would "succeed",
    # returning (b"", b"") forever to a caller that is waiting for IEND.
    if len(size_field) != 4 or len(ctype) != 4:
        raise ValueError("truncated PNG chunk header")
    size = int.from_bytes(size_field, "big")

    body = stream.read(size)
    crc_field = stream.read(4)
    if len(body) != size or len(crc_field) != 4:
        raise ValueError("truncated PNG chunk body")

    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if zlib.crc32(ctype + body) != int.from_bytes(crc_field, "big"):
        raise ValueError(f"CRC mismatch in PNG chunk {ctype!r}")
    return ctype, body
def parse_png(stream):
    """Parse a PNG file and return ``(ihdr, deflate_data)``.

    Chunks are read with ``parse_png_chunk`` until ``IEND``. The bodies of
    all ``IDAT`` chunks are concatenated in file order; every chunk type
    other than ``IHDR``/``IDAT``/``IEND`` is ignored.

    Args:
        stream: binary file-like object positioned at the start of the file.

    Returns:
        Tuple ``(ihdr, deflate_data)`` where ``ihdr`` is the raw IHDR chunk
        body and ``deflate_data`` is the concatenated IDAT payload with its
        first 2 and last 4 bytes (the zlib wrapper) removed.

    Raises:
        ValueError: if the signature is wrong, the data ends before ``IEND``
            or no ``IHDR`` chunk was seen.
    """
    magic = stream.read(len(PNG_MAGIC))
    if magic != PNG_MAGIC:
        raise ValueError("not a PNG file (bad signature)")

    ihdr = None
    idat_parts = []  # joined once at the end instead of quadratic `+=`
    while True:
        ctype, body = parse_png_chunk(stream)
        if not ctype:
            # An empty chunk type means the stream is exhausted; bail out
            # instead of spinning forever waiting for IEND.
            raise ValueError("PNG data ended before IEND chunk")
        if ctype == b"IEND":
            break
        if ctype == b"IDAT":
            idat_parts.append(body)
        elif ctype == b"IHDR":
            ihdr = body

    if ihdr is None:
        raise ValueError("PNG has no IHDR chunk")
    idat = b"".join(idat_parts)
    return ihdr, idat[2:-4]  # strip zlib
def decompress(raw):
    """Inflate *raw*, a headerless (raw) DEFLATE stream, and return the bytes.

    ``wbits=-15`` selects raw DEFLATE, i.e. no zlib header or trailer is
    expected around the data.
    """
    inflater = zlib.decompressobj(wbits=-15)
    # Decompress.flush() takes an optional output-buffer *length*, not a
    # flush mode, so no zlib.Z_* constant belongs here.
    return inflater.decompress(raw) + inflater.flush()
# TODO: implement rabin-karp algorithm
# current implementation is very slow!!! (and bad!!!)
class BackrefFinder():
    """Brute-force search for the longest back-reference into recent output.

    Bytes that have already been emitted are supplied through ``feed()``;
    ``find()`` then reports the longest run (3 to 258 bytes) at the start of
    a lookahead buffer that also occurs in the last ``window_size`` fed
    bytes. A match may run past the end of the window into the lookahead
    itself, which is how run-length style repeats are found.
    """

    def __init__(self, window_size=2**15):
        self.window_size = window_size
        self.buf = b""  # most recent history, at most window_size bytes

    def feed(self, data):
        """Append already-emitted *data* to the search history."""
        self.buf += data
        # find() only ever looks at the trailing window_size bytes, so older
        # history is dropped to keep memory and per-call copying bounded.
        if len(self.buf) > self.window_size:
            self.buf = self.buf[-self.window_size:]

    def find(self, lookahead):
        """Return ``(length, distance)`` of the longest match for *lookahead*.

        Args:
            lookahead: bytes that are about to be emitted.

        Returns:
            ``(length, distance)`` with ``3 <= length <= 258`` and
            ``distance`` counted backwards from the end of the window, or
            ``(0, None)`` when the window is empty, the lookahead is shorter
            than 3 bytes, or nothing matches. Among equally long matches the
            one with the smallest distance wins.
        """
        window = self.buf[-self.window_size:]
        if not window or len(lookahead) < 3:
            return 0, None

        window_len = len(window)
        # A match may start in the last two window bytes and continue into
        # the lookahead, so those two bytes are appended to the search text.
        # Both values are loop-invariant and built only once.
        haystack = window + lookahead[:2]
        needle = lookahead[:3]
        max_len = min(258, len(lookahead))

        longest = 0
        longest_dist = None
        start = haystack.find(needle)
        while start != -1:
            length = 3  # the three needle bytes are already known to match
            while length < max_len:
                src = start + length
                if src < window_len:  # look back within window
                    candidate = window[src]
                else:  # look back within lookahead (e.g. for RLE)
                    candidate = lookahead[src - window_len]
                if candidate != lookahead[length]:
                    break
                length += 1
            # ">=" so that a later (closer) match of equal length is preferred.
            if length >= longest:
                longest = length
                longest_dist = window_len - start
            start = haystack.find(needle, start + 1)
        return longest, longest_dist
# Match-length table: entry (extra_bits, base_length) for length symbols
# 257 and up, indexed by ``symbol - 257``. The decoded length is base_length
# plus the value of the extra bits that follow the symbol.
LENGTHS = [
    # symbols 257-264: no extra bits
    (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
    # symbols 265-268: 1 extra bit
    (1, 11), (1, 13), (1, 15), (1, 17),
    # symbols 269-272: 2 extra bits
    (2, 19), (2, 23), (2, 27), (2, 31),
    # symbols 273-276: 3 extra bits
    (3, 35), (3, 43), (3, 51), (3, 59),
    # symbols 277-280: 4 extra bits
    (4, 67), (4, 83), (4, 99), (4, 115),
    # symbols 281-284: 5 extra bits
    (5, 131), (5, 163), (5, 195), (5, 227),
    # symbol 285: fixed length, no extra bits
    (0, 258),
    (None, 259),  # does not exist, here to make some code neater...
]
# Match-distance table: entry (extra_bits, base_distance) indexed by the
# 5-bit distance code. The decoded distance is base_distance plus the value
# of the extra bits that follow the code.
DISTANCES = [
    (0, 1), (0, 2), (0, 3), (0, 4),   # codes 0-3: no extra bits
    (1, 5), (1, 7),                   # codes 4-5
    (2, 9), (2, 13),                  # codes 6-7
    (3, 17), (3, 25),                 # codes 8-9
    (4, 33), (4, 49),                 # codes 10-11
    (5, 65), (5, 97),                 # codes 12-13
    (6, 129), (6, 193),               # codes 14-15
    (7, 257), (7, 385),               # codes 16-17
    (8, 513), (8, 769),               # codes 18-19
    (9, 1025), (9, 1537),             # codes 20-21
    (10, 2049), (10, 3073),           # codes 22-23
    (11, 4097), (11, 6145),           # codes 24-25
    (12, 8193), (12, 12289),          # codes 26-27
    (13, 16385), (13, 24577),         # codes 28-29
    (None, 32769),  # does not exist, here to make some code neater...
]
class Decompressor():
    """Walk a DEFLATE stream symbol by symbol and recover hidden bits.

    The stream is decoded bit by bit. Before each literal or back-reference
    symbol is applied, ``BackrefFinder`` is asked for the longest match
    available at the current position of the already-known plaintext
    (``original``). The gap between that length and the length the stream
    actually used is turned into payload bits by ``try_recover_bits`` and
    collected in ``steg_bytes``.

    Only stored (BTYPE=0) and fixed-Huffman (BTYPE=1) blocks are handled.
    """

    def __init__(self, stream, original):
        self.stream = stream
        self.orig = original  # original decompressed data
        self.byte = None      # input byte currently being consumed
        self.prevbit = 7      # index of the last bit taken; 7 forces a new byte
        self.buf = b""        # output reconstructed so far
        self.bf = BackrefFinder()
        self.steg_bytes = b""  # completed payload bytes
        self.steg_bit = 0      # next bit position inside steg_byte
        self.steg_byte = 0     # payload byte being assembled

    def next_bit(self):
        """Return the next input bit, least-significant bit of each byte first."""
        if self.prevbit == 7:
            self.prevbit = 0
            self.byte = self.stream.read(1)[0]
        else:
            self.prevbit += 1
        return (self.byte >> self.prevbit) & 1

    def write_steg_bit(self, bit):
        """Append one payload bit; bytes are filled from bit 0 upwards."""
        self.steg_byte |= bit << self.steg_bit
        self.steg_bit += 1
        if self.steg_bit == 8:
            self.steg_bit = 0
            self.steg_bytes += bytes([self.steg_byte])
            self.steg_byte = 0

    def read_data_element(self, nbits):
        """Read an *nbits*-wide integer whose first bit is the least significant."""
        value = 0
        for shift in range(nbits):
            value |= self.next_bit() << shift
        return value

    def read_huffman_bits(self, nbits, prefix=0):
        """Read *nbits* most-significant-first, appended below *prefix*."""
        value = prefix
        for _ in range(nbits):
            value = (value << 1) | self.next_bit()
        return value

    def read_huffman_symbol(self):
        """Decode one literal/length symbol of the fixed Huffman code.

        The first five bits are enough to tell the code length (7, 8 or 9
        bits); the remaining bits are then read and the code is mapped onto
        its symbol range.
        """
        preview = self.read_huffman_bits(5)
        if 0b00110 <= preview <= 0b10111:    # 8-bit codes -> symbols 0..143
            return self.read_huffman_bits(3, preview) - 0b0011_0000 + 0
        elif 0b11001 <= preview <= 0b11111:  # 9-bit codes -> symbols 144..255
            return self.read_huffman_bits(4, preview) - 0b1_1001_0000 + 144
        elif 0b00000 <= preview <= 0b00101:  # 7-bit codes -> symbols 256..279
            return self.read_huffman_bits(2, preview) - 0b000_0000 + 256
        else:                                # 8-bit codes -> symbols 280..287
            return self.read_huffman_bits(3, preview) - 0b1100_0000 + 280

    def read_fixed_huffman_block(self):
        """Decode symbols up to the end-of-block marker (symbol 256).

        Returns:
            List of ``("lit", byte_value)`` and
            ``("ref", length, distance)`` tuples in stream order.
        """
        symbols = []
        while True:
            symbol = self.read_huffman_symbol()
            if symbol < 0x100:
                symbols.append(("lit", symbol))
            elif symbol == 0x100:
                break
            else:
                ebits, length = LENGTHS[symbol - 257]
                length += self.read_data_element(ebits)
                # Distance codes are plain 5-bit values, most significant first.
                ebits, distance = DISTANCES[self.read_huffman_bits(5)]
                distance += self.read_data_element(ebits)
                symbols.append(("ref", length, distance))
        return symbols

    def try_recover_bits(self, actual_len, actual_dist, optimal_len, optimal_dist):
        """Turn the gap between the optimal and the used match into payload bits.

        * ``optimal_len == 0``: no match was possible, nothing is written.
        * ``optimal_len < 6``: one bit, set when the used length differs
          from the optimal one.
        * otherwise: two bits holding ``optimal_len - actual_len``, high bit
          written first.

        ``actual_dist`` and ``optimal_dist`` are accepted but not used.
        """
        if optimal_len == 0:
            return  # nothing
        if optimal_len < 6:
            self.write_steg_bit((optimal_len != actual_len) & 1)
        else:
            delta = optimal_len - actual_len
            self.write_steg_bit(delta >> 1)
            self.write_steg_bit(delta & 1)

    def process_symbols(self, symbols):
        """Apply decoded *symbols* to the output and harvest payload bits.

        Raises:
            IndexError: if a back-reference points before the start of the
                output produced so far.
            Exception: if a symbol tuple has an unknown tag.
        """
        for symbol in symbols:
            pos = len(self.buf)
            # A match is at most 258 bytes long, so handing the finder a
            # 258-byte lookahead gives the same answer as the whole remainder
            # without copying the rest of the plaintext for every symbol.
            longest, longest_dist = self.bf.find(self.orig[pos:pos + 258])
            kind = symbol[0]
            if kind == "lit":
                new = bytes([symbol[1]])
                self.bf.feed(new)
                self.buf += new
                self.try_recover_bits(0, None, longest, longest_dist)
            elif kind == "ref":
                length, distance = symbol[1:3]
                self.try_recover_bits(length, distance, longest, longest_dist)
                if distance > len(self.buf):
                    raise IndexError("back-reference distance exceeds output produced so far")
                # When length > distance the copy overlaps itself, which is
                # the same as repeating the last `distance` bytes.
                pattern = self.buf[-distance:]
                news = (pattern * (length // distance + 1))[:length]
                self.buf += news
                self.bf.feed(news)
            else:
                raise Exception("unexpected symbol")

    def read_block(self):
        """Decode one DEFLATE block and return its BFINAL bit.

        Raises:
            ValueError: if a stored block has non-zero padding bits or its
                length check fails.
            Exception: for dynamic-Huffman blocks (not implemented) and for
                the reserved block type.
        """
        bfinal = self.next_bit()
        btype = self.read_data_element(2)
        if btype == 0b00:
            # Skip to the next byte boundary. next_bit() must be called
            # outside any assert: under `python -O` an assert (and the call
            # inside it) is removed and this loop would never advance.
            while self.prevbit != 7:
                if self.next_bit() != 0:
                    raise ValueError("non-zero padding bits before stored block")
            size = self.read_data_element(16)
            notsize = self.read_data_element(16)
            if size != notsize ^ 0xffff:
                raise ValueError("stored block length does not match its complement")
            # NOTE(review): stored bytes are added to the output but never
            # fed to self.bf, so later match searches do not see them.
            # Confirm against the matching encoder before changing this.
            self.buf += self.stream.read(size)
        elif btype == 0b01:
            self.process_symbols(self.read_fixed_huffman_block())
        elif btype == 0b10:
            # Dynamic Huffman blocks are not supported.
            raise Exception("not implemented")
        else:
            raise Exception("invalid block type (BTYPE=0b11 is reserved)")
        return bfinal

    def decompress(self):
        """Decode blocks until one has BFINAL set; return the output bytes."""
        while self.read_block() != 1:
            pass
        return self.buf
def steg_unpack(image):
    """Return the payload bytes hidden in the PNG read from *image*.

    *image* is a binary file-like object. Its IDAT data is first inflated
    normally to obtain the plaintext, then walked again with
    ``Decompressor``, which collects the hidden bits along the way.
    """
    _, deflate_data = parse_png(image)  # the IHDR body is not needed here
    plaintext = decompress(deflate_data)
    walker = Decompressor(io.BytesIO(deflate_data), plaintext)
    walker.decompress()
    return walker.steg_bytes
def main():
    """Command-line entry point: ``<script> input.png output.whatever``.

    Prints a usage line and exits with status 1 unless exactly two
    arguments are given.
    """
    if len(sys.argv) != 3:
        print(f"USAGE: {sys.argv[0]} input.png output.whatever")
        # sys.exit rather than the interactive-only exit() helper, and a
        # non-zero status so callers can detect the usage error.
        sys.exit(1)

    # Decode first and only then open the output: opening it up front would
    # truncate the file even when decoding fails or both paths are the same.
    with open(sys.argv[1], "rb") as image:
        payload = steg_unpack(image)
    with open(sys.argv[2], "wb") as out:
        out.write(payload)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment