Last active
October 15, 2019 19:16
-
-
Save jkatz/7444eda78a6fff18ab5d74c024e3761d to your computer and use it in GitHub Desktop.
SCRAM-SHA-256 proof of concept for PostgreSQL; the client does not support channel binding; requires Python 3.6+ (for the `secrets` module)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64 | |
import hashlib | |
import hmac | |
import re | |
import secrets | |
import socket | |
class SCRAMAuthentication(object):
    """Proof-of-concept SCRAM-SHA-256 client for the PostgreSQL wire protocol.

    Sends a startup message and then performs a SCRAM-SHA-256 SASL exchange
    (RFC 5802 / RFC 7677) without channel binding (GS2 header ``n,,``). Only the
    happy path is handled: each server reply is expected to arrive whole in a
    single ``recv`` call.

    Can be used as a context manager so the socket is always closed::

        with SCRAMAuthentication(host="localhost") as auth:
            auth.authenticate("postgres", "secret")
    """

    AUTHENTICATION_METHOD = b"SCRAM-SHA-256"
    DIGEST = hashlib.sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding', 'client_proof', 'server_nonce']
    REQUIREMENTS_CLIENT_PROOF = ['client_first_message_bare', 'password_iterations', 'password_salt',
                                 'server_first_message', 'server_nonce']
    SOCKET_BYTE_CHUNK = 65536
    # request codes carried in PostgreSQL 'R' (Authentication) backend messages
    AUTH_SASL_CONTINUE = 11
    AUTH_SASL_FINAL = 12

    def __init__(self, host='localhost', port=5432, dbname="postgres"):
        """Create an (unconnected) TCP socket and empty SCRAM state."""
        # keep this for the socket; connect() happens in authenticate()
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # connection specific settings
        self.host = host
        self.port = port
        self.dbname = dbname
        # SCRAM variables; "n,," = GS2 header for "no channel binding, no authzid"
        self.client_channel_binding = b"n,,"
        self.client_first_message_bare = None
        self.client_nonce = None
        self.client_proof = None
        self.password_salt = None
        self.password_iterations = None
        self.server_first_message = None
        self.server_nonce = None
        # ServerSignature we expect in server-final-message; set with the client proof
        self.server_signature = None

    def close(self):
        """Close the underlying socket (calling it more than once is harmless)."""
        self.socket.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def authenticate(self, username, password):
        """Connect and run the startup + SCRAM-SHA-256 exchange.

        Prints the raw protocol messages for debugging, as the original POC did.

        Raises:
            Exception: if a server message is malformed or unexpected, the server
                nonce does not extend the client nonce, authentication fails, or
                the server signature does not verify. The socket is closed
                before the exception propagates.
        """
        try:
            self.socket.connect((self.host, self.port))
            # the startup message is not SCRAM itself, but it is needed to reach
            # the SASL exchange when testing against a real server
            msg = self._create_startup_message(username, self.dbname)
            print(msg)
            # sendall: plain send() may write only part of the buffer
            self.socket.sendall(msg)
            # happy path: assume the reply is AuthenticationSASL offering SCRAM-SHA-256
            self.socket.recv(self.SOCKET_BYTE_CHUNK)
            # TODO: here is where we determine whether or not to use channel binding
            msg = self._create_client_first_message(username)
            print(msg)
            self.socket.sendall(msg)
            # the server-first-message supplies the nonce, salt and iteration count
            server_first_response = self.socket.recv(self.SOCKET_BYTE_CHUNK)
            self._parse_server_first_response(server_first_response)
            print(server_first_response)
            self.client_proof = self._generate_client_proof(password=password)
            msg = self._create_client_final_message()
            self.socket.sendall(msg)
            r = self.socket.recv(self.SOCKET_BYTE_CHUNK)
            print(r)
            # SCRAM is mutual: make sure the server also proved knowledge of the credentials
            self._verify_server_final_response(r)
        except Exception:
            self.close()
            raise

    def _create_client_first_message(self, username):
        """Create the SASLInitialResponse ('p') message carrying client-first-message.

        Generates a fresh client nonce and stores client-first-message-bare; both
        are needed later to compute the client proof.
        """
        self.client_nonce = self._generate_client_nonce()
        # RFC 5802 saslname: '=' and ',' are escaped as '=3D' and '=2C'
        # ('=' first so the inserted escapes are not escaped again)
        saslname = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
        self.client_first_message_bare = b"n=" + saslname + b",r=" + self.client_nonce
        client_first_message = self.client_channel_binding + self.client_first_message_bare
        # NUL-terminated mechanism name, then the int32-length-prefixed initial response
        msg = (self.AUTHENTICATION_METHOD + b"\0"
               + len(client_first_message).to_bytes(4, byteorder='big') + client_first_message)
        # the message length counts itself (4 bytes) but not the type byte
        return b"p" + (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _create_client_final_message(self):
        """Create the SASLResponse ('p') message carrying client-final-message.

        Raises:
            Exception: if the channel binding, server nonce or client proof is missing.
        """
        if any(getattr(self, val) is None for val in self.REQUIREMENTS_CLIENT_FINAL_MESSAGE):
            raise Exception("you need values from server in order to generate a client proof")
        msg = (b"c=" + base64.b64encode(self.client_channel_binding)
               + b",r=" + self.server_nonce
               + b",p=" + base64.b64encode(self.client_proof))
        # length counts itself but not the leading "p" type byte
        return b"p" + (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _create_startup_message(self, username, dbname):
        """Create a protocol-3.0 StartupMessage with ``user`` and ``database`` parameters."""
        # protocol version 3.0 encoded as (3 << 16) | 0
        msg = (196608).to_bytes(4, byteorder='big')
        msg += b"user\0" + username.encode("utf-8") + b"\0"
        # the extra trailing NUL terminates the parameter list
        msg += b"database\0" + dbname.encode("utf-8") + b"\0\0"
        return (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _generate_client_nonce(self, bytes=24):
        """Return a base64-encoded nonce made from *bytes* random bytes."""
        return base64.b64encode(secrets.token_bytes(bytes))

    def _generate_client_proof(self, password):
        """Return ClientProof = ClientKey XOR ClientSignature (RFC 5802).

        Also stores the expected ServerSignature in ``self.server_signature`` so
        the server-final-message can be checked.

        Raises:
            Exception: if the server-first-message has not been parsed yet.
        """
        if any(getattr(self, val) is None for val in self.REQUIREMENTS_CLIENT_PROOF):
            raise Exception("you need values from server in order to generate a client proof")
        salted_password = self._generate_salted_password(password, self.password_salt,
                                                         self.password_iterations)
        client_key = hmac.new(salted_password, b"Client Key", self.DIGEST).digest()
        # StoredKey = H(ClientKey) is what the server keeps in pg_authid
        stored_key = self.DIGEST(client_key).digest()
        # AuthMessage = client-first-bare "," server-first "," client-final-without-proof
        authorization_message = (self.client_first_message_bare + b"," + self.server_first_message
                                 + b",c=" + base64.b64encode(self.client_channel_binding)
                                 + b",r=" + self.server_nonce)
        client_signature = hmac.new(stored_key, authorization_message, self.DIGEST).digest()
        server_key = hmac.new(salted_password, b"Server Key", self.DIGEST).digest()
        self.server_signature = hmac.new(server_key, authorization_message, self.DIGEST).digest()
        return self._bytes_xor(client_key, client_signature)

    def _generate_salted_password(self, password, salt, iterations):
        """Compute SaltedPassword = Hi(password, salt, iterations) from RFC 5802.

        Hi() is PBKDF2 with HMAC-<DIGEST> as the PRF and the digest size as the
        output length, so the standard-library implementation is used.

        Args:
            password: the plain-text password. It is UTF-8 encoded without
                SASLprep normalization.
            salt: the base64-encoded salt, as sent by the server.
            iterations: the iteration count, which must be >= 1.
        """
        p = password.encode("utf8")
        # the salt arrives base64 encoded; the raw bytes are what gets hashed
        s = base64.b64decode(salt)
        return hashlib.pbkdf2_hmac(self.DIGEST().name, p, s, iterations)

    def _read_authentication_message(self, response, expected_code):
        """Return the payload of the first backend message in *response*.

        The message must be an 'R' (Authentication) message whose request code
        is *expected_code*. Any bytes received after that message are ignored.

        Raises:
            Exception: "failed" if the message is truncated or of another kind.
        """
        if len(response) < 9 or response[:1] != b'R':
            raise Exception("failed")
        # the int32 length counts itself but not the type byte
        msg_length = int.from_bytes(response[1:5], byteorder='big')
        msg_code = int.from_bytes(response[5:9], byteorder='big')
        if msg_code != expected_code or msg_length < 8 or len(response) < 1 + msg_length:
            raise Exception("failed")
        return response[9:1 + msg_length]

    @staticmethod
    def _parse_attributes(message):
        """Split a SCRAM message like b"r=x,s=y" into {b"r": b"x", b"s": b"y"}.

        Only the first '=' separates key from value (base64 values end in '=').
        For duplicate keys, the first occurrence wins.
        """
        attributes = {}
        for part in message.split(b','):
            key, sep, value = part.partition(b'=')
            if sep and key not in attributes:
                attributes[key] = value
        return attributes

    def _parse_server_first_response(self, response):
        """Parse an AuthenticationSASLContinue message carrying server-first-message.

        Stores and returns ``(server_nonce, password_salt, password_iterations)``.
        The salt is kept base64 encoded, exactly as sent.

        Raises:
            Exception: if the message is malformed, a required attribute is missing,
                the server nonce does not extend the client nonce, or the server
                demands an unsupported mandatory extension.
        """
        self.server_first_message = self._read_authentication_message(response,
                                                                      self.AUTH_SASL_CONTINUE)
        attributes = self._parse_attributes(self.server_first_message)
        # RFC 5802: the client must fail if it does not understand a mandatory extension
        if b'm' in attributes:
            raise Exception("unsupported mandatory extension")
        self.server_nonce = attributes.get(b'r')
        if not self.server_nonce:
            raise Exception("could not get nonce")
        # the server nonce must be our nonce with the server's part appended
        if self.client_nonce is None or not self.server_nonce.startswith(self.client_nonce):
            raise Exception("invalid nonce")
        self.password_salt = attributes.get(b's')
        if not self.password_salt:
            raise Exception("could not get salt")
        iterations = attributes.get(b'i', b'')
        if not iterations.isdigit() or int(iterations) < 1:
            raise Exception("could not get iterations")
        self.password_iterations = int(iterations)
        return (self.server_nonce, self.password_salt, self.password_iterations)

    def _verify_server_final_response(self, response):
        """Check the ServerSignature (``v=``) carried in AuthenticationSASLFinal.

        This is the step that authenticates the server to the client.

        Raises:
            Exception: if no client proof has been generated yet, the server
                returned an ErrorResponse, the message is malformed, or the
                signature does not match the locally computed one.
        """
        expected = getattr(self, "server_signature", None)
        if expected is None:
            raise Exception("you need to generate a client proof before verifying the server")
        if response[:1] == b'E':
            raise Exception("authentication failed")
        payload = self._read_authentication_message(response, self.AUTH_SASL_FINAL)
        attributes = self._parse_attributes(payload)
        try:
            signature = base64.b64decode(attributes[b'v'], validate=True)
        except (KeyError, ValueError):
            raise Exception("could not get server signature") from None
        # constant-time comparison of secret-derived values
        if not hmac.compare_digest(signature, expected):
            raise Exception("invalid server signature")

    def _bytes_xor(self, a, b):
        """XOR two byte strings; the result is as long as the shorter input."""
        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.