Skip to content

Instantly share code, notes, and snippets.

@dmage
Created May 30, 2024 08:58
Show Gist options
  • Save dmage/77ebddf9d6320ad02e256f25121f71b7 to your computer and use it in GitHub Desktop.
Save dmage/77ebddf9d6320ad02e256f25121f71b7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import os
import socket
import sys
from hashlib import sha1, sha256
from Crypto.Cipher import AES # provided by the package pycryptodome
from pyasn1.codec.der import decoder
from pyasn1_modules import rfc2437, rfc2459
# Per-connection handshake state. Apart from the client sequence number,
# these dicts start out empty and are filled in as the handshake proceeds.
CLIENT_RANDOM = os.urandom(32)  # ClientHello.random: 32 fresh random bytes

SERVER_HELLO = dict()  # later receives "random" (ServerHello.random)

PUBLIC_KEY = dict()  # later receives "modulus" and "publicExponent" of the server key

CIPHER = dict(
    seq_num=0,  # sequence number of the next record this client encrypts
    # later also receives: client_write_MAC_key, server_write_MAC_key,
    # client_write_key, server_write_key
)

# Security parameters, see https://datatracker.ietf.org/doc/html/rfc5246#appendix-C
MAC_LENGTH = 20  # mac_length for MAC SHA
MAC_KEY_LENGTH = 20  # mac_key_length for MAC SHA
ENC_KEY_LENGTH = 16  # Key Material for Cipher AES_128_CBC
FIXED_IV_LENGTH = 16  # IV Size for Cipher AES_128_CBC
BLOCK_SIZE = 16  # Block Size for Cipher AES_128_CBC
def non_zero_random_bytes(n):
    """Return ``n`` random bytes from ``os.urandom``, none of which is zero.

    Each zero byte is redrawn from ``os.urandom`` until it is non-zero, so
    every output byte is uniformly distributed over 1..255. Used for the PS
    padding string of PKCS #1 v1.5 encryption, which must not contain zeros.
    """
    result = bytearray()
    for byte in os.urandom(n):
        while byte == 0:
            # Indexing works on every Python 3; int.from_bytes() without a
            # byteorder argument only exists from Python 3.11 on.
            byte = os.urandom(1)[0]
        result.append(byte)
    return bytes(result)
def rsa_encrypt(plaintext):
    """Encrypt ``plaintext`` with the server RSA key using RSAES-PKCS1-v1_5.

    The key is read from the module-level ``PUBLIC_KEY`` dict (``"modulus"``
    and ``"publicExponent"``), which handle_handshake_certificate fills in.

    Returns the ciphertext as a big-endian byte string of exactly ``k`` bytes,
    where ``k`` is the length of the modulus in bytes.

    Raises ValueError if ``plaintext`` is longer than ``k - 11`` bytes: the
    random padding string PS must be at least 8 bytes long.
    """
    modulus = PUBLIC_KEY["modulus"]
    exponent = PUBLIC_KEY["publicExponent"]
    k = (modulus.bit_length() + 7) // 8
    if len(plaintext) > k - 11:
        raise ValueError("message too long: {} bytes for a {}-byte RSA modulus".format(len(plaintext), k))
    # EB = 0x00 || 0x02 || PS || 0x00 || M, with PS made of non-zero random bytes.
    ps = non_zero_random_bytes(k - 3 - len(plaintext))
    eb = b"\x00\x02" + ps + b"\x00" + plaintext
    ciphertext = pow(int.from_bytes(eb, byteorder="big"), exponent, modulus)
    return ciphertext.to_bytes(k, byteorder="big")
def HMAC_SHA1(key, text):
    """Compute HMAC-SHA1 (RFC 2104) of ``text`` under ``key``.

    Hand-rolled on top of hashlib.sha1 with SHA-1's 64-byte block size.
    Returns the 20-byte digest.
    """
    block_size = 64
    # Keys longer than one block are hashed first; shorter ones are zero-padded.
    if len(key) > block_size:
        key = sha1(key).digest()
    key = key.ljust(block_size, b"\x00")
    inner_pad = bytes(b ^ 0x36 for b in key)
    outer_pad = bytes(b ^ 0x5C for b in key)
    inner_digest = sha1(inner_pad + text).digest()
    return sha1(outer_pad + inner_digest).digest()
def HMAC_SHA256(key, text):
    """Compute HMAC-SHA256 (RFC 2104) of ``text`` under ``key``; returns 32 bytes.

    SHA-256 has a 64-byte block size: a longer key is hashed down first, and
    the (possibly hashed) key is then right-padded with zero bytes to 64 bytes.
    """
    block = 64
    k0 = sha256(key).digest() if len(key) > block else key
    k0 += bytes(block - len(k0))
    inner = sha256(bytes(x ^ 0x36 for x in k0) + text).digest()
    return sha256(bytes(x ^ 0x5C for x in k0) + inner).digest()
def PRF(secret: bytes, label: bytes, seed: bytes, size: int) -> bytes:
    """TLS 1.2 PRF: P_SHA256(secret, label + seed), truncated to ``size`` bytes.

    Data expansion per RFC 5246 section 5:
    A(0) = label + seed, A(i) = HMAC(secret, A(i-1)), and the output is
    HMAC(secret, A(1) + label + seed) || HMAC(secret, A(2) + label + seed) || ...
    """
    labeled_seed = label + seed
    chunks = []
    produced = 0
    a_i = labeled_seed
    while produced < size:
        a_i = HMAC_SHA256(secret, a_i)
        chunk = HMAC_SHA256(secret, a_i + labeled_seed)
        chunks.append(chunk)
        produced += len(chunk)
    return b"".join(chunks)[:size]
def MAC(key, seq_num, typ, length, data):
    """Compute the HMAC-SHA1 record MAC for one TLS 1.2 record.

    The MAC input is seq_num (8 bytes) || type (1) || version 3.3 (2) ||
    length (2) || data, all big-endian.
    """
    header = b"".join(
        (
            seq_num.to_bytes(8, "big"),
            typ.to_bytes(1, "big"),
            b"\x03\x03",  # TLS 1.2
            length.to_bytes(2, "big"),
        )
    )
    return HMAC_SHA1(key, header + data)
def construct_tls_plaintext_record(content_type, body):
    """Prefix ``body`` with a TLS record header: type (1) || version 3.3 (2) || length (2)."""
    header = bytes([content_type, 3, 3]) + len(body).to_bytes(2, byteorder="big")
    return header + body
def has_complete_tls_record(buf):
    """Return True if ``buf`` starts with at least one whole TLS record (5-byte header + payload)."""
    header_len = 5
    if len(buf) >= header_len:
        declared = int.from_bytes(buf[3:5], byteorder="big")
        return len(buf) >= header_len + declared
    return False
def get_tls_plaintext_record(buf):
    """Split the first TLS record off ``buf``.

    Returns ``(content_type, fragment, rest)``, where ``rest`` is everything
    after the record. Asserts that the whole record is present in ``buf``.
    """
    length = buf[3] * 256 + buf[4]
    record_end = 5 + length
    assert len(buf) >= record_end
    content_type = buf[0]
    fragment = buf[5:record_end]
    return content_type, fragment, buf[record_end:]
def construct_tls_ciphertext_record(content_type, content):
    """Encrypt ``content`` as a TLS 1.2 CBC record and frame it.

    Computes the record MAC with the client write MAC key and the current
    ``CIPHER["seq_num"]``, pads to a multiple of BLOCK_SIZE, and encrypts with
    AES in CBC mode under ``CIPHER["client_write_key"]`` and a fresh random
    explicit IV. The IV is sent in front of the ciphertext. Increments
    ``CIPHER["seq_num"]``.
    """
    seq = CIPHER["seq_num"]
    mac = MAC(CIPHER["client_write_MAC_key"], seq, content_type, len(content), content)
    assert len(mac) == MAC_LENGTH
    # padding[pad_len] || pad_len: pad_len + 1 bytes, every one equal to pad_len,
    # bringing the plaintext to a multiple of BLOCK_SIZE.
    pad_len = BLOCK_SIZE - (len(content) + MAC_LENGTH + 1) % BLOCK_SIZE
    plaintext = content + mac + bytes([pad_len]) * (pad_len + 1)
    explicit_iv = os.urandom(FIXED_IV_LENGTH)
    ciphertext = AES.new(CIPHER["client_write_key"], AES.MODE_CBC, explicit_iv).encrypt(plaintext)
    CIPHER["seq_num"] = seq + 1
    return construct_tls_plaintext_record(content_type, explicit_iv + ciphertext)
def get_tls_ciphertext_record(buf):
    """Split the first encrypted TLS record off ``buf``, then decrypt and verify it.

    The fragment is ``IV || ciphertext``. It is decrypted with AES in CBC mode
    under ``CIPHER["server_write_key"]``. Then the padding and the HMAC-SHA1
    record MAC, keyed with ``CIPHER["server_write_MAC_key"]``, are checked.

    The server's sequence number is kept in ``CIPHER["server_seq_num"]``,
    separate from the client's ``CIPHER["seq_num"]``. It starts at 0 and is
    incremented for every record that verifies.

    Returns ``(content_type, content, rest)``: ``content`` is the plaintext
    without MAC and padding, and ``rest`` is the unconsumed part of ``buf``.

    Raises ValueError if the record is malformed or fails verification.
    """
    import hmac  # for constant-time MAC comparison

    content_type, fragment, buf = get_tls_plaintext_record(buf)
    iv, encrypted_block = fragment[:FIXED_IV_LENGTH], fragment[FIXED_IV_LENGTH:]
    if len(iv) != FIXED_IV_LENGTH or not encrypted_block or len(encrypted_block) % BLOCK_SIZE:
        raise ValueError("bad_record_mac: malformed {}-byte ciphertext fragment".format(len(fragment)))
    block = AES.new(CIPHER["server_write_key"], AES.MODE_CBC, iv).decrypt(encrypted_block)

    # block = content || MAC || padding_length + 1 bytes each equal to padding_length
    padding_length = block[-1]
    content_length = len(block) - padding_length - 1 - MAC_LENGTH
    if content_length < 0 or block[content_length + MAC_LENGTH :] != bytes([padding_length]) * (padding_length + 1):
        raise ValueError("bad_record_mac: invalid padding")
    content = block[:content_length]
    received_mac = block[content_length : content_length + MAC_LENGTH]

    seq_num = CIPHER.get("server_seq_num", 0)
    expected_mac = MAC(CIPHER["server_write_MAC_key"], seq_num, content_type, content_length, content)
    if not hmac.compare_digest(received_mac, expected_mac):
        raise ValueError("bad_record_mac: MAC verification failed")
    CIPHER["server_seq_num"] = seq_num + 1
    return content_type, content, buf
def construct_handshake(msg_type, body):
    """Frame ``body`` as a handshake message: msg_type (1) || length (3, big-endian) || body."""
    header = bytes((msg_type,)) + len(body).to_bytes(3, byteorder="big")
    return header + body
def get_handshake_body(buf):
    """Split the first handshake message off ``buf``.

    A handshake message is msg_type (1) || length (3, big-endian) || body.
    Returns ``(msg_type, body, rest)``. ``rest`` holds whatever follows the
    message, e.g. further handshake messages coalesced into the same record.

    Raises AssertionError if ``buf`` does not contain a complete message.
    """
    assert len(buf) >= 4, "Truncated handshake header: {!r}".format(buf)
    length = int.from_bytes(buf[1:4], byteorder="big")
    # ">=" rather than "==": trailing bytes are legitimate and returned as ``rest``.
    assert len(buf) >= 4 + length, "Truncated handshake message: want {} body bytes, have {}".format(length, len(buf) - 4)
    return buf[0], buf[4 : 4 + length], buf[4 + length :]
def construct_client_hello_handshake():
    """Build the ClientHello handshake message and log the client random to stderr.

    Offers TLS 1.2, an empty session id, the single cipher suite 0x002f
    (TLS_RSA_WITH_AES_128_CBC_SHA) and only null compression, with no extensions.
    """
    print("Client random:", CLIENT_RANDOM.hex(), file=sys.stderr)
    client_version = b"\x03\x03"  # TLS 1.2
    session_id = b"\x00"  # length 0: no session to resume
    cipher_suites = b"\x00\x02" + b"\x00\x2f"  # 2-byte list length, then the suite
    compression_methods = b"\x01\x00"  # 1-byte list length, then null
    hello = client_version + CLIENT_RANDOM + session_id + cipher_suites + compression_methods
    return construct_handshake(1, hello)
def handle_handshake_server_hello(body):
    """Check the parameters the server negotiated and store its random.

    ServerHello body layout: server_version (2) || random (32) ||
    session_id (1-byte length + data) || cipher_suite (2) ||
    compression_method (1) || optional extensions.

    Stores the 32-byte server random in ``SERVER_HELLO["random"]`` and prints
    it to stderr. Raises AssertionError if the body is truncated, or if the
    server did not pick TLS 1.2, cipher suite 0x002f and null compression,
    the only options the ClientHello offers.
    """
    assert len(body) >= 35, "ServerHello too short: {} bytes".format(len(body))
    assert body[0] == 3 and body[1] == 3, "Only TLS 1.2 is supported"
    SERVER_HELLO["random"] = body[2 : 2 + 32]
    session_id_length = body[34]
    offset = 35 + session_id_length
    assert len(body) >= offset + 3, "ServerHello truncated after session_id"
    cipher_suite = body[offset : offset + 2]
    compression_method = body[offset + 2]
    assert cipher_suite == b"\x00\x2f", "Server chose unsupported cipher suite 0x{}".format(cipher_suite.hex())
    assert compression_method == 0, "Server chose unsupported compression method {}".format(compression_method)
    print("Server random:", SERVER_HELLO["random"].hex(), file=sys.stderr)
def handle_handshake_certificate(body):
    """Extract the server's RSA public key from a Certificate handshake body.

    The body is a 3-byte total length followed by certificates, each with
    its own 3-byte length prefix. Only the first (server) certificate is
    decoded. Its subjectPublicKey is parsed as an RSAPublicKey, and the
    modulus and public exponent are stored in PUBLIC_KEY and printed to stderr.
    """
    certs_length = body[0] << 16 | body[1] << 8 | body[2]
    assert len(body) == 3 + certs_length
    cert_list = body[3:]

    first_length = cert_list[0] << 16 | cert_list[1] << 8 | cert_list[2]
    assert len(cert_list) >= 3 + first_length
    der = cert_list[3 : 3 + first_length]

    certificate, _ = decoder.decode(der, asn1Spec=rfc2459.Certificate())
    tbs = certificate.getComponentByName("tbsCertificate")
    spki = tbs.getComponentByName("subjectPublicKeyInfo")
    key_der = spki.getComponentByName("subjectPublicKey").asOctets()
    rsa_key, _ = decoder.decode(key_der, asn1Spec=rfc2437.RSAPublicKey())

    PUBLIC_KEY["modulus"] = int(rsa_key.getComponentByName("modulus"))
    PUBLIC_KEY["publicExponent"] = int(rsa_key.getComponentByName("publicExponent"))
    print("Public key:", file=sys.stderr)
    print(" Modulus: {:x}".format(PUBLIC_KEY["modulus"]), file=sys.stderr)
    print(" Public exponent: {:x}".format(PUBLIC_KEY["publicExponent"]), file=sys.stderr)
def handle_handshake_server_hello_done(body):
    """Check a ServerHelloDone body, which must be empty."""
    assert not body, "ServerHelloDone should not have body"
def construct_client_key_exchange_handshake(encrypted_pre_master_secret):
    """Build a ClientKeyExchange handshake message (msg_type 16) for RSA key exchange.

    The body is the encrypted pre-master secret, prefixed with its 2-byte
    big-endian length.
    """
    length_prefix = len(encrypted_pre_master_secret).to_bytes(2, byteorder="big")
    return construct_handshake(16, length_prefix + encrypted_pre_master_secret)
def construct_finished_handshake(verify_data):
    """Build a Finished handshake message carrying ``verify_data``."""
    finished_msg_type = 20
    return construct_handshake(finished_msg_type, verify_data)
def _recv_record(sock, buf):
    """Read from ``sock`` until ``buf`` begins with a complete TLS record.

    Returns the (possibly grown) buffer. Raises ConnectionError if the peer
    closes the connection before a whole record has arrived.
    """
    while not has_complete_tls_record(buf):
        data = sock.recv(2**14)
        if not data:
            raise ConnectionError("Connection closed mid-record with buffered data: {!r}".format(buf))
        buf += data
    return buf


def _pop_handshake_message(buf):
    """Split the first complete handshake message (4-byte header + body) off ``buf``.

    Returns ``(message, rest)``. ``message`` is None when ``buf`` does not yet
    hold a whole message, because the remainder arrives in a later record.
    """
    if len(buf) < 4:
        return None, buf
    end = 4 + int.from_bytes(buf[1:4], byteorder="big")
    if len(buf) < end:
        return None, buf
    return buf[:end], buf[end:]


def main(hostname="example.com", port=443):
    """Do a TLS 1.2 handshake with ``hostname:port`` and print the response to ``GET /``.

    The decrypted HTTP response goes to stdout. Handshake secrets and keys are
    logged to stderr for debugging. The socket is closed on every exit path.
    """
    with socket.create_connection((hostname, port)) as s:
        handshake_messages = b""

        # Send ClientHello
        client_hello_handshake = construct_client_hello_handshake()
        handshake_messages += client_hello_handshake
        s.sendall(construct_tls_plaintext_record(22, client_hello_handshake))

        # Receive ServerHello, Certificate, ServerHelloDone. TCP may deliver
        # records in pieces, and a server may coalesce several handshake
        # messages into one record or split one message across records, so
        # buffer at both layers.
        buf = b""
        pending = b""  # handshake bytes that do not yet form a whole message
        server_hello_done = False
        while not server_hello_done:
            buf = _recv_record(s, buf)
            content_type, fragment, buf = get_tls_plaintext_record(buf)
            assert content_type == 22, "Got {}, want 22 (handshake)".format(content_type)
            pending += fragment
            message, pending = _pop_handshake_message(pending)
            while message is not None:
                assert not server_hello_done, "ServerHelloDone should be the last message"
                handshake_messages += message
                handshake_type, body, _ = get_handshake_body(message)
                if handshake_type == 2:
                    handle_handshake_server_hello(body)
                elif handshake_type == 11:
                    handle_handshake_certificate(body)
                elif handshake_type == 14:
                    handle_handshake_server_hello_done(body)
                    server_hello_done = True
                else:
                    raise NotImplementedError("Unsupported handshake: msg_type={}, length={}, body={}".format(handshake_type, len(body), body.hex()))
                message, pending = _pop_handshake_message(pending)
        assert not pending and not buf, "ServerHelloDone should be the last message"

        # Send ClientKeyExchange
        pre_master_secret = b"\x03\x03" + os.urandom(46)
        print("Pre master secret:", pre_master_secret.hex(), file=sys.stderr)
        client_key_exchange_handshake = construct_client_key_exchange_handshake(rsa_encrypt(pre_master_secret))
        handshake_messages += client_key_exchange_handshake
        s.sendall(construct_tls_plaintext_record(22, client_key_exchange_handshake))

        master_secret = PRF(pre_master_secret, b"master secret", CLIENT_RANDOM + SERVER_HELLO["random"], 48)
        print("Master secret:", master_secret.hex(), file=sys.stderr)

        key_block = PRF(master_secret, b"key expansion", SERVER_HELLO["random"] + CLIENT_RANDOM, 2 * MAC_KEY_LENGTH + 2 * ENC_KEY_LENGTH)
        CIPHER["client_write_MAC_key"], key_block = key_block[:MAC_KEY_LENGTH], key_block[MAC_KEY_LENGTH:]
        CIPHER["server_write_MAC_key"], key_block = key_block[:MAC_KEY_LENGTH], key_block[MAC_KEY_LENGTH:]
        CIPHER["client_write_key"], key_block = key_block[:ENC_KEY_LENGTH], key_block[ENC_KEY_LENGTH:]
        CIPHER["server_write_key"], key_block = key_block[:ENC_KEY_LENGTH], key_block[ENC_KEY_LENGTH:]
        assert len(key_block) == 0, "Key block has trailing bytes: {!r}".format(key_block)
        print("Client write MAC key:", CIPHER["client_write_MAC_key"].hex(), file=sys.stderr)
        print("Server write MAC key:", CIPHER["server_write_MAC_key"].hex(), file=sys.stderr)
        print("Client write key:", CIPHER["client_write_key"].hex(), file=sys.stderr)
        print("Server write key:", CIPHER["server_write_key"].hex(), file=sys.stderr)

        # Send ChangeCipherSpec
        s.sendall(construct_tls_plaintext_record(20, b"\x01"))

        # Send Finished
        verify_data = PRF(master_secret, b"client finished", sha256(handshake_messages).digest(), 12)
        finished_handshake = construct_finished_handshake(verify_data)
        handshake_messages += finished_handshake
        s.sendall(construct_tls_ciphertext_record(22, finished_handshake))

        # Receive ChangeCipherSpec and Finished; each may arrive in its own segment(s).
        expected_verify_data = PRF(master_secret, b"server finished", sha256(handshake_messages).digest(), 12)
        buf = _recv_record(s, buf)
        content_type, fragment, buf = get_tls_plaintext_record(buf)
        assert content_type == 20 and fragment == b"\x01", "Got content_type={}, fragment={!r}, want ChangeCipherSpec".format(content_type, fragment)
        buf = _recv_record(s, buf)
        content_type, content, buf = get_tls_ciphertext_record(buf)
        assert content_type == 22, "Got {}, want 22 (handshake)".format(content_type)
        handshake_type, body, rest = get_handshake_body(content)
        if handshake_type != 20:
            raise NotImplementedError("Unsupported handshake: msg_type={}, length={}, body={}".format(handshake_type, len(body), body.hex()))
        assert not rest, "Finished record has trailing bytes: {!r}".format(rest)
        assert body == expected_verify_data, "Finished verification failed: {!r} != {!r}".format(body, expected_verify_data)
        print("Handshake successful", file=sys.stderr)
        assert not buf, "Finished should be the last message"

        # Send HTTP request
        http_request = b"GET / HTTP/1.1\r\nHost: " + hostname.encode("utf-8") + b"\r\nConnection: close\r\n\r\n"
        s.sendall(construct_tls_ciphertext_record(23, http_request))

        # Receive HTTP response until close_notify or until the server closes TCP.
        closed = False
        while not closed:
            data = s.recv(2**14)
            if not data:
                assert len(buf) == 0, "Connection closed with trailing data: {!r}".format(buf)
                break
            buf += data
            while has_complete_tls_record(buf):
                content_type, content, buf = get_tls_ciphertext_record(buf)
                if content_type == 21 and content == b"\x01\x00":
                    assert len(buf) == 0, "Trailing data after close_notify: {!r}".format(buf)
                    closed = True  # stop reading instead of blocking on another recv
                    break
                assert content_type == 23, "Got {}, want 23 (application_data)".format(content_type)
                print(content.decode("latin1"), end="")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment