Last active
November 19, 2018 08:18
-
-
Save nilium/1ccbc05023d4327ae0055957bc676f21 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This is an example message reader/writer that prefixes each message with its length and a
// digest (HMAC-SHA1 in this example) so that the reader can confirm it received the entire,
// correct message.
package main | |
import ( | |
"bytes" | |
"crypto/hmac" | |
"crypto/sha1" | |
"encoding/binary" | |
"fmt" | |
"hash" | |
"io" | |
"math" | |
) | |
func main() { | |
var b bytes.Buffer | |
var hasher HashFunc | |
// CRC64: | |
// isoTable := crc64.MakeTable(crc64.ISO) | |
// hasher = HashFunc(func() hash.Hash { return crc64.New(isoTable) }) | |
// HMAC-SHA1: | |
hasher = HMACFunc("Secret key", sha1.New) | |
// SHA1: | |
// hasher = HashFunc(sha1.New) | |
// No hash: | |
// hasher = HashFunc(NullHash) | |
err := NewMessenger(100, hasher).WriteMsg(&b, []byte{}) | |
if err != nil { | |
panic(err) | |
} | |
// Print written message | |
fmt.Printf("BUF = %x\n", &b) | |
// Create hash mismatch: | |
// hasher = hmacHashFunc("key 2", sha1.New) | |
msg, err := NewMessenger(100, hasher).ReadMsg(&b) | |
if err != nil { | |
panic(err) | |
} | |
// Print read message | |
fmt.Printf("MSG = %q\n", msg) | |
} | |
// HashFunc constructs a fresh hash.Hash. A Messenger calls it once for every message it reads
// or writes, and never reuses a returned hash for a second message.
type HashFunc func() hash.Hash

// Messenger writes messages framed as a uvarint length, a digest, and a payload, and reads
// such frames back, verifying the digest. Build one with NewMessenger.
type Messenger struct {
	// hashfn builds the digest hash used for each message.
	hashfn HashFunc
	// maxSize is the size limit, in bytes, that ReadMsg and WriteMsg enforce.
	maxSize int64
}
// NewMessenger allocates a new Messenger for reading and writing length-and-digest prefixed | |
// messages. maxMsgSize specifies the maximum size of a message in bytes, and must be greater than | |
// eight bytes. If any arguments are invalid, NewMessenger panics. | |
// | |
// Returning an error would be more appropriate in real code, but this is a prototype. | |
func NewMessenger(maxMsgSize int64, hashfn HashFunc) *Messenger { | |
if hashfn == nil { | |
hashfn = NullHash | |
} | |
if maxMsgSize <= 1 { | |
panic("max message size must be > 1 bytes") | |
} | |
return &Messenger{ | |
maxSize: maxMsgSize, | |
hashfn: hashfn, | |
} | |
} | |
// msgHeader holds the framing fields that precede a message payload.
type msgHeader struct {
	// Size is the length encoded in the prefix: digest bytes plus payload bytes,
	// not counting the prefix itself.
	Size int64
	// Digest is computed over the encoded length-prefix bytes followed by the payload.
	Digest []byte
}
// ReadMsg reads length, digest, and payload from r. | |
// If the digest of the payload does not match when filtered through Messenger's hash function, it | |
// returns an error. This is to prevent acceptance of corrupt, spoofed, or otherwise invalid | |
// messages (depending on the digest). | |
// In addition, too-small and too-large messages will also return errors. | |
// All other errors arise from reading from r. | |
func (m *Messenger) ReadMsg(r io.Reader) ([]byte, error) { | |
h := m.hashfn() | |
hashReader := io.TeeReader(r, h) | |
hashSize := h.Size() | |
var header msgHeader | |
// Read size | |
usize, err := binary.ReadUvarint(asByteReader(hashReader)) | |
if err != nil { | |
return nil, err | |
} else if usize > math.MaxInt64 || int64(usize) > m.maxSize { | |
return nil, fmt.Errorf("message of %d bytes exceeds max size of %d bytes", | |
usize, m.maxSize) | |
} | |
header.Size = int64(usize) | |
if header.Size < int64(hashSize) { | |
return nil, fmt.Errorf("message of %d bytes is too short", | |
header.Size) | |
} | |
// Read digest | |
header.Digest = make([]byte, hashSize) | |
if _, err := io.ReadFull(r, header.Digest); err != nil { | |
return nil, err | |
} | |
// Read and hash payload | |
payloadSize := int(header.Size) - hashSize | |
p := make([]byte, payloadSize) | |
_, err = io.ReadFull(io.TeeReader(r, h), p) | |
if err != nil { | |
return nil, err | |
} | |
// Compare received and computed digests | |
if sum := h.Sum(nil); !hmac.Equal(sum, header.Digest) { | |
return nil, fmt.Errorf("message digests don't match: sent(%x) <> received(%x)", | |
sum, header.Digest) | |
} | |
return p, nil | |
} | |
// WriteMsg writes the payload, p, prefixed by length and digest, to the writer w. | |
// If the length of p combined with the header's length would exceed the Messenger's maximum message | |
// size, then the message is not written and an error is returned. Empty messages also return an | |
// error. All other errors arise from writing to w. | |
func (m *Messenger) WriteMsg(w io.Writer, p []byte) error { | |
var intbuf [10]byte | |
h := m.hashfn() | |
hashSize := int64(h.Size()) | |
headerSize := hashSize + int64(len(intbuf)) | |
if int64(len(p)) > m.maxSize-headerSize { | |
return fmt.Errorf("message of %d bytes exceeds max size of %d - header(%d) bytes", | |
len(p), m.maxSize, headerSize) | |
} | |
header := msgHeader{ | |
Size: int64(len(p)) + hashSize, | |
Digest: make([]byte, 0, hashSize), | |
} | |
// Encode message length (excluding length itself) | |
intp := intbuf[:] | |
intlen := binary.PutUvarint(intp, uint64(header.Size)) | |
intp = intp[:intlen] | |
// Compute digest | |
h.Write(intp) // length bytes | |
h.Write(p) // payload | |
header.Digest = h.Sum(header.Digest) | |
// Would love an iovec here | |
err := writeFull(w, intp, nil) // length bytes | |
err = writeFull(w, header.Digest, err) // digest bytes | |
err = writeFull(w, p, err) // payload | |
return nil | |
} | |
// writeFull writes all of p to w, unless an earlier step in a chain of writes already failed.
// A non-nil err is returned untouched without writing anything, which lets callers chain
// several writeFull calls and keep the first failure. A write that reports fewer than len(p)
// bytes without an error is reported as io.ErrShortWrite.
func writeFull(w io.Writer, p []byte, err error) error {
	if err != nil {
		return err
	}
	written, werr := w.Write(p)
	switch {
	case werr != nil:
		return werr
	case written < len(p):
		return io.ErrShortWrite
	default:
		return nil
	}
}
// byteReader is the combined io.Reader and io.ByteReader that asByteReader returns.
// binary.ReadUvarint only needs the io.ByteReader half.
type byteReader interface {
	io.ByteReader
	io.Reader
}

// simpleByteReader adds a naive, unbuffered ReadByte to any io.Reader. Each call pulls exactly
// one byte from the wrapped reader and never reads ahead. It has no special cases because it
// is only meant to wrap a TeeReader and decode a length prefix of at most ten bytes. If the
// underlying reader is expensive per call, buffer that reader rather than complicating this.
type simpleByteReader struct {
	io.Reader
}

// asByteReader wraps r in a simpleByteReader. Because bytes are taken from r one at a time,
// no data beyond what the caller asks for is consumed.
func asByteReader(r io.Reader) byteReader {
	var wrapped simpleByteReader
	wrapped.Reader = r
	return wrapped
}

// ReadByte reads and returns the next byte from the wrapped reader. If no byte could be read,
// it returns the reader's error (io.EOF at a clean end of input).
func (b simpleByteReader) ReadByte() (byte, error) {
	var out [1]byte
	if _, err := io.ReadAtLeast(b.Reader, out[:], len(out)); err != nil {
		return out[0], err
	}
	return out[0], nil
}
// Hash functions for testing: | |
func HMACFunc(key string, hashfn HashFunc) HashFunc { | |
bkey := []byte(key) | |
return func() hash.Hash { | |
return hmac.New(hashfn, bkey) | |
} | |
} | |
// NullHash returns a hash.Hash that ignores its input and produces an empty (zero-length)
// digest. It is useful for framing messages with no integrity check.
func NullHash() hash.Hash {
	return nullHasher{}
}

// nullHasher is the stateless hash.Hash returned by NullHash.
type nullHasher struct{}

// Compile-time check that nullHasher satisfies hash.Hash.
var _ hash.Hash = nullHasher{}

// Write discards p and reports it as fully written; it never returns an error.
func (nullHasher) Write(p []byte) (n int, err error) {
	return len(p), nil
}

// Size reports a digest length of zero bytes.
func (nullHasher) Size() int { return 0 }

// Reset is a no-op because nullHasher has no state.
func (nullHasher) Reset() {}

// BlockSize reports 1: there is no preferred write granularity.
func (nullHasher) BlockSize() int { return 1 }

// Sum appends the (empty) digest to b and returns the result, so b comes back unchanged.
// hash.Hash requires Sum to append; truncating b would discard the caller's existing bytes.
func (nullHasher) Sum(b []byte) []byte { return b }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment