Skip to content

Instantly share code, notes, and snippets.

@bwplotka
Last active May 14, 2020 17:17
Show Gist options
  • Save bwplotka/55a383a218f40f0a8f59da7851018c68 to your computer and use it in GitHub Desktop.
Save bwplotka/55a383a218f40f0a8f59da7851018c68 to your computer and use it in GitHub Desktop.
Go replayable io.Reader: useful when you want to share a slice of bytes across many sequential io.Reader consumers, e.g. the same request.Body in an HTTP server used by multiple RoundTrippers (!).
package replayable
import (
"bytes"
"io"
)
// Reader is an io.Reader that records every byte it pulls from the wrapped
// reader so that, after Rewind, the same content can be read again.
//
// Reads are served from the recorded bytes first. Once those are exhausted,
// new bytes are fetched lazily from the wrapped reader and appended to the
// recording. The recording is never trimmed, so memory grows with the amount
// of data read from the wrapped reader.
//
// Reader has no internal locking; callers sharing one across goroutines must
// synchronize access themselves. Use it through the *Reader returned by
// NewReader rather than copying the struct.
type Reader struct {
	wrapped io.Reader    // source of not-yet-recorded bytes; nil is treated as an empty source
	buf     bytes.Buffer // every byte read from wrapped so far, in order
	offset  int          // index in buf of the next byte Read will return
}

// Rewind allows replayable Reader to be read again from its first byte.
// It is a no-op on a nil *Reader.
func (b *Reader) Rewind() {
	if b == nil {
		return
	}
	b.offset = 0
}

// Read implements io.Reader.
//
// Read first copies recorded bytes that have not yet been replayed into p.
// If that does not fill p, it issues at most one Read on the wrapped reader
// for the remainder of p. It records whatever that read returns and includes
// it in the result. Like any io.Reader it may return fewer than len(p) bytes
// with a nil error.
//
// io.EOF is returned only when the call produced no bytes. If the wrapped
// reader reports io.EOF in a call that did produce bytes, those bytes are
// returned with a nil error and io.EOF follows on the next call. Other errors
// from the wrapped reader are returned together with any bytes produced.
// A nil *Reader, or one wrapping a nil io.Reader, behaves as an empty stream.
func (b *Reader) Read(p []byte) (int, error) {
	if b == nil {
		return 0, io.EOF
	}
	if len(p) == 0 {
		return 0, nil
	}

	// Replay previously recorded bytes first.
	n := copy(p, b.buf.Bytes()[b.offset:])
	b.offset += n
	if n == len(p) {
		return n, nil
	}

	// The recording is exhausted (b.offset == b.buf.Len() here); pull more
	// from the source. A nil source is an empty stream rather than a panic.
	if b.wrapped == nil {
		if n > 0 {
			return n, nil
		}
		return 0, io.EOF
	}

	// A single Read: return what is available instead of blocking until p
	// is full.
	m, err := b.wrapped.Read(p[n:])
	if m > 0 {
		// Record the fresh bytes so they can be replayed after Rewind.
		// bytes.Buffer.Write always returns a nil error.
		_, _ = b.buf.Write(p[n : n+m])
		b.offset += m
		n += m
	}
	if err == io.EOF && n > 0 {
		// Hand out the data now; the next call reports io.EOF.
		err = nil
	}
	return n, err
}

// NewReader returns a replayable Reader wrapping src.
// The content read from src is buffered lazily, only as it is consumed. This
// keeps storage limited to the bytes actually read so far while still
// allowing the Reader to be rewound and that content to be replayed.
// A nil src yields a Reader that behaves as an empty stream.
func NewReader(src io.Reader) *Reader {
	return &Reader{wrapped: src}
}
package replayable
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestReplayableReader(t *testing.T) {
for _, tcase := range []struct {
name string
src io.Reader
sequentialReadBytes []int
rewindBeforeRead []bool
expectedBytes [][]byte
expectedErrs []error
}{
{
name: "WrappedNil_Read_ShouldReturnEOF",
src: nil,
sequentialReadBytes: []int{10},
rewindBeforeRead: []bool{false},
expectedBytes: [][]byte{{}},
expectedErrs: []error{io.EOF},
},
{
name: "WrappedNil_RewindRead_ShouldReturnEOF",
src: nil,
sequentialReadBytes: []int{10},
rewindBeforeRead: []bool{true},
expectedBytes: [][]byte{{}},
expectedErrs: []error{io.EOF},
},
{
name: "SmallBigBigReads_FinishedWithEOF",
src: bytes.NewReader([]byte{1, 2, 3, 4}),
sequentialReadBytes: []int{1, 8192, 8192},
rewindBeforeRead: []bool{false, false, false},
expectedBytes: [][]byte{{1}, {2, 3, 4}, {}},
expectedErrs: []error{nil, nil, io.EOF},
},
{
name: "SmallReads_FinishedWithEOF",
src: bytes.NewReader([]byte{1, 2, 3, 4}),
sequentialReadBytes: []int{1, 2, 4, 1},
rewindBeforeRead: []bool{false, false, false, false},
expectedBytes: [][]byte{{1}, {2, 3}, {4}, {}},
expectedErrs: []error{nil, nil, nil, io.EOF},
},
{
name: "SmallReadsTakingExactBytes",
src: bytes.NewReader([]byte{1, 2, 3, 4, 5}),
sequentialReadBytes: []int{1, 2, 2},
rewindBeforeRead: []bool{false, false, false},
expectedBytes: [][]byte{{1}, {2, 3}, {4, 5}},
expectedErrs: []error{nil, nil, nil},
},
{
name: "SmallReadsRewindSmallRead",
src: bytes.NewReader([]byte{1, 2, 3, 4, 5}),
sequentialReadBytes: []int{1, 2, 4, 2},
rewindBeforeRead: []bool{false, false, true, false},
expectedBytes: [][]byte{{1}, {2, 3}, {1, 2, 3, 4}, {5}},
expectedErrs: []error{nil, nil, nil, nil},
},
{
name: "BigReadRewindSmallReads",
src: bytes.NewReader([]byte{1, 2, 3, 4}),
sequentialReadBytes: []int{8192, 2, 3},
rewindBeforeRead: []bool{false, true, false},
expectedBytes: [][]byte{{1, 2, 3, 4}, {1, 2}, {3, 4}},
expectedErrs: []error{nil, nil, nil},
},
{
name: "BigReadRewindBigReadSmall_FinishedWithEOF",
src: bytes.NewReader([]byte{1, 2, 3, 4}),
sequentialReadBytes: []int{8192, 8192, 3},
rewindBeforeRead: []bool{false, true, false},
expectedBytes: [][]byte{{1, 2, 3, 4}, {1, 2, 3, 4}, {}},
expectedErrs: []error{nil, nil, io.EOF},
},
} {
if ok := t.Run(tcase.name, func(t *testing.T) {
b := NewReader(tcase.src)
for i, read := range tcase.sequentialReadBytes {
if tcase.rewindBeforeRead[i] {
b.Rewind()
}
toRead := make([]byte, read)
n, err := b.Read(toRead)
require.Equal(t, tcase.expectedErrs[i], err, "read %d", i+1)
require.Len(t, tcase.expectedBytes[i], n, "read %d", i+1)
require.Equal(t, tcase.expectedBytes[i], toRead[:len(tcase.expectedBytes[i])], "read %d", i+1)
}
}); !ok {
return
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment