Skip to content

Instantly share code, notes, and snippets.

@flisboac
Last active September 3, 2022 05:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save flisboac/48762e176061b520a606868bc4ce089f to your computer and use it in GitHub Desktop.
Save flisboac/48762e176061b520a606868bc4ce089f to your computer and use it in GitHub Desktop.
Utility Python library and CLI capable of downloading a whole certificate chain; barely tested
# Requires at least Python 3.7, and typing_extensions
from __future__ import annotations
import collections.abc
import datetime
import functools
import os
import pathlib
import re
import shutil
import ssl
import sys
import tempfile
import urllib.request
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
Sequence,
Sized,
TextIO,
Tuple,
TypeVar,
)
from typing_extensions import TypedDict
_CERT_ENCODING = "utf-8"
_PEM_CERTIFICATE_RE = re.compile(
r"(?P<content>-----BEGIN CERTIFICATE-----[a-zA-Z0-9+\/=\r\n]+(-----END CERTIFICATE-----)?\r?\n?)",
flags=re.MULTILINE,
)
_WS_RE_S = "\s*"
_ENVNAME_RE_S = r"[a-zA-Z][\-_a-zA-Z0-9]*"
_HOSTNAME_RE_S = r"([^\[\]\\#?&:=,;]+|\[[\d:]\])"
_HOSTPORT_RE_S = r"\d+"
_INPUT_HOST_RE = re.compile(
rf"\A({_WS_RE_S},{_WS_RE_S})?"
rf"((?P<envname>{_ENVNAME_RE_S}){_WS_RE_S}\={_WS_RE_S})?"
rf"(?P<hostname>{_HOSTNAME_RE_S})"
rf"(:(?P<hostport>{_HOSTPORT_RE_S}))?"
)
_T = TypeVar("_T")
class RawCertificateProperty(NamedTuple):
    """A single ``(key, value)`` string pair from decoded certificate data."""

    # Attribute name of the pair.
    key: str
    # Attribute value of the pair.
    value: str
# Validity timestamps are kept as plain strings in the raw certificate data;
# SingleCertificate converts them with ssl.cert_time_to_seconds.
RawCertificateDatetime = str
# One group of (key, value) pairs. The raw "subject" and "issuer" entries are
# sequences of such groups, and CertificatePrincipal.from_raw indexes them as
# group[0][0] / group[0][1].
# NOTE(review): presumably mirrors the nested tuples the ssl module produces
# for decoded certificates (one group per relative distinguished name) - confirm.
RawCertificateProperties = Sequence[RawCertificateProperty] # One group of pairs; see note above.
class _RawCertificateInfo_Optional(TypedDict, total=False):
OCSP: Sequence[str]
caIssuers: Sequence[str]
crlDistributionPoints: Sequence[str]
class _RawCertificateInfo_Required(TypedDict, total=True):
subject: Sequence[RawCertificateProperties]
issuer: Sequence[RawCertificateProperties]
version: int
serialNumber: str
notBefore: RawCertificateDatetime
notAfter: RawCertificateDatetime
subjectAltName: Sequence[RawCertificateProperty]
class RawCertificateInfo(
    _RawCertificateInfo_Required,
    _RawCertificateInfo_Optional,
    TypedDict,
):
    """Complete raw certificate dictionary: required keys plus optional ones."""
@dataclass(frozen=True)
class CertificateHostInfo:
    """Immutable description of the endpoint a certificate is fetched from."""

    # Host name (or address) as given by the user.
    name: str
    # TCP port of the endpoint.
    port: int
    # Optional override; when set it takes precedence over ``name`` in
    # ``server_name``.
    sni_name: str | None = None

    @classmethod
    def parse(cls, value: str, *, sni_name: str | None = None) -> CertificateHostInfo:
        """Build an instance from a ``hostname[:port]`` string.

        Args:
            value: Text matched against ``_INPUT_HOST_RE``.
            sni_name: Stored unchanged on the resulting instance.

        Returns:
            A new instance of ``cls``; the port defaults to 443 when absent.

        Raises:
            ValueError: If ``value`` does not match or yields an empty host name.
        """
        match = _INPUT_HOST_RE.match(value)
        # The host name pattern tolerates surrounding whitespace, so trim it
        # and reject a name that is nothing but whitespace.
        hostname = match.group("hostname").strip() if match else ""
        if not hostname:
            raise ValueError(f"Invalid certificate host/hostname value: {value}")
        return cls(
            name=hostname,
            port=int(match.group("hostport") or "443"),
            sni_name=sni_name,
        )

    @property
    def server_name(self) -> str:
        """Return ``sni_name`` when it is set (and non-empty), else ``name``."""
        return self.sni_name or self.name
@dataclass(frozen=True)
class CertificatePrincipal(
    collections.abc.Mapping,
    Mapping[str, str],
):
    """Read-only mapping of the attributes of a certificate subject or issuer."""

    # Backing mapping of attribute name -> attribute value.
    _properties: Mapping[str, str]

    @classmethod
    def from_raw(
        cls,
        value: Sequence[RawCertificateProperties],
    ) -> CertificatePrincipal:
        """Flatten raw groups of ``(key, value)`` pairs into a principal.

        Every pair of every group is kept (not only the first pair of each
        group). When a key occurs more than once, the last occurrence wins.
        """
        return cls({pair[0]: pair[1] for group in value for pair in group})

    def __len__(self) -> int:
        return len(self._properties)

    def __iter__(self) -> Iterator[str]:
        return iter(self._properties)

    def __getitem__(self, key: str) -> str:
        return self._properties[key]

    def __hash__(self) -> int:
        # The dataclass-generated hash would hash the backing mapping itself,
        # which fails for a dict; hash its items instead. This stays
        # consistent with the generated field-wise equality.
        return hash(frozenset(self._properties.items()))
class Certificate(metaclass=ABCMeta):
    """Abstract interface shared by the certificate-like objects of this module."""

    @property
    @abstractmethod
    def subject(self) -> CertificatePrincipal:
        """Principal the certificate was issued to."""
        ...

    @property
    @abstractmethod
    def issuer(self) -> CertificatePrincipal:
        """Principal that issued the certificate."""
        ...

    @property
    @abstractmethod
    def host(self) -> CertificateHostInfo | None:
        """Host the certificate is associated with, if any."""
        ...

    @property
    @abstractmethod
    def pem_content(self) -> str:
        """PEM (text) form of the certificate."""
        ...

    @property
    @abstractmethod
    def der_content(self) -> bytes:
        """DER (binary) form of the certificate."""
        ...

    @property
    @abstractmethod
    def not_before(self) -> datetime.datetime:
        """Start of the validity period."""
        ...

    @property
    @abstractmethod
    def not_after(self) -> datetime.datetime:
        """End of the validity period."""
        ...

    # Declared abstract like every other member of the interface; without the
    # decorator a subclass that forgot to implement it would return None.
    @abstractmethod
    def is_root(self) -> bool:
        """Tell whether the certificate is a root certificate."""
        ...
class SingleCertificate(Certificate):
    """One certificate, built from PEM text and/or DER bytes.

    Values that were not passed to the constructor (the other encoding, the
    decoded raw info, subject, issuer and validity dates) are computed on
    first access and cached on the instance.
    """

    def __init__(
        self,
        *,
        host: CertificateHostInfo | None = None,
        raw_info: RawCertificateInfo | None = None,
        pem_content: str | None = None,
        der_content: bytes | None = None,
        location: str | pathlib.PurePath | None = None,
        not_before: datetime.datetime | None = None,
        not_after: datetime.datetime | None = None,
    ) -> None:
        self._host = host
        self._raw_info_ = raw_info
        self._pem_content = pem_content
        self._der_content = der_content
        self._not_before = not_before
        self._not_after = not_after
        self._location = pathlib.Path(location) if location is not None else None
        self._subject: CertificatePrincipal | None = None
        self._issuer: CertificatePrincipal | None = None

    @property
    def subject(self) -> CertificatePrincipal:
        if self._subject is None:
            self._subject = self._get_subject()
        return self._subject

    @property
    def issuer(self) -> CertificatePrincipal:
        if self._issuer is None:
            self._issuer = self._get_issuer()
        return self._issuer

    @property
    def host(self) -> CertificateHostInfo | None:
        return self._host

    @property
    def location(self) -> pathlib.Path | None:
        """Path given at construction time, or None."""
        return self._location

    @property
    def pem_content(self) -> str:
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    def der_content(self) -> bytes:
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    @property
    def not_before(self) -> datetime.datetime:
        if self._not_before is None:
            self._not_before = self._get_not_before_date()
        return self._not_before

    @property
    def not_after(self) -> datetime.datetime:
        if self._not_after is None:
            self._not_after = self._get_not_after_date()
        return self._not_after

    @property
    def _raw_info(self) -> RawCertificateInfo:
        if self._raw_info_ is None:
            self._raw_info_ = self._get_raw_info()
        return self._raw_info_

    def is_root(self) -> bool:
        """Return True when subject and issuer are equal."""
        return self.subject == self.issuer

    def to_chain(self) -> CertificateChain:
        """Build a chain starting at this certificate.

        First follows the ``caIssuers`` URL of each certificate, downloading
        the issuer as DER. Then, unless the last certificate is a root, looks
        for issuers among the system CA certificates.
        """
        certificates = [self]
        visited_urls = set()
        current = self
        while True:
            issuer_url = current._get_issuer_url()
            # Stop on a missing URL, and also on a URL seen before so that
            # certificates referring to each other cannot loop forever.
            if issuer_url is None or issuer_url in visited_urls:
                break
            visited_urls.add(issuer_url)
            current = SingleCertificate(der_content=_download(issuer_url))
            certificates.append(current)

        # A root matches itself as its own issuer, so the search must stop as
        # soon as a root has been reached.
        while not current.is_root():
            # The helper's signature declares a positional parameter it never
            # uses; pass a placeholder so the call is valid.
            for candidate in get_system_ca_certificates(None):
                if candidate in certificates:
                    continue
                if current.issuer == candidate.subject:
                    certificates.append(candidate)
                    current = candidate
                    break
            else:
                # No system certificate issued the current one.
                break
        return CertificateChain(certificates)

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_der_content(self) -> bytes:
        assert self._pem_content is not None, "Missing certificate content."
        # PEM -> DER (the conversion must go in this direction here).
        return ssl.PEM_cert_to_DER_cert(self._pem_content)

    def _get_pem_content(self) -> str:
        assert self._der_content is not None, "Missing certificate content."
        return ssl.DER_cert_to_PEM_cert(self._der_content)

    def _get_raw_info(self) -> RawCertificateInfo:
        # The decoder used here is a private helper of the ssl module that
        # reads a certificate from a file path, hence the temporary file.
        with _open_temp_rw_text_file(suffix=".pem") as (tmp_file, tmp_path):
            tmp_file.write(self.pem_content)
            tmp_file.flush()
            info = ssl._ssl._test_decode_cert(str(tmp_path))
        return info

    def _get_not_before_date(self) -> datetime.datetime:
        return self._parse_datetime(self._raw_info["notBefore"])

    def _get_not_after_date(self) -> datetime.datetime:
        return self._parse_datetime(self._raw_info["notAfter"])

    def _get_subject(self) -> CertificatePrincipal:
        return CertificatePrincipal.from_raw(self._raw_info["subject"])

    def _get_issuer(self) -> CertificatePrincipal:
        return CertificatePrincipal.from_raw(self._raw_info["issuer"])

    def _parse_datetime(self, value: str) -> datetime.datetime:
        """Convert a raw certificate time string to a naive UTC datetime."""
        timestamp = ssl.cert_time_to_seconds(value)
        # Same naive-UTC result as the deprecated utcfromtimestamp().
        aware = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
        return aware.replace(tzinfo=None)

    def _get_issuer_url(self) -> str | None:
        """Return the first ``caIssuers`` URL, or None when there is none."""
        issuer_urls = self._raw_info.get("caIssuers") or ()
        if len(issuer_urls) > 0:
            return issuer_urls[0]
        return None
class CertificateBundle(
    collections.abc.Sequence,
    Sequence[SingleCertificate],
):
    """Sequence of certificates, given directly or parsed from PEM/DER content.

    The certificates, the split PEM list and the joined PEM/DER content are
    computed on first access and cached on the instance.
    """

    def __init__(
        self,
        certificates: Iterable[SingleCertificate] | None = None,
        *,
        host: CertificateHostInfo | None = None,
        pem_content: str | Iterable[str] | None = None,
        der_content: bytes | None = None,
        location: str | pathlib.PurePath | None = None,
    ) -> None:
        self._host = host
        self._input_pem_content = pem_content
        self._input_der_content = der_content
        self._location = pathlib.Path(location) if location is not None else None
        self._certificates_: Sequence[SingleCertificate] | None = (
            tuple(certificates) if certificates is not None else None
        )
        self._proper_chain: bool | None = None
        self._pem_content: str | None = None
        self._der_content: bytes | None = None
        self._pem_content_list_: Sequence[str] | None = None

    @property
    def location(self) -> pathlib.PurePath | None:
        return self._location

    @property
    def pem_content(self) -> str:
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    def der_content(self) -> bytes:
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    def is_proper_chain(self) -> bool:
        """Tell (and cache) whether the certificates form a proper chain."""
        if self._proper_chain is None:
            self._proper_chain = self._is_proper_chain()
        return self._proper_chain

    def to_chain(self) -> CertificateChain:
        """Return a chain made of this bundle, extended from its last certificate.

        Raises:
            AssertionError: If the bundle is not a proper chain.
        """
        assert (
            self.is_proper_chain()
        ), "This certificate bundle is not a proper certificate chain."
        own_certificates = self._certificates
        last_certificate = own_certificates[-1]
        # The last certificate is replaced by the chain built from it.
        return CertificateChain([*own_certificates[:-1], *last_certificate.to_chain()])

    @property
    def _pem_content_list(self) -> Sequence[str]:
        if self._pem_content_list_ is None:
            self._pem_content_list_ = self._get_pem_content_list()
        return self._pem_content_list_

    @property
    def _certificates(self) -> Sequence[SingleCertificate]:
        if self._certificates_ is None:
            self._certificates_ = self._get_certificates()
        return self._certificates_

    def __len__(self) -> int:
        return len(self._certificates)

    def __getitem__(self, key: Any) -> SingleCertificate:
        return self._certificates[key]

    def __iter__(self) -> Iterator[SingleCertificate]:
        yield from self._certificates

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_pem_content_list(self) -> Sequence[str]:
        """Split the input content into one PEM string per certificate."""
        if self._input_pem_content is not None:
            joined = _join_pem_content(self._input_pem_content)
        else:
            assert (
                self._input_der_content is not None
            ), "Either PEM or DER content must be provided."
            joined = ssl.DER_cert_to_PEM_cert(self._input_der_content)
        return tuple(_split_pem_content(joined))

    def _get_pem_content(self) -> str:
        return _join_pem_content(self._pem_content_list)

    def _get_der_content(self) -> bytes:
        # Convert certificate by certificate: the ssl converter handles one
        # PEM block only, and would decode the markers between blocks of a
        # joined multi-certificate string as if they were payload.
        return b"".join(
            ssl.PEM_cert_to_DER_cert(pem) for pem in self._pem_content_list
        )

    def _get_certificates(self) -> Sequence[SingleCertificate]:
        # Only the first certificate is associated with the bundle's host.
        return tuple(
            SingleCertificate(
                pem_content=pem,
                host=self._host if index == 0 else None,
            )
            for index, pem in enumerate(self._pem_content_list)
        )

    def _is_proper_chain(self) -> bool:
        return _is_proper_certificate_chain(self._certificates)
class CertificateChain(
    Certificate,
    collections.abc.Sequence,
    Sequence[SingleCertificate],
):
    """Ordered certificates; the first is the target and the last is the root.

    The ``Certificate`` interface is answered by delegating to the target
    (first) certificate, except for the PEM/DER content, which covers every
    certificate of the chain.
    """

    def __init__(
        self,
        certificates: Iterable[SingleCertificate] | None = None,
    ) -> None:
        self._certificates = list(certificates or ())
        self._pem_content: str | None = None
        self._der_content: bytes | None = None

    @property
    def root(self) -> SingleCertificate:
        """Last certificate of the chain."""
        return self._certificates[-1]

    @property
    def target(self) -> SingleCertificate:
        """First certificate of the chain."""
        return self._certificates[0]

    @property
    # Certificate
    def subject(self) -> CertificatePrincipal:
        return self.target.subject

    @property
    # Certificate
    def issuer(self) -> CertificatePrincipal:
        # Must be the target's issuer, not its subject.
        return self.target.issuer

    @property
    # Certificate
    def host(self) -> CertificateHostInfo | None:
        return self.target.host

    @property
    def raw_info(self) -> RawCertificateInfo:
        """Raw decoded data of the target certificate."""
        return self.target._raw_info

    @property
    # Certificate
    def pem_content(self) -> str:
        if self._pem_content is None:
            self._pem_content = self._get_pem_content()
        return self._pem_content

    @property
    # Certificate
    def der_content(self) -> bytes:
        if self._der_content is None:
            self._der_content = self._get_der_content()
        return self._der_content

    @property
    # Certificate
    def not_before(self) -> datetime.datetime:
        return self.target.not_before

    @property
    # Certificate
    def not_after(self) -> datetime.datetime:
        return self.target.not_after

    # Certificate
    def is_root(self) -> bool:
        return self.target.is_root()

    def with_root(self, *certificates: SingleCertificate) -> CertificateChain:
        """Return a new chain with ``certificates`` appended at the end."""
        return CertificateChain((*self._certificates, *certificates))

    def __len__(self) -> int:
        return len(self._certificates)

    def __getitem__(self, key: Any) -> SingleCertificate:
        return self._certificates[key]

    def __iter__(self) -> Iterator[SingleCertificate]:
        yield from self._certificates

    def __repr__(self) -> str:
        props = ", ".join([f"{k}={v!r}" for k, v in vars(self).items()])
        return f"{type(self).__name__}({props})"

    def _get_pem_content(self) -> str:
        return _join_pem_content([c.pem_content for c in self._certificates])

    def _get_der_content(self) -> bytes:
        # Concatenate each certificate's own DER form; converting the joined
        # multi-certificate PEM text in one call would decode the markers
        # between blocks as payload.
        return b"".join(c.der_content for c in self._certificates)
@functools.lru_cache(maxsize=None)
def get_system_ca_certificates(cls: Any = None) -> Sequence[SingleCertificate]:
    """Return the CA certificates of the default SSL context, cached.

    Args:
        cls: Unused. It now has a default so that the function can be called
            without arguments; passing a value still works.

    Returns:
        A tuple with one ``SingleCertificate`` per DER certificate reported
        by the default server-authentication context.
    """
    default_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
    return tuple(
        SingleCertificate(der_content=der_content)
        for der_content in default_context.get_ca_certs(binary_form=True)
    )
def get_server_certificate_pem(
    host: CertificateHostInfo | str,
    *,
    ssl_version: int | None = None,
    ca_certs: str | pathlib.PurePath | None = None,
    timeout: int | None = None,
) -> str:
    """Fetch the certificate presented by ``host`` as PEM text.

    Args:
        host: Host info, or a string parsed with ``CertificateHostInfo.parse``.
        ssl_version: Forwarded to ``ssl.get_server_certificate`` when given.
        ca_certs: CA file path; forwarded as a string when given.
        timeout: Forwarded to ``ssl.get_server_certificate`` when given.

    Returns:
        The PEM string returned by ``ssl.get_server_certificate``.
    """
    if isinstance(host, str):
        host = CertificateHostInfo.parse(host)
    # Only forward the options that were actually supplied, so that the
    # defaults of the ssl module apply to the others.
    params = {}
    if ssl_version is not None:
        params["ssl_version"] = ssl_version
    if ca_certs is not None:
        # Forward the CA file path itself (not the protocol version).
        params["ca_certs"] = str(ca_certs)
    if timeout is not None:
        params["timeout"] = timeout
    # The address uses server_name, i.e. the SNI override when one is set.
    address = (host.server_name, host.port)
    return ssl.get_server_certificate(address, **params)
def get_server_certificate_chain(
    host: CertificateHostInfo | str,
    *,
    ssl_version: int | None = None,
    ca_certs: str | pathlib.PurePath | None = None,
    timeout: int | None = None,
) -> CertificateChain:
    """Fetch the certificate of ``host`` and expand it into a chain.

    A string ``host`` is parsed first. The fetched PEM content is wrapped in
    a bundle; a bundle that already is a proper chain is converted as a
    whole, otherwise the chain is built from its first certificate.
    """
    host_info = CertificateHostInfo.parse(host) if isinstance(host, str) else host
    bundle = CertificateBundle(
        pem_content=get_server_certificate_pem(
            host_info,
            ssl_version=ssl_version,
            ca_certs=ca_certs,
            timeout=timeout,
        ),
        host=host_info,
    )
    if not bundle.is_proper_chain():
        return bundle[0].to_chain()
    return bundle.to_chain()
def _is_proper_certificate_chain(chain: Iterable[SingleCertificate]) -> bool:
chain = list(chain)
root_index = len(chain) - 1
if len(chain) > 1:
return all(
i == root_index or chain[i].issuer == chain[i + 1].subject
for i in range(len(chain))
)
if len(chain) == 1:
return chain[0].is_root()
return False
def _split_pem_content(content: str) -> Iterator[str]:
    """Yield the text of every PEM certificate block found in ``content``."""
    for found in _PEM_CERTIFICATE_RE.finditer(content):
        yield found.group("content")
def _join_pem_content(contents: str | Iterable[str]) -> str:
if isinstance(contents, str):
contents = [contents]
return re.sub(r"(\r?\n)+", "\n", "\n".join(contents))
def _download(url: str) -> bytes:
with urllib.request.urlopen(url) as opened_url:
content = opened_url.read()
return content
@contextmanager
def _open_temp_dir() -> Iterator[pathlib.Path]:
path = pathlib.Path(tempfile.gettempdir()) / f"scc-py-dir-{os.urandom(24).hex()}"
os.makedirs(path, exist_ok=False)
yield path
shutil.rmtree(path)
@contextmanager
def _open_temp_rw_text_file(
*, suffix: str | None = ""
) -> Iterator[Tuple[TextIO, pathlib.Path]]:
path = (
pathlib.Path(tempfile.gettempdir())
/ f"scc-py-file-{os.urandom(24).hex()}{suffix}"
)
with open(path, "w+", encoding="utf-8") as file:
yield file, path
if __name__ == "__main__":
    ## USAGE: pass the target as the first argument, written as HOST or HOST:PORT.
    ## The PEM content of the resulting certificate chain is written to stdout.
    if len(sys.argv) < 2:
        # Exit with a short message instead of an IndexError traceback.
        sys.exit("usage: python <this-script> HOST[:PORT]")
    sys.stdout.write(get_server_certificate_chain(sys.argv[1]).pem_content)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment