Skip to content

Instantly share code, notes, and snippets.

@cameronhotchkies
Created August 17, 2013 02:14
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save cameronhotchkies/6254921 to your computer and use it in GitHub Desktop.
An implementation of MD4 in pure Scala. More info at: http://www.semisafe.com/2013/08/md4-implementation-in-pure-scala/
package com.semisafe.cryptopals
/*
* Implementation of MD4 in pure Scala
*
* derived from the RSA Data Security, Inc. MD4
* Message-Digest Algorithm (http://tools.ietf.org/html/rfc1320)
*
* This was implemented as part of the Matasano Crypto Challenges
* http://www.matasano.com/articles/crypto-challenges/
*
* Having a native implementation of MD4 is required for the
* fourth set of challenges, building your own isn't, so feel
* free to make use of this, but expect to get your hands dirty
* later.
*
* !!! WARNING !!! -- This is a toy implementation of an
* obsolete algorithm; please, for the love of god and all that
* is holy, do not deploy this to any production environment.
* If you do, dogs and cats will start living together, making
* pancakes and wearing pyjamas, then everyone will die, and
* it will be your fault.
*
* *** CAVEAT *** -- At this point I'm a very novice Scala
* developer. If you see things that don't make sense, or
* could be clearer, please fork this and fix it, or just
* let me know. I'd love to learn how to improve this.
*
* Author: Cameron Hotchkies <handsomecam@semisafe.com>
*
*/
import java.nio.ByteBuffer
import scala.annotation.tailrec
/** Pure-Scala MD4 message digest (RFC 1320).
  *
  * Toy implementation for the Matasano/cryptopals challenges. MD4 is broken
  * and must not be used where security matters.
  */
object Md4Digest {

  // Initial chaining words A, B, C, D (RFC 1320, step 3). The RFC lists them
  // low-order byte first (A = 01 23 45 67), which is why they look
  // byte-swapped when written as Int literals.
  private[this] val InitialState: Array[Int] =
    Array(0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476)

  // Round 2 visits the block words column-wise: 0,4,8,12, 1,5,9,13, ...
  private[this] val Round2Offsets: Array[Int] = Array(0, 4, 8, 12)

  // Round 3 word order, exactly as listed in RFC 1320.
  private[this] val Round3Order: Array[Int] =
    Array(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15)

  /** Computes the MD4 digest of `message`.
    *
    * @param message the bytes to hash (any length, including empty)
    * @return the 16-byte digest: words A, B, C, D, each written low-order
    *         byte first, as RFC 1320 specifies
    */
  def hashMessage(message: Array[Byte]): Array[Byte] = {
    val digestWords = hashLoop(preprocessMessage(message), InitialState)
    val out = ByteBuffer.allocate(16).order(java.nio.ByteOrder.LITTLE_ENDIAN)
    digestWords.foreach(word => out.putInt(word))
    out.array()
  }

  /** Runs the 16 steps of one MD4 round. All three rounds share this shape.
    *
    * The first parameter list configures the round:
    *   - `shifts`: the four per-step rotation amounts
    *   - `transform`: the step function (FF/GG/HH)
    *   - `iter`: maps a step number to a block word index
    *
    * The second parameter list carries the block words `X`, the current
    * 4-word `state` and the step counter.
    *
    * @throws IllegalArgumentException if `state` does not hold exactly four words
    */
  @tailrec
  private[this] def roundX(
      shifts: Array[Int],
      transform: (Int, Int, Int, Int, Int, Int) => Int,
      iter: Int => Int)(
      X: Array[Int],
      state: List[Int],
      iteration: Int): List[Int] =
    if (iteration >= 16) state
    else state match {
      case a :: b :: c :: d :: Nil =>
        val updated = transform(a, b, c, d, X(iter(iteration)), shifts(iteration % 4))
        // Rotate roles so the next step updates d: [abcd] -> [dabc] -> [cdab] -> [bcda].
        // After 16 steps (a multiple of 4) the list is back in A, B, C, D order.
        roundX(shifts, transform, iter)(X, List(d, updated, b, c), iteration + 1)
      case other =>
        throw new IllegalArgumentException(
          s"MD4 round state must hold exactly four words, got ${other.length}")
    }

  /** Block word index used by step `iteration` (0-15) of round 2. */
  def r2Iter(iteration: Int): Int = iteration / 4 + Round2Offsets(iteration % 4)

  /** Block word index used by step `iteration` (0-15) of round 3. */
  def r3Iter(iteration: Int): Int = Round3Order(iteration)

  def roundOne: (Array[Int], List[Int], Int) => List[Int] =
    roundX(Array(3, 7, 11, 19), FF, i => i) _
  def roundTwo: (Array[Int], List[Int], Int) => List[Int] =
    roundX(Array(3, 5, 9, 13), GG, r2Iter) _
  def roundThree: (Array[Int], List[Int], Int) => List[Int] =
    roundX(Array(3, 9, 11, 15), HH, r3Iter) _

  /** Folds every 64-byte block of the padded `message` into the chaining state.
    *
    * `message` must already be padded to a multiple of 64 bytes.
    * `currentHash` is not modified.
    */
  private[this] def hashLoop(message: Array[Byte], currentHash: Array[Int]): Array[Int] =
    message.grouped(64).foldLeft(currentHash.clone()) { (running, block) =>
      // RFC 1320 interprets each 64-byte block as sixteen little-endian words.
      val buf = ByteBuffer.wrap(block).order(java.nio.ByteOrder.LITTLE_ENDIAN)
      val words = Array.fill(16)(buf.getInt())
      val afterRoundOne = roundOne(words, running.toList, 0)
      val afterRoundTwo = roundTwo(words, afterRoundOne, 0)
      val afterRoundThree = roundThree(words, afterRoundTwo, 0)
      // Feed-forward: add the round output to the block's input state (mod 2^32).
      (running zip afterRoundThree).map { case (h, r) => h + r }
    }

  /** Returns `message` followed by its MD4 padding (a multiple of 64 bytes). */
  private[this] def preprocessMessage(message: Array[Byte]): Array[Byte] =
    message ++ generatePadding(message)

  /** Builds the MD4 padding for `message`.
    *
    * The padding is one 0x80 byte, then zero bytes up to a length of 56 mod 64,
    * then the original bit length as a 64-bit little-endian integer.
    */
  private[this] def generatePadding(message: Array[Byte]): Array[Byte] = {
    // Widen before multiplying: Int arithmetic overflows for messages >= 256 MiB.
    val bitLength = message.length.toLong * 8
    // Bytes occupied in the final 64-byte block once the 0x80 marker is appended.
    val used = (message.length % 64 + 1) % 64
    // Zero-fill up to offset 56 of a block. If the marker already lies past
    // offset 56, the 8-byte length no longer fits, so padding spills into one
    // more block.
    val zeroCount = if (used <= 56) 56 - used else 120 - used
    val lengthBytes = ByteBuffer.allocate(8)
      .order(java.nio.ByteOrder.LITTLE_ENDIAN)
      .putLong(bitLength)
      .array()
    Array[Byte](0x80.toByte) ++ Array.fill[Byte](zeroCount)(0) ++ lengthBytes
  }

  // Auxiliary functions from RFC 1320.
  // F: bitwise "if x then y else z".
  private[this] def F(x: Int, y: Int, z: Int) = (x & y) | (~x & z)
  // G: bitwise majority of x, y, z.
  private[this] def G(x: Int, y: Int, z: Int) = (x & y) | (x & z) | (y & z)
  // H: bitwise parity.
  private[this] def H(x: Int, y: Int, z: Int) = x ^ y ^ z

  // Step functions for rounds 1-3. Each returns the new value of `a`.
  // Int addition wraps, which gives the mod-2^32 arithmetic MD4 requires.
  private[this] def FF(a: Int, b: Int, c: Int, d: Int, x: Int, s: Int): Int =
    Integer.rotateLeft(a + F(b, c, d) + x, s)

  private[this] def GG(a: Int, b: Int, c: Int, d: Int, x: Int, s: Int): Int =
    Integer.rotateLeft(a + G(b, c, d) + x + 0x5a827999, s)

  private[this] def HH(a: Int, b: Int, c: Int, d: Int, x: Int, s: Int): Int =
    Integer.rotateLeft(a + H(b, c, d) + x + 0x6ed9eba1, s)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment