Skip to content

Instantly share code, notes, and snippets.

@jkatz
Last active October 15, 2019 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jkatz/7444eda78a6fff18ab5d74c024e3761d to your computer and use it in GitHub Desktop.
Save jkatz/7444eda78a6fff18ab5d74c024e3761d to your computer and use it in GitHub Desktop.
SCRAM-SHA-256 proof of concept (POC) for PostgreSQL; the client does not support channel binding. Requires Python 3.6+ (uses the `secrets` module).
import base64
import hashlib
import hmac
import re
import secrets
import socket
class SCRAMAuthentication(object):
    """Proof-of-concept SCRAM-SHA-256 client for PostgreSQL.

    Implements the client side of RFC 5802 / RFC 7677 on top of the PostgreSQL
    frontend/backend protocol (version 3.0). Channel binding is not supported:
    the GS2 header is always ``n,,``. Passwords are UTF-8 encoded but NOT
    SASLprep-normalized, so some non-ASCII passwords may fail to authenticate.
    """

    AUTHENTICATION_METHOD = b"SCRAM-SHA-256"
    DIGEST = hashlib.sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding', 'client_proof', 'server_nonce']
    REQUIREMENTS_CLIENT_PROOF = ['client_first_message_bare', 'password_iterations', 'password_salt',
                                 'server_first_message', 'server_nonce']
    SOCKET_BYTE_CHUNK = 65536
    # protocol version 3.0 (3 << 16), sent in the startup message
    PROTOCOL_VERSION = 196608
    # authentication codes carried in a backend 'R' (Authentication) message
    AUTH_REQ_SASL_CONTINUE = 11
    AUTH_REQ_SASL_FINAL = 12

    def __init__(self, host='localhost', port=5432, dbname="postgres"):
        # keep this for the socket
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # connection specific settings
        self.host = host
        self.port = port
        self.dbname = dbname
        # SCRAM variables, populated as the exchange progresses
        self.client_channel_binding = b"n,,"
        self.client_first_message_bare = None
        self.client_nonce = None
        self.client_proof = None
        self.password_salt = None
        self.password_iterations = None
        self.server_first_message = None
        self.server_nonce = None
        # kept once the proof is computed so the server signature can be checked
        self.salted_password = None
        self.auth_message = None

    def authenticate(self, username, password):
        """Connect to the server and run the SCRAM-SHA-256 exchange.

        The messages sent and the server responses are printed for debugging.

        Raises:
            Exception: if the server-first message is malformed, or if the
                server does not return a valid SCRAM server signature.
        """
        # open the socket
        self.socket.connect((self.host, self.port))
        # the startup message is not part of SCRAM, but the server only offers
        # SASL authentication after receiving it
        msg = self._create_startup_message(username, self.dbname)
        print(msg)
        # sendall: socket.send() may transmit only part of the buffer
        self.socket.sendall(msg)
        # happy path: assume the reply is AuthenticationSASL offering SCRAM-SHA-256
        self.socket.recv(self.SOCKET_BYTE_CHUNK)
        # TODO: here is where we determine whether or not to use channel binding
        msg = self._create_client_first_message(username)
        print(msg)
        self.socket.sendall(msg)
        # this time we need to parse the response
        server_first_response = self.socket.recv(self.SOCKET_BYTE_CHUNK)
        print(server_first_response)
        self._parse_server_first_response(server_first_response)
        # now we need to generate the client proof and send the final message
        self.client_proof = self._generate_client_proof(password=password)
        msg = self._create_client_final_message()
        self.socket.sendall(msg)
        server_final_response = self.socket.recv(self.SOCKET_BYTE_CHUNK)
        print(server_final_response)
        # RFC 5802: the client authenticates the server by checking its signature
        self._verify_server_final_response(server_final_response)

    def _create_client_first_message(self, username):
        """Create the SASLInitialResponse ('p') message carrying client-first-message.

        Side effects: generates and stores ``client_nonce`` and
        ``client_first_message_bare``, both of which later steps need.
        """
        self.client_nonce = self._generate_client_nonce()
        # RFC 5802 saslname escaping. Escape '=' first so the '=' introduced
        # by '=2C' is not escaped a second time.
        saslname = username.replace("=", "=3D").replace(",", "=2C")
        self.client_first_message_bare = b"n=" + saslname.encode("utf-8") + b",r=" + self.client_nonce
        client_first_message = self.client_channel_binding + self.client_first_message_bare
        msg = (self.AUTHENTICATION_METHOD + b"\0"
               + len(client_first_message).to_bytes(4, byteorder='big')
               + client_first_message)
        # the length field counts itself but not the leading message-type byte
        return b"p" + (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _client_final_message_without_proof(self):
        """Return the ``c=<b64 GS2 header>,r=<server nonce>`` part of client-final-message."""
        return b"c=" + base64.b64encode(self.client_channel_binding) + b",r=" + self.server_nonce

    def _create_client_final_message(self):
        """Create the SASLResponse ('p') message carrying client-final-message.

        Raises:
            Exception: if the channel binding, the client proof or the server
                nonce is not available yet.
        """
        if any(getattr(self, attr) is None for attr in self.REQUIREMENTS_CLIENT_FINAL_MESSAGE):
            raise Exception("you need values from server in order to generate a client proof")
        msg = self._client_final_message_without_proof() + b",p=" + base64.b64encode(self.client_proof)
        # the length field counts itself but not the leading message-type byte
        return b"p" + (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _create_startup_message(self, username, dbname):
        """Create a protocol-3.0 StartupMessage with the ``user`` and ``database`` parameters."""
        msg = self.PROTOCOL_VERSION.to_bytes(4, byteorder='big')
        msg += b"user\0" + username.encode("utf-8") + b"\0"
        # the trailing extra NUL terminates the parameter list
        msg += b"database\0" + dbname.encode("utf-8") + b"\0\0"
        return (len(msg) + 4).to_bytes(4, byteorder='big') + msg

    def _generate_client_nonce(self, bytes=24):
        """Return a base64-encoded nonce made from ``bytes`` random bytes.

        The base64 alphabet contains no ',', so the result is a valid SCRAM nonce.
        """
        return base64.b64encode(secrets.token_bytes(bytes))

    def _generate_client_proof(self, password):
        """Compute ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage).

        The server-first message must have been parsed first. Stores
        ``salted_password`` and ``auth_message`` so that the server signature
        can be verified afterwards.

        Raises:
            Exception: if any value listed in REQUIREMENTS_CLIENT_PROOF is missing.
        """
        if any(getattr(self, attr) is None for attr in self.REQUIREMENTS_CLIENT_PROOF):
            raise Exception("you need values from server in order to generate a client proof")
        self.salted_password = self._generate_salted_password(
            password, self.password_salt, self.password_iterations)
        # client key is derived from the salted password; the server stores H(ClientKey)
        client_key = hmac.new(self.salted_password, b"Client Key", self.DIGEST).digest()
        stored_key = self.DIGEST(client_key).digest()
        # AuthMessage = client-first-bare "," server-first "," client-final-without-proof
        self.auth_message = (self.client_first_message_bare + b"," + self.server_first_message
                             + b"," + self._client_final_message_without_proof())
        client_signature = hmac.new(stored_key, self.auth_message, self.DIGEST).digest()
        return self._bytes_xor(client_key, client_signature)

    def _generate_salted_password(self, password, salt, iterations):
        """Compute SaltedPassword = Hi(password, salt, iterations) from RFC 5802.

        Hi() is PBKDF2 with HMAC(DIGEST) as the PRF and an output length equal
        to the digest size, so this uses hashlib's implementation of it.
        ``salt`` is the base64 text sent by the server.
        """
        # UTF-8 without SASLprep normalization
        p = password.encode("utf8")
        # the salt needs to be base64 decoded -- full binary must be used
        s = base64.b64decode(salt)
        return hashlib.pbkdf2_hmac(self.DIGEST().name, p, s, iterations)

    @staticmethod
    def _parse_attributes(message):
        """Split a SCRAM ``k=v,k=v`` message into a dict of bytes.

        Each part is split at its first '=' so that base64 padding in values is
        kept. If a key repeats, the first occurrence wins.
        """
        attributes = {}
        for part in message.split(b","):
            key, sep, value = part.partition(b"=")
            if sep:
                attributes.setdefault(key, value)
        return attributes

    def _parse_server_first_response(self, response):
        """Parse an AuthenticationSASLContinue message holding server-first-message.

        Stores and returns ``(server_nonce, password_salt, password_iterations)``.
        The salt is returned still base64-encoded.

        Raises:
            Exception: if the message is not AuthenticationSASLContinue, or if
                the nonce, salt or iteration count is missing or invalid.
        """
        # Authentication message layout: 'R', Int32 length (counts itself but
        # not the 'R'), Int32 authentication code, then the SASL data
        if response[:1] != b'R':
            raise Exception("failed")
        msg_length = int.from_bytes(response[1:5], byteorder='big')
        msg_type = int.from_bytes(response[5:9], byteorder='big')
        if msg_type != self.AUTH_REQ_SASL_CONTINUE:
            raise Exception("failed")
        # slice by the declared length so any following backend message is
        # not folded into the AuthMessage
        self.server_first_message = response[9:1 + msg_length]
        attributes = self._parse_attributes(self.server_first_message)
        server_nonce = attributes.get(b'r')
        if not server_nonce:
            raise Exception("could not get nonce")
        # the server must extend, not replace, the nonce we sent
        if not server_nonce.startswith(self.client_nonce):
            raise Exception("invalid nonce")
        self.server_nonce = server_nonce
        password_salt = attributes.get(b's')
        if not password_salt:
            raise Exception("could not get salt")
        self.password_salt = password_salt
        iterations = attributes.get(b'i', b'')
        if not iterations.isdigit() or int(iterations) < 1:
            raise Exception("could not get iterations")
        self.password_iterations = int(iterations)
        return (self.server_nonce, self.password_salt, self.password_iterations)

    def _verify_server_final_response(self, response):
        """Verify the server signature in an AuthenticationSASLFinal message.

        ``response`` is the raw data received after sending client-final-message.
        Only the first backend message in it is examined.

        Returns:
            True if ServerSignature = HMAC(ServerKey, AuthMessage) equals ``v=``.

        Raises:
            Exception: if no proof has been generated yet, if the server replied
                with something other than AuthenticationSASLFinal (for example
                an error), or if the signature is missing or does not match.
        """
        salted_password = getattr(self, "salted_password", None)
        auth_message = getattr(self, "auth_message", None)
        if salted_password is None or auth_message is None:
            raise Exception("client proof must be generated before verifying the server")
        if response[:1] != b'R':
            raise Exception("server did not accept the client proof")
        msg_length = int.from_bytes(response[1:5], byteorder='big')
        msg_type = int.from_bytes(response[5:9], byteorder='big')
        if msg_type != self.AUTH_REQ_SASL_FINAL:
            raise Exception("unexpected authentication message type")
        attributes = self._parse_attributes(response[9:1 + msg_length])
        if b'e' in attributes:
            raise Exception("server error: " + attributes[b'e'].decode("utf-8", "replace"))
        try:
            server_signature = base64.b64decode(attributes[b'v'], validate=True)
        except (KeyError, ValueError):
            raise Exception("could not get server signature")
        server_key = hmac.new(salted_password, b"Server Key", self.DIGEST).digest()
        expected = hmac.new(server_key, auth_message, self.DIGEST).digest()
        # constant-time comparison to avoid leaking how many bytes matched
        if not hmac.compare_digest(server_signature, expected):
            raise Exception("invalid server signature")
        return True

    def _bytes_xor(self, a, b):
        """XOR two byte strings together (zip stops at the shorter one)."""
        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment