Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@lifthrasiir
Last active February 6, 2022 07:41
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lifthrasiir/5c24058f21ce6fba231cf1bfba45bf28 to your computer and use it in GitHub Desktop.
Save lifthrasiir/5c24058f21ce6fba231cf1bfba45bf28 to your computer and use it in GitHub Desktop.
Super-experimental PNG recompressor with JPEG XL and reconstruction support
#!/usr/bin/env python3
# jxl-preflate.py - Experimental reconstructable PNG recompressor to JXL
# Kang Seonghoon, 2021-07-18, Public Domain.
import sys
import os.path
import tempfile
import subprocess
import struct
import zlib
import hashlib
import time
import contextlib
# Verbosity of the diagnostics emitted by split()/join(): above 0 prints the
# size and SHA-256 of the unfiltered image data to stderr, above 1 also dumps
# that data as hex lines to `origdata-in` / `origdata-out` in the current
# directory.
DEBUG_LEVEL = 0
# The 8-byte signature that split() requires at the start of a PNG.
PNG_SIGNATURE = b'\x89PNG\x0d\x0a\x1a\x0a'
# Accepted (colour type, bit depth) pairs from IHDR. The index into this list
# is stored in the low 4 bits of the first reconstruction byte, so the order
# is part of the reconstruction data format and must not change.
POSSIBLE_CT_BPP = [
(0, 1), (0, 2), (0, 4), (0, 8), (0, 16),
(2, 8), (2, 16),
(3, 1), (3, 2), (3, 4), (3, 8),
(4, 8), (4, 16),
(6, 8), (6, 16),
]
# Passes used by scanlines() for interlace method 1, each given as
# (first row, first column, row increment, column increment).
PASSES = [
(0, 0, 8, 8),
(0, 4, 8, 8),
(4, 0, 8, 4),
(0, 2, 4, 4),
(2, 0, 4, 2),
(0, 1, 2, 2),
(1, 0, 2, 1),
]
# Executables of the external tools this script shells out to.
CJXL_PATH = 'cjxl'
DJXL_PATH = 'djxl'
PREFLATE_PATH = 'preflate_demo'
BROTLI_PATH = 'brotli'
@contextlib.contextmanager
def timing(operation):
    """Context manager that reports how long the wrapped block took.

    Prints a "starting" line to stderr on entry and a "finished" line with the
    elapsed monotonic time (one decimal place) on exit, including when the
    block raises.
    """
    def report(text):
        print(text, file=sys.stderr)

    report('(%s starting...)' % operation)
    began = time.monotonic()
    try:
        yield
    finally:
        took = time.monotonic() - began
        report('(%s finished in %.1f secs.)' % (operation, took))
def cjxl(data, opts=()):
    """Encode PNG bytes to JPEG XL by running the external cjxl tool.

    Args:
        data: contents of the PNG file to encode.
        opts: extra command-line arguments, inserted after `-q 100`.

    Returns:
        The bytes of the JPEG XL file that cjxl wrote.

    Raises:
        subprocess.CalledProcessError: cjxl exited with a non-zero status.
        OSError: cjxl could not be started or a temporary file could not be
            read or written.
    """
    # Work inside a private temporary directory rather than next to a
    # NamedTemporaryFile: the input is closed before cjxl opens it (an open
    # NamedTemporaryFile cannot be reopened by name on Windows), the output
    # path cannot collide with anything else in the shared temp directory,
    # and everything is removed on exit even if cjxl fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        inname = os.path.join(tmpdir, 'input.png')
        outname = os.path.join(tmpdir, 'output.jxl')
        with open(inname, 'wb') as fin:
            fin.write(data)
        subprocess.run([CJXL_PATH, '-q', '100', *opts, inname, outname], check=True)
        with open(outname, 'rb') as fout:
            return fout.read()
def djxl(data, opts=()):
    """Decode JPEG XL bytes to a PNG by running the external djxl tool.

    Args:
        data: contents of the JPEG XL file to decode.
        opts: extra command-line arguments, inserted before the file names.

    Returns:
        The bytes of the PNG file that djxl wrote.

    Raises:
        subprocess.CalledProcessError: djxl exited with a non-zero status.
        OSError: djxl could not be started or a temporary file could not be
            read or written.
    """
    # A private temporary directory keeps the input closed while djxl reads it
    # (an open NamedTemporaryFile cannot be reopened by name on Windows),
    # avoids predictable output names in the shared temp directory, and is
    # cleaned up on exit even if djxl fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        inname = os.path.join(tmpdir, 'input.jxl')
        outname = os.path.join(tmpdir, 'output.png')
        with open(inname, 'wb') as fin:
            fin.write(data)
        subprocess.run([DJXL_PATH, *opts, inname, outname], check=True)
        with open(outname, 'rb') as fout:
            return fout.read()
def preflate_split(data):
    """Run `preflate_demo -s` on a raw deflate stream.

    Args:
        data: the deflate stream to analyse.

    Returns:
        A `(udata, rdata)` tuple holding the contents of the `.u` and `.r`
        files the tool writes next to its input. split() checks that `udata`
        equals the zlib-decompressed data and stores `rdata` as the preflate
        reconstruction data.

    Raises:
        subprocess.CalledProcessError: preflate_demo exited with a non-zero
            status.
        OSError: the tool could not be started or one of its files could not
            be read or written.
    """
    # A private temporary directory keeps the input closed while the tool
    # reads it (an open NamedTemporaryFile cannot be reopened by name on
    # Windows) and removes the input and both outputs on exit, even on failure.
    with tempfile.TemporaryDirectory() as tmpdir:
        base = os.path.join(tmpdir, 'stream')
        with open(base, 'wb') as fin:
            fin.write(data)
        subprocess.run([PREFLATE_PATH, '-s', base], check=True)
        with open(base + '.u', 'rb') as fout:
            udata = fout.read()
        with open(base + '.r', 'rb') as fout:
            rdata = fout.read()
        return udata, rdata
def preflate_join(udata, rdata):
    """Run `preflate_demo -x` to rebuild a deflate stream.

    Args:
        udata: contents for the `.u` input file (join() passes the
            reconstructed uncompressed data).
        rdata: contents for the `.r` input file (the preflate reconstruction
            data that preflate_split() returned).

    Returns:
        The contents of the `.x` file written by the tool.

    Raises:
        subprocess.CalledProcessError: preflate_demo exited with a non-zero
            status.
        OSError: the tool could not be started or one of its files could not
            be read or written.
    """
    # A private temporary directory keeps both inputs closed while the tool
    # reads them (an open NamedTemporaryFile cannot be reopened by name on
    # Windows) and guarantees that the `.u`, `.r` and `.x` files are all
    # removed on exit, including when writing an input fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        base = os.path.join(tmpdir, 'stream')
        with open(base + '.u', 'wb') as fin:
            fin.write(udata)
        with open(base + '.r', 'wb') as fin:
            fin.write(rdata)
        subprocess.run([PREFLATE_PATH, '-x', base], check=True)
        with open(base + '.x', 'rb') as fout:
            return fout.read()
def hash(data): # only for debugging, so anything works
    """Return the SHA-256 digest of `data` as a hexadecimal string."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()
def brotli(data, opts=()):
    """Compress `data` by piping it through the external brotli tool.

    Args:
        data: bytes to compress.
        opts: extra command-line arguments, inserted after `-c`.

    Returns:
        The bytes brotli wrote to its standard output.

    Raises:
        subprocess.CalledProcessError: brotli exited with a non-zero status.
        OSError: brotli could not be started.
    """
    # check=True rather than an assert on the return code: asserts disappear
    # under `python -O`, which would let a failed run's truncated output be
    # stored as reconstruction data. This also matches the other tool wrappers.
    result = subprocess.run(
        [BROTLI_PATH, '-c', *opts, '-'],
        input=data, stdout=subprocess.PIPE, check=True)
    return result.stdout
def unbrotli(data):
    """Decompress `data` by piping it through the external brotli tool.

    Args:
        data: brotli-compressed bytes.

    Returns:
        The bytes brotli wrote to its standard output.

    Raises:
        subprocess.CalledProcessError: brotli exited with a non-zero status.
        OSError: brotli could not be started.
    """
    # check=True rather than an assert on the return code: asserts disappear
    # under `python -O`, which would let the output of a failed run be parsed
    # as reconstruction data. This also matches the other tool wrappers.
    result = subprocess.run(
        [BROTLI_PATH, '-d', '-'],
        input=data, stdout=subprocess.PIPE, check=True)
    return result.stdout
def bits_per_sample(ct, bpp):
    """Return the number of bits one pixel occupies.

    `ct` is the PNG colour type and `bpp` the bit depth of a single channel.
    Three channels are counted when the low two bits of `ct` equal 2, one
    otherwise, plus one more when bit 2 of `ct` is set.
    """
    channels = 3 if (ct & 3) == 2 else 1
    if ct & 4:
        channels += 1
    return bpp * channels
# returns a list of (scanline length in bytes, padding mask, pass info)
# scanline length might be 0, in which case previous scanline has to be discarded
# pass info is (row, column start in pixels, column increment)
def scanlines(width, height, ct, bpp, im):
    """List the scanlines of an image in the order they are stored.

    Each entry is `(length in bytes, padding mask, (row, first column, column
    increment))`. The padding mask selects the unused low bits of the last
    byte and is 0 when the scanline ends on a byte boundary. For interlace
    method 1 a `(0, 0, None)` entry separates consecutive non-empty passes.
    """
    bps = bits_per_sample(ct, bpp)

    def layout(pixels):
        # Byte length and padding mask of a scanline holding `pixels` pixels.
        bits = pixels * bps
        leftover = bits % 8
        mask = (1 << (8 - leftover)) - 1 if leftover else 0
        return (bits + 7) // 8, mask

    if im == 0:
        nbytes, mask = layout(width)
        return [(nbytes, mask, (row, 0, 1)) for row in range(height)]
    if im == 1:
        entries = []
        for rstart, cstart, rinc, cinc in PASSES:
            passwidth = (width - cstart + cinc - 1) // cinc
            passheight = (height - rstart + rinc - 1) // rinc
            if not (passwidth and passheight):
                continue # empty scanline is not recorded
            if entries:
                # marks the pass boundary; the previous scanline is discarded
                entries.append((0, 0, None))
            nbytes, mask = layout(passwidth)
            for row in range(rstart, height, rinc):
                entries.append((nbytes, mask, (row, cstart, cinc)))
        return entries
    assert False, 'bad im'
def paeth(a, b, c):
    """Return whichever of `a`, `b`, `c` is closest to `a + b - c`.

    Ties are resolved in the order a, b, c.
    """
    estimate = a + b - c
    dist_a = abs(estimate - a)
    dist_b = abs(estimate - b)
    dist_c = abs(estimate - c)
    if dist_a <= dist_b and dist_a <= dist_c:
        return a
    return b if dist_b <= dist_c else c
def refilter(stride, ty, last, this):
    """Apply PNG filter type `ty` (0..4) to one unfiltered scanline.

    `stride` is the byte distance to the corresponding byte on the left,
    `this` the unfiltered scanline and `last` the unfiltered previous
    scanline; a false `last` (None or empty) stands for all zeroes.
    Returns the filtered bytes as a new bytearray.
    """
    if not last:
        last = bytearray(len(this))
    assert len(last) == len(this)

    def left_of(row, n):
        # Byte `stride` positions before index n, or 0 before the row start.
        return row[n - stride] if n >= stride else 0

    indices = range(len(this))
    if ty == 0: # none
        ret = bytearray(this)
    elif ty == 1: # sub
        ret = bytearray((this[n] - left_of(this, n)) & 0xff for n in indices)
    elif ty == 2: # up
        ret = bytearray((this[n] - last[n]) & 0xff for n in indices)
    elif ty == 3: # average
        ret = bytearray(
            (this[n] - ((left_of(this, n) + last[n]) >> 1)) & 0xff
            for n in indices)
    elif ty == 4: # paeth
        ret = bytearray(
            (this[n] - paeth(left_of(this, n), last[n], left_of(last, n))) & 0xff
            for n in indices)
    else:
        assert None, 'unknown filter type'
    return ret
def unfilter(stride, ty, last, this):
    """Undo PNG filter type `ty` (0..4) on one filtered scanline.

    `stride` is the byte distance to the corresponding byte on the left,
    `this` the filtered scanline and `last` the unfiltered previous scanline;
    a false `last` (None or empty) stands for all zeroes.
    Returns the unfiltered bytes as a new bytearray.
    """
    if not last:
        last = bytearray(len(this))
    assert len(last) == len(this)
    if ty == 0: # none
        ret = bytearray(this)
    elif ty == 1: # sub
        ret = bytearray()
        for n, x in enumerate(this):
            # already-decoded byte to the left, 0 before the row start
            left = ret[n - stride] if n >= stride else 0
            ret.append((x + left) & 0xff)
    elif ty == 2: # up
        ret = bytearray((x + above) & 0xff for x, above in zip(this, last))
    elif ty == 3: # average
        ret = bytearray()
        for n, x in enumerate(this):
            left = ret[n - stride] if n >= stride else 0
            ret.append((x + ((left + last[n]) >> 1)) & 0xff)
    elif ty == 4: # paeth
        ret = bytearray()
        for n, x in enumerate(this):
            if n >= stride:
                left, upper_left = ret[n - stride], last[n - stride]
            else:
                left, upper_left = 0, 0
            ret.append((x + paeth(left, last[n], upper_left)) & 0xff)
    else:
        assert None, 'unknown filter type'
    return ret
def split(png, *, cjxlopts, brotliopts):
    """Split a PNG into a JPEG XL image and reconstruction data.

    The PNG is handed to cjxl first, then walked chunk by chunk to record
    everything join() needs beyond the decoded pixels: the IHDR colour type,
    bit depth and interlace method, the IDAT chunk boundaries, the zlib header
    bits, each scanline's filter type (and padding bits), the preflate
    reconstruction data, and every other chunk verbatim.

    Args:
        png: contents of the PNG file.
        cjxlopts: extra command-line arguments passed to cjxl().
        brotliopts: extra command-line arguments passed to brotli().

    Returns:
        A `(jxl, recons)` tuple: the encoded JPEG XL data and the
        brotli-compressed reconstruction data.

    Raises:
        AssertionError: the PNG is malformed or uses a feature this tool does
            not handle (the message names the failed check).
    """
    with timing('jxl encoder'):
        jxl = cjxl(png, opts=cjxlopts)
    assert png.startswith(PNG_SIGNATURE), 'invalid PNG signature'
    recons = bytearray()
    idatseen = False
    i = 8 # offset of the current chunk; the signature occupies bytes 0..7
    while i < len(png):
        chunklen, = struct.unpack('!I', png[i:i+4])
        assert chunklen < 0x80000000, 'too long chunk'
        # chunkdata is the 4-byte chunk name followed by the chunk payload,
        # i.e. exactly the bytes covered by the chunk's CRC
        chunkdata = png[i+4:i+8+chunklen]
        assert len(chunkdata) == chunklen + 4, 'premature end of file'
        chunkcrc, = struct.unpack('!I', png[i+8+chunklen:i+8+chunklen+4])
        assert zlib.crc32(chunkdata) == chunkcrc, 'mismatching CRC32'
        assert (i == 8) == chunkdata.startswith(b'IHDR'), 'PNG does not start with IHDR'
        # from here on `i` points at the chunk following the current one
        i += 12 + chunklen
        assert (i == len(png)) == chunkdata.startswith(b'IEND'), 'PNG does not end with IEND'
        if chunkdata.startswith(b'IHDR'):
            assert chunklen == 13, 'unexpected IHDR length'
            width, height, bpp, ct, cm, fm, im = struct.unpack('!IIBBBBB', chunkdata[4:])
            assert 0 < width < 0x80000000, 'width out of range'
            assert 0 < height < 0x80000000, 'height out of range'
            assert (ct, bpp) in POSSIBLE_CT_BPP, 'unexpected colour type & bpp combinations'
            assert cm == 0, 'unknown compression method'
            assert fm == 0, 'unknown filter method'
            assert 0 <= im <= 1, 'unknown interlace method'
            # reconstruction data starts with one byte:
            # bits 0..3 - ct/bpp combination (0..14), mapped by POSSIBLE_CT_BPP
            # bit 4 - interlace method (0..1)
            # bits 5..7 - reserved
            recons.append(POSSIBLE_CT_BPP.index((ct, bpp)) | (im << 4))
            if ct & 1: print('Warning: indexed color might be not reconstructed at the moment', file=sys.stderr)
        elif chunkdata.startswith(b'IDAT'):
            # multiple IDAT chunks should be consecutive, read them all
            assert not idatseen, 'multiple IDAT chunks should be consecutive'
            idatseen = True
            idat = bytearray(chunkdata[4:])
            idatlens = [chunklen]
            # `i` already points past the first IDAT; keep consuming chunks
            # for as long as the next one is also an IDAT
            while png[i+4:i+8] == b'IDAT':
                chunklen, = struct.unpack('!I', png[i:i+4])
                assert chunklen < 0x80000000, 'too long chunk'
                assert i + 12 + chunklen <= len(png), 'premature end of file'
                chunkcrc, = struct.unpack('!I', png[i+8+chunklen:i+8+chunklen+4])
                assert zlib.crc32(png[i+4:i+8+chunklen]) == chunkcrc, 'mismatching CRC32'
                idat.extend(png[i+8:i+8+chunklen])
                idatlens.append(chunklen)
                i += 12 + chunklen
            idatlens.pop() # last size can be reconstructed trivially
            cmf = idat[0]
            flg = idat[1]
            assert (cmf & 0xf) == 8, 'zlib stream with unsupported compression method'
            assert (cmf >> 4) < 8, 'zlib stream with too large window size'
            assert (flg & 0x20) == 0, 'zlib stream with unexpected preset dictionary'
            assert ((cmf << 8) | flg) % 31 == 0, 'zlib stream with bad header checksum'
            # FLEVEL is technically possible to have two possible values in some cases,
            # but thankfully unique for CMF/FLG ranges we concern...
            origdata = zlib.decompress(idat) # implicitly checks ADLER-32 checksum
            # reconstruction data for IDAT consists of:
            # chunklen (u32) - length for reconstruction data except for "IDAT"
            # "IDAT" (u32)
            # chunkcount-1 (u32)
            # chunk sizes in the original (u32[chunkcount-1]) - last size implied
            # cmflg (u8) - meaningful bits from zlib CMF & FLG
            # bits 0..2 - CMF bit 4..6 (CINFO / window size, always less than 8)
            # bits 3..4 - FLG bit 6..7 (FLEVEL)
            # bits 5..7 - reserved
            # for each scanline:
            # filter type (u8)
            # padding (u8) - only present when the scanline does not end on a
            # byte boundary: padded lowermost bits; other bits reset to 0
            # preflate data (u8[*]) - once, after all scanlines
            reconschunk = bytearray()
            reconschunk.extend(struct.pack('!I', len(idatlens)))
            for idatlen in idatlens:
                reconschunk.extend(struct.pack('!I', idatlen))
            reconschunk.append((cmf >> 4) | ((flg & 0xc0) >> 3))
            k = 0 # offset of the current scanline's filter byte in origdata
            last = None
            stride = max(1, bits_per_sample(ct, bpp) // 8)
            for scanlen, padmask, _ in scanlines(width, height, ct, bpp, im):
                if not scanlen:
                    # pass boundary: the next scanline has no predecessor
                    last = None
                    continue
                assert origdata[k] < 5, 'unknown filter type'
                reconschunk.append(origdata[k])
                if padmask: # PNG does not constrain padding bits
                    # unfiltering is only needed to read the padding bits, so
                    # `last` is only tracked for scanlines that have padding
                    unfiltered = unfilter(stride, origdata[k], last, origdata[k+1:k+scanlen+1])
                    last = unfiltered
                    reconschunk.append(unfiltered[-1] & padmask)
                k += 1 + scanlen
            assert k == len(origdata), 'zlib stream size does not agree with scanline data'
            with timing('preflate'):
                # idat[2:-4] strips the 2-byte zlib header and the 4-byte
                # ADLER-32 trailer, leaving the raw deflate stream
                zorigdata, zreconsdata = preflate_split(idat[2:-4])
            assert zorigdata == origdata, 'preflate result mismatches with zlib'
            if DEBUG_LEVEL > 0:
                print('original data: size %d, hash %s' % (len(origdata), hash(origdata)), file=sys.stderr)
            if DEBUG_LEVEL > 1:
                with open('origdata-in', 'w') as f:
                    for p in range(0, len(zorigdata), width*stride+1): f.write(zorigdata[p:p+width*stride+1].hex() + '\n')
            reconschunk.extend(zreconsdata)
            recons.extend(struct.pack('!I4s', len(reconschunk), b'IDAT'))
            recons.extend(reconschunk)
        elif chunkdata.startswith(b'IEND'):
            assert chunklen == 0, 'IEND chunk has unexpected data'
            break
        else:
            # reconstruction data for other chunks consists of:
            # chunklen (u32)
            # chunkname (u32)
            # chunkdata (u8[chunklen])
            recons.extend(struct.pack('!I', chunklen))
            recons.extend(chunkdata)
    else:
        # the loop ran off the end of the file without reaching IEND
        assert False, 'premature end of PNG file'
    with timing('brotli compression'):
        recons = brotli(recons, opts=brotliopts)
    return jxl, recons
def join(jxl, recons):
    """Rebuild the original PNG from JPEG XL data and reconstruction data.

    The JPEG XL image is decoded with djxl to obtain the raw pixel bytes, then
    the reconstruction data produced by split() is replayed: pixels are
    re-filtered with the recorded filter types, the deflate stream is rebuilt
    with preflate, the IDAT chunks are cut at the recorded sizes and all other
    chunks are copied back verbatim.

    Args:
        jxl: JPEG XL data as returned by split().
        recons: brotli-compressed reconstruction data as returned by split().

    Returns:
        The bytes of the reconstructed PNG file.

    Raises:
        AssertionError: djxl produced a PNG this code cannot consume, or the
            reconstruction data describes an unsupported image (the message
            names the failed check).
    """
    with timing('jxl decoder'):
        png = djxl(jxl)
    assert png.startswith(PNG_SIGNATURE), 'djxl returned non-PNG'
    # assume that djxl returns a valid PNG, so can skip most checks
    i = 8
    while i < len(png):
        chunklen, = struct.unpack('!I', png[i:i+4])
        chunkname = png[i+4:i+8]
        if chunkname == b'IHDR':
            width, height, pngbpp, pngct, pngcm, pngfm, pngim = struct.unpack('!IIBBBBB', png[i+8:i+8+13])
            assert pngct in (0, 2, 4, 6), 'djxl returned indexed color'
            assert pngbpp in (8, 16), 'djxl returned non-8n-bit color'
            assert pngim == 0, 'djxl returned interlaced PNG'
        elif chunkname == b'IDAT':
            # only the first IDAT chunk is decompressed, so this relies on
            # djxl writing the whole image into a single IDAT chunk
            origdata = zlib.decompress(png[i+8:i+8+chunklen])
            pngdata = bytearray()
            k = 0
            last = None
            stride = max(1, bits_per_sample(pngct, pngbpp) // 8)
            for scanlen, padmask, _ in scanlines(width, height, pngct, pngbpp, pngim):
                assert scanlen
                unfiltered = unfilter(stride, origdata[k], last, origdata[k+1:k+1+scanlen])
                last = unfiltered
                if padmask: assert (unfiltered[-1] & padmask) == 0
                pngdata.extend(unfiltered)
                k += 1 + scanlen
            break
        i += 12 + chunklen
    else:
        assert False, 'djxl returned PNG with no IDAT'
    with timing('brotli decompressor'):
        recons = unbrotli(recons)
    # first byte: ct/bpp index in bits 0..3, interlace method in bit 4
    ct, bpp = POSSIBLE_CT_BPP[recons[0] & 0xf]
    im = (recons[0] >> 4) & 1
    assert (ct & 1) == 0, 'indexed color reconstruction not yet supported'
    # each entry is a chunk name followed by its payload; length and CRC are
    # added when the file is assembled at the end
    pngchunks = [
        struct.pack('!4sIIBBBBB', b'IHDR', width, height, bpp, ct, 0, 0, im),
    ]
    i = 1
    while i < len(recons):
        chunklen, = struct.unpack('!I', recons[i:i+4])
        chunkname = recons[i+4:i+8]
        if chunkname == b'IDAT':
            i += 8
            end = i + chunklen # first byte after this IDAT record
            idatlencount, = struct.unpack('!I', recons[i:i+4])
            idatlens = list(struct.unpack('!%dI' % idatlencount, recons[i+4:i+4+idatlencount*4])) if idatlencount else []
            i += 4 + idatlencount * 4
            # rebuild the zlib header: CM is always 8, the window size and
            # FLEVEL come from the stored byte, and FCHECK is whatever makes
            # the 16-bit header a multiple of 31
            cmf = 0x08 | ((recons[i] & 7) << 4)
            flg = (recons[i] & 0x18) << 3
            flg |= -((cmf << 8) | flg) % 31
            i += 1
            last = None
            stride = max(1, bits_per_sample(ct, bpp) // 8)
            zorigdata = bytearray()
            for scanlen, padmask, passinfo in scanlines(width, height, ct, bpp, im):
                if not scanlen:
                    # pass boundary: the next scanline has no predecessor
                    last = None
                    continue
                r, cstart, cinc = passinfo
                # gather this scanline's pixels from the decoded image
                this = bytearray()
                rbase = r * width * stride
                for k in range(rbase + cstart * stride, rbase + width * stride, cinc * stride):
                    this.extend(pngdata[k:k+stride])
                filtered = refilter(stride, recons[i], last, this)
                last = this
                zorigdata.append(recons[i])
                i += 1
                if padmask:
                    assert (filtered[-1] & padmask) == 0, 'djxl padded byte boundary with non-zeroes'
                    filtered[-1] |= recons[i] & padmask
                    i += 1
                zorigdata.extend(filtered)
            zreconsdata = recons[i:end]
            # Skip over the preflate payload just consumed. Without this the
            # next iteration would parse the payload bytes as chunk headers.
            i = end
            if DEBUG_LEVEL > 0:
                print('reconstructed original data: size %d, hash %s' % (len(zorigdata), hash(zorigdata)), file=sys.stderr)
            if DEBUG_LEVEL > 1:
                with open('origdata-out', 'w') as f:
                    for p in range(0, len(zorigdata), width*stride+1): f.write(zorigdata[p:p+width*stride+1].hex() + '\n')
            with timing('preflate reconstruction'):
                idat = struct.pack('!BB', cmf, flg) + preflate_join(zorigdata, zreconsdata) + struct.pack('!I', zlib.adler32(zorigdata))
            # the size of the last IDAT chunk is whatever remains
            idatlens.append(len(idat) - sum(idatlens))
            assert idatlens[-1] >= 0, 'multiple IDAT chunks are longer than the reconstructed data'
            chunkoff = 0
            for idatlen in idatlens:
                pngchunks.append(b'IDAT' + idat[chunkoff:chunkoff+idatlen])
                chunkoff += idatlen
        else:
            # any other chunk was stored verbatim as name + payload
            pngchunks.append(recons[i+4:i+8+chunklen])
            i += 8 + chunklen
    pngchunks.append(b'IEND') # implied
    return PNG_SIGNATURE + b''.join(
        struct.pack('!I', len(chunk) - 4) +
        chunk +
        struct.pack('!I', zlib.crc32(chunk))
        for chunk in pngchunks
    )
# Help text printed by main(); %(prog)s is replaced with the base name of the
# running script.
USAGE = '''\
Usage: %(prog)s [-c] <opts> PNG [JXL] [RECONS] for compression
or %(prog)s [-r] <opts> JXL [PNG] [RECONS] for reconstruction
Recompresses PNG into JPEG XL and small data that can be reconstructed.
If the first argument ends with `.png` the compression (-c) is implied.
If it ends with `.jxl` the reconstruction (-r) is implied.
Otherwise you need to give -c or -r for the operation.
For compression JXL defaults to PNG.jxl.
For reconstruction PNG defaults to JXL.png to avoid accidental overwrite.
For both usages RECONS defaults to JXL.recons.
This tool depends on these tools in $PATH by default:
cjxl, djxl from libjxl https://github.com/libjxl/libjxl
preflate_demo from preflate https://github.com/deus-libri/preflate
brotli from brotli https://github.com/google/brotli
Options:
-h, --help This.
-c, --compress Compresses PNG into JXL and reconstruction data.
-r, --reconstruct Reconstructs PNG from JXL and reconstruction data.
-s EFFORT, --speed=EFFORT [Default: 4 through 7 depending on file size]
JPEG XL effort/speed setting from 1 to 9.
Note: -s 1/2 might not be supported by old cjxl.
-p, --progressive Make JPEG XL progressive. [Default: off]
--jxl=OPTIONS Additional options to JPEG XL encoder (cjxl).
Warning: this can ruin the reconstruction!
--brotli=OPTIONS Additional options to Brotli compressor.
--cjxl-path=PATH Path to cjxl if it is not in $PATH.
--djxl-path=PATH Path to djxl if it is not in $PATH.
--brotli-path=PATH Path to brotli if it is not in $PATH.
--preflate-path=PATH Path to preflate_demo if it is not in $PATH.
'''
def main(argv):
    """Command-line entry point: compress a PNG or reconstruct one.

    Args:
        argv: full argument vector. `argv[0]` is the program name shown in the
            usage text; `argv[1:]` are parsed as options and file names as
            described by USAGE.

    Raises:
        SystemExit: with status 2 when the command line is invalid.
    """
    # The --*-path options must rebind the module-level tool paths read by the
    # tool wrappers; without this declaration they would only create locals.
    global CJXL_PATH, DJXL_PATH, BROTLI_PATH, PREFLATE_PATH
    import getopt

    usage = USAGE % dict(prog=os.path.basename(argv[0]))

    def fail(message, *, show_usage=False):
        # Report a command-line error on stderr and exit with status 2.
        print(message, file=sys.stderr)
        if show_usage:
            print(usage, file=sys.stderr)
        raise SystemExit(2)

    try:
        # Options that take a value need a trailing ':' (short form) or '='
        # (long form); flags such as -h and -p must not have one.
        opts, args = getopt.gnu_getopt(
            argv[1:],
            'hcrs:p',
            ['help', 'compress', 'reconstruct', 'speed=', 'progressive',
             'jxl=', 'brotli=', 'cjxl-path=', 'djxl-path=', 'brotli-path=',
             'preflate-path='])
    except getopt.GetoptError as e:
        fail(e, show_usage=True)

    compress = False
    reconstruct = False
    effort = 0
    progressive = False
    cjxlopts = []
    brotliopts = []
    for o, a in opts:
        if o in ('-h', '--help'):
            print(usage, file=sys.stderr)
            return
        elif o in ('-c', '--compress'):
            compress = True
        elif o in ('-r', '--reconstruct'):
            reconstruct = True
        elif o in ('-s', '--speed'):
            try:
                effort = int(a)
            except ValueError:
                fail('only -s 1 through -s 9 are supported')
            if not 1 <= effort <= 9:
                fail('only -s 1 through -s 9 are supported')
        elif o in ('-p', '--progressive'):
            progressive = True
        elif o == '--jxl':
            cjxlopts.append(a)
        elif o == '--brotli':
            brotliopts.append(a)
        elif o == '--cjxl-path':
            CJXL_PATH = a
        elif o == '--djxl-path':
            DJXL_PATH = a
        elif o == '--brotli-path':
            BROTLI_PATH = a
        elif o == '--preflate-path':
            PREFLATE_PATH = a
        else:
            # getopt only returns options listed above
            assert False, 'unknown option'

    if compress and reconstruct:
        fail('-c and -r cannot be used together')
    if not args:
        if not compress and not reconstruct:
            # nothing to do at all: just show the help text
            print(usage, file=sys.stderr)
            return
        fail('file name(s) should be given')
    if not compress and not reconstruct:
        # infer the operation from the first file name's extension
        if args[0].endswith('.png'):
            compress = True
        elif args[0].endswith('.jxl'):
            reconstruct = True
        else:
            fail('one of -c or -r should be given with unknown extensions')

    if compress:
        pngpath = args[0]
        jxlpath = args[1] if len(args) > 1 else pngpath + '.jxl'
        reconspath = args[2] if len(args) > 2 else jxlpath + '.recons'
        with open(pngpath, 'rb') as f:
            png = f.read()
        if not effort: # heuristic effort determination
            # NOTE(review): USAGE advertises "4 through 7 depending on file
            # size" while this picks 3 or 7 - confirm which is intended.
            effort = 3 if len(png) >= 10000000 else 7
        # final order: -s EFFORT, then -p if requested, then any --jxl options
        if progressive:
            cjxlopts[0:0] = ['-p']
        cjxlopts[0:0] = ['-s', str(effort)]
        jxl, recons = split(png, cjxlopts=cjxlopts, brotliopts=brotliopts)
        with open(jxlpath, 'wb') as f:
            f.write(jxl)
        with open(reconspath, 'wb') as f:
            f.write(recons)
        print(
            'original %d -> jxl %d + recons %d = compressed %d (%.1f%% smaller)' % (
                len(png), len(jxl), len(recons), len(jxl) + len(recons),
                100.0 * (1 - (len(jxl) + len(recons)) / len(png))
            ), file=sys.stderr)
    else:
        jxlpath = args[0]
        pngpath = args[1] if len(args) > 1 else jxlpath + '.png'
        reconspath = args[2] if len(args) > 2 else jxlpath + '.recons'
        with open(jxlpath, 'rb') as f:
            jxl = f.read()
        with open(reconspath, 'rb') as f:
            recons = f.read()
        png = join(jxl, recons)
        with open(pngpath, 'wb') as f:
            f.write(png)
# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment