Skip to content

Instantly share code, notes, and snippets.

@danifus
Created July 27, 2019 03:18
Show Gist options
  • Save danifus/73d258df243bbb386c1dd64c0888cddf to your computer and use it in GitHub Desktop.
Save danifus/73d258df243bbb386c1dd64c0888cddf to your computer and use it in GitHub Desktop.
Implementing WinZip AES encryption/decryption with the zipfile refactor
import zipfile
import zipfile_aes
# Example: write an AES-encrypted, LZMA-compressed member to a new archive,
# then reopen the archive and read the member back with the same password.
archive_path = 'new_test.zip'
secret_password = b'lost art of keeping a secret'

writer_options = dict(compression=zipfile.ZIP_LZMA, encryption=zipfile_aes.WZ_AES)
with zipfile_aes.AESZipFile(archive_path, 'w', **writer_options) as zf:
    zf.setpassword(secret_password)
    zf.writestr('test.txt', "What ever you do, don't tell anyone!")

# Reading needs the password too; the AES parameters come from the archive.
with zipfile_aes.AESZipFile(archive_path) as zf:
    zf.setpassword(secret_password)
    my_secrets = zf.read('test.txt')

print(my_secrets)
import struct
# requires pip install pycryptodomex
from Cryptodome.Protocol.KDF import PBKDF2
from Cryptodome.Cipher import AES
from Cryptodome.Hash import HMAC
from Cryptodome.Hash.SHA1 import SHA1Hash
from Cryptodome.Util import Counter
from Cryptodome import Random
from zipfile import (
ZIP_BZIP2, BadZipFile, BaseDecrypter, ZipFile, ZipInfo, ZipExtFile,
_ZipWriteFile, crc32, _MASK_ENCRYPTED, CRCZipDecrypter,
)
# Value for ``AESZipFile(..., encryption=WZ_AES)`` selecting WinZip AES.
WZ_AES = 'WZ_AES'
# Compression method stored in the headers of AES-encrypted members; the real
# compression method is recorded inside the AES extra field instead.
WZ_AES_COMPRESS_TYPE = 99
# AES vendor versions (AE-1 and AE-2). AE-2 members store a CRC of 0.
WZ_AES_V1 = 0x0001
WZ_AES_V2 = 0x0002
# Vendor ID stored in the AES extra field.
WZ_AES_VENDOR_ID = b'AE'
# Header ID of the WinZip AES extra field.
EXTRA_WZ_AES = 0x9901

# AES strength byte -> key size in bits.
_WZ_KEY_BITS = {1: 128, 2: 192, 3: 256}
# Salt length in bytes per strength byte (half the key length).
WZ_SALT_LENGTHS = {strength: bits // 16 for strength, bits in _WZ_KEY_BITS.items()}
# AES key length in bytes per strength byte.
WZ_KEY_LENGTHS = {strength: bits // 8 for strength, bits in _WZ_KEY_BITS.items()}
class AESZipDecrypter(BaseDecrypter):
    """Decrypt a WinZip AES (AE-1 / AE-2) encrypted zip member.

    The encrypted member data is laid out as::

        salt | password verifier (2 bytes) | AES-CTR ciphertext | HMAC (10 bytes)

    ``start_decrypt`` consumes the salt and verifier. The caller reads the
    trailing HMAC and passes it to ``check_hmac``.
    """

    # Number of leading bytes of the HMAC-SHA1 digest stored after the data.
    hmac_size = 10
    # Number of bytes of the password verification value following the salt.
    pwd_verify_length = 2

    def __init__(self, zinfo, pwd):
        """Remember the member info and the password.

        Raises:
            RuntimeError: if ``pwd`` is empty or None.
        """
        self.zinfo = zinfo
        self.name = zinfo.filename
        if not pwd:
            raise RuntimeError("File %r is encrypted, a password is "
                               "required for extraction" % self.name)
        self.pwd = pwd

    def start_decrypt(self, fileobj):
        """Read the encryption header from ``fileobj`` and set up AES/HMAC state.

        Returns:
            The number of bytes of encryption overhead: header plus trailing HMAC.

        Raises:
            BadZipFile: the AES strength is unknown or the header is truncated.
            RuntimeError: the password verifier does not match (wrong password).
        """
        from hmac import compare_digest

        strength = self.zinfo.wz_aes_strength
        try:
            key_length = WZ_KEY_LENGTHS[strength]
            salt_length = WZ_SALT_LENGTHS[strength]
        except KeyError:
            raise BadZipFile(
                "Unsupported WinZip AES strength %r for file %r"
                % (strength, self.name)
            ) from None

        encryption_header_length = salt_length + self.pwd_verify_length
        encryption_header = fileobj.read(encryption_header_length)
        # A short read means the archive is damaged, not that the password is wrong.
        if len(encryption_header) != encryption_header_length:
            raise BadZipFile(
                "Truncated encryption header for file %r" % self.name)
        salt = encryption_header[:salt_length]
        pwd_verify = encryption_header[salt_length:]

        # Derived key material layout: AES key | HMAC key | password verifier.
        dk_len = 2 * key_length + self.pwd_verify_length
        keymaterial = PBKDF2(self.pwd, salt, count=1000, dkLen=dk_len)
        # Constant-time comparison so timing does not reveal matching bytes.
        if not compare_digest(keymaterial[2 * key_length:], pwd_verify):
            raise RuntimeError("Bad password for file %r" % self.name)

        enckey = keymaterial[:key_length]
        encmac_key = keymaterial[key_length:2 * key_length]
        # NOTE: attribute name is misspelled (sic); left unchanged in case
        # external code reads it.
        self.decypter = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True),
        )
        self.hmac = HMAC.new(encmac_key, digestmod=SHA1Hash())
        return encryption_header_length + self.hmac_size

    def decrypt(self, data):
        """Feed ciphertext ``data`` to the HMAC and return it decrypted."""
        # The HMAC covers the ciphertext, so update it before decrypting.
        self.hmac.update(data)
        return self.decypter.decrypt(data)

    def check_hmac(self, hmac_check):
        """Compare the stored HMAC bytes with the one computed over the data.

        Raises:
            BadZipFile: if the values differ, including a short/truncated value.
        """
        from hmac import compare_digest

        expected = self.hmac.digest()[:self.hmac_size]
        if not compare_digest(expected, hmac_check):
            raise BadZipFile("Bad HMAC check for file %r" % self.name)
class BaseZipEncrypter:
    """Interface for objects that encrypt zip member data as it is written.

    Subclasses must provide ``update_zipinfo``, ``encrypt`` and
    ``encryption_header``. ``flush`` may be overridden to emit trailing
    bytes; by default it emits nothing.
    """

    @staticmethod
    def _missing(method_name):
        # Build the exception raised by every unimplemented hook.
        return NotImplementedError(
            'BaseZipEncrypter implementations must implement `%s`.' % method_name
        )

    def update_zipinfo(self, zipinfo):
        """Record encryption details on ``zipinfo``."""
        raise self._missing('update_zipinfo')

    def encrypt(self, data):
        """Return ``data`` encrypted."""
        raise self._missing('encrypt')

    def encryption_header(self):
        """Return the bytes written before the encrypted member data."""
        raise self._missing('encryption_header')

    def flush(self):
        """Return bytes to append after the encrypted data (none by default)."""
        return b''
class AESZipEncrypter(BaseZipEncrypter):
    """Encrypt zip member data with WinZip AES (AES-CTR plus truncated HMAC-SHA1).

    The salt, derived keys, cipher and HMAC state are created once in
    ``__init__``.

    Args:
        pwd: password; must be non-empty.
        nbits: AES key size, one of 128, 192 or 256.
        force_wz_aes_version: if not None, written as the member's AES vendor
            version instead of letting the zip info choose one.

    Raises:
        RuntimeError: empty password or unsupported ``nbits``.
    """

    # Number of leading bytes of the HMAC-SHA1 digest appended after the data.
    hmac_size = 10

    def __init__(self, pwd, nbits=256, force_wz_aes_version=None):
        if not pwd:
            raise RuntimeError(
                '%s encryption requires a password.' % WZ_AES
            )
        if nbits not in (128, 192, 256):
            raise RuntimeError(
                "`nbits` must be one of 128, 192, 256. Got '%s'" % nbits
            )
        self.force_wz_aes_version = force_wz_aes_version

        # Map the key size to the AES strength byte, then look up the sizes.
        strength = {128: 1, 192: 2, 256: 3}[nbits]
        key_length = WZ_KEY_LENGTHS[strength]
        self.aes_strength = strength
        self.salt_length = WZ_SALT_LENGTHS[strength]
        self.salt = Random.new().read(self.salt_length)

        verifier_length = 2
        keymaterial = PBKDF2(
            pwd, self.salt, count=1000,
            dkLen=2 * key_length + verifier_length,
        )
        # Derived key material layout: AES key | HMAC key | password verifier.
        enckey = keymaterial[:key_length]
        mac_key = keymaterial[key_length:2 * key_length]
        self.encpwdverify = keymaterial[2 * key_length:]

        self.encrypter = AES.new(
            enckey,
            AES.MODE_CTR,
            counter=Counter.new(nbits=128, little_endian=True),
        )
        self.hmac = HMAC.new(mac_key, digestmod=SHA1Hash())

    def update_zipinfo(self, zipinfo):
        """Mark ``zipinfo`` as an AES member with this encrypter's strength."""
        zipinfo.wz_aes_vendor_id = WZ_AES_VENDOR_ID
        zipinfo.wz_aes_strength = self.aes_strength
        forced = self.force_wz_aes_version
        if forced is not None:
            zipinfo.wz_aes_version = forced

    def encryption_header(self):
        """Return the bytes written before the ciphertext: salt + password verifier."""
        return self.salt + self.encpwdverify

    def encrypt(self, data):
        """Encrypt ``data`` and feed the ciphertext to the HMAC."""
        ciphertext = self.encrypter.encrypt(data)
        self.hmac.update(ciphertext)
        return ciphertext

    def flush(self):
        """Return the truncated HMAC that follows the ciphertext."""
        layout = '<{}s'.format(self.hmac_size)
        return struct.pack(layout, self.hmac.digest()[:10])
class AESZipInfo(ZipInfo):
    """ZipInfo that reads and writes the WinZip AES extra field (ID 0x9901)."""

    # __slots__ on subclasses only need to contain the additional slots.
    __slots__ = (
        'wz_aes_version',    # AES vendor version (WZ_AES_V1 / WZ_AES_V2) or None
        'wz_aes_vendor_id',  # vendor ID bytes (b'AE'); None for non-AES members
        'wz_aes_strength',   # 1 = 128-bit, 2 = 192-bit, 3 = 256-bit key
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.wz_aes_version = None
        self.wz_aes_vendor_id = None
        self.wz_aes_strength = None

    def decode_extra_wz_aes(self, ln, extra_payload):
        """Parse the body of a WinZip AES extra field into this object.

        Raises:
            BadZipFile: if the field body is not exactly 7 bytes.
        """
        if ln != 7:
            raise BadZipFile(
                "Corrupt extra field %04x (size=%d)" % (EXTRA_WZ_AES, ln))
        version, vendor_id, strength, actual_compress_type = struct.unpack(
            "<H2sBH", extra_payload)
        self.wz_aes_version = version
        self.wz_aes_vendor_id = vendor_id
        self.wz_aes_strength = strength
        # The header's compression method is 99 (AES), so the method that
        # would otherwise have been stored there comes from this field.
        self.compress_type = actual_compress_type

    def get_extra_decoders(self):
        """Return the base class extra decoders plus the WinZip AES decoder."""
        extra_decoders = super().get_extra_decoders()
        extra_decoders[EXTRA_WZ_AES] = self.decode_extra_wz_aes
        return extra_decoders

    def encode_extra(self, crc, compress_type):
        """Build the WinZip AES extra field and the header values to write.

        Returns:
            ``(extra_bytes, crc, compress_type)``. For non-AES members this is
            ``(b'', crc, compress_type)`` unchanged. For AES members the
            compression type becomes ``WZ_AES_COMPRESS_TYPE``, and the CRC is
            forced to 0 for AE-2.
        """
        if self.wz_aes_vendor_id is None:
            return b'', crc, compress_type

        aes_version = self.wz_aes_version
        if aes_version is None:
            # AE-1 and AE-2 differ only in CRC handling: AE-2 stores a CRC of
            # 0. Small files use AE-2 because a CRC could leak their contents.
            # bzip2 has its own integrity checks, so it does not need the CRC.
            # (Previously written with `|`, which bound tighter than the
            # comparisons and made this condition always false.)
            # NOTE(review): file_size may not be final when a local header is
            # written before streaming; confirm local/central headers agree.
            if self.file_size < 20 or self.compress_type == ZIP_BZIP2:
                aes_version = WZ_AES_V2
            else:
                aes_version = WZ_AES_V1
        if aes_version == WZ_AES_V2:
            crc = 0

        wz_aes_extra = struct.pack(
            "<3H2sBH",
            EXTRA_WZ_AES,
            7,  # extra block body length: H2sBH
            aes_version,
            self.wz_aes_vendor_id,
            self.wz_aes_strength,
            self.compress_type,  # real compression method
        )
        return wz_aes_extra, crc, WZ_AES_COMPRESS_TYPE

    def _apply_wz_aes_params(self, params):
        """Patch header ``params`` in place with the AES extra field, CRC and type."""
        wz_aes_extra, crc, compress_type = self.encode_extra(
            params["crc"], params["compress_type"])
        params["extra"] = params["extra"] + wz_aes_extra
        params["crc"] = crc
        params["compress_type"] = compress_type
        return params

    def get_local_header_params(self, zip64=False):
        """Return the base local header params with WinZip AES adjustments."""
        return self._apply_wz_aes_params(
            super().get_local_header_params(zip64=zip64))

    def get_central_directory_kwargs(self):
        """Return the base central directory kwargs with WinZip AES adjustments."""
        return self._apply_wz_aes_params(
            super().get_central_directory_kwargs())
class AESZipExtFile(ZipExtFile):
    """ZipExtFile that verifies WinZip AES members once they are read.

    For AES members the trailing HMAC is always checked. The CRC is checked
    when a non-zero one is recorded, or when the member is not AE-2.
    Non-AES members use the base class check.
    """

    def check_wz_aes(self):
        """Read the stored HMAC from the member stream and verify it."""
        decrypter = self._decrypter
        decrypter.check_hmac(self._fileobj.read(decrypter.hmac_size))

    def check_integrity(self):
        """Run the integrity checks appropriate to this member."""
        version = self._zinfo.wz_aes_version
        if version is None:
            super().check_integrity()
            return

        self.check_wz_aes()
        # AE-2 members should carry a CRC of 0. A non-zero CRC is still
        # verified when present, which goes beyond the spec's requirement.
        crc_recorded = self._expected_crc is not None and self._expected_crc != 0
        if crc_recorded or version != WZ_AES_V2:
            self.check_crc()
class AESZipWriteFile(_ZipWriteFile):
def __init__(self, zf, zinfo, zip64, encrypter):
super().__init__(zf, zinfo, zip64)
self.encrypter = encrypter
if self.encrypter:
self.write_encryption_header()
def write_encryption_header(self):
buf = self.encrypter.encryption_header()
self._compress_size += len(buf)
self._fileobj.write(buf)
def write(self, data):
if self.closed:
raise ValueError('I/O operation on closed file.')
nbytes = len(data)
self._file_size += nbytes
self._crc = crc32(data, self._crc)
if self._compressor:
data = self._compressor.compress(data)
if self.encrypter:
data = self.encrypter.encrypt(data)
self._compress_size += len(data)
self._fileobj.write(data)
return nbytes
def flush_data(self):
if self._compressor:
buf = self._compressor.flush()
else:
buf = b""
if self.encrypter:
buf = self.encrypter.encrypt(buf)
buf += self.encrypter.flush()
self._compress_size += len(buf)
self._fileobj.write(buf)
class AESZipFile(ZipFile):
    """ZipFile that can read and write WinZip AES encrypted members.

    Extra keyword arguments (removed before calling ``ZipFile.__init__``):
        encryption: ``WZ_AES`` to encrypt written members, or None.
        encryption_kwargs: optional dict of keyword arguments forwarded to
            ``AESZipEncrypter`` (e.g. ``nbits``, ``force_wz_aes_version``).
    """

    zipinfo_cls = AESZipInfo
    zipextfile_cls = AESZipExtFile
    zipwritefile_cls = AESZipWriteFile

    def __init__(self, *args, **kwargs):
        encryption = kwargs.pop('encryption', None)
        encryption_kwargs = kwargs.pop('encryption_kwargs', None)
        super().__init__(*args, **kwargs)
        self.encryption = encryption
        self.encryption_kwargs = encryption_kwargs

    def get_decrypter(self, zinfo, pwd):
        """Return a decrypter for ``zinfo``, or None if it is not encrypted.

        AES members (``wz_aes_version`` set) get an ``AESZipDecrypter``;
        other encrypted members get a ``CRCZipDecrypter``.
        """
        # NOTE(review): assumes ``is_encrypted`` is an attribute/property on
        # the refactored ZipInfo; if it is a method this is always truthy.
        if not zinfo.is_encrypted:
            return None
        if zinfo.wz_aes_version is not None:
            return AESZipDecrypter(zinfo, pwd)
        return CRCZipDecrypter(zinfo, pwd)

    def get_encrypter(self, pwd=None):
        """Return an encrypter for the configured ``encryption``, or None.

        Args:
            pwd: password to use; defaults to ``self.pwd`` when None.
        """
        if self.encryption != WZ_AES:
            return None
        encryption_kwargs = self.encryption_kwargs or {}
        password = self.pwd if pwd is None else pwd
        return AESZipEncrypter(pwd=password, **encryption_kwargs)

    def get_zipwritefile(self, zinfo, zip64, pwd, **kwargs):
        """Create the writable stream for ``zinfo``, encrypting if requested.

        Raises:
            NotImplementedError: a password or encryption was requested but
                no supported encryption method is configured.
        """
        encrypter = None
        if pwd is not None or self.encryption is not None:
            encrypter = self.get_encrypter(pwd)
            if encrypter is None:
                # Previously this fell through to ``None.update_zipinfo``
                # (AttributeError) after already flagging the member encrypted.
                raise NotImplementedError(
                    "Unsupported encryption method %r; use encryption=%r to "
                    "write encrypted members" % (self.encryption, WZ_AES)
                )
            # Only mark the member encrypted once an encrypter exists.
            zinfo.flag_bits |= _MASK_ENCRYPTED
            encrypter.update_zipinfo(zinfo)
        return self.zipwritefile_cls(self, zinfo, zip64, encrypter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment