Skip to content

Instantly share code, notes, and snippets.

@alex
Last active June 12, 2020 19:49
Show Gist options
  • Save alex/2a91b94c799609075b7abedb7aa8d30a to your computer and use it in GitHub Desktop.
Save alex/2a91b94c799609075b7abedb7aa8d30a to your computer and use it in GitHub Desktop.
package validate
import (
"fmt"
"io"
"unicode/utf8"
)
// validatingWriter is the io.WriteCloser returned by New. It passes data
// through to an underlying writer while checking that the stream is valid
// UTF-8.
type validatingWriter struct {
	// w receives the bytes that have passed validation.
	w io.Writer
	// buf holds trailing bytes from earlier Write calls that have not been
	// forwarded to w yet because they may be the start of a multi-byte
	// character that a later Write completes. Write prepends them to the next
	// chunk, and Close validates whatever is left here.
	buf []byte
}
// New wraps w in an io.WriteCloser that forwards data to w only after it has
// been checked as valid UTF-8. A multi-byte character split across Write
// calls is held back until the rest of it arrives. Invalid UTF-8 is reported
// as an error from Write or Close; see those methods for the exact error and
// closing behavior.
func New(w io.Writer) io.WriteCloser {
	vw := &validatingWriter{w: w}
	return vw
}
// Write validates p as UTF-8, together with any bytes held back by earlier
// calls, and forwards the validated bytes to the underlying writer.
//
// A multi-byte character whose encoding is cut off at the end of p is held
// back and completed by the next Write, so callers may split characters
// across calls. The held-back bytes are copied; p is never retained. Bytes
// that can never form valid UTF-8 are rejected immediately: Write returns 0
// and an error without forwarding anything from this call.
//
// On success Write returns len(p). If the underlying writer fails, Write
// returns its error along with the number of bytes of p that it accepted.
func (w *validatingWriter) Write(p []byte) (int, error) {
	pending := len(w.buf)

	// Only pay for a copy when an incomplete character from an earlier call
	// has to be joined with p; otherwise validate the caller's slice in place.
	buf := p
	if pending > 0 {
		buf = append(w.buf, p...)
	}

	i := 0
	for i < len(buf) {
		if buf[i] < utf8.RuneSelf {
			i++
			continue
		}
		if !utf8.FullRune(buf[i:]) {
			// buf[i:] is a valid-so-far prefix of a multi-byte character;
			// wait for the remaining bytes.
			break
		}
		r, sz := utf8.DecodeRune(buf[i:])
		// An encoding error decodes as RuneError with width 1. A genuine
		// U+FFFD in the input also yields RuneError, but with width 3, and
		// must be accepted.
		if r == utf8.RuneError && sz == 1 {
			return 0, fmt.Errorf("Invalid UTF8")
		}
		i += sz
	}

	// Keep the incomplete tail (at most utf8.UTFMax-1 bytes) in a fresh slice.
	// p must not be retained, and a small copy avoids pinning a large array.
	w.buf = append([]byte(nil), buf[i:]...)

	if i > 0 {
		m, err := w.w.Write(buf[:i])
		if err != nil {
			// The first `pending` bytes sent came from earlier calls, not p.
			n := m - pending
			if n < 0 {
				n = 0
			}
			return n, err
		}
	}
	return len(p), nil
}
// Close flushes the bytes that Write held back and then closes the underlying
// writer if it implements io.Closer.
//
// If the held-back bytes are not valid UTF-8 (for example, a multi-byte
// character whose remaining bytes never arrived), they are discarded and an
// error is returned. The underlying writer is closed even when validation or
// the final write fails, so it is not leaked. The first error encountered is
// the one returned. The held-back bytes are cleared, so calling Close again
// does not write them a second time.
func (w *validatingWriter) Close() error {
	pending := w.buf
	w.buf = nil

	var firstErr error
	if !utf8.Valid(pending) {
		firstErr = fmt.Errorf("Invalid UTF-8 in buffer at Close()")
	} else if len(pending) > 0 {
		if _, err := w.w.Write(pending); err != nil {
			firstErr = err
		}
	}

	if c, ok := w.w.(io.Closer); ok {
		if err := c.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
package validate
import (
"bytes"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
// TestValidatingWriter writes each case's chunks through New, one Write call
// per chunk, and then calls Close. A case without expectErr must succeed and
// produce output byte-identical to the concatenated chunks. A case with
// expectErr must fail in one of the Write calls or in Close.
func TestValidatingWriter(t *testing.T) {
	for i, c := range []struct {
		bufs      [][]byte
		expectErr bool
	}{
		// No writes at all.
		{bufs: [][]byte{}},
		// ASCII, one byte per write.
		{bufs: [][]byte{{0x61}, {0x62}, {0x63}}},
		// U+1F648 (F0 9F 99 88) split across two writes.
		{bufs: [][]byte{{0xF0, 0x9F}, {0x99, 0x88}}},
		// U+1F648 in a single write.
		{bufs: [][]byte{{0xF0, 0x9F, 0x99, 0x88}}},
		// ASCII prefix, then U+1F648 split across writes.
		{bufs: [][]byte{{0x61, 0x61, 0x61, 0x61, 0xF0, 0x9F}, {0x99, 0x88}}},
		// A complete U+1F648, then a second one split across writes.
		{bufs: [][]byte{{0xF0, 0x9F, 0x99, 0x88, 0xF0, 0x9F}, {0x99, 0x88}}},
		// Lead byte of a 4-byte sequence with nothing after it.
		{bufs: [][]byte{{0xF0}}, expectErr: true},
		// A 4-byte sequence that stops after 3 bytes, split across writes.
		{bufs: [][]byte{{0xF0}, {0x9F, 0x99}}, expectErr: true},
		// A complete U+1F648 followed by a truncated one.
		{bufs: [][]byte{{0xF0, 0x9F, 0x99, 0x88, 0xF0, 0x9F, 0x99}}, expectErr: true},
	} {
		t.Run(fmt.Sprintf("Case #%d", i), func(t *testing.T) {
			var want, got bytes.Buffer
			writer := New(&got)
			for _, part := range c.bufs {
				if _, err := writer.Write(part); err != nil {
					require.True(t, c.expectErr, "unexpected Write error: %v", err)
					return
				}
				want.Write(part)
			}
			if err := writer.Close(); err != nil {
				require.True(t, c.expectErr, "unexpected Close error: %v", err)
				return
			}
			require.False(t, c.expectErr, "expected an error from Write or Close")
			// testify's argument order is (expected, actual).
			require.Equal(t, want.Bytes(), got.Bytes())
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment