Skip to content

Instantly share code, notes, and snippets.

@autumnjolitz
Last active October 23, 2023 21:15
Show Gist options
  • Save autumnjolitz/56e3dd55266207e21575330372d5f32f to your computer and use it in GitHub Desktop.
Save autumnjolitz/56e3dd55266207e21575330372d5f32f to your computer and use it in GitHub Desktop.
Monkey patch a specific module in a way that is isolated to the owning scope
"""
Allow patching a requests.Session(...) at a specific module path with a certificate store,
handling the issues of OpenSSL requiring certificates to be stored under a subject-name-hash
filename, as well as multiple certificates inside a provided X509 certificate file (like a
chain of certificates).

This allows one to target a specific place, like ``snowflake.connector.network``, and replace
its ``requests`` module with our proxied one without affecting other users of
``requests.Session``.

Typical use::

    patch("snowflake.connector.network")
    setup_custom_certificate(cafile="my-ca.pem")
"""
import collections
import hashlib
import importlib
import logging
import os
import os.path
import shutil
import types
import weakref
from contextlib import suppress
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Mapping, Any, NamedTuple, IO, Iterable
from cryptography import x509
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives import hashes
from OpenSSL.crypto import X509
try:
import certifi
except ImportError:
certifi = None
CA_CERTIFICATE_HOME: None | str | TemporaryDirectory = None
SESSIONS = weakref.WeakSet()
logger = logging.getLogger(__name__ if __name__ != "__main__" else __file__)
class CustomCASessionMixin:
    """Mixin for ``requests.Session`` classes that verify TLS against ``CA_CERTIFICATE_HOME``.

    Every instance registers itself in the module-level ``SESSIONS`` weak set so that
    ``setup_custom_certificate`` can re-point live sessions at the certificate directory.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Weakly tracked: registering must not keep the session alive.
        SESSIONS.add(self)
        self._update_cacert_home()

    def _update_cacert_home(self):
        """Point ``self.verify`` at the current CA certificate directory, if there is one.

        While no directory is configured (``CA_CERTIFICATE_HOME is None``) the session's
        existing ``verify`` setting is left untouched. Assigning ``None`` instead may turn
        certificate verification off in requests versions that treat a falsy ``verify``
        as "do not verify".
        """
        ca_cert_dir = CA_CERTIFICATE_HOME
        if ca_cert_dir is None:
            return
        # Check the type explicitly: a pathlib.Path also has a ``.name`` (its basename).
        if isinstance(ca_cert_dir, TemporaryDirectory):
            ca_cert_dir = ca_cert_dir.name
        # Hand requests a plain string path for the CA directory.
        self.verify = os.fspath(ca_cert_dir)
def proxy_module(
    in_module: types.ModuleType, **intercept_by_names: Any
) -> "ProxyModule":
    """
    Return a proxy of *in_module* whose attributes named in *intercept_by_names* are replaced.

    If *in_module* is already a ``ProxyModule``, its interception table is updated in
    place and that same proxy is returned, so repeated patching never stacks proxies.

    :param in_module: the module (or an existing ``ProxyModule``) to wrap.
    :param intercept_by_names: attribute name -> replacement value.
    :returns: the ``ProxyModule`` serving the replacement values.
    """
    if isinstance(in_module, ProxyModule):
        in_module.__proxy_value_for_key__.update(intercept_by_names)
        return in_module
    return ProxyModule(in_module, **intercept_by_names)
class ProxyModule:
    """Attribute-level stand-in for a module that overrides selected names.

    Reading a name present in ``__proxy_value_for_key__`` returns the replacement value;
    every other attribute read is forwarded to the wrapped module. The wrapped module
    itself is never modified.
    """

    __slots__ = "__wrapped__", "__proxy_value_for_key__"

    def __init__(self, module: types.ModuleType, **proxy_value_for_key: Any):
        self.__wrapped__ = module
        self.__proxy_value_for_key__ = proxy_value_for_key

    def __dir__(self):
        # Advertise intercepted names too, even those the wrapped module lacks.
        return sorted(set(dir(self.__wrapped__)) | set(self.__proxy_value_for_key__))

    def __getattr__(self, name):
        # Only reached when normal lookup fails. For the slot names that means the slot
        # was never assigned: raise the normal AttributeError rather than recursing.
        if name in ("__wrapped__", "__proxy_value_for_key__"):
            return object.__getattribute__(self, name)
        # Consult the overrides first, so an intercepted name works even when the
        # wrapped module does not define it.
        overrides = self.__proxy_value_for_key__
        if name in overrides:
            return overrides[name]
        return getattr(self.__wrapped__, name)
def setup_custom_certificate(
*,
cafile: Path | str | None = None,
cadata: bytes | None = None,
capath: Path | str | None = None,
init_default_certificates: bool = True,
ensure_openssl_certificate_names: bool = True,
):
"""
pick one of ``cafile``, ``cadata``, ``capath``.
Can be called multiple times.
Each call will ensure any open requests.Session are pointed
to CA_CERTIFICATE_HOME upon function call (i.e. we can add new certificates to live
Session objects for new connections).
init_default_certificates - copy from cacerts the default certificates into our storage
ensure_openssl_certificate_names - requests.Session (urllib3) will not use our certificates
unless they are named an openssl specific filename (in format of
$hash.0, $hash.1, ..., $hash.N) for N files that share the same Subject SHA1 hash.
This will copy any CA cert pem files into appropriately named ones.
"""
global CA_CERTIFICATE_HOME
name = None
if cafile:
if not isinstance(cafile, Path):
cafile = Path(cafile).resolve(True)
if not cafile.is_file():
raise FileNotFoundError(str(cafile))
elif capath:
if not isinstance(capath, Path):
capath = Path(capath).resolve(True)
if not capath.is_dir():
raise NotADirectoryError(str(capath))
else:
if not cadata:
raise ValueError("cadata or capath or cafile must be specified!")
if hasattr(cadata, "read"):
cadata = cadata.read()
if hasattr(cadata, "name"):
name = os.path.basename(cadata.name)
if not isinstance(cadata, bytes):
raise TypeError(f"cadata must be PEM-encoded bytes, not {type(cadata)!r}")
if CA_CERTIFICATE_HOME is None:
CA_CERTIFICATE_HOME = TemporaryDirectory()
certificate_home = CA_CERTIFICATE_HOME
if hasattr(certificate_home, "name"):
certificate_home = certificate_home.name
if not isinstance(certificate_home, Path):
certificate_home = Path(certificate_home)
if not certificate_home.exists():
certificate_home.makedirs()
certificate_home = certificate_home.resolve(True)
if init_default_certificates:
if certifi is not None:
default_ca_certs = certificate_home / os.path.basename(certifi.where())
if not default_ca_certs.exists():
with open(default_ca_certs, "wb") as fh:
fh.write(certifi.contents().encode())
else:
logger.error("certifi is not installed, not seeding CA certificates")
if cafile is not None:
if cafile.parent != certificate_home:
with open(certificate_home / cafile.name, "wb") as fh:
shutil.copyfileobj(cafile.open("rb"), fh)
elif capath is not None:
for file in capath.iterdir():
if file.parent != certificate_home:
shutil.copyfile(file, certificate_home / file.name)
elif cadata:
name = name or f"{hashlib.sha1(cadata).hexdigest()}.pem"
with open(certificate_home / name, "wb") as fh:
fh.write(cadata)
cert_hashes = collections.defaultdict(set)
fingerprints = {}
for file in certificate_home.iterdir():
for certificate in load_certificates_from(file):
cert_hashes[certificate.openssl_filename_prefix].add(file)
if file.name.startswith(certificate.openssl_filename_prefix):
fingerprints[certificate.fingerprint] = file
continue
if certificate.fingerprint not in fingerprints:
num_distinct_certificates = len(
cert_hashes[certificate.openssl_filename_prefix]
)
current_count = num_distinct_certificates - 1
new_name = f"{certificate.openssl_filename_prefix}.{current_count}"
openssl_filename = file.with_name(new_name)
with open(openssl_filename, "wb") as fh:
fh.write(certificate.content)
fingerprints[certificate.fingerprint] = openssl_filename
for instance in SESSIONS:
instance._update_cacert_home()
class CertStoreHash(NamedTuple):
    """The two keys under which a certificate is tracked in the certificate store."""

    # Subject-name hash, used as the ``$hash`` part of OpenSSL ``$hash.N`` filenames.
    openssl_subject_name_hash: str
    # SHA1 fingerprint; tells apart certificates whose subject hashes collide.
    sha1_fingerprint: str
class Certificate(NamedTuple):
    """One certificate split out of a (possibly multi-certificate) PEM source."""

    file: Path | None  # source file, or None when parsed from in-memory data
    cert_index: int  # position of this certificate within its source
    certificate: x509.Certificate  # the parsed cryptography certificate
    content: bytes  # this certificate alone, re-encoded as PEM
    hashes: CertStoreHash  # subject-name hash and SHA1 fingerprint
    multiple_certificates: bool  # True when the source held more than one certificate

    @property
    def fingerprint(self) -> str:
        """Shortcut for ``hashes.sha1_fingerprint``."""
        return self.hashes.sha1_fingerprint

    @property
    def openssl_filename_prefix(self) -> str:
        """Shortcut for ``hashes.openssl_subject_name_hash`` (the ``$hash`` in ``$hash.N``)."""
        return self.hashes.openssl_subject_name_hash
def load_certificates_from(
    file: bytes | str | Path | IO[bytes],
) -> Iterable[Certificate]:
    """
    Yield every certificate contained in a PEM source.

    A file may hold multiple certificates (e.g. a chain). Unfortunately most
    implementations pick only the very first one, which is hard to debug when you
    provide N certificates in one blob; this splits them out.

    :param file: PEM data as bytes/str, a ``Path``, or a readable file object.
    :returns: a generator of ``Certificate`` records in source order. Nothing is read
        until the generator is first advanced.
    :raises ValueError: (from cryptography) when no PEM certificate can be parsed.
    """
    if hasattr(file, "read"):
        content = file.read()
        # In-memory streams (e.g. io.BytesIO) have no usable ``name``.
        source_name = getattr(file, "name", None)
        file = Path(source_name) if isinstance(source_name, (str, os.PathLike)) else None
    elif hasattr(file, "read_bytes"):
        content = file.read_bytes()
    else:
        content = file
        file = None
    if hasattr(content, "encode"):
        content = content.encode()
    certs = x509.load_pem_x509_certificates(content)
    multicert = len(certs) > 1
    for index, cert in enumerate(certs):
        # Hex string, matching CertStoreHash.sha1_fingerprint's declared type.
        fingerprint = cert.fingerprint(hashes.SHA1()).hex()
        openssl_subject_hash: int = X509.from_cryptography(cert).subject_name_hash()
        pem = cert.public_bytes(Encoding.PEM)
        yield Certificate(
            file,
            index,
            cert,
            pem,
            # OpenSSL names hash-directory entries with a zero-padded 8-digit hex
            # hash ("%08lx"); hex() would drop leading zeros and break the lookup.
            CertStoreHash(f"{openssl_subject_hash:08x}", fingerprint),
            multicert,
        )
def patch_requests_in(module: types.ModuleType, name: str = "requests"):
    """
    Given a specific module that ``import requests``, replace it with a proxied module.

    ``module.<name>`` is swapped for a proxy whose ``Session`` is a
    ``CustomCASessionMixin`` subclass; the real ``requests`` module is left untouched.
    Re-patching an already patched module reuses its session class instead of
    stacking a second mixin subclass on it (which has no consistent MRO).

    :raises AttributeError: when ``module`` has no attribute ``name``.
    """
    try:
        target = getattr(module, name)
    except AttributeError:
        raise AttributeError(f"Unable to access {module.__name__}.{name}") from None
    session_cls = target.Session
    if not issubclass(session_cls, CustomCASessionMixin):
        session_cls = make_session_cls(session_cls)
    setattr(module, name, proxy_module(target, Session=session_cls))
def patch_requests_session_in(module: types.ModuleType, name: str = "Session"):
    """
    Replace ``module.<name>`` (a ``requests.Session`` class) with a CA-aware subclass.

    Idempotent: an attribute that is already a ``CustomCASessionMixin`` subclass is
    left as is.

    :returns: the patched session class now stored at ``module.<name>``.
    :raises AttributeError: when ``module`` has no attribute ``name``.
    """
    try:
        session_cls = getattr(module, name)
    except AttributeError:
        raise AttributeError(f"{name} not found in {module!r}") from None
    if issubclass(session_cls, CustomCASessionMixin):
        return session_cls
    # Write back under the same attribute name that was looked up.
    patched = make_session_cls(session_cls)
    setattr(module, name, patched)
    return patched
def make_session_cls(session_cls):
    """
    Return a subclass of *session_cls* with ``CustomCASessionMixin`` mixed in first.

    If *session_cls* already derives from the mixin, it is returned unchanged.
    Putting the mixin in front of a class that already inherits it has no consistent
    MRO, so ``type()`` would raise ``TypeError``.
    """
    if issubclass(session_cls, CustomCASessionMixin):
        return session_cls
    # Mixin first: its __init__ runs, chains to the session's via super(), then registers.
    return type("CustomCASession", (CustomCASessionMixin, session_cls), {})
def patch(path: types.ModuleType | str):
"""
Patch a module at path.
Examples:
``patch("snowflake.connector.network")``
``patch("snowflake.connector.network.requests")``
``patch(snowflake.connector.network)``
"""
if isinstance(path, str):
if path.endswith(".requests"):
path, name = path.rsplit(".", 1)
return patch_requests_in(importlib.import_module(path), name)
elif path.endswith(".Session"):
path, name = path.rsplit(".", 1)
module = importlib.import_module(path)
return patch_requests_session_in(module)
module = importlib.import_module(path)
else:
module = path
if hasattr(module, "Session"):
if not hasattr(module.Session, "request"):
raise TypeError("Not a requests.Session class")
patch_requests_session_in(module)
if hasattr(module, "requests"):
patch_requests_in(module)
__all__ = "patch", "setup_custom_certificate"


def _main(argv=None):
    """Command-line smoke test: trust the given CA file or directory, then GET a URL
    through ``snowflake.connector.network``'s patched ``requests``."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Fetch a URL via snowflake.connector.network using custom CA certificates."
    )
    parser.add_argument("capath", type=Path, help="PEM file or directory of CA certificates")
    parser.add_argument("url", type=str, help="URL to fetch")
    # Parse (and validate) first, so --help and usage errors do not need snowflake.
    args = parser.parse_args(argv)
    if args.capath.is_dir():
        kwargs = {"capath": args.capath}
    elif args.capath.is_file():
        kwargs = {"cafile": args.capath}
    else:
        parser.error(f"{args.capath} is neither a file nor a directory")

    # Passing the imported module object to patch() works the same way.
    patch("snowflake.connector.network")
    from snowflake.connector import network

    setup_custom_certificate(**kwargs, init_default_certificates=False)
    with network.requests.Session() as session:
        print("I am using these Certs", os.listdir(session.verify))
        print(session.get(args.url))


if __name__ == "__main__":
    _main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment