Created
May 30, 2024 08:58
-
-
Save dmage/77ebddf9d6320ad02e256f25121f71b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import hmac
import os
import socket
import sys
from hashlib import sha1, sha256

from Crypto.Cipher import AES  # provided by the package pycryptodome
from pyasn1.codec.der import decoder
from pyasn1_modules import rfc2437, rfc2459
# Sizes for TLS_RSA_WITH_AES_128_CBC_SHA, per
# https://datatracker.ietf.org/doc/html/rfc5246#appendix-C
MAC_LENGTH = 20  # HMAC-SHA1 output length (mac_length for MAC SHA)
MAC_KEY_LENGTH = 20  # HMAC-SHA1 key length cut from the key block (mac_key_length)
ENC_KEY_LENGTH = 16  # AES-128 key length cut from the key block
FIXED_IV_LENGTH = 16  # AES-CBC IV size; a fresh IV is sent in every record, none comes from the key block
BLOCK_SIZE = 16  # AES block size, the CBC padding granularity

# Handshake state shared by the functions below, filled in as the handshake progresses.
CLIENT_RANDOM = os.urandom(32)  # ClientHello.random
SERVER_HELLO = dict()  # gets "random" (ServerHello.random)
PUBLIC_KEY = dict()  # gets "modulus" and "publicExponent" of the server's RSA key
# Starts with the client write sequence number; the four "*_write_*key" entries are added after key expansion.
CIPHER = dict(seq_num=0)
def non_zero_random_bytes(n): | |
b = list(os.urandom(n)) | |
for i in range(0, len(b)): | |
while b[i] == 0: | |
b[i] = int.from_bytes(os.urandom(1)) | |
return bytes(b) | |
def rsa_encrypt(plaintext):
    """Encrypt *plaintext* to the server's RSA key with RSAES-PKCS1-v1_5 (RFC 8017, 7.2.1).

    The key is read from PUBLIC_KEY["modulus"] and PUBLIC_KEY["publicExponent"]
    (filled in by handle_handshake_certificate). The encryption block is
    0x00 0x02 || PS || 0x00 || plaintext, where PS is non-zero random padding.

    Returns the ciphertext as a big-endian byte string exactly k bytes long,
    k being the modulus length in bytes.

    Raises ValueError if plaintext is longer than k - 11 bytes, because PS must be
    at least 8 bytes long.
    """
    modulus = PUBLIC_KEY["modulus"]
    k = (modulus.bit_length() + 7) // 8
    if len(plaintext) > k - 11:
        raise ValueError("message too long: {} bytes for a {}-byte RSA modulus".format(len(plaintext), k))
    PS = non_zero_random_bytes(k - 3 - len(plaintext))
    EB = b"\x00\x02" + PS + b"\x00" + plaintext
    assert len(EB) == k  # internal invariant given the length check above
    m = int.from_bytes(EB, byteorder="big")
    c = pow(m, PUBLIC_KEY["publicExponent"], modulus)
    return c.to_bytes(k, byteorder="big")
def HMAC_SHA1(key, text):
    """Return the 20-byte HMAC-SHA1 (RFC 2104) of *text* under *key*.

    Keys longer than the 64-byte SHA-1 block are hashed first; shorter keys are
    zero-padded to 64 bytes. The caller's key object is never modified.
    """
    block_size = 64
    if len(key) > block_size:
        key = sha1(key).digest()
    # Build a new padded object: "key += ..." would pad a caller's bytearray in place.
    key = key + b"\x00" * (block_size - len(key))
    inner_pad = bytes(b ^ 0x36 for b in key)
    outer_pad = bytes(b ^ 0x5C for b in key)
    inner_hash = sha1(inner_pad + text).digest()
    return sha1(outer_pad + inner_hash).digest()
def HMAC_SHA256(key, text):
    """Return the 32-byte HMAC-SHA256 (RFC 2104) of *text* under *key*.

    Keys longer than the 64-byte SHA-256 block are hashed first; shorter keys are
    zero-padded to 64 bytes. The caller's key object is never modified.
    """
    block_size = 64
    if len(key) > block_size:
        key = sha256(key).digest()
    # Build a new padded object: "key += ..." would pad a caller's bytearray in place.
    key = key + b"\x00" * (block_size - len(key))
    ipad_key = bytes(k ^ 0x36 for k in key)
    opad_key = bytes(k ^ 0x5C for k in key)
    return sha256(opad_key + sha256(ipad_key + text).digest()).digest()
def PRF(secret: bytes, label: bytes, seed: bytes, size: int) -> bytes:
    """TLS 1.2 PRF: P_SHA256(secret, label + seed), truncated to *size* bytes (RFC 5246, section 5).

    P_SHA256 chains A(0) = label + seed and A(i) = HMAC_SHA256(secret, A(i-1)).
    For i = 1, 2, ... it emits HMAC_SHA256(secret, A(i) + label + seed) until
    *size* bytes are available.
    """
    labelled_seed = label + seed
    chunks = []
    produced = 0
    a = labelled_seed  # A(0)
    while produced < size:
        a = HMAC_SHA256(secret, a)
        chunk = HMAC_SHA256(secret, a + labelled_seed)
        chunks.append(chunk)
        produced += len(chunk)
    return b"".join(chunks)[:size]
def MAC(key, seq_num, typ, length, data):
    """Compute the TLS 1.2 record MAC with HMAC-SHA1 (RFC 5246, 6.2.3.1).

    The MAC input is seq_num (8 bytes) || type (1) || version 3,3 (2) || length (2) || data,
    with all integers big-endian.
    """
    header = b"".join(
        (
            seq_num.to_bytes(8, "big"),
            typ.to_bytes(1, "big"),
            b"\x03\x03",
            length.to_bytes(2, "big"),
        )
    )
    return HMAC_SHA1(key, header + data)
def construct_tls_plaintext_record(content_type, body):
    """Frame *body* as a TLS 1.2 record.

    The header is type (1 byte), version 3,3, and the body length (2 bytes, big-endian).
    """
    header = bytes((content_type, 3, 3)) + len(body).to_bytes(2, byteorder="big")
    return header + body
def has_complete_tls_record(buf):
    """Return True when *buf* starts with at least one whole TLS record.

    A whole record is the 5-byte header plus as many bytes as the header's length field announces.
    """
    header_size = 5
    if len(buf) < header_size:
        return False
    announced = int.from_bytes(buf[3:5], "big")
    return len(buf) - header_size >= announced
def get_tls_plaintext_record(buf):
    """Split the first TLS record off the front of *buf*.

    Returns (content_type, fragment, rest): the record's type byte, its payload,
    and whatever follows the record in *buf*.

    Raises ValueError if *buf* does not hold a complete record. The original used
    an assert here; ``python -O`` strips asserts, after which a truncated fragment
    would have been returned silently.
    """
    if len(buf) < 5:
        raise ValueError("Incomplete TLS record header: {} of 5 bytes".format(len(buf)))
    end = 5 + ((buf[3] << 8) | buf[4])
    if len(buf) < end:
        raise ValueError("Incomplete TLS record: have {} of {} bytes".format(len(buf), end))
    return buf[0], buf[5:end], buf[end:]
def construct_tls_ciphertext_record(content_type, content):
    """Protect *content* as a TLS 1.2 GenericBlockCipher record and frame it.

    Layout: IV || AES-CBC(content || MAC || padding || padding_length) (RFC 5246, 6.2.3.2).
    The MAC is computed over the plaintext with CIPHER["client_write_MAC_key"] and
    the current CIPHER["seq_num"]. Encryption uses CIPHER["client_write_key"] and a
    fresh random IV. CIPHER["seq_num"] is incremented afterwards.
    """
    seq_num = CIPHER["seq_num"]
    mac = MAC(CIPHER["client_write_MAC_key"], seq_num, content_type, len(content), content)
    assert len(mac) == MAC_LENGTH
    unpadded_length = len(content) + MAC_LENGTH + 1  # +1 for the padding_length byte
    pad = BLOCK_SIZE - unpadded_length % BLOCK_SIZE
    # The padding bytes and the trailing padding_length byte all carry the value `pad`.
    plaintext = content + mac + bytes([pad]) * (pad + 1)
    iv = os.urandom(FIXED_IV_LENGTH)
    ciphertext = AES.new(CIPHER["client_write_key"], AES.MODE_CBC, iv).encrypt(plaintext)
    CIPHER["seq_num"] = seq_num + 1
    return construct_tls_plaintext_record(content_type, iv + ciphertext)
def get_tls_ciphertext_record(buf):
    """Split the first record off *buf*, decrypt it, and verify its padding and MAC.

    The fragment must be IV || AES-CBC(content || MAC || padding || padding_length),
    the layout produced by construct_tls_ciphertext_record (RFC 5246, 6.2.3.1-6.2.3.2).
    It is decrypted with CIPHER["server_write_key"] and authenticated with
    CIPHER["server_write_MAC_key"].

    The server's record sequence number is kept in CIPHER["server_seq_num"]. It is
    treated as 0 when absent, i.e. on the first call, and incremented after every
    record that verifies.

    Returns (content_type, content, rest_of_buf).

    Raises ValueError for fragments that are not a whole number of cipher blocks.
    Raises ValueError("bad_record_mac") when the padding or the MAC does not verify.
    """
    content_type, fragment, buf = get_tls_plaintext_record(buf)
    ciphertext_length = len(fragment) - FIXED_IV_LENGTH
    if ciphertext_length <= 0 or ciphertext_length % BLOCK_SIZE:
        raise ValueError("Malformed encrypted record: fragment of {} bytes".format(len(fragment)))
    iv, encrypted_block = fragment[:FIXED_IV_LENGTH], fragment[FIXED_IV_LENGTH:]
    aes = AES.new(CIPHER["server_write_key"], AES.MODE_CBC, iv)
    block = aes.decrypt(encrypted_block)
    padding_length = block[-1]
    trailer_length = padding_length + 1  # padding bytes plus the padding_length byte itself
    content_length = len(block) - trailer_length - MAC_LENGTH
    if content_length < 0:
        raise ValueError("bad_record_mac")
    content = block[:content_length]
    received_mac = block[content_length : content_length + MAC_LENGTH]
    padding_ok = block[-trailer_length:] == bytes([padding_length]) * trailer_length
    seq_num = CIPHER.get("server_seq_num", 0)
    expected_mac = MAC(CIPHER["server_write_MAC_key"], seq_num, content_type, content_length, content)
    mac_ok = hmac.compare_digest(received_mac, expected_mac)
    # Report padding and MAC failures identically so the error does not reveal which check failed.
    # NOTE(review): timing still depends on padding_length (cf. Lucky Thirteen); acceptable for a demo client.
    if not (mac_ok and padding_ok):
        raise ValueError("bad_record_mac")
    CIPHER["server_seq_num"] = seq_num + 1
    return content_type, content, buf
def construct_handshake(msg_type, body):
    """Prefix *body* with a handshake header: msg_type (1 byte) and body length (3 bytes, big-endian)."""
    header = bytes((msg_type,)) + len(body).to_bytes(3, byteorder="big")
    return header + body
def get_handshake_body(buf):
    """Parse a buffer holding exactly one handshake message.

    The message is a 1-byte msg_type, a 3-byte big-endian length, then the body.
    Returns (msg_type, body, rest). Because *buf* must hold exactly one message,
    rest is always empty.

    Raises ValueError if *buf* is shorter than the header or its length does not
    match the header. The original used asserts, which ``python -O`` strips.
    """
    if len(buf) < 4:
        raise ValueError("Incomplete handshake header: {} of 4 bytes".format(len(buf)))
    length = int.from_bytes(buf[1:4], "big")
    if len(buf) != 4 + length:
        raise ValueError("Handshake length mismatch: header says {} body bytes, buffer has {}".format(length, len(buf) - 4))
    return buf[0], buf[4 : 4 + length], buf[4 + length :]
def construct_client_hello_handshake():
    """Build a minimal TLS 1.2 ClientHello handshake message (msg_type 1) and log CLIENT_RANDOM to stderr.

    It offers a single cipher suite (0x002F, TLS_RSA_WITH_AES_128_CBC_SHA) and only null
    compression. It sends an empty session id and no extensions.
    NOTE(review): without extensions no SNI is sent; servers that require it may refuse.
    """
    print("Client random:", CLIENT_RANDOM.hex(), file=sys.stderr)
    hello = b"".join(
        [
            b"\x03\x03",  # client_version: TLS 1.2
            CLIENT_RANDOM,  # random (32 bytes)
            b"\x00",  # session_id: empty
            b"\x00\x02\x00\x2f",  # cipher_suites: 2 bytes long, TLS_RSA_WITH_AES_128_CBC_SHA
            b"\x01\x00",  # compression_methods: 1 byte long, null
        ]
    )
    return construct_handshake(1, hello)
def handle_handshake_server_hello(body):
    """Process a ServerHello handshake body (RFC 5246, 7.4.1.3).

    Checks that the server answered with TLS 1.2, picked the only cipher suite this
    client offers (0x002F, TLS_RSA_WITH_AES_128_CBC_SHA, which the key sizes in this
    file assume), and picked null compression. Then stores the 32-byte server random
    in SERVER_HELLO["random"] and logs it to stderr.

    Raises ValueError if the body is truncated or any check fails.
    """
    # Layout: version(2) random(32) session_id_length(1) session_id cipher_suite(2) compression_method(1) [extensions]
    if len(body) < 35:
        raise ValueError("ServerHello too short: {} bytes".format(len(body)))
    if body[0] != 3 or body[1] != 3:
        raise ValueError("Only TLS 1.2 is supported")
    suite_offset = 35 + body[34]  # skip the variable-length session_id
    if len(body) < suite_offset + 3:
        raise ValueError("ServerHello truncated after session_id")
    cipher_suite = body[suite_offset : suite_offset + 2]
    if cipher_suite != b"\x00\x2f":
        raise ValueError(
            "Server chose cipher suite 0x{}, only 0x002f (TLS_RSA_WITH_AES_128_CBC_SHA) is supported".format(cipher_suite.hex())
        )
    compression_method = body[suite_offset + 2]
    if compression_method != 0:
        raise ValueError("Server chose compression method {}, only 0 (null) is supported".format(compression_method))
    SERVER_HELLO["random"] = body[2 : 2 + 32]
    print("Server random:", SERVER_HELLO["random"].hex(), file=sys.stderr)
def handle_handshake_certificate(body):
    """Extract the server's RSA public key from a Certificate handshake body into PUBLIC_KEY.

    The body is a 3-byte total length followed by certificates, each with its own
    3-byte length prefix. Only the first certificate (the server's own) is DER-decoded.
    Its subjectPublicKey is parsed as an RSAPublicKey and stored in
    PUBLIC_KEY["modulus"] and PUBLIC_KEY["publicExponent"], then logged to stderr.

    NOTE: no chain, validity period, or hostname verification is done here.
    """

    def read_u24(data):
        # 3-byte big-endian length prefix
        return (data[0] << 16) + (data[1] << 8) + data[2]

    certs_length = read_u24(body)
    assert len(body) == 3 + certs_length
    cert_list = body[3:]
    first_length = read_u24(cert_list)
    assert len(cert_list) >= 3 + first_length
    der_certificate = cert_list[3 : 3 + first_length]
    certificate, _ = decoder.decode(der_certificate, asn1Spec=rfc2459.Certificate())
    spki = certificate.getComponentByName("tbsCertificate").getComponentByName("subjectPublicKeyInfo")
    key_octets = spki.getComponentByName("subjectPublicKey").asOctets()
    rsa_key, _ = decoder.decode(key_octets, asn1Spec=rfc2437.RSAPublicKey())
    PUBLIC_KEY["modulus"] = int(rsa_key.getComponentByName("modulus"))
    PUBLIC_KEY["publicExponent"] = int(rsa_key.getComponentByName("publicExponent"))
    print("Public key:", file=sys.stderr)
    print(" Modulus: {:x}".format(PUBLIC_KEY["modulus"]), file=sys.stderr)
    print(" Public exponent: {:x}".format(PUBLIC_KEY["publicExponent"]), file=sys.stderr)
def handle_handshake_server_hello_done(body):
    """Check a ServerHelloDone body, which must be empty; raises AssertionError otherwise."""
    assert not len(body), "ServerHelloDone should not have body"
def construct_client_key_exchange_handshake(encrypted_pre_master_secret):
    """Build an RSA ClientKeyExchange message (msg_type 16).

    The body is the encrypted premaster secret with a 2-byte big-endian length prefix.
    """
    length_prefix = len(encrypted_pre_master_secret).to_bytes(2, byteorder="big")
    return construct_handshake(16, length_prefix + encrypted_pre_master_secret)
def construct_finished_handshake(verify_data):
    """Build a Finished handshake message (msg_type 20) whose body is *verify_data*."""
    finished_msg_type = 20
    return construct_handshake(finished_msg_type, verify_data)
class _RecordReader:
    """Reassemble TLS records and handshake messages from a TCP socket.

    TCP is a byte stream. One recv() may return part of a record or several records,
    one record may carry several handshake messages, and one message may span
    records. This helper keeps the leftovers between calls.
    """

    def __init__(self, sock):
        self.sock = sock
        self.buf = b""  # received bytes not yet split into records
        self.handshake_buf = b""  # handshake-layer bytes not yet split into messages

    def read_record(self, encrypted=False):
        """Return (content_type, fragment) of the next record, or None on EOF at a record boundary.

        With encrypted=True the record goes through get_tls_ciphertext_record
        (decrypt + verify); otherwise it goes through get_tls_plaintext_record.
        """
        while not has_complete_tls_record(self.buf):
            data = self.sock.recv(2**14)
            if not data:
                assert len(self.buf) == 0, "Connection closed with trailing data: {!r}".format(self.buf)
                return None
            self.buf += data
        parse = get_tls_ciphertext_record if encrypted else get_tls_plaintext_record
        content_type, fragment, self.buf = parse(self.buf)
        return content_type, fragment

    def read_handshake_message(self, encrypted=False):
        """Return the next complete handshake message, 4-byte header included, reading records as needed."""
        while True:
            if len(self.handshake_buf) >= 4:
                end = 4 + int.from_bytes(self.handshake_buf[1:4], "big")
                if len(self.handshake_buf) >= end:
                    message, self.handshake_buf = self.handshake_buf[:end], self.handshake_buf[end:]
                    return message
            record = self.read_record(encrypted)
            if record is None:
                raise ConnectionError("Connection closed during the handshake")
            content_type, fragment = record
            if content_type == 21:
                raise ConnectionError("Server sent alert {} during the handshake".format(fragment.hex()))
            assert content_type == 22, "Got {}, want 22 (handshake)".format(content_type)
            self.handshake_buf += fragment

    def has_pending_data(self):
        """Return True if bytes have been received but not yet handed out."""
        return bool(self.buf or self.handshake_buf)


def main():
    """Do a hand-written TLS 1.2 handshake with example.com:443 and send one HTTP GET.

    The response body is printed to stdout; progress and key material go to stderr.
    """
    hostname = "example.com"
    with socket.create_connection((hostname, 443)) as s:
        reader = _RecordReader(s)
        handshake_messages = b""  # transcript hashed into both Finished messages

        # Send ClientHello
        client_hello_handshake = construct_client_hello_handshake()
        handshake_messages += client_hello_handshake
        s.sendall(construct_tls_plaintext_record(22, client_hello_handshake))

        # Receive ServerHello, Certificate, ServerHelloDone, however they are split into records
        while True:
            message = reader.read_handshake_message()
            handshake_messages += message
            handshake_type, body, _ = get_handshake_body(message)
            if handshake_type == 2:
                handle_handshake_server_hello(body)
            elif handshake_type == 11:
                handle_handshake_certificate(body)
            elif handshake_type == 14:
                handle_handshake_server_hello_done(body)
                assert not reader.has_pending_data(), "ServerHelloDone should be the last message"
                break
            else:
                raise NotImplementedError("Unsupported handshake: msg_type={}, length={}, body={}".format(handshake_type, len(body), body.hex()))

        # Send ClientKeyExchange
        pre_master_secret = b"\x03\x03" + os.urandom(46)
        print("Pre master secret:", pre_master_secret.hex(), file=sys.stderr)
        client_key_exchange_handshake = construct_client_key_exchange_handshake(rsa_encrypt(pre_master_secret))
        handshake_messages += client_key_exchange_handshake
        s.sendall(construct_tls_plaintext_record(22, client_key_exchange_handshake))

        master_secret = PRF(pre_master_secret, b"master secret", CLIENT_RANDOM + SERVER_HELLO["random"], 48)
        print("Master secret:", master_secret.hex(), file=sys.stderr)
        key_block = PRF(master_secret, b"key expansion", SERVER_HELLO["random"] + CLIENT_RANDOM, 2 * MAC_KEY_LENGTH + 2 * ENC_KEY_LENGTH)
        CIPHER["client_write_MAC_key"], key_block = key_block[:MAC_KEY_LENGTH], key_block[MAC_KEY_LENGTH:]
        CIPHER["server_write_MAC_key"], key_block = key_block[:MAC_KEY_LENGTH], key_block[MAC_KEY_LENGTH:]
        CIPHER["client_write_key"], key_block = key_block[:ENC_KEY_LENGTH], key_block[ENC_KEY_LENGTH:]
        CIPHER["server_write_key"], key_block = key_block[:ENC_KEY_LENGTH], key_block[ENC_KEY_LENGTH:]
        assert len(key_block) == 0, "Key block has trailing bytes: {!r}".format(key_block)
        print("Client write MAC key:", CIPHER["client_write_MAC_key"].hex(), file=sys.stderr)
        print("Server write MAC key:", CIPHER["server_write_MAC_key"].hex(), file=sys.stderr)
        print("Client write key:", CIPHER["client_write_key"].hex(), file=sys.stderr)
        print("Server write key:", CIPHER["server_write_key"].hex(), file=sys.stderr)

        # Send ChangeCipherSpec
        s.sendall(construct_tls_plaintext_record(20, b"\x01"))

        # Send Finished
        verify_data = PRF(master_secret, b"client finished", sha256(handshake_messages).digest(), 12)
        finished_handshake = construct_finished_handshake(verify_data)
        handshake_messages += finished_handshake
        s.sendall(construct_tls_ciphertext_record(22, finished_handshake))

        # Receive ChangeCipherSpec and Finished
        expected_verify_data = PRF(master_secret, b"server finished", sha256(handshake_messages).digest(), 12)
        record = reader.read_record()
        assert record is not None, "Connection closed before the server's ChangeCipherSpec"
        content_type, fragment = record
        assert content_type == 20 and fragment == b"\x01", "Got content_type={}, fragment={!r}, want ChangeCipherSpec".format(content_type, fragment)
        message = reader.read_handshake_message(encrypted=True)
        handshake_type, body, _ = get_handshake_body(message)
        if handshake_type != 20:
            raise NotImplementedError("Unsupported handshake: msg_type={}, length={}, body={}".format(handshake_type, len(body), body.hex()))
        assert body == expected_verify_data, "Finished verification failed: {!r} != {!r}".format(body, expected_verify_data)
        print("Handshake successful", file=sys.stderr)
        assert not reader.has_pending_data(), "Finished should be the last message"

        # Send HTTP request
        http_request = b"GET / HTTP/1.1\r\nHost: " + hostname.encode("utf-8") + b"\r\nConnection: close\r\n\r\n"
        s.sendall(construct_tls_ciphertext_record(23, http_request))

        # Receive HTTP response until close_notify or EOF; either one ends the whole loop
        while True:
            record = reader.read_record(encrypted=True)
            if record is None:
                break
            content_type, content = record
            if content_type == 21 and content == b"\x01\x00":
                assert len(reader.buf) == 0, "Trailing data after close_notify: {!r}".format(reader.buf)
                break
            assert content_type == 23, "Got {}, want 23 (application_data)".format(content_type)
            print(content.decode("latin1"), end="")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment