Skip to content

Instantly share code, notes, and snippets.

@jborean93
Last active February 22, 2024 20:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jborean93/28a3e44e3645d0ba56ad876adf33164a to your computer and use it in GitHub Desktop.
Save jborean93/28a3e44e3645d0ba56ad876adf33164a to your computer and use it in GitHub Desktop.
A test HTTP server with TLS enabled to test out some TLS behaviour for web based commands
#!/usr/bin/env python
"""Test TLS Enabled Web Server
A script that can start a temporary TLS enabled web server. This server
supports a basic GET request and will return metadata on the request from the
client. By default it will create an ephemeral certificate when starting up but
a custom certificate can be provided. Also supports client authentication by
providing a CA bundle to use for verification or using --tls-client-auth to
generate a new set of keys.
"""
from __future__ import annotations
import argparse
import datetime
import http.server
import json
import os
import os.path
import pathlib
import socket
import ssl
import sys
import typing as t
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, rsa, types
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
NoEncryption,
PrivateFormat,
pkcs12,
)
from cryptography.x509.oid import ExtendedKeyUsageOID
# argcomplete is optional; parse_args only calls its completion hook when the
# import succeeds.
try:
    import argcomplete
except ImportError:
    HAS_ARGCOMPLETE = False
else:
    HAS_ARGCOMPLETE = True

# Stem of this script's file name, used as the prefix of the certificate files
# generated next to it.
FILE_NAME = pathlib.Path(__file__).stem
class HTTPHandler(http.server.BaseHTTPRequestHandler):
    """Request handler that echoes TLS session details back to the client.

    Every GET request gets a 200 JSON response containing the negotiated TLS
    protocol and cipher, the subject of the client certificate (or ``None``
    when the client presented none) and the request headers.
    """

    def do_GET(self) -> None:
        """Reply with a JSON document describing the TLS session and request."""
        # Assumes the server socket was wrapped with SSLContext.wrap_socket (as
        # main() does), so self.connection provides the SSLSocket API.
        # cipher() returns (cipher name, protocol version, secret bits).
        cipher = self.connection.cipher()
        tls_info = {
            "protocol": cipher[1],
            "cipher": cipher[0],
            "client_cert": None,
        }

        b_peer_cert = self.connection.getpeercert(binary_form=True)
        if b_peer_cert:
            peer_cert = x509.load_der_x509_certificate(b_peer_cert)
            tls_info["client_cert"] = peer_cert.subject.rfc4514_string()

        # Flush so the line is visible immediately even when stdout is a pipe.
        print(f"TLS Client {tls_info}", flush=True)

        data = {
            "tls": tls_info,
            "request_headers": dict(self.headers),
        }
        # Encode the body up front so its length can be advertised; clients
        # then do not depend on the connection closing to find the body's end.
        body = json.dumps(data).encode("utf-8")

        self.send_response(200)
        self.send_header("Content-type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def parse_args(argv: list[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="tls_server.py",
description="Test TLS HTTP Server in Python",
)
parse_path = lambda v: pathlib.Path(os.path.expanduser(os.path.expandvars(v)))
parser.add_argument(
"--tls-cert",
action="store",
type=parse_path,
help="Path to PEM encoded certificate with option embedded key, will use self signed certificate if not set.",
)
parser.add_argument(
"--tls-key",
action="store",
type=parse_path,
help="Path to PEM encoded key for the certificate if not present in --tls-cert.",
)
parser.add_argument(
"--tls-key-pass",
action="store",
type=str,
help="The password needed to decrypt the TLS key provided, can be omitted if the key is not encrypted.",
)
client_ca = parser.add_mutually_exclusive_group()
client_ca.add_argument(
"--tls-client-ca",
action="store",
type=parse_path,
help="Path to a TLS CA bundle file or directory to use with identifying the client. This enforces client cert authentication if set.",
)
client_ca.add_argument(
"--tls-client-auth",
action="store_true",
help="Require TLS Client authentication through pre-generated certificates next to this script",
)
parser.add_argument(
"--tls-min-protocol",
action="store",
choices=["default", "tlsv1_2", "tlsv1_3"],
default="default",
type=str.lower,
help="The minimum TLS protocol to allow, the default is the default for Python.",
)
parser.add_argument(
"--tls-max-protocol",
action="store",
choices=["default", "tlsv1_2", "tlsv1_3"],
default="default",
type=str.lower,
help="The maximum TLS protocol to allow, the default is the default for Python.",
)
parser.add_argument(
"--tls-ciphers",
action="store",
default="",
type=str,
help="The TLS cipher suites to allow in the format of the OpenSSL cipher list string, this cannot restrict ciphers in TLS 1.3.",
)
parser.add_argument(
"--port",
action="store",
type=int,
default=0,
help="The port to listen on, defaults to an ephemeral port available on the host",
)
if HAS_ARGCOMPLETE:
argcomplete.autocomplete(parser)
return parser.parse_args(argv)
def generate_cert(
    subject: str,
    *,
    issuer: (
        tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes] | None
    ) = None,
    key_type: t.Literal["rsa", "ecdsa"] = "rsa",
    extensions: list[tuple[x509.ExtensionType, bool]] | None = None,
) -> tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes]:
    """Generate a new private key and an X.509 certificate for it.

    :param subject: Common name placed in the certificate subject.
    :param issuer: Certificate and private key of the issuing CA. When omitted
        the certificate is self-signed with the newly generated key.
    :param key_type: ``"rsa"`` creates a 2048-bit RSA key; any other value
        creates an EC key on the SECP384R1 curve.
    :param extensions: Additional ``(extension, critical)`` pairs to add. A
        subjectKeyIdentifier and an authorityKeyIdentifier are added
        automatically unless the caller already supplied one.
    :return: The SHA256-signed certificate and its newly generated private key.
    """
    private_key: types.PrivateKeyTypes
    if key_type == "rsa":
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )
    else:
        private_key = ec.generate_private_key(
            curve=ec.SECP384R1(),
        )

    subject_name = x509.Name(
        [
            x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "Au"),
            x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, "State"),
            x509.NameAttribute(x509.NameOID.LOCALITY_NAME, "City"),
            x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Organization"),
            x509.NameAttribute(x509.NameOID.COMMON_NAME, subject),
        ]
    )

    issuer_name = subject_name
    sign_key: types.CertificateIssuerPrivateKeyTypes = private_key
    if issuer is not None:
        issuer_name = issuer[0].subject
        sign_key = issuer[1]

    # Validity starts one day in the past (presumably to tolerate clock skew
    # between hosts) and lasts 365 days from that point.
    not_before = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
        days=1
    )
    builder = (
        x509.CertificateBuilder()
        .subject_name(subject_name)
        .issuer_name(issuer_name)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(not_before + datetime.timedelta(days=365))
    )

    # Copy so the caller's list is never mutated.
    all_extensions = list(extensions or [])
    supplied_types = {type(ext) for ext, _ in all_extensions}
    # RFC 5280 key identifiers. OpenSSL's strict verification mode (a default
    # flag of ssl.create_default_context() since Python 3.13) rejects CA
    # certificates without a subjectKeyIdentifier and issued certificates
    # without an authorityKeyIdentifier. Skip any the caller already provided,
    # since add_extension refuses duplicates.
    if x509.SubjectKeyIdentifier not in supplied_types:
        all_extensions.append(
            (x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), False)
        )
    if x509.AuthorityKeyIdentifier not in supplied_types:
        all_extensions.append(
            (
                x509.AuthorityKeyIdentifier.from_issuer_public_key(
                    sign_key.public_key()
                ),
                False,
            )
        )
    for ext, critical in all_extensions:
        builder = builder.add_extension(ext, critical)

    return builder.sign(sign_key, SHA256()), private_key
def serialize_cert(
    cert: x509.Certificate,
    key: types.CertificateIssuerPrivateKeyTypes,
    path: pathlib.Path,
    *,
    key_password: bytes | None = None,
    cert_only: bool = False,
    generate_pfx: bool = False,
) -> None:
    """Write a certificate, and optionally its private key, to *path* as PEM.

    Unless ``cert_only`` is set, the private key is written first in the
    traditional OpenSSL PEM format, encrypted with ``key_password`` when one is
    given and unencrypted otherwise. The certificate PEM follows it.

    With ``generate_pfx`` a PKCS#12 bundle of the key and certificate is also
    written to the same path with a ``.pfx`` suffix. It is protected by
    ``key_password``, or by ``b"password"`` when no password was supplied.
    """
    pem_chunks: list[bytes] = []

    if not cert_only:
        if key_password:
            key_protection = BestAvailableEncryption(key_password)
        else:
            key_protection = NoEncryption()
        pem_chunks.append(
            key.private_bytes(
                encoding=Encoding.PEM,
                format=PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=key_protection,
            )
        )
    pem_chunks.append(cert.public_bytes(Encoding.PEM))

    with open(path, mode="wb") as pem_file:
        pem_file.write(b"".join(pem_chunks))

    if not generate_pfx:
        return

    pfx_bytes = pkcs12.serialize_key_and_certificates(
        cert.subject.rfc4514_string().encode(),
        key,
        cert,
        None,
        BestAvailableEncryption(key_password or b"password"),
    )
    with open(path.with_suffix(".pfx"), mode="wb") as pfx_file:
        pfx_file.write(pfx_bytes)
def _tls_version_from_name(name: str, option: str) -> ssl.TLSVersion:
    """Return the ``ssl.TLSVersion`` member whose name matches *name*, ignoring case.

    :param name: Member name such as ``tlsv1_2``.
    :param option: Command line option the value came from, used in the error.
    :raises ValueError: If no ``ssl.TLSVersion`` member has that name.
    """
    for tls_version in ssl.TLSVersion:
        if tls_version.name.lower() == name.lower():
            return tls_version
    raise ValueError(f"Unknown {option} '{name}' specified")


def create_tls_context(
    args: argparse.Namespace,
) -> tuple[ssl.SSLContext, list[pathlib.Path]]:
    """Build the server-side SSLContext described by the parsed arguments.

    Steps:
    - Apply the protocol range and cipher options.
    - Configure client certificate verification (``--tls-client-ca`` or
      ``--tls-client-auth``).
    - Load the server certificate. Without ``--tls-cert``, an RSA and an ECDSA
      certificate are issued from a newly generated CA and loaded. They cover
      this host's name and ``localhost``.

    :param args: Namespace produced by ``parse_args``.
    :return: The context, plus the files written next to this script that the
        caller must delete when finished: the CA certificate, and for
        ``--tls-client-auth`` the client PEM and PFX.
    :raises ValueError: For an unknown protocol name or a missing
        ``--tls-client-ca`` path.
    """
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

    if (min_protocol := args.tls_min_protocol) != "default":
        context.minimum_version = _tls_version_from_name(
            min_protocol, "--tls-min-protocol"
        )
    if (max_protocol := args.tls_max_protocol) != "default":
        context.maximum_version = _tls_version_from_name(
            max_protocol, "--tls-max-protocol"
        )
    if args.tls_ciphers:
        context.set_ciphers(args.tls_ciphers)

    my_ca: (
        tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes, pathlib.Path]
        | None
    ) = None
    temp_files: list[pathlib.Path] = []

    def generate_ca() -> (
        tuple[x509.Certificate, types.CertificateIssuerPrivateKeyTypes, pathlib.Path]
    ):
        """Create a CA, write its certificate next to this script and track the file."""
        ca_cert, ca_key = generate_cert(
            "TlsWebServerCA",
            extensions=[
                (x509.BasicConstraints(ca=True, path_length=None), True),
                # OpenSSL strict verification (a default flag of
                # ssl.create_default_context() since Python 3.13) rejects CA
                # certificates that carry no keyUsage extension.
                (
                    x509.KeyUsage(
                        digital_signature=False,
                        content_commitment=False,
                        key_encipherment=False,
                        data_encipherment=False,
                        key_agreement=False,
                        key_cert_sign=True,
                        crl_sign=True,
                        encipher_only=False,
                        decipher_only=False,
                    ),
                    True,
                ),
            ],
        )
        ca_path = pathlib.Path(__file__).parent / f"{FILE_NAME}_ca.pem"
        # Track the file before writing it so a failed write is still cleaned up.
        temp_files.append(ca_path)
        serialize_cert(
            ca_cert,
            ca_key,
            ca_path,
            cert_only=True,
        )
        return ca_cert, ca_key, ca_path

    try:
        if args.tls_client_ca:
            context.verify_mode = ssl.VerifyMode.CERT_REQUIRED
            tls_client_ca = t.cast(pathlib.Path, args.tls_client_ca)
            if tls_client_ca.is_dir():
                context.load_verify_locations(capath=str(tls_client_ca.absolute()))
            elif tls_client_ca.exists():
                context.load_verify_locations(cafile=str(tls_client_ca.absolute()))
            else:
                raise ValueError(
                    f"Certificate CA verify path '{tls_client_ca}' does not exist"
                )

        elif args.tls_client_auth:
            context.verify_mode = ssl.VerifyMode.CERT_REQUIRED
            my_ca = generate_ca()
            client_cert, client_key = generate_cert(
                "TlsWebServerClient",
                issuer=(my_ca[0], my_ca[1]),
                extensions=[
                    (
                        x509.KeyUsage(
                            digital_signature=True,
                            content_commitment=False,
                            key_encipherment=False,
                            data_encipherment=False,
                            key_agreement=False,
                            key_cert_sign=False,
                            crl_sign=False,
                            encipher_only=False,
                            decipher_only=False,
                        ),
                        True,
                    ),
                    (x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), False),
                ],
            )
            client_ca_path = pathlib.Path(__file__).parent / f"{FILE_NAME}_client.pem"
            # Track both outputs before writing so partial failures are cleaned up.
            temp_files.append(client_ca_path)
            temp_files.append(client_ca_path.with_suffix(".pfx"))
            serialize_cert(
                client_cert,
                client_key,
                client_ca_path,
                generate_pfx=True,
            )
            context.load_verify_locations(cafile=str(my_ca[2].absolute()))

        if not args.tls_cert:
            if my_ca is None:
                my_ca = generate_ca()

            hostname = socket.gethostname()
            san = x509.SubjectAlternativeName(
                [
                    x509.DNSName(hostname),
                    x509.DNSName("localhost"),
                ]
            )
            # Load one RSA and one ECDSA certificate into the same context.
            for key_type in ["rsa", "ecdsa"]:
                cert, key = generate_cert(
                    hostname,
                    issuer=(my_ca[0], my_ca[1]),
                    extensions=[(san, False)],
                    key_type=key_type,  # type: ignore[arg-type] # This is the literal string
                )
                # load_cert_chain only reads from a file, so the key is written
                # encrypted with a random one-off password and deleted at once.
                tls_key_pass = os.urandom(32)
                temp_cert = (
                    pathlib.Path(__file__).parent / f"tls_server_temp_cert_{key_type}.pem"
                )
                try:
                    serialize_cert(
                        cert,
                        key,
                        temp_cert,
                        key_password=tls_key_pass,
                    )
                    context.load_cert_chain(
                        certfile=str(temp_cert.absolute()),
                        password=tls_key_pass,
                    )
                finally:
                    temp_cert.unlink(missing_ok=True)
        else:
            context.load_cert_chain(
                certfile=str(args.tls_cert.absolute()),
                keyfile=str(args.tls_key.absolute()) if args.tls_key else None,
                password=args.tls_key_pass,
            )
    except BaseException:
        # The caller never receives temp_files when setup fails, so remove any
        # generated certificate/key files here before propagating the error.
        for temp_file in temp_files:
            temp_file.unlink(missing_ok=True)
        raise

    return context, temp_files
def main() -> None:
    """Start the TLS test server and serve requests until interrupted.

    Certificate files that create_tls_context generated next to this script
    are deleted when the server stops, whether it exits normally or raises.
    """
    args = parse_args(sys.argv[1:])
    tls_context, temp_files = create_tls_context(args)
    try:
        # The context manager closes the listening socket on exit.
        with http.server.HTTPServer(("", args.port), HTTPHandler) as httpd:
            httpd.socket = tls_context.wrap_socket(
                httpd.socket,
                server_side=True,
            )
            # Flush so a parent process reading stdout through a pipe can learn
            # the (by default ephemeral) port immediately.
            print(f"Listening on {httpd.server_address}", flush=True)
            httpd.serve_forever()
    finally:
        for temp_file in temp_files:
            temp_file.unlink(missing_ok=True)
# Start the server only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment