Last active
January 29, 2023 15:59
-
-
Save joonas-fi/c48c556b77eab28f9fed374928266c43 to your computer and use it in GitHub Desktop.
A stab at faster MD4 for Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"encoding/binary"
	"hash"
	"math/bits"
)
const (
	// Size is the length of an MD4 digest, in bytes.
	Size = 16

	// BlockSize is the number of input bytes MD4 compresses at a time.
	BlockSize = 64
)
// le is shorthand for the byte order used to decode input words and encode the digest.
var le = binary.LittleEndian
type md4Hash struct { | |
a, b, c, d uint32 // digest | |
inputBytesHashed int64 // how many "payload" bytes were sent to Write() | |
// not all Write()s align with BlockSize (64 bytes), so in those cases we've to copy bytes | |
// here until we receive more writes to complete a full block. | |
queued [BlockSize]byte // reused to reduce allocations | |
queuedLength int // queued bytes = queued[:queuedLength] | |
} | |
func New() *md4Hash { | |
return &md4Hash{ | |
a: 0x67452301, | |
b: 0xefcdab89, | |
c: 0x98badcfe, | |
d: 0x10325476, | |
} | |
} | |
var _ hash.Hash = (*md4Hash)(nil) | |
func (m *md4Hash) Write(p []byte) (int, error) { | |
m.inputBytesHashed += int64(len(p)) | |
leftTowrite := p // pointer to subset of p which is advanced as we process blocks | |
for len(leftTowrite) > 0 { | |
if m.queuedLength > 0 { // have queued writes | |
if m.queuedLength+len(leftTowrite) >= BlockSize { // manage to now get a full block | |
n := copy(m.queued[m.queuedLength:], leftTowrite) | |
m.runBlock(m.queued[:]) | |
leftTowrite = leftTowrite[n:] // advance | |
m.queuedLength = 0 // reset | |
} else { // still not enough input to get a full block | |
// store in queued to wait for the next Write() | |
n := copy(m.queued[m.queuedLength:], leftTowrite) // by definition will fit | |
leftTowrite = leftTowrite[n:] // advance | |
m.queuedLength += n | |
} | |
continue | |
} | |
if len(leftTowrite) >= BlockSize { // can write full block from input | |
m.runBlock(leftTowrite) | |
leftTowrite = leftTowrite[BlockSize:] // advance | |
} else { // only partial block. need to store it so we can pass full blocks to runBlock() | |
// queued was empty. leftTowrite necessarily fits in full | |
n := copy(m.queued[:], leftTowrite) | |
m.queuedLength += n | |
break // *leftTowrite* is necessarily empty now | |
} | |
} | |
return len(p), nil | |
} | |
func (m *md4Hash) Sum(b []byte) []byte { | |
// need to copy to honor the contract that Sum() must not change internal state. | |
final := New() | |
*final = *m // copy contents | |
return final.SumUnsafe(b) | |
} | |
// faster version of Sum() that is allowed to change internal state | |
func (m *md4Hash) SumUnsafe(b []byte) []byte { | |
// capture original input's size because our following end marker, padding etc. Write() | |
// would change it | |
inputBytesHashed := m.inputBytesHashed | |
m.Write([]byte{0x80}) | |
// need 8 bytes to write length of digested stream | |
if m.queuedLength+8 > BlockSize { // doesn't fit in current queued block? pad and flush it out to start new block | |
paddingZeroes := make([]byte, BlockSize-m.queuedLength) | |
m.Write(paddingZeroes) | |
// now m.queued is quaranteed to have space for 8 bytes | |
} | |
for i := m.queuedLength; i < BlockSize; i++ { // zero rest of queued buffer | |
m.queued[i] = 0x00 | |
} | |
le.PutUint64(m.queued[BlockSize-8:], uint64(inputBytesHashed*8)) // in bits | |
m.runBlock(m.queued[:]) | |
sum := [Size]byte{} | |
le.PutUint32(sum[0*4:], m.a) | |
le.PutUint32(sum[1*4:], m.b) | |
le.PutUint32(sum[2*4:], m.c) | |
le.PutUint32(sum[3*4:], m.d) | |
return append(b, sum[:]...) | |
} | |
func (m *md4Hash) Reset() { | |
c := New() | |
*m = *c | |
} | |
func (m *md4Hash) Size() int { | |
return Size | |
} | |
func (m *md4Hash) BlockSize() int { | |
return BlockSize | |
} | |
// len(block) >= 64 is guaranteed here | |
func (m *md4Hash) runBlock(block []byte) { | |
a, b, c, d := m.runBlock2(block) | |
m.a += a | |
m.b += b | |
m.c += c | |
m.d += d | |
} | |
// returns new (a, b, c, d) which are derived from current digest, and which are supposed to be added to digest | |
func (m *md4Hash) runBlock2(block []byte) (uint32, uint32, uint32, uint32) { | |
a, b, c, d := m.a, m.b, m.c, m.d // shorthands | |
words := [16]uint32{ | |
le.Uint32(block[0*4:]), | |
le.Uint32(block[1*4:]), | |
le.Uint32(block[2*4:]), | |
le.Uint32(block[3*4:]), | |
le.Uint32(block[4*4:]), | |
le.Uint32(block[5*4:]), | |
le.Uint32(block[6*4:]), | |
le.Uint32(block[7*4:]), | |
le.Uint32(block[8*4:]), | |
le.Uint32(block[9*4:]), | |
le.Uint32(block[10*4:]), | |
le.Uint32(block[11*4:]), | |
le.Uint32(block[12*4:]), | |
le.Uint32(block[13*4:]), | |
le.Uint32(block[14*4:]), | |
le.Uint32(block[15*4:]), | |
} | |
for _, i := range []int{0, 4, 8, 12} { | |
a = rotl(a+f(b, c, d)+words[i+0], 3) | |
d = rotl(d+f(a, b, c)+words[i+1], 7) | |
c = rotl(c+f(d, a, b)+words[i+2], 11) | |
b = rotl(b+f(c, d, a)+words[i+3], 19) | |
} | |
for _, i := range []int{0, 1, 2, 3} { | |
a = rotl(a+g(b, c, d)+words[i+0]+0x5a827999, 3) | |
d = rotl(d+g(a, b, c)+words[i+4]+0x5a827999, 5) | |
c = rotl(c+g(d, a, b)+words[i+8]+0x5a827999, 9) | |
b = rotl(b+g(c, d, a)+words[i+12]+0x5a827999, 13) | |
} | |
for _, i := range []int{0, 2, 1, 3} { | |
a = rotl(a+h(b, c, d)+words[i+0]+0x6ed9eba1, 3) | |
d = rotl(d+h(a, b, c)+words[i+8]+0x6ed9eba1, 9) | |
c = rotl(c+h(d, a, b)+words[i+4]+0x6ed9eba1, 11) | |
b = rotl(b+h(c, d, a)+words[i+12]+0x6ed9eba1, 15) | |
} | |
return a, b, c, d | |
} | |
// rotl rotates x left by n bits. The rotation count is taken modulo 32, so every n is valid.
// The previous hand-written form returned 0 for n >= 33 because 32-n underflowed.
// bits.RotateLeft32 is also recognized by the compiler as a single rotate instruction.
func rotl(x, n uint32) uint32 {
	return bits.RotateLeft32(x, int(n))
}
// f is MD4's round-1 selection function: for each bit, pick y where x is set, else z.
func f(x, y, z uint32) uint32 {
	return z ^ (x & (y ^ z))
}
// g is MD4's round-2 majority function: each output bit is set when at least two inputs have it.
func g(x, y, z uint32) uint32 {
	return (x & y) | (z & (x | y))
}
// h is MD4's round-3 parity function: the bitwise XOR of all three inputs.
func h(x, y, z uint32) uint32 {
	xy := x ^ y
	return xy ^ z
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"testing" | |
"github.com/function61/gokit/testing/assert" | |
gomd4 "golang.org/x/crypto/md4" | |
) | |
// var testMaterial = []byte("Hello world") | |
var testMaterial = gen1MBFile() | |
func BenchmarkGoMD4(b *testing.B) { | |
for n := 0; n < b.N; n++ { | |
h := gomd4.New() | |
h.Write(testMaterial) | |
h.Sum(nil) | |
} | |
} | |
func BenchmarkOurMD4(b *testing.B) { | |
for n := 0; n < b.N; n++ { | |
h := New() | |
h.Write(testMaterial) | |
h.Sum(nil) | |
} | |
} | |
func TestMD4(t *testing.T) { | |
for _, tc := range []struct { | |
input string | |
expectedOutput string | |
}{ | |
{ | |
"The quick brown fox jumps over the lazy dog", | |
"1bee69a46ba811185c194762abaeae90", | |
}, | |
{ | |
"The quick brown fox jumps over the lazy cog", | |
"b86e130ce7028da59e672d56ad0113df", | |
}, | |
{ | |
"", | |
"31d6cfe0d16ae931b73c59d7e0c089c0", | |
}, | |
{ | |
"a", | |
"bde52cb31de33e46245e05fbdbd6fb24", | |
}, | |
{ | |
"abc", | |
"a448017aaf21d8525fc10ae87aa6729d", | |
}, | |
{ | |
"message digest", | |
"d9130a8164549fe818874806e1c7014b", | |
}, | |
{ | |
"abcdefghijklmnopqrstuvwxyz", | |
"d79e1c308aa5bbcdeea8ed63df412da9", | |
}, | |
{ | |
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", | |
"043f8582f241db351ce627e153e7f0e4", | |
}, | |
{ | |
"12345678901234567890123456789012345678901234567890123456789012345678901234567890", | |
"e33b4ddc9c38f2199c3e7b164fcc0536", | |
}, | |
} { | |
t.Run(tc.input, func(t *testing.T) { | |
h := New() | |
h.Write([]byte(tc.input)) | |
assert.EqualString(t, hex(h.Sum(nil)), tc.expectedOutput) | |
}) | |
} | |
} | |
// hex renders input as lowercase hexadecimal, two digits per byte, with no separators.
func hex(input []byte) string {
	var out []byte
	for _, octet := range input {
		out = append(out, fmt.Sprintf("%02x", octet)...)
	}
	return string(out)
}
// gen1MBFile returns 1 MiB (1024*1024 bytes) whose i-th byte is byte(i), i.e. the sequence
// 0x00..0xff repeated.
func gen1MBFile() []byte {
	const size = 1 << 20
	buf := make([]byte, size)
	for i := range buf {
		buf[i] = byte(i)
	}
	return buf
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment