Skip to content

Instantly share code, notes, and snippets.

@rfk
Created October 22, 2018 20:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rfk/f10ef148ed47a78881fd0603064c0f03 to your computer and use it in GitHub Desktop.
Save rfk/f10ef148ed47a78881fd0603064c0f03 to your computer and use it in GitHub Desktop.
//
// This is a JavaScript implementation of (an approximate subset of)
// the Noise Protocol Framework [1], using the WebCrypto API.
//
// It's not spec-compliant but it aims to be about as close as we
// can possibly get. Key points of divergence:
//
// * Uses P-256 rather than Curve25519 for ECDH, since Curve25519
// is not available in WebCrypto.
//
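// * Restricts nonces to a single byte rather than a 64-bit counter,
// since we don't have 64-bit integers.
//
// * DH() returns the 32-byte P-256 shared secret rather than DHLEN
// (65) bytes; see the comment in DH() below.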
//
// It includes spec-compliant AES-GCM and SHA-256 operations, so it uses
// an algorithm identifier of "WebCryptoP256_AESGCM_SHA256".
//
// The implementation favours easy auditability over performance;
// it should read very closely to a transliteration of the protocol
// spec document. But it hasn't actually *been* audited by anyone,
// so ya know, caveat emptor...
//
// [1] http://noiseprotocol.org/noise.html
//
// Public API of this module: the handshake driver, the buffer wrappers used
// to pass messages into and out of it, the string encoder, and the role
// constants accepted by HandshakeState.prototype.Initialize.
export {
HandshakeState,
InputBuffer,
OutputBuffer,
string_to_buffer,
ROLE_RESPONDER,
ROLE_INITIATOR
}
// Constants which depend on the specific algorithms chosen.
// Some are re-checked by inline assertions elsewhere (DHLEN in
// GENERATE_KEYPAIR, HASHLEN in HASH and HMAC_HASH).
// WebCrypto named curve used for all ECDH operations.
const CURVE_NAME = "P-256"
// Byte length of a public key in WebCrypto "raw" export format. Note that
// the DH() shared secret is shorter than this (see the comment in DH()).
const DHLEN = 65
// WebCrypto AEAD algorithm; the three lengths below are in bits.
const AEAD_NAME = "AES-GCM"
const AEAD_KEY_LENGTH = 256
const AEAD_TAG_LENGTH = 128
const AEAD_IV_LENGTH = 96
// WebCrypto digest used by HASH and HMAC_HASH; HASHLEN is its output size in bytes.
const HASH_NAME = "SHA-256"
const HASHLEN = 32
// Suffix appended to "Noise_<pattern>_" to build the protocol name.
const ALGORITHM_IDENTIFIER = "WebCryptoP256_AESGCM_SHA256"
// Some static constants needed by the protocol.
// Capacity, in bytes, of every OutputBuffer.
const BUFSIZE = 65535
// Empty byte string (default prologue; HKDF input in Split()).
const ZEROLEN = new Uint8Array(0)
// Values for the `role` argument of HandshakeState.prototype.Initialize.
const ROLE_RESPONDER = 0
const ROLE_INITIATOR = 1
// We only support a limited number of handshake patterns.
// No static keys (yet).
// Each value lists the handshake messages separated by "|"; each message is
// a comma-separated list of tokens ("e", "ee", "psk").
const HANDSHAKE_PATTERNS = {
"NN": "e|e,ee",
"NNpsk0": "psk,e|e,ee",
"NNpsk2": "e|e,ee,psk"
}
// Various low-level helper functions.

// Throw an Error unless `cond` is truthy. `msg` states the condition that
// was expected to hold and is appended to the error text.
function assert(cond, msg) {
  if (cond) {
    return
  }
  throw new Error("assert failed: " + msg)
}
// Encode a string of single-byte characters (code units < 256) as a
// Uint8Array, one byte per character.
//
// If `length` is given, the result is exactly that many bytes long, zero-padded
// after the string; the string must not be longer than `length`.
// Throws (via assert) if the string is too long, or if it contains a
// character whose UTF-16 code unit is 256 or above.
function string_to_buffer(value, length=null) {
  if (length !== null) {
    assert(value.length <= length, "string must not be longer than requested buffer length")
  }
  const buf = new Uint8Array(length === null ? value.length : length)
  // A plain index loop: for...in would yield string keys and would also visit
  // any enumerable properties inherited from String.prototype.
  for (let i = 0; i < value.length; i++) {
    const c = value.charCodeAt(i)
    assert(c < 256, "string contains only single-byte characters")
    buf[i] = c
  }
  return buf
}
// Return a fresh Uint8Array holding the bytes of buf1 followed by those of buf2.
function concat(buf1, buf2) {
  const split_at = buf1.byteLength
  const joined = new Uint8Array(split_at + buf2.byteLength)
  joined.set(buf1)
  joined.set(buf2, split_at)
  return joined
}
// Wrap a single byte value in a one-element Uint8Array.
function byte(n) {
  const single = new Uint8Array(1)
  single[0] = n
  return single
}
// Encode a CipherState nonce counter as an AES-GCM IV of AEAD_IV_LENGTH bits.
//
// Noise's AESGCM encoding is 32 bits of zeros followed by the big-endian
// 64-bit counter. Because the counter is restricted to a single byte here,
// only the last byte of the IV can be non-zero.
// Nonce encoding per http://noiseprotocol.org/noise.html#the-aesgcm-cipher-functions
function nonce_to_iv(n) {
  assert(Number.isInteger(n) && n >= 0, "nonce counter is a non-negative integer")
  assert(n < 256, "nonce counter is sufficiently small for JS to cope with")
  // The IV size must live in a declared local: assigning an undeclared name
  // throws a ReferenceError in module (strict) code.
  const iv = new Uint8Array(AEAD_IV_LENGTH / 8)
  iv[iv.length - 1] = n
  return iv
}
// A convenience wrapper for accumulating output in a buffer.
//
// Bytes are appended into a fixed BUFSIZE-byte backing store; `i` counts the
// bytes written so far.
function OutputBuffer() {
  this.i = 0
  this.buf = new Uint8Array(BUFSIZE)
}
// Append the bytes of `data` after anything already written.
// Capacity is checked before writing, so an oversized append fails with the
// assertion (not a bare RangeError from set()) and leaves the buffer unchanged.
OutputBuffer.prototype.append = function append(data) {
  assert(this.i + data.byteLength <= BUFSIZE, "do not write past end of buffer")
  this.buf.set(data, this.i)
  this.i += data.byteLength
}
// Return a view of the bytes written so far. The view shares memory with
// this buffer's backing store; it is not a copy.
OutputBuffer.prototype.finalize = function finalize() {
  return new Uint8Array(this.buf.buffer, 0, this.i)
}
// A convenience wrapper for reading through an input buffer.
//
// `buf` is the message being consumed and `i` the offset of the next unread
// byte. Reads return views into `buf`'s memory, not copies.
function InputBuffer(message) {
  this.i = 0
  this.buf = message
}
// Return the next `length` bytes and advance past them.
// Bounds are checked before building the view. Otherwise, when `buf` is a
// view into a larger ArrayBuffer, a short message could expose bytes beyond
// its end, and a failed read would still move the position.
InputBuffer.prototype.read = function read(length) {
  assert(Number.isInteger(length) && length >= 0, "read length is a non-negative integer")
  assert(this.i + length <= this.buf.byteLength, "do not read past end of buffer")
  const slice = new Uint8Array(this.buf.buffer, this.buf.byteOffset + this.i, length)
  this.i += length
  return slice
}
// Return every remaining unread byte (possibly none) and mark the buffer consumed.
InputBuffer.prototype.readall = function readall() {
  const slice = new Uint8Array(this.buf.buffer, this.buf.byteOffset + this.i, this.buf.byteLength - this.i)
  this.i += slice.byteLength
  assert(this.i === this.buf.byteLength, "read to end of buffer")
  return slice
}
// The ECDH primitives required by Noise.
// Ref http://noiseprotocol.org/noise.html#dh-functions

// Generate a fresh ECDH key pair on CURVE_NAME.
//
// Returns { public_key, private_key }. public_key is the DHLEN-byte "raw"
// encoding as a Uint8Array. private_key is a WebCrypto CryptoKey usable only
// for deriveBits.
async function GENERATE_KEYPAIR() {
  // extractable=false governs only the private key. WebCrypto always marks the
  // public half of a generated pair as extractable, so it can still be exported
  // below while the private scalar never leaves the crypto implementation.
  const kp = await crypto.subtle.generateKey({
    name: "ECDH",
    namedCurve: CURVE_NAME
  }, false, ["deriveBits"])
  const pubkey = await crypto.subtle.exportKey("raw", kp.publicKey)
  assert(pubkey.byteLength === DHLEN, "public key encoding has correct length")
  return {
    public_key: new Uint8Array(pubkey),
    private_key: kp.privateKey
  }
}
async function DH(key_pair, public_key) {
let pubkey = await crypto.subtle.importKey("raw", public_key, {
name: "ECDH",
namedCurve: CURVE_NAME
}, false, ["deriveBits"])
let output = await crypto.subtle.deriveBits({
name: "ECDH",
namedCurve: CURVE_NAME,
public: pubkey
}, key_pair.private_key, 256)
// We deviate from Noise spec here, which requires that DH() produce
// an output of DHLEN bytes. Our encoded P-256 keys are 65 bytes long,
// but the P-256 DH operation only produces a 32-byte secret. We could
// try to language-laywer our way out of it, e.g. by HKDF'ing it into
// a longer key or by compressing the point representation, but it seems
// better to just explicitly note a divergence rather than beeing too clever.
// assert(output.byteLength == DHLEN, "generated secret has correct length")
return new Uint8Array(output)
}
// The encryption primitives required by Noise.
// Ref http://noiseprotocol.org/noise.html#cipher-functions

// AEAD-encrypt `plaintext` under the raw key `k` with nonce counter `n` and
// associated data `ad`. Returns ciphertext || tag as a Uint8Array, which is
// AEAD_TAG_LENGTH / 8 bytes longer than the plaintext.
async function ENCRYPT(k, n, ad, plaintext) {
  const key = await crypto.subtle.importKey("raw", k, {
    name: AEAD_NAME,
    length: AEAD_KEY_LENGTH
  }, false, ["encrypt"])
  const ciphertext = await crypto.subtle.encrypt({
    name: AEAD_NAME,
    tagLength: AEAD_TAG_LENGTH,
    iv: nonce_to_iv(n),
    additionalData: ad
  }, key, plaintext)
  // Tag size derived from the configured tag length rather than hard-coded.
  assert(ciphertext.byteLength === plaintext.byteLength + AEAD_TAG_LENGTH / 8, "ciphertext has correct encoded length")
  return new Uint8Array(ciphertext)
}
// AEAD-decrypt `ciphertext` (ciphertext || tag) under the raw key `k` with
// nonce counter `n` and associated data `ad`. Rejects if authentication
// fails; otherwise returns the plaintext as a Uint8Array.
async function DECRYPT(k, n, ad, ciphertext) {
  const key = await crypto.subtle.importKey("raw", k, {
    name: AEAD_NAME,
    length: AEAD_KEY_LENGTH
  }, false, ["decrypt"])
  const plaintext = await crypto.subtle.decrypt({
    name: AEAD_NAME,
    tagLength: AEAD_TAG_LENGTH,
    iv: nonce_to_iv(n),
    additionalData: ad
  }, key, ciphertext)
  // Tag size derived from the configured tag length rather than hard-coded.
  assert(plaintext.byteLength === ciphertext.byteLength - AEAD_TAG_LENGTH / 8, "plaintext has correct decoded length")
  return new Uint8Array(plaintext)
}
// The hashing primitives required by Noise.
// Ref http://noiseprotocol.org/noise.html#hash-functions

// Hash `data` (binary data, not a JS string) with HASH_NAME, returning
// HASHLEN bytes as a Uint8Array.
async function HASH(data) {
  // A declared local: assigning an undeclared `hash` throws in module code.
  const hash = await crypto.subtle.digest({ name: HASH_NAME }, data)
  assert(hash.byteLength === HASHLEN, "hash has correct output length")
  return new Uint8Array(hash)
}
// HMAC of `data` keyed by `key` using HASH_NAME; both arguments are binary
// data. Returns HASHLEN bytes as a Uint8Array.
async function HMAC_HASH(key, data) {
  const hmac_key = await crypto.subtle.importKey("raw", key, {
    name: "HMAC",
    hash: { name: HASH_NAME }
  }, false, ["sign"])
  // A declared local: assigning an undeclared name throws in module code.
  const mac = await crypto.subtle.sign("HMAC", hmac_key, data)
  assert(mac.byteLength === HASHLEN, "hmac-hash has correct output length")
  return new Uint8Array(mac)
}
// Noise's HKDF: derive HASHLEN-byte outputs from `chaining_key` and
// `input_key_material` (which must be 0, 32 or DHLEN bytes long).
// Returns [output1, output2] when num_outputs is 2, and
// [output1, output2, output3] otherwise.
// Each output is HMAC_HASH(temp_key, previous_output || byte(i)), where the
// "previous output" before the first one is empty.
async function HKDF(chaining_key, input_key_material, num_outputs) {
  const acceptable_lengths = [0, 32, DHLEN]
  assert(acceptable_lengths.includes(input_key_material.byteLength), "input_key_material has acceptable length")
  const temp_key = await HMAC_HASH(chaining_key, input_key_material)
  const count = num_outputs === 2 ? 2 : 3
  const outputs = []
  let previous = new Uint8Array(0)
  for (let i = 1; i <= count; i++) {
    previous = await HMAC_HASH(temp_key, concat(previous, byte(i)))
    outputs.push(previous)
  }
  return outputs
}
// The low-level "CipherState" object specced by Noise.
// Ref http://noiseprotocol.org/noise.html#the-cipherstate-object
//
// Holds an optional 32-byte key `k` and the nonce counter `n` used with it.
// Without a key, EncryptWithAD/DecryptWithAD pass data through unchanged.
function CipherState() {
  this.k = null
  this.n = 0
}
// Set (or clear, with null) the key and reset the nonce counter.
CipherState.prototype.InitializeKey = function InitializeKey(key) {
  assert(key === null || key.byteLength === 32, "valid cipherstate key")
  this.k = key
  this.n = 0
}
CipherState.prototype.HasKey = function HasKey() {
  return this.k !== null
}
// Encrypt `plaintext` with associated data `ad` using the next nonce.
// The spec consumes a nonce on every encryption (ENCRYPT(k, n++, ...)); reusing
// one with AES-GCM would break confidentiality and integrity. The nonce is
// reserved before awaiting, so overlapping calls can never share one.
CipherState.prototype.EncryptWithAD = async function EncryptWithAD(ad, plaintext) {
  if (! this.HasKey()) {
    return plaintext
  }
  const nonce = this.n
  this.n += 1
  return await ENCRYPT(this.k, nonce, ad, plaintext)
}
// Decrypt and authenticate `ciphertext` with associated data `ad`; without
// a key the ciphertext is returned unchanged.
CipherState.prototype.DecryptWithAD = async function DecryptWithAD(ad, ciphertext) {
  if (! this.HasKey()) {
    return ciphertext
  }
  // We must not increment n if there is a decryption error, so it is only
  // advanced once DECRYPT has resolved successfully.
  const plaintext = await DECRYPT(this.k, this.n, ad, ciphertext)
  this.n += 1
  return plaintext
}
// The mid-level "SymmetricState" object specced by Noise.
// Ref http://noiseprotocol.org/noise.html#the-symmetricstate-object
//
// Holds the chaining key `ck`, the handshake hash `h`, and the CipherState
// used to protect handshake payloads. Call InitializeSymmetric() before use.
function SymmetricState(protocol_name) {
  // `protocol_name` is accepted but unused; InitializeSymmetric() takes it.
  this.h = null
  this.ck = null
  this.cipherstate = new CipherState()
}
// Derive the initial h and ck from the protocol name and clear the cipher key.
SymmetricState.prototype.InitializeSymmetric = async function InitializeSymmetric(protocol_name) {
  // HASH() hands its argument to WebCrypto's digest(), which accepts binary
  // data only and rejects a JS string with a TypeError. Encode the name first.
  // Every name HandshakeState builds ("Noise_NN_WebCryptoP256_AESGCM_SHA256"
  // and longer) exceeds HASHLEN, so the hashing branch is the usual one.
  const name_bytes = string_to_buffer(protocol_name)
  if (name_bytes.byteLength <= HASHLEN) {
    // Short names are used directly, zero-padded to HASHLEN bytes.
    this.h = string_to_buffer(protocol_name, HASHLEN)
  } else {
    this.h = await HASH(name_bytes)
  }
  this.ck = this.h
  this.cipherstate.InitializeKey(null)
}
// Mix input key material into ck and re-key the CipherState.
SymmetricState.prototype.MixKey = async function MixKey(input_key_material) {
  const [ck, temp_k] = await HKDF(this.ck, input_key_material, 2)
  this.ck = ck
  assert(HASHLEN === 32, "no need to truncate temp_k")
  this.cipherstate.InitializeKey(temp_k)
}
// h = HASH(h || data)
SymmetricState.prototype.MixHash = async function MixHash(data) {
  this.h = await HASH(concat(this.h, data))
}
// Mix key material into ck, h and the CipherState key (used for PSKs).
SymmetricState.prototype.MixKeyAndHash = async function MixKeyAndHash(input_key_material) {
  const [ck, temp_h, temp_k] = await HKDF(this.ck, input_key_material, 3)
  this.ck = ck
  await this.MixHash(temp_h)
  assert(HASHLEN === 32, "no need to truncate temp_k")
  this.cipherstate.InitializeKey(temp_k)
}
SymmetricState.prototype.GetHandshakeHash = async function GetHandshakeHash() {
  return this.h
}
// Encrypt a payload with h as associated data, then mix the ciphertext into h.
SymmetricState.prototype.EncryptAndHash = async function EncryptAndHash(plaintext) {
  const ciphertext = await this.cipherstate.EncryptWithAD(this.h, plaintext)
  await this.MixHash(ciphertext)
  return ciphertext
}
// Decrypt a payload with h as associated data, then mix the ciphertext into h.
SymmetricState.prototype.DecryptAndHash = async function DecryptAndHash(ciphertext) {
  const plaintext = await this.cipherstate.DecryptWithAD(this.h, ciphertext)
  await this.MixHash(ciphertext)
  return plaintext
}
// Derive the two transport CipherStates from the final chaining key.
SymmetricState.prototype.Split = async function Split() {
  const [temp_k1, temp_k2] = await HKDF(this.ck, ZEROLEN, 2)
  assert(HASHLEN === 32, "no need to truncate temp_k1 or temp_k2")
  const c1 = new CipherState()
  const c2 = new CipherState()
  c1.InitializeKey(temp_k1)
  c2.InitializeKey(temp_k2)
  return [c1, c2]
}
// The high-level "HandshakeState" object specced by Noise.
// Ref http://noiseprotocol.org/noise.html#the-handshakestate-object
//
// It also includes support for pre-shared symmetric keys,
// ref http://noiseprotocol.org/noise.html#pre-shared-symmetric-keys
//
// Typical use: Initialize(), then WriteMessage()/ReadMessage() in the order
// the pattern dictates until HasCompleted(), then Split().
function HandshakeState() {
  this.symmetricstate = new SymmetricState()
  this.initiator = null
  this.e = null
  this.re = null
  this.psk = null
  this.message_patterns = null
}
// Prepare a handshake for `handshake_pattern` (a key of HANDSHAKE_PATTERNS).
// `role` is stored as given. `prologue` is mixed into the handshake hash.
// `keys` may pre-supply e, re and psk.
HandshakeState.prototype.Initialize = async function Initialize(handshake_pattern, role, prologue=ZEROLEN, keys={}) {
  // Own-property check: `in` would also accept inherited names such as
  // "toString", which would then fail obscurely at .split() below.
  assert(Object.prototype.hasOwnProperty.call(HANDSHAKE_PATTERNS, handshake_pattern), "using a known handshake pattern")
  // Noise requires pre-shared keys to be exactly 32 bytes.
  assert(! keys.psk || keys.psk.byteLength === 32, "psk must be 32 bytes long")
  this.role = role
  await this.symmetricstate.InitializeSymmetric("Noise_" + handshake_pattern + "_" + ALGORITHM_IDENTIFIER)
  await this.symmetricstate.MixHash(prologue)
  this.e = keys.e || null
  this.re = keys.re || null
  this.psk = keys.psk || null
  this.message_patterns = HANDSHAKE_PATTERNS[handshake_pattern].split("|").map(msg => msg.split(","))
  this.handshake_contains_psk = this.message_patterns.some(tokens => tokens.indexOf("psk") !== -1)
}
// Consume the next message pattern as the sender. Each token is processed in
// order, then the (possibly encrypted) payload is appended to message_buffer.
HandshakeState.prototype.WriteMessage = async function WriteMessage(payload, message_buffer) {
  assert(this.message_patterns.length > 0, "WriteMessage must have a pending message pattern")
  assert(message_buffer instanceof OutputBuffer, "caller has provided an OutputBuffer")
  const tokens = this.message_patterns.shift()
  for (const token of tokens) {
    switch (token) {
      case "e":
        assert(this.e === null, "e must be empty")
        this.e = await GENERATE_KEYPAIR()
        message_buffer.append(this.e.public_key)
        await this.symmetricstate.MixHash(this.e.public_key)
        // PSK handshakes additionally mix the ephemeral key into ck.
        if (this.handshake_contains_psk) {
          await this.symmetricstate.MixKey(this.e.public_key)
        }
        break
      case "ee":
        await this.symmetricstate.MixKey(await DH(this.e, this.re))
        break
      case "psk":
        assert(this.psk, "psk was provided")
        await this.symmetricstate.MixKeyAndHash(this.psk)
        break
      default:
        throw new Error("unexpected handshake token: " + token)
    }
  }
  message_buffer.append(await this.symmetricstate.EncryptAndHash(payload))
}
// Consume the next message pattern as the receiver. Tokens are read from
// `message`, and the decrypted payload is appended to payload_buffer.
HandshakeState.prototype.ReadMessage = async function ReadMessage(message, payload_buffer) {
  assert(this.message_patterns.length > 0, "ReadMessage must have a pending message pattern")
  assert(message instanceof InputBuffer, "caller has provided an InputBuffer")
  assert(payload_buffer instanceof OutputBuffer, "caller has provided an OutputBuffer")
  const tokens = this.message_patterns.shift()
  for (const token of tokens) {
    switch (token) {
      case "e":
        assert(this.re === null, "re must be empty")
        // Copy: read() returns a view into the caller's message, which the
        // caller may reuse before re is needed by a later "ee".
        this.re = message.read(DHLEN).slice()
        await this.symmetricstate.MixHash(this.re)
        if (this.handshake_contains_psk) {
          await this.symmetricstate.MixKey(this.re)
        }
        break
      case "ee":
        await this.symmetricstate.MixKey(await DH(this.e, this.re))
        break
      case "psk":
        assert(this.psk, "psk was provided")
        await this.symmetricstate.MixKeyAndHash(this.psk)
        break
      default:
        throw new Error("unexpected handshake token: " + token)
    }
  }
  payload_buffer.append(await this.symmetricstate.DecryptAndHash(message.readall()))
}
// Resolves to true once every message pattern has been processed.
HandshakeState.prototype.HasCompleted = async function HasCompleted() {
  return this.message_patterns.length === 0
}
// Derive the transport CipherStates. HasCompleted() is async, so its result
// must be awaited: a bare Promise is always truthy and would disable the check.
HandshakeState.prototype.Split = async function Split() {
  assert(await this.HasCompleted(), "handshake must have completed before splitting")
  return this.symmetricstate.Split()
}
HandshakeState.prototype.GetHandshakeHash = async function GetHandshakeHash() {
  assert(await this.HasCompleted(), "handshake must have completed before getting hash")
  return this.symmetricstate.GetHandshakeHash()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment