Skip to content

Instantly share code, notes, and snippets.

@autumnjolitz
Last active October 24, 2023 06:32
Show Gist options
  • Save autumnjolitz/50f663dc51f7f636f54422edcf917e5c to your computer and use it in GitHub Desktop.
Save autumnjolitz/50f663dc51f7f636f54422edcf917e5c to your computer and use it in GitHub Desktop.
Implementation of SCRAM-SHA-512 using Wikipedia as the only source of information. Does not care about GS2 header nonsense. Install the dependencies first with `python -m pip install argon2-cffi pytest`. Now done in a functional style!
import base64
import hashlib
import hmac
import uuid
from typing import NamedTuple, Union, Tuple, Optional
import argon2 # type: ignore
import pytest
class ClientSession(NamedTuple):
    """First client-side state of the handshake: only the account name is known."""

    username: str

    def login(self, password: Union[str, bytes]) -> "ClientRequestChallenge":
        """Begin a login attempt with ``password``.

        A text password is encoded as UTF-8; a bytes password is used as given.
        """
        raw_password = password if isinstance(password, bytes) else password.encode("utf8")
        return ClientRequestChallenge.new(self, raw_password)
class ClientRequestChallenge(NamedTuple):
    """Client state after ``login``: the session, the password bytes and a client nonce."""

    client: ClientSession
    password: bytes
    nonce: bytes

    @classmethod
    def new(cls, client: ClientSession, password: bytes) -> "ClientRequestChallenge":
        """Create the state for ``client`` with a freshly generated random nonce."""
        fresh_nonce = uuid.uuid4().bytes
        return cls(client=client, password=password, nonce=fresh_nonce)

    @property
    def username(self) -> str:
        """Account name, read from the originating session."""
        return self.client.username

    def create_request(self) -> bytes:
        """
        Payload to send to server: ``<username>,<client nonce>``.
        """
        return b",".join((self.username.encode(), self.nonce))

    def parse_response(self, challenge: bytes) -> "HandleServerChallenge":
        """Parse the server's ``challenge`` and move on to the proof-building state."""
        parsed_challenge = ClientServerChallenge.new(challenge, self.nonce)
        return HandleServerChallenge.new(self, parsed_challenge)
class ClientServerChallenge(NamedTuple):
    """The server's challenge as seen by the client: salt plus both nonces."""

    salt: bytes
    client_nonce: bytes
    server_nonce: bytes

    @classmethod
    def new(cls, challenge: bytes, client_nonce: bytes) -> "ClientServerChallenge":
        """Split ``challenge`` (``<client nonce><server nonce>,<salt>``) into its parts.

        Raises:
            ValueError: if ``challenge`` contains no comma, or if the part before
                the last comma does not begin with ``client_nonce``.
        """
        # NOTE(review): the salt is taken as everything after the LAST comma, so
        # a salt that itself contains a 0x2C byte would be cut short here.
        combined_nonces, salt = challenge.rsplit(b",", 1)
        prefix_len = len(client_nonce)
        received_prefix = combined_nonces[:prefix_len]
        if received_prefix != client_nonce:
            raise ValueError(
                f"Expected server challenge to start with client nonce {client_nonce!r} but "
                f"instead got {received_prefix!r}"
            )
        return cls(
            salt=salt,
            client_nonce=client_nonce,
            server_nonce=combined_nonces[prefix_len:],
        )

    @property
    def challenge(self) -> bytes:
        """Re-serialise the challenge in the same layout ``new`` parses."""
        return b",".join((self.client_nonce + self.server_nonce, self.salt))
class HandleServerChallenge(NamedTuple):
    """Client state once the server's challenge has been parsed.

    Derives every value the client needs for the second exchange: the salted
    password hash, the client/server keys, the client proof to send, and the
    server signature the client expects to receive back.

    All HMACs use SHA-512 explicitly. ``hmac.new`` without a ``digestmod``
    raises ``TypeError`` on Python 3.8+ (and fell back to MD5 before that), so
    the digest must be named; the peer has to use the same one.
    """

    initial_request: ClientRequestChallenge
    server_challenge: ClientServerChallenge

    @classmethod
    def new(
        cls, initial_request: ClientRequestChallenge, server_challenge: ClientServerChallenge
    ) -> "HandleServerChallenge":
        """Pair the client's first request with the parsed server challenge."""
        return cls(initial_request, server_challenge)

    @property
    def encrypted_password(self) -> bytes:
        """Password hashed via ``hash_password`` with the salt the server supplied."""
        return hash_password(self.initial_request.password, salt=self.server_challenge.salt)

    @property
    def client_key(self) -> bytes:
        """HMAC-SHA512 of ``b"Client Key"`` keyed with the hashed password."""
        return hmac.new(self.encrypted_password, b"Client Key", hashlib.sha512).digest()

    @property
    def hashed_client_key(self) -> bytes:
        """SHA-512 digest of ``client_key``."""
        return hashlib.sha512(self.client_key).digest()

    @property
    def authentication_message(self) -> bytes:
        """Comma-joined transcript: first request, server challenge, both nonces."""
        return b",".join(
            (
                self.initial_request.create_request(),
                self.server_challenge.challenge,
                b"".join((self.server_challenge.client_nonce, self.server_challenge.server_nonce)),
            )
        )

    @property
    def client_proof(self) -> bytes:
        """``client_key`` XOR HMAC-SHA512(SHA-512(client_key), authentication_message)."""
        # Derive the client key once: each access re-runs the password hash,
        # and both the XOR and the hashed key need the same value.
        client_key = self.client_key
        hashed_key = hashlib.sha512(client_key).digest()
        signature = hmac.new(hashed_key, self.authentication_message, hashlib.sha512).digest()
        return bytes(x ^ y for x, y in zip(client_key, signature))

    def create_request(self) -> bytes:
        """Second client payload: client nonce, server nonce, then the client proof."""
        return b"".join(
            (
                self.server_challenge.client_nonce,
                self.server_challenge.server_nonce,
                self.client_proof,
            )
        )

    @property
    def server_key(self) -> bytes:
        """HMAC-SHA512 of ``b"Server Key"`` keyed with the hashed password."""
        return hmac.new(self.encrypted_password, b"Server Key", hashlib.sha512).digest()

    @property
    def server_certificate(self) -> bytes:
        """Signature the server is expected to return: HMAC of the transcript under ``server_key``."""
        return hmac.new(self.server_key, self.authentication_message, hashlib.sha512).digest()

    def parse_response(self, response: bytes) -> "VerifyClientServerResponse":
        """Check the server's final ``response`` against ``server_certificate``.

        Raises:
            ValueError: (from ``validate``) if the server's signature does not match.
        """
        assert isinstance(response, bytes)
        return VerifyClientServerResponse(response, self).validate()
class VerifyClientServerResponse(NamedTuple):
    """Final client state: the signature the server sent plus the state that can verify it."""

    server_certificate: bytes
    client_challenge: HandleServerChallenge

    def validate(self):
        """Return ``self`` when the received signature equals the locally computed one.

        Raises:
            ValueError: if the two signatures differ.
        """
        expected = self.client_challenge.server_certificate
        # compare_digest keeps the comparison constant-time.
        if hmac.compare_digest(self.server_certificate, expected):
            return self
        raise ValueError("server gave a bad cert!")
# Shared argon2 hasher, created with the library's default parameters.
HASHER = argon2.PasswordHasher()


def hash_password(password: bytes, salt: Optional[bytes] = None) -> bytes:
    """Hash ``password`` with argon2 and return the encoded hash as UTF-8 bytes.

    Args:
        password: the secret to hash.
        salt: salt to hash with, or ``None`` to let the hasher choose one.
            NOTE(review): passing ``salt`` relies on ``PasswordHasher.hash``
            accepting a ``salt`` keyword -- confirm the installed argon2-cffi
            version supports it.

    Raises:
        ValueError: if ``salt`` is given but empty. This is an explicit raise
            rather than an ``assert`` so the check survives ``python -O``.
    """
    if salt is not None and not salt:
        raise ValueError("salt must either be a non empty string or None")
    return HASHER.hash(password, salt=salt).encode("utf-8")
class User(NamedTuple):
    """A stored account: user name plus the encoded password hash."""

    username: str
    encrypted_password: bytes

    @property
    def salt(self) -> bytes:
        """Decode the salt from the second-to-last ``$``-separated field of the hash.

        The field is base64 without ``=`` padding, so padding is restored
        before decoding.
        """
        _, salt_b64, _ = self.encrypted_password.rsplit(b"$", 2)
        # Pad up to the next multiple of four; adds nothing when already aligned
        # (the previous ``4 - rem`` arithmetic appended "====" in that case).
        salt_b64 += b"=" * (-len(salt_b64) % 4)
        return base64.b64decode(salt_b64)

    @classmethod
    def new(cls, username: str, password: Union[str, bytes]) -> "User":
        """Create a user by hashing ``password`` (text is UTF-8 encoded first)."""
        if not isinstance(password, bytes):
            password = password.encode("utf8")
        encrypted_password = hash_password(password)
        return cls(username, encrypted_password)
class Server(NamedTuple):
    """Immutable server state: the tuple of known users."""

    users: Tuple[User, ...]

    @classmethod
    def create(cls, users=()) -> "Server":
        """Build a server holding ``users`` (none by default)."""
        return cls(users=users)

    def handle_request(self, client_request: bytes) -> "ServerHandleInitialRequest":
        """Handle the client's first payload, ``<username>,<client nonce>``.

        Returns the per-login state for the first user whose name matches.

        Raises:
            PermissionError: if no user has that name.
        """
        name_bytes, client_nonce = client_request.split(b",", 1)
        server_nonce = uuid.uuid4().bytes
        wanted: str = name_bytes.decode()
        match = next((user for user in self.users if user.username == wanted), None)
        if match is None:
            raise PermissionError("Invalid user/password")
        return ServerHandleInitialRequest(match, client_nonce, server_nonce)

    def add_user(self, username, password: str) -> "Server":
        """Return a new ``Server`` with one more user appended; ``self`` is unchanged."""
        newcomer = User.new(username, password)
        extended = self.users + (newcomer,)
        return self._replace(users=extended)
class ServerHandleInitialRequest(NamedTuple):
    """Server state for one login attempt after the client's first message.

    Holds the matched user and both nonces, and derives the values needed to
    check the client's proof and to sign the exchange.

    All HMACs use SHA-512 explicitly. ``hmac.new`` without a ``digestmod``
    raises ``TypeError`` on Python 3.8+ (and fell back to MD5 before that), so
    the digest must be named; the peer has to use the same one.
    """

    user: User
    client_nonce: bytes
    server_nonce: bytes

    @property
    def initial_request(self) -> bytes:
        """Rebuild the client's first payload: ``<username>,<client nonce>``."""
        return self.user.username.encode() + b"," + self.client_nonce

    def create_response(self) -> bytes:
        """Challenge sent to the client: ``<client nonce><server nonce>,<salt>``."""
        return self.client_nonce + self.server_nonce + b"," + self.user.salt

    @property
    def server_key(self) -> bytes:
        """HMAC-SHA512 of ``b"Server Key"`` keyed with the stored password hash."""
        return hmac.new(self.user.encrypted_password, b"Server Key", hashlib.sha512).digest()

    @property
    def hashed_client_key(self) -> bytes:
        """SHA-512 of the client key derived from the stored password hash."""
        client_key = hmac.new(
            self.user.encrypted_password, b"Client Key", hashlib.sha512
        ).digest()
        return hashlib.sha512(client_key).digest()

    @property
    def authentication_message(self) -> bytes:
        """Comma-joined transcript: first request, challenge, both nonces."""
        return b",".join(
            (
                self.initial_request,
                self.create_response(),
                b"".join((self.client_nonce, self.server_nonce)),
            )
        )

    @property
    def hmac_client_key(self) -> bytes:
        """HMAC-SHA512 of the transcript keyed with ``hashed_client_key``."""
        return hmac.new(
            self.hashed_client_key, self.authentication_message, hashlib.sha512
        ).digest()

    @property
    def server_signature(self) -> bytes:
        """HMAC-SHA512 of the transcript keyed with ``server_key``."""
        return hmac.new(self.server_key, self.authentication_message, hashlib.sha512).digest()

    def handle_request(self, client_request: bytes) -> "ServerChallengeResponse":
        """Handle the client's second payload: both nonces followed by the proof.

        Raises:
            ValueError: if the payload does not start with the expected nonces
                (a failed proof check in ``ServerChallengeResponse.new`` also
                raises ``ValueError``).
        """
        expected_start = self.client_nonce + self.server_nonce
        prefix_len = len(expected_start)
        if not hmac.compare_digest(expected_start, client_request[:prefix_len]):
            raise ValueError("Client request does not match expected client request")
        # Everything after the echoed nonces is the client proof.
        return ServerChallengeResponse.new(self, client_request[prefix_len:])
class ServerChallengeResponse(NamedTuple):
    """Server state once the client's proof has been received."""

    initial_request_handler: ServerHandleInitialRequest
    client_proof: bytes

    def create_response(self) -> bytes:
        """Final server payload: the server signature for this exchange."""
        handler = self.initial_request_handler
        return handler.server_signature

    @property
    def client_key(self) -> bytes:
        """Client key obtained by XOR-ing the proof with the server-side HMAC."""
        mask = self.initial_request_handler.hmac_client_key
        return bytes(m ^ p for m, p in zip(mask, self.client_proof))

    @property
    def hashed_client_key(self) -> bytes:
        """SHA-512 digest of the recovered ``client_key``."""
        digest = hashlib.sha512(self.client_key)
        return digest.digest()

    @classmethod
    def new(
        cls, initial_response: ServerHandleInitialRequest, client_proof: bytes
    ) -> "ServerChallengeResponse":
        """Build the state and validate the proof immediately."""
        candidate = cls(initial_response, client_proof)
        return candidate.validate()

    def validate(self):
        """Return ``self`` if the recovered key hashes to the stored value.

        Raises:
            ValueError: if the hashed keys differ.
        """
        matches = hmac.compare_digest(
            self.hashed_client_key, self.initial_request_handler.hashed_client_key
        )
        if matches:
            return self
        raise ValueError("client keys do not match!")
@pytest.fixture
def random_password():
    """Fixture: a fresh random UUID4 rendered as a string, used as a password."""
    generated = uuid.uuid4()
    return str(generated)
# Any of the state objects a client can hold while the handshake progresses.
IClientRequest = Union[
ClientSession, ClientRequestChallenge, HandleServerChallenge, VerifyClientServerResponse
]
# The per-login state objects the server moves through.
IServerRequest = Union[ServerHandleInitialRequest, ServerChallengeResponse]
@pytest.fixture
def init_server(random_password: str) -> Server:
    """Fixture: a server with user "autumn" (the fixture password) and "joe" (a random one)."""
    with_autumn = Server.create().add_user("autumn", random_password)
    joe_password = uuid.uuid4().hex
    return with_autumn.add_user("joe", joe_password)
class Client:
    """Mutable holder for whichever client-side handshake state is current."""

    session: IClientRequest

    @classmethod
    def new(cls, session: ClientSession):
        """Create a holder whose current state is ``session``."""
        holder = cls()
        holder.session = session
        return holder
class AuthServer:
    """Mutable holder for whichever server-side state is current."""

    session: Union[IServerRequest, Server]

    @classmethod
    def new(cls, init_state: Server):
        """Create a holder whose current state is ``init_state``."""
        holder = cls()
        holder.session = init_state
        return holder
def test_successful_login(init_server: Server, random_password: str):
    """
    Showcase the two message exchanges bidirectionally between client and server.
    """
    client = Client.new(ClientSession("autumn"))
    server = AuthServer.new(init_server)

    # Exchange 1, client -> server: user name and client nonce.
    client.session = client.session.login(random_password)
    first_request: bytes = client.session.create_request()

    # Exchange 1, server -> client: both nonces and the user's salt.
    server.session = server.session.handle_request(first_request)
    challenge: bytes = server.session.create_response()

    # The client parses the challenge and moves to the proof-building state.
    client.session = client.session.parse_response(challenge)

    # Exchange 2, client -> server: the client's answer to the challenge.
    proof_request: bytes = client.session.create_request()

    # Exchange 2, server -> client: the server checks the proof and replies
    # with its signature for this session.
    proof_checked: ServerChallengeResponse = server.session.handle_request(proof_request)
    signature: bytes = proof_checked.create_response()

    # The client verifies the signature against the one it computed itself.
    client.session = client.session.parse_response(signature)
def test_invalid_password(init_server: Server):
    """
    It can take two exchanges to determine it's a bunk password...
    """
    client_state: IClientRequest = ClientSession("joe").login("1234")
    server_state: IServerRequest = init_server.handle_request(client_state.create_request())
    challenge = server_state.create_response()
    client_state = client_state.parse_response(challenge)
    # The wrong password only surfaces once the server checks the proof.
    with pytest.raises(ValueError):
        server_state = server_state.handle_request(client_state.create_request())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment