Skip to content

Instantly share code, notes, and snippets.

@nilium
Last active November 19, 2018 08:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nilium/1ccbc05023d4327ae0055957bc676f21 to your computer and use it in GitHub Desktop.
Save nilium/1ccbc05023d4327ae0055957bc676f21 to your computer and use it in GitHub Desktop.
// This is an example message reader/writer that frames every message with a
// length prefix and a digest (HMAC-SHA1 in the example below) so that the
// reader can check that it received the complete, unmodified message.
package main
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
"math"
)
func main() {
var b bytes.Buffer
var hasher HashFunc
// CRC64:
// isoTable := crc64.MakeTable(crc64.ISO)
// hasher = HashFunc(func() hash.Hash { return crc64.New(isoTable) })
// HMAC-SHA1:
hasher = HMACFunc("Secret key", sha1.New)
// SHA1:
// hasher = HashFunc(sha1.New)
// No hash:
// hasher = HashFunc(NullHash)
err := NewMessenger(100, hasher).WriteMsg(&b, []byte{})
if err != nil {
panic(err)
}
// Print written message
fmt.Printf("BUF = %x\n", &b)
// Create hash mismatch:
// hasher = hmacHashFunc("key 2", sha1.New)
msg, err := NewMessenger(100, hasher).ReadMsg(&b)
if err != nil {
panic(err)
}
// Print read message
fmt.Printf("MSG = %q\n", msg)
}
// HashFunc is a constructor for the hash.Hash that a Messenger uses to compute message digests.
// ReadMsg and WriteMsg each call it once per message and do not keep the returned hash afterwards,
// so a HashFunc should return a hash that is ready to accept input.
type HashFunc func() hash.Hash
// Messenger reads and writes messages framed as a length prefix, a digest, and a payload.
// It holds only configuration: ReadMsg and WriteMsg call hashfn for a hash on every message.
type Messenger struct {
// maxSize is the upper bound, in bytes, that ReadMsg and WriteMsg enforce on message size.
maxSize int64
// hashfn constructs the hash used to compute each message's digest.
hashfn HashFunc
}
// NewMessenger returns a Messenger for reading and writing length-and-digest prefixed messages.
//
// maxMsgSize is the maximum size of a message in bytes and must be greater than 1; NewMessenger
// panics otherwise. hashfn supplies the digest; when it is nil, NullHash is used instead.
//
// Returning an error would be more appropriate in real code, but this is a prototype.
func NewMessenger(maxMsgSize int64, hashfn HashFunc) *Messenger {
	if maxMsgSize <= 1 {
		panic("max message size must be > 1 bytes")
	}

	m := &Messenger{maxSize: maxMsgSize, hashfn: hashfn}
	if m.hashfn == nil {
		// No digest requested: fall back to the do-nothing hash.
		m.hashfn = NullHash
	}
	return m
}
// msgHeader holds the two values that precede the payload of a message on the wire.
type msgHeader struct {
// Size is the number of bytes that follow the length prefix: the digest plus the payload.
Size int64
// Digest is the hash of the encoded Size bytes followed by the payload.
Digest []byte
}
// ReadMsg reads one message (length prefix, digest, payload) from r and returns the payload.
//
// The length prefix is a uvarint counting the digest plus the payload. The digest is computed with
// the Messenger's hash function over the encoded length bytes followed by the payload, and the
// message is rejected if it does not match the digest that was read. This is to prevent acceptance
// of corrupt, spoofed, or otherwise invalid messages (depending on the digest).
//
// Errors:
//   - the error from reading the length prefix, unchanged; in particular io.EOF when r is
//     already exhausted, which signals a clean end of stream;
//   - io.ErrUnexpectedEOF when r ends after the length prefix but before the message is complete;
//   - a descriptive error when the declared size exceeds the Messenger's maximum, is smaller than
//     the digest size, or when the digests differ;
//   - any other error from r, unchanged.
//
// NOTE: the digest-mismatch error contains the locally computed digest. With a keyed hash that is
// the valid MAC for the received bytes, so the error text should not be sent to an untrusted peer.
func (m *Messenger) ReadMsg(r io.Reader) ([]byte, error) {
	h := m.hashfn()
	hashSize := h.Size()

	// Once the length prefix has been consumed, running out of input means the message was cut
	// short. io.ReadFull reports that as a bare io.EOF when it could not read a single byte, which
	// a caller would mistake for a clean end of stream.
	truncated := func(err error) error {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	// Read the length prefix through a tee so that its encoded bytes are fed into the hash.
	usize, err := binary.ReadUvarint(asByteReader(io.TeeReader(r, h)))
	if err != nil {
		return nil, err
	}
	if usize > math.MaxInt64 || int64(usize) > m.maxSize {
		return nil, fmt.Errorf("message of %d bytes exceeds max size of %d bytes",
			usize, m.maxSize)
	}
	header := msgHeader{Size: int64(usize)}
	if header.Size < int64(hashSize) {
		return nil, fmt.Errorf("message of %d bytes is too short",
			header.Size)
	}

	// Read the digest straight from r: it is compared against the hash, not fed into it.
	header.Digest = make([]byte, hashSize)
	if _, err := io.ReadFull(r, header.Digest); err != nil {
		return nil, truncated(err)
	}

	// Read the payload, hashing it as it arrives.
	p := make([]byte, int(header.Size)-hashSize)
	if _, err := io.ReadFull(io.TeeReader(r, h), p); err != nil {
		return nil, truncated(err)
	}

	// Compare the computed digest with the received one in constant time.
	if sum := h.Sum(nil); !hmac.Equal(sum, header.Digest) {
		return nil, fmt.Errorf("message digests don't match: computed(%x) <> received(%x)",
			sum, header.Digest)
	}
	return p, nil
}
// WriteMsg writes the payload p to w as one message: a uvarint length prefix, the digest, and then
// the payload itself.
//
// The length prefix counts the digest plus the payload. The digest is computed with the
// Messenger's hash function over the encoded length bytes followed by the payload.
//
// If len(p) plus the worst-case header size (the digest size plus binary.MaxVarintLen64 bytes for
// the length prefix) would exceed the Messenger's maximum message size, nothing is written and an
// error is returned. An empty payload is allowed. All other errors arise from writing to w; when
// one is returned, part of the message may already have been written.
func (m *Messenger) WriteMsg(w io.Writer, p []byte) error {
	var intbuf [binary.MaxVarintLen64]byte

	h := m.hashfn()
	hashSize := int64(h.Size())
	headerSize := hashSize + int64(len(intbuf))
	if int64(len(p)) > m.maxSize-headerSize {
		return fmt.Errorf("message of %d bytes exceeds max size of %d - header(%d) bytes",
			len(p), m.maxSize, headerSize)
	}

	header := msgHeader{
		Size:   int64(len(p)) + hashSize,
		Digest: make([]byte, 0, hashSize),
	}

	// Encode message length (excluding length itself)
	intp := intbuf[:binary.PutUvarint(intbuf[:], uint64(header.Size))]

	// Compute digest. The hash.Hash contract says Write never returns an error.
	h.Write(intp) // length bytes
	h.Write(p)    // payload
	header.Digest = h.Sum(header.Digest)

	// Would love an iovec here. writeFull skips its write once an earlier one has failed, so the
	// first error is the one that reaches the caller.
	err := writeFull(w, intp, nil)         // length bytes
	err = writeFull(w, header.Digest, err) // digest bytes
	return writeFull(w, p, err)            // payload
}
// writeFull is one link in a chain of writes: it writes p to w unless an earlier link failed.
//
// err is the result of the previous link. When it is non-nil, writeFull returns it unchanged and
// does not call w.Write. Otherwise it returns the error from w.Write, io.ErrShortWrite if w
// accepted fewer than len(p) bytes without reporting an error, or nil on success.
func writeFull(w io.Writer, p []byte, err error) error {
	if err != nil {
		return err
	}

	written, werr := w.Write(p)
	switch {
	case werr != nil:
		return werr
	case written < len(p):
		return io.ErrShortWrite
	default:
		return nil
	}
}
// byteReader combines io.ByteReader with io.Reader. binary.ReadUvarint only needs the ByteReader
// half; the Reader half is what the wrapped value already provides.
type byteReader interface {
	io.ByteReader
	io.Reader
}

// simpleByteReader layers a naive io.ByteReader on top of an arbitrary io.Reader.
// It has no fast paths and no buffering of its own: it is only meant to wrap a TeeReader and to
// read the handful of bytes (at most 10) that make up a uvarint. If throughput matters, buffer the
// underlying reader rather than making this type smarter.
type simpleByteReader struct {
	io.Reader
}

// asByteReader wraps r so that it can be read one byte at a time.
func asByteReader(r io.Reader) byteReader {
	var wrapped simpleByteReader
	wrapped.Reader = r
	return wrapped
}

// ReadByte reads exactly one byte from the wrapped reader, retrying reads that return no data and
// no error. On failure it returns the reader's error.
func (b simpleByteReader) ReadByte() (byte, error) {
	one := make([]byte, 1)
	_, err := io.ReadAtLeast(b.Reader, one, len(one))
	return one[0], err
}
// Hash functions for testing:
// HMACFunc returns a HashFunc whose hashes are HMACs keyed with key and built on the hash
// produced by hashfn. Every call of the returned function creates a new HMAC via hmac.New.
func HMACFunc(key string, hashfn HashFunc) HashFunc {
	secret := []byte(key)
	newMAC := func() hash.Hash {
		return hmac.New(hashfn, secret)
	}
	return newMAC
}
// NullHash returns a hash.Hash that discards everything written to it and produces a zero-length
// digest. Using it as a Messenger's HashFunc yields messages with an empty digest.
func NullHash() hash.Hash {
	return nullHasher{}
}

// nullHasher is the stateless hash.Hash behind NullHash.
type nullHasher struct{}

// Compile-time check that nullHasher satisfies hash.Hash.
var _ hash.Hash = nullHasher{}

// Write accepts and ignores p; it never fails.
func (nullHasher) Write(p []byte) (n int, err error) {
	return len(p), nil
}

// Size reports the digest length, which is zero.
func (nullHasher) Size() int { return 0 }

// Reset does nothing: there is no state to clear.
func (nullHasher) Reset() {}

// BlockSize returns 1, meaning writes of any length are acceptable.
func (nullHasher) BlockSize() int { return 1 }

// Sum appends the (empty) digest to b and returns the result, so b comes back unchanged.
// Truncating b here would discard the prefix that hash.Hash.Sum is required to preserve.
func (nullHasher) Sum(b []byte) []byte { return b }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment