Skip to content

Instantly share code, notes, and snippets.

@blubberdiblub
Last active June 17, 2019 08:58
Show Gist options
  • Save blubberdiblub/6a56bd4f3c26d807a786187610780233 to your computer and use it in GitHub Desktop.
Save blubberdiblub/6a56bd4f3c26d807a786187610780233 to your computer and use it in GitHub Desktop.
find-certs
#!/usr/bin/env python3
import base64
import pickle
import sys
from binascii import crc32
from datetime import timezone

import cryptography.exceptions
import cryptography.hazmat.primitives.hashes as _hashes
import cryptography.hazmat.primitives.serialization as _serialization
import cryptography.x509
import pathvalidate
from cryptography.hazmat.backends import default_backend as _backend
from cryptography.x509.oid import NameOID

# Optional key-type modules: any of these may be unavailable in the installed
# version of cryptography, in which case that key type is simply unsupported.
try:
    import cryptography.hazmat.primitives.asymmetric.rsa as _rsa
except ImportError:
    _rsa = None
try:
    import cryptography.hazmat.primitives.asymmetric.ec as _ec
except ImportError:
    _ec = None
try:
    import cryptography.hazmat.primitives.asymmetric.dsa as _dsa
except ImportError:
    _dsa = None
try:
    import cryptography.hazmat.primitives.asymmetric.ed25519 as _ed25519
except ImportError:
    _ed25519 = None
try:
    import cryptography.hazmat.primitives.asymmetric.ed448 as _ed448
except ImportError:
    _ed448 = None
try:
    import cryptography.hazmat.primitives.asymmetric.x25519 as _x25519
except ImportError:
    _x25519 = None
try:
    import cryptography.hazmat.primitives.asymmetric.x448 as _x448
except ImportError:
    _x448 = None
def crc32_bytes(*args: bytes) -> int:
    """Return the unsigned CRC-32 of all arguments taken as one byte stream.

    Each argument continues the running checksum of the previous one, so
    ``crc32_bytes(a, b) == crc32_bytes(a + b)``.  With no arguments the
    result is 0.
    """
    value = 0
    for piece in args:
        value = crc32(piece, value)
    # On Python 3 binascii.crc32 is already unsigned; the mask just makes the
    # 32-bit range explicit.
    return value & 0xFFFFFFFF
def crc32_values(*args) -> int:
    """Return the CRC-32 over the pickled form of every value in *args*.

    NOTE(review): ``pickle.dumps`` uses the interpreter's default protocol,
    so the resulting checksum (and any file name derived from it) may differ
    between Python versions that have different default protocols.
    """
    pickled = [pickle.dumps(item) for item in args]
    return crc32_bytes(*pickled)
def crc32_pubkey(pubkey) -> int:
    """Return a CRC-32 fingerprint of a public key's content.

    The checksum is computed from the key's public numbers (RSA, EC, DSA) or
    its raw public bytes (Ed25519, Ed448, X25519, X448), prefixed by a type
    tag. As a result, the same key yields the same value no matter how it
    was serialized.

    A key type is only recognized if its ``cryptography`` module could be
    imported.

    :raises NotImplementedError: for any other public key type.
    """
    # The base interfaces are used instead of the legacy *WithSerialization
    # aliases; every concrete key class is an instance of both.
    if _rsa and isinstance(pubkey, _rsa.RSAPublicKey):
        numbers = pubkey.public_numbers()
        return crc32_values(b'RSA',
                            numbers.n, numbers.e)
    if _ec and isinstance(pubkey, _ec.EllipticCurvePublicKey):
        numbers = pubkey.public_numbers()
        curve = numbers.curve
        return crc32_values(b'EC',
                            curve.name, curve.key_size, numbers.x, numbers.y)
    if _dsa and isinstance(pubkey, _dsa.DSAPublicKey):
        numbers = pubkey.public_numbers()
        params = numbers.parameter_numbers
        return crc32_values(b'DSA',
                            numbers.y, params.p, params.q, params.g)

    # Key types that have no "public numbers", only a raw byte encoding.
    raw_key_types = (
        (_ed25519 and _ed25519.Ed25519PublicKey, b'ED25519'),
        (_ed448 and _ed448.Ed448PublicKey, b'ED448'),
        (_x25519 and _x25519.X25519PublicKey, b'X25519'),
        (_x448 and _x448.X448PublicKey, b'X448'),
    )
    for key_class, tag in raw_key_types:
        if key_class and isinstance(pubkey, key_class):
            data = pubkey.public_bytes(
                encoding=_serialization.Encoding.Raw,
                format=_serialization.PublicFormat.Raw,
            )
            return crc32_values(tag, data)

    type_name = type(pubkey).__name__
    raise NotImplementedError(f"public key type {type_name} not implemented")
class DataType:
    """Common interface for the PEM object types handled by write_file().

    A subclass parses the bytes passed to its constructor and implements the
    three accessors below. *data* is accepted here only so every subclass
    shares one constructor signature; parsing is the subclass's job. Any
    keyword arguments this class does not consume are forwarded to
    ``super().__init__`` so cooperative mix-ins can receive their own.
    """

    #: File name extension (including the leading dot) for this type.
    EXTENSION = ''

    def __init__(self, data: bytes,
                 backend=_backend,
                 digest_algo=_hashes.SHA256,
                 **kwargs) -> None:
        super().__init__(**kwargs)  # make it play well with mix-ins
        # *backend* and *digest_algo* are callables; keep their products.
        self._backend, self._digest_algo = backend(), digest_algo()
        # Cache slot for get_digest(); None means "not computed yet".
        self._digest = None

    def get_data(self) -> bytes:
        """Return the bytes to store in the output file."""
        raise NotImplementedError

    def get_digest(self) -> bytes:
        """Return a digest used to test two objects for equality."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Return the base file name (without extension) for this object."""
        raise NotImplementedError
class Opaque(DataType):
    """Arbitrary bytes kept verbatim and named after their CRC-32."""

    EXTENSION = ''

    def __init__(self, data: bytes, **kwargs):
        super().__init__(data, **kwargs)
        # Normalize any bytes-like input to an immutable bytes object.
        self._data = bytes(data)

    def get_data(self) -> bytes:
        """Return the stored bytes unchanged."""
        return self._data

    def get_digest(self) -> bytes:
        """Return the configured digest of the stored bytes (cached)."""
        if self._digest is not None:
            return self._digest
        hasher = _hashes.Hash(algorithm=self._digest_algo,
                              backend=self._backend)
        hasher.update(self._data)
        self._digest = hasher.finalize()
        return self._digest

    def get_name(self) -> str:
        """Return the CRC-32 of the stored bytes as 8 hex digits."""
        return '{:08x}'.format(crc32_bytes(self._data))
class CSR(DataType):
    """A PEM-encoded certificate signing request."""

    EXTENSION = '.csr'

    def __init__(self, data: bytes, **kwargs) -> None:
        super().__init__(data, **kwargs)
        self._csr = cryptography.x509.load_pem_x509_csr(data,
                                                        backend=self._backend)

    def get_data(self) -> bytes:
        """Return the CSR re-serialized as PEM."""
        return self._csr.public_bytes(encoding=_serialization.Encoding.PEM)

    def get_digest(self) -> bytes:
        """Return the digest of the to-be-signed request bytes (cached)."""
        if self._digest is None:
            digest = _hashes.Hash(algorithm=self._digest_algo,
                                  backend=self._backend)
            digest.update(self._csr.tbs_certrequest_bytes)
            self._digest = digest.finalize()
        return self._digest

    def get_name(self) -> str:
        """Return ``[<CN>-]<crc32>`` for this request.

        The checksum identifies the public key. If the key type cannot be
        handled, the to-be-signed request bytes are checksummed instead
        (as Certificate does), so an exotic key does not abort the scan.
        """
        try:
            checksum = crc32_pubkey(self._csr.public_key())
        except (cryptography.exceptions.UnsupportedAlgorithm,
                NotImplementedError):
            checksum = crc32_bytes(self._csr.tbs_certrequest_bytes)
        attrs = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not attrs:
            return f'{checksum:08x}'
        name = pathvalidate.replace_symbol(attrs[-1].value,
                                           replacement_text='_',
                                           is_replace_consecutive_chars=True,
                                           is_strip=True)
        # A common name consisting only of symbols may sanitize to nothing;
        # avoid producing a name with a dangling leading '-'.
        if not name:
            return f'{checksum:08x}'
        return f'{name}-{checksum:08x}'
class Certificate(DataType):
    """A PEM-encoded X.509 certificate."""

    EXTENSION = '.crt'

    def __init__(self, data: bytes, **kwargs) -> None:
        super().__init__(data, **kwargs)
        self._cert = cryptography.x509.load_pem_x509_certificate(
            data,
            backend=self._backend,
        )

    def get_data(self) -> bytes:
        """Return the certificate re-serialized as PEM."""
        return self._cert.public_bytes(encoding=_serialization.Encoding.PEM)

    def get_digest(self) -> bytes:
        """Return the certificate fingerprint under the configured digest (cached)."""
        if self._digest is None:
            self._digest = self._cert.fingerprint(algorithm=self._digest_algo)
        return self._digest

    def _validity_bound(self, attr: str):
        """Return validity bound *attr* ('not_valid_before'/'not_valid_after') in UTC.

        Newer cryptography versions expose a timezone-aware ``<attr>_utc``.
        Older ones only expose ``<attr>``, a naive datetime holding UTC.
        UTC is attached rather than converted with ``astimezone()``, because
        ``astimezone()`` treats a naive datetime as *local* time and would
        shift the date on machines not running in UTC.
        """
        aware = getattr(self._cert, attr + '_utc', None)
        if aware is not None:
            return aware
        return getattr(self._cert, attr).replace(tzinfo=timezone.utc)

    def get_name(self) -> str:
        """Return ``[<CN>-]<crc32>-<not before>-<not after>`` (dates YYYYMMDD, UTC).

        The checksum identifies the public key. If the key type cannot be
        handled, the to-be-signed certificate bytes are checksummed instead.
        """
        try:
            checksum = crc32_pubkey(self._cert.public_key())
        except (cryptography.exceptions.UnsupportedAlgorithm,
                NotImplementedError):
            checksum = crc32_bytes(self._cert.tbs_certificate_bytes)
        before = self._validity_bound('not_valid_before')
        after = self._validity_bound('not_valid_after')
        suffix = f'{checksum:08x}-{before:%Y%m%d}-{after:%Y%m%d}'
        attrs = self._cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not attrs:
            return suffix
        name = pathvalidate.replace_symbol(attrs[-1].value,
                                           replacement_text='_',
                                           is_replace_consecutive_chars=True,
                                           is_strip=True)
        # A common name made only of symbols may sanitize to an empty string.
        if not name:
            return suffix
        return f'{name}-{suffix}'
class PrivateKey(DataType):
    """A PEM-encoded private key (loaded without a password), stored as PKCS#8."""

    EXTENSION = '.key'

    def __init__(self, data: bytes, **kwargs) -> None:
        super().__init__(data, **kwargs)
        self._key = _serialization.load_pem_private_key(
            data, password=None, backend=self._backend)

    def _pkcs8_bytes(self, encoding) -> bytes:
        """Serialize the key as unencrypted PKCS#8 in *encoding*."""
        return self._key.private_bytes(
            encoding=encoding,
            format=_serialization.PrivateFormat.PKCS8,
            encryption_algorithm=_serialization.NoEncryption(),
        )

    def get_data(self) -> bytes:
        """Return the key as unencrypted PKCS#8 PEM."""
        return self._pkcs8_bytes(_serialization.Encoding.PEM)

    def get_digest(self) -> bytes:
        """Return the digest of the DER PKCS#8 encoding (cached)."""
        if self._digest is not None:
            return self._digest
        hasher = _hashes.Hash(algorithm=self._digest_algo,
                              backend=self._backend)
        hasher.update(self._pkcs8_bytes(_serialization.Encoding.DER))
        self._digest = hasher.finalize()
        return self._digest

    def get_name(self) -> str:
        """Return the CRC-32 of the matching public key as 8 hex digits."""
        return '%08x' % crc32_pubkey(self._key.public_key())
class PublicKey(DataType):
    """A PEM-encoded public key, stored as SubjectPublicKeyInfo."""

    EXTENSION = '.pub'

    def __init__(self, data: bytes, **kwargs) -> None:
        super().__init__(data, **kwargs)
        self._pub = _serialization.load_pem_public_key(
            data, backend=self._backend)

    def _spki_bytes(self, encoding) -> bytes:
        """Serialize the key as SubjectPublicKeyInfo in *encoding*."""
        return self._pub.public_bytes(
            encoding=encoding,
            format=_serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    def get_data(self) -> bytes:
        """Return the key as SubjectPublicKeyInfo PEM."""
        return self._spki_bytes(_serialization.Encoding.PEM)

    def get_digest(self) -> bytes:
        """Return the digest of the DER SubjectPublicKeyInfo (cached)."""
        if self._digest is None:
            hasher = _hashes.Hash(algorithm=self._digest_algo,
                                  backend=self._backend)
            hasher.update(self._spki_bytes(_serialization.Encoding.DER))
            self._digest = hasher.finalize()
        return self._digest

    def get_name(self) -> str:
        """Return the CRC-32 of the key as 8 hex digits."""
        return '%08x' % crc32_pubkey(self._pub)
# Maps the label of a PEM block ("-----BEGIN <label>-----") to the DataType
# subclass used to parse and save it. find_in_stream() ignores any block
# whose label is not a key here.
TYPE_MAP = {
    b'PRIVATE KEY': PrivateKey,
    b'PUBLIC KEY': PublicKey,
    b'CERTIFICATE REQUEST': CSR,
    b'CERTIFICATE': Certificate,
    # legacy identifiers
    b'RSA PRIVATE KEY': PrivateKey,
    b'RSA PUBLIC KEY': PublicKey,
    b'DSA PRIVATE KEY': PrivateKey,
    b'DSA PUBLIC KEY': PublicKey,
    # SEC 1 / RFC 5915 EC keys, e.g. as written by `openssl ecparam -genkey`
    b'EC PRIVATE KEY': PrivateKey,
    b'NEW CERTIFICATE REQUEST': CSR,
    b'X509 CERTIFICATE': Certificate,
}
class WriteError(Exception):
    """Raised by write_file() when a PEM object cannot be parsed, compared or saved."""
def write_file(data_type: type, content: bytes,
               filename_fmt: str = '{name}{extension}',
               buffering: int = 16 * 1024 * 1024) -> None:
    """Parse *content* as *data_type* and save it under a derived file name.

    The file name is *filename_fmt* formatted with the object's
    ``get_name()`` and ``EXTENSION``, then sanitized. If a file with that
    name already exists, it is parsed and compared by digest. An identical
    object is left alone; a different one is reported as an error. An
    existing file is never overwritten.

    :param data_type: DataType subclass used to parse *content*.
    :param content: PEM-encoded bytes.
    :param filename_fmt: format string with ``{name}`` and ``{extension}``.
    :param buffering: buffer size passed to ``open()``.
    :raises WriteError: on any parse, naming, read, compare or write failure.
    """
    # Errors the PEM loaders may raise for malformed input (ValueError),
    # encrypted private keys (TypeError: no password supplied) or key types
    # the backend does not support (UnsupportedAlgorithm).
    parse_errors = (ValueError, TypeError,
                    cryptography.exceptions.UnsupportedAlgorithm)

    try:
        data = data_type(content)
    except parse_errors as err:
        raise WriteError("cannot parse data") from err

    # crc32_pubkey() raises NotImplementedError for unknown key types; report
    # it per block instead of letting it abort the whole scan.
    try:
        name = data.get_name()
    except (NotImplementedError,
            cryptography.exceptions.UnsupportedAlgorithm) as err:
        raise WriteError("cannot derive file name") from err
    filename = filename_fmt.format(name=name, extension=data.EXTENSION)
    filename = pathvalidate.sanitize_filename(filename)

    try:
        with open(filename, mode='rb', buffering=buffering) as f:
            existing = f.read()
    except FileNotFoundError:
        pass
    except OSError as err:
        raise WriteError("cannot read existing file") from err
    else:
        try:
            file_data = data_type(existing)
        except parse_errors as err:
            raise WriteError("cannot parse existing file") from err
        if file_data.get_digest() != data.get_digest():
            raise WriteError("new file differs from existing file")
        return  # identical object already saved

    try:
        output = data.get_data()
    except ValueError as err:
        raise WriteError("cannot generate file content") from err
    try:
        with open(filename, mode='wb', buffering=buffering) as f:
            f.write(output)
    except OSError as err:
        raise WriteError("cannot write file") from err
def find_in_stream(stream, buffer_size: int = 16 * 1024 * 1024) -> None:
MARKER = b'-----'
len_marker = len(MARKER)
(
STATE_BEFORE_MARKER,
STATE_PREFIX,
STATE_BODY,
STATE_SUFFIX,
) = range(4)
buffer_size = max(buffer_size, 8192, len_marker * 2)
buffer = bytearray(buffer_size)
buffer_view = memoryview(buffer)
buffer_head = 0
buffer_tail = 0
start_tail = len_marker
state = STATE_BEFORE_MARKER
type_id = None
content = None
while True:
while buffer_tail < start_tail:
if buffer_tail >= buffer_size:
if buffer_head > 0:
buffer_view = None
del buffer[:buffer_head]
buffer += bytearray(buffer_head)
buffer_view = memoryview(buffer)
assert len(buffer_view) == buffer_size
buffer_tail -= buffer_head
start_tail -= buffer_head
buffer_head = 0
else:
# buffer size exceeded, so flush and reset state
buffer_head = 0
buffer_tail = 0
start_tail = len_marker
state = STATE_BEFORE_MARKER
num_bytes = stream.readinto1(buffer_view[buffer_tail:])
if num_bytes == 0:
return
buffer_tail += num_bytes
assert buffer_tail <= buffer_size
pos_head = buffer.find(MARKER, start_tail - len_marker, buffer_tail)
if pos_head < 0:
start_tail = buffer_tail + 1
continue
if state == STATE_BEFORE_MARKER:
chunk = None
content = None
type_id = None
else:
chunk = buffer_view[buffer_head:pos_head].tobytes()
buffer_head = pos_head + len_marker
start_tail = buffer_head + len_marker
if state == STATE_BEFORE_MARKER:
state = STATE_PREFIX
continue
if (state == STATE_SUFFIX
and chunk == b'END ' + type_id):
content = base64.b64encode(content)
content = [content[i:i+64] for i in range(0, len(content), 64)]
content = (
MARKER + b'BEGIN ' + type_id + MARKER + b'\n'
+ b'\n'.join(content) + b'\n'
+ MARKER + b'END ' + type_id + MARKER + b'\n'
)
try:
write_file(TYPE_MAP[type_id], content, buffering=buffer_size)
except WriteError as e:
content = content.rstrip().decode('ascii',
errors='backslashreplace')
print(f"\n{e}:\n{content}", file=sys.stderr, flush=True)
state = STATE_BEFORE_MARKER
continue
if (state == STATE_BODY
and (chunk.startswith(b'\n') or chunk.startswith(b'\r\n'))
and (chunk.endswith(b'\n'))):
try:
content = base64.b64decode(chunk.strip())
except ValueError:
pass
else:
state = STATE_SUFFIX
continue
chunk = chunk.lstrip(b'-')
if not chunk.startswith(b'BEGIN '):
state = STATE_PREFIX
continue
type_id = chunk[6:]
if type_id not in TYPE_MAP:
state = STATE_PREFIX
continue
state = STATE_BODY
def main(*args):
    """Scan each file named in *args*, in order, for PEM blocks to save."""
    for path in args:
        with open(path, 'rb') as stream:
            find_in_stream(stream)
if __name__ == '__main__':
    # Every command-line argument is the path of a file to scan.
    cli_args = sys.argv[1:]
    main(*cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment