Last active
October 23, 2023 21:15
-
-
Save autumnjolitz/56e3dd55266207e21575330372d5f32f to your computer and use it in GitHub Desktop.
Monkey patch a specific module in a way that is isolated to the owning scope
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Allow patching a requests.Session(...) at a specific module path with a certificate store,
handling two OpenSSL issues: certificates in a CA directory must be stored under a file name
derived from their subject name hash, and a provided X509 certificate file may contain
multiple certificates (like a chain of certificates).
This allows one to target a specific place, like ``snowflake.connector.network``, and replace its
``requests`` module with our proxied one without affecting other users of ``requests.Session``.
"""
import collections | |
import hashlib | |
import importlib | |
import logging | |
import os | |
import os.path | |
import shutil | |
import types | |
import weakref | |
from contextlib import suppress | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Mapping, Any, NamedTuple, IO, Iterable | |
from cryptography import x509 | |
from cryptography.hazmat.primitives.serialization import Encoding | |
from cryptography.hazmat.primitives import hashes | |
from OpenSSL.crypto import X509 | |
try: | |
import certifi | |
except ImportError: | |
certifi = None | |
# Location of the CA certificate directory handed to patched sessions.
# Stays None until setup_custom_certificate() creates a TemporaryDirectory for it.
CA_CERTIFICATE_HOME: None | str | TemporaryDirectory = None
# Live patched sessions, held weakly; setup_custom_certificate() walks this set
# to re-point each session at CA_CERTIFICATE_HOME.
SESSIONS = weakref.WeakSet()
# When executed as a script, name the logger after the file rather than "__main__".
logger = logging.getLogger(__name__ if __name__ != "__main__" else __file__)
class CustomCASessionMixin:
    """Mixin for ``requests.Session`` classes that verifies against CA_CERTIFICATE_HOME.

    Every instance registers itself in the module-level ``SESSIONS`` weak set so
    that later calls to ``setup_custom_certificate`` can re-point it at the
    certificate directory.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        SESSIONS.add(self)
        self._update_cacert_home()

    def _update_cacert_home(self):
        """Point ``self.verify`` at the configured CA directory, if there is one."""
        ca_cert_dir = CA_CERTIFICATE_HOME
        if ca_cert_dir is None:
            # Nothing configured yet: keep whatever ``verify`` the base class
            # set up instead of overwriting it with None.
            return
        # A TemporaryDirectory exposes its path as ``.name``; a plain str is the path.
        if hasattr(ca_cert_dir, "name"):
            ca_cert_dir = ca_cert_dir.name
        self.verify = ca_cert_dir
def proxy_module(
    in_module: types.ModuleType, **intercept_by_names: Mapping[str, Any]
) -> types.ModuleType:
    """Return a proxy of *in_module* whose named attributes are overridden.

    A module that is already a ``ProxyModule`` is updated in place and returned
    as-is rather than being wrapped a second time.
    """
    if not isinstance(in_module, ProxyModule):
        return ProxyModule(in_module, **intercept_by_names)
    overrides = in_module.__proxy_value_for_key__
    overrides.update(intercept_by_names)
    return in_module
class ProxyModule:
    """Stand-in for a module that overrides a chosen set of its attributes.

    Attribute access is forwarded to the wrapped module; names present in the
    override mapping return the override instead. An override only takes effect
    for names the wrapped module actually defines, because the wrapped lookup
    runs first and raises ``AttributeError`` otherwise.
    """

    __slots__ = ("__wrapped__", "__proxy_value_for_key__")

    def __init__(
        self, module: types.ModuleType, **proxy_value_for_key: Mapping[str, Any]
    ):
        self.__wrapped__ = module
        self.__proxy_value_for_key__ = proxy_value_for_key

    def __dir__(self):
        wrapped = self.__wrapped__
        return dir(wrapped)

    def __getattr__(self, name):
        # Only reached for the slot names when a slot is still unset; defer to
        # the default lookup so that case raises instead of recursing.
        if name == "__wrapped__" or name == "__proxy_value_for_key__":
            return object.__getattribute__(self, name)
        fallback = getattr(self.__wrapped__, name)
        overrides = self.__proxy_value_for_key__
        try:
            return overrides[name]
        except KeyError:
            return fallback
def setup_custom_certificate( | |
*, | |
cafile: Path | str | None = None, | |
cadata: bytes | None = None, | |
capath: Path | str | None = None, | |
init_default_certificates: bool = True, | |
ensure_openssl_certificate_names: bool = True, | |
): | |
""" | |
pick one of ``cafile``, ``cadata``, ``capath``. | |
Can be called multiple times. | |
Each call will ensure any open requests.Session are pointed | |
to CA_CERTIFICATE_HOME upon function call (i.e. we can add new certificates to live | |
Session objects for new connections). | |
init_default_certificates - copy from cacerts the default certificates into our storage | |
ensure_openssl_certificate_names - requests.Session (urllib3) will not use our certificates | |
unless they are named an openssl specific filename (in format of | |
$hash.0, $hash.1, ..., $hash.N) for N files that share the same Subject SHA1 hash. | |
This will copy any CA cert pem files into appropriately named ones. | |
""" | |
global CA_CERTIFICATE_HOME | |
name = None | |
if cafile: | |
if not isinstance(cafile, Path): | |
cafile = Path(cafile).resolve(True) | |
if not cafile.is_file(): | |
raise FileNotFoundError(str(cafile)) | |
elif capath: | |
if not isinstance(capath, Path): | |
capath = Path(capath).resolve(True) | |
if not capath.is_dir(): | |
raise NotADirectoryError(str(capath)) | |
else: | |
if not cadata: | |
raise ValueError("cadata or capath or cafile must be specified!") | |
if hasattr(cadata, "read"): | |
cadata = cadata.read() | |
if hasattr(cadata, "name"): | |
name = os.path.basename(cadata.name) | |
if not isinstance(cadata, bytes): | |
raise TypeError(f"cadata must be PEM-encoded bytes, not {type(cadata)!r}") | |
if CA_CERTIFICATE_HOME is None: | |
CA_CERTIFICATE_HOME = TemporaryDirectory() | |
certificate_home = CA_CERTIFICATE_HOME | |
if hasattr(certificate_home, "name"): | |
certificate_home = certificate_home.name | |
if not isinstance(certificate_home, Path): | |
certificate_home = Path(certificate_home) | |
if not certificate_home.exists(): | |
certificate_home.makedirs() | |
certificate_home = certificate_home.resolve(True) | |
if init_default_certificates: | |
if certifi is not None: | |
default_ca_certs = certificate_home / os.path.basename(certifi.where()) | |
if not default_ca_certs.exists(): | |
with open(default_ca_certs, "wb") as fh: | |
fh.write(certifi.contents().encode()) | |
else: | |
logger.error("certifi is not installed, not seeding CA certificates") | |
if cafile is not None: | |
if cafile.parent != certificate_home: | |
with open(certificate_home / cafile.name, "wb") as fh: | |
shutil.copyfileobj(cafile.open("rb"), fh) | |
elif capath is not None: | |
for file in capath.iterdir(): | |
if file.parent != certificate_home: | |
shutil.copyfile(file, certificate_home / file.name) | |
elif cadata: | |
name = name or f"{hashlib.sha1(cadata).hexdigest()}.pem" | |
with open(certificate_home / name, "wb") as fh: | |
fh.write(cadata) | |
cert_hashes = collections.defaultdict(set) | |
fingerprints = {} | |
for file in certificate_home.iterdir(): | |
for certificate in load_certificates_from(file): | |
cert_hashes[certificate.openssl_filename_prefix].add(file) | |
if file.name.startswith(certificate.openssl_filename_prefix): | |
fingerprints[certificate.fingerprint] = file | |
continue | |
if certificate.fingerprint not in fingerprints: | |
num_distinct_certificates = len( | |
cert_hashes[certificate.openssl_filename_prefix] | |
) | |
current_count = num_distinct_certificates - 1 | |
new_name = f"{certificate.openssl_filename_prefix}.{current_count}" | |
openssl_filename = file.with_name(new_name) | |
with open(openssl_filename, "wb") as fh: | |
fh.write(certificate.content) | |
fingerprints[certificate.fingerprint] = openssl_filename | |
for instance in SESSIONS: | |
instance._update_cacert_home() | |
class CertStoreHash(NamedTuple):
    """The two digests used to place and de-duplicate a certificate in the store."""

    # Subject name hash as computed by OpenSSL; used as the file name prefix.
    openssl_subject_name_hash: str
    # SHA-1 fingerprint of the certificate; used to recognise duplicates.
    sha1_fingerprint: str
class Certificate(NamedTuple):
    """One certificate extracted from a (possibly multi-certificate) PEM source."""

    # Path the certificate was loaded from, or None when it came from raw bytes.
    file: Path | None
    # Zero-based position of this certificate within its source.
    cert_index: int
    # Parsed certificate object.
    certificate: x509.Certificate
    # PEM encoding of this single certificate.
    content: bytes
    # Subject name hash and fingerprint for this certificate.
    hashes: CertStoreHash
    # True when the source contained more than one certificate.
    multiple_certificates: bool

    @property
    def openssl_filename_prefix(self) -> str:
        """Subject name hash that prefixes the certificate's OpenSSL file name."""
        digests = self.hashes
        return digests.openssl_subject_name_hash

    @property
    def fingerprint(self) -> str:
        """SHA-1 fingerprint identifying this exact certificate."""
        digests = self.hashes
        return digests.sha1_fingerprint
def load_certificates_from(file: bytes | Path | IO[bytes]) -> Iterable[Certificate]:
    """
    Yield every certificate contained in a PEM source.

    A file may have multiple certificates. Unfortunately most implementations
    will choose the very first one, which makes this hard to debug when
    you provide N certificates in one blob. This splits them out.

    file - PEM data as bytes (or str), a ``Path`` to read, or an open binary
        file object. A file object without a usable ``name`` yields
        certificates whose ``file`` is None.

    Each yielded ``Certificate`` carries the subject name hash formatted the way
    OpenSSL names files in a CA directory (eight lowercase hex digits, zero
    padded) and the SHA-1 fingerprint as a hex string.

    Raises whatever ``x509.load_pem_x509_certificates`` raises for input it
    cannot parse.
    """
    source = None
    if hasattr(file, "read"):
        content = file.read()
        # In-memory streams (e.g. BytesIO) have no name; that is not an error.
        stream_name = getattr(file, "name", None)
        if isinstance(stream_name, (str, os.PathLike)):
            source = Path(stream_name)
    elif hasattr(file, "read_bytes"):
        content = file.read_bytes()
        source = file
    else:
        content = file
    if hasattr(content, "encode"):
        content = content.encode()
    certs = x509.load_pem_x509_certificates(content)
    multicert = len(certs) > 1
    for index, cert in enumerate(certs):
        fingerprint = cert.fingerprint(hashes.SHA1()).hex()
        openssl_subject_hash: int = X509.from_cryptography(cert).subject_name_hash()
        pem = cert.public_bytes(Encoding.PEM)
        yield Certificate(
            source,
            index,
            cert,
            pem,
            # Zero-pad to 8 digits: a shorter name is never looked up by OpenSSL.
            CertStoreHash(f"{openssl_subject_hash:08x}", fingerprint),
            multicert,
        )
def patch_requests_in(module: types.ModuleType, name: str = "requests"):
    """
    Given a specific module that ``import requests``, replace it with a proxied module.

    The proxy exposes a ``Session`` class that mixes in ``CustomCASessionMixin``.
    Calling this again on an already patched module is a no-op.

    Raises AttributeError when *module* has no attribute *name*.
    """
    try:
        target = getattr(module, name)
    except AttributeError:
        raise AttributeError(f"Unable to access {module.__name__}.{name}") from None
    session_cls = target.Session
    if issubclass(session_cls, CustomCASessionMixin):
        # Already patched: mixing the mixin in a second time would fail with
        # an inconsistent method resolution order.
        return
    setattr(module, name, proxy_module(target, Session=make_session_cls(session_cls)))
def patch_requests_session_in(module: types.ModuleType, name: str = "Session"):
    """
    Replace the session class stored at ``module.<name>`` with a patched subclass.

    Returns the patched class; a class that already includes
    ``CustomCASessionMixin`` is returned unchanged and the module is left alone.

    Raises AttributeError when *module* has no attribute *name*.
    """
    try:
        session_cls = getattr(module, name)
    except AttributeError:
        raise AttributeError(f"{name} not found in {module!r}") from None
    if issubclass(session_cls, CustomCASessionMixin):
        return session_cls
    patched_cls = make_session_cls(session_cls)
    # Write back to the attribute that was read, not a hard-coded "Session".
    setattr(module, name, patched_cls)
    return patched_cls
def make_session_cls(session_cls):
    """
    Return a subclass of *session_cls* that also inherits ``CustomCASessionMixin``.

    A class that already includes the mixin is returned unchanged; deriving from
    it again with the mixin listed first cannot produce a consistent method
    resolution order and would raise TypeError.
    """
    if issubclass(session_cls, CustomCASessionMixin):
        return session_cls
    return type("CustomCASession", (CustomCASessionMixin, session_cls), {})
def patch(path: types.ModuleType | str):
    """
    Patch a module at path.

    A dotted string ending in ``.requests`` or ``.Session`` patches exactly that
    attribute of the parent module. Any other string is imported as a module;
    that module (or a module object passed directly) then has its ``Session``
    and/or ``requests`` attributes patched, whichever it has.

    Examples:
    ``patch("snowflake.connector.network")``
    ``patch("snowflake.connector.network.requests")``
    ``patch(snowflake.connector.network)``
    """
    module = path
    if isinstance(path, str):
        parent, dot, leaf = path.rpartition(".")
        if dot and leaf == "requests":
            return patch_requests_in(importlib.import_module(parent), leaf)
        if dot and leaf == "Session":
            return patch_requests_session_in(importlib.import_module(parent))
        module = importlib.import_module(path)

    if hasattr(module, "Session"):
        # Refuse to wrap something that does not look like a requests session class.
        if not hasattr(module.Session, "request"):
            raise TypeError("Not a requests.Session class")
        patch_requests_session_in(module)
    if hasattr(module, "requests"):
        patch_requests_in(module)
__all__ = "patch", "setup_custom_certificate"

if __name__ == "__main__":
    # Demo: fetch a URL through a patched snowflake connector session using
    # only the certificates found at the given file or directory.
    import argparse

    # Parse arguments before touching third-party modules so that ``--help``
    # and usage errors do not depend on snowflake being importable.
    parser = argparse.ArgumentParser()
    parser.add_argument("capath", type=Path)
    parser.add_argument("url", type=str)
    args = parser.parse_args()

    kwargs = {}
    if args.capath.is_dir():
        kwargs["capath"] = args.capath
    elif args.capath.is_file():
        kwargs["cafile"] = args.capath
    else:
        parser.error(f"{args.capath} is neither a certificate file nor a directory")

    # Passing the imported module object instead of its dotted name works too.
    patch("snowflake.connector.network")

    from snowflake.connector import network

    setup_custom_certificate(**kwargs, init_default_certificates=False)
    with network.requests.Session() as session:
        print("I am using these Certs", os.listdir(session.verify))
        print(session.get(args.url))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment