Skip to content

Instantly share code, notes, and snippets.

@ghetzel
Last active July 11, 2018 18:07
Show Gist options
  • Save ghetzel/d780121a7e5a887f4c0a8ce15e9db1e5 to your computer and use it in GitHub Desktop.
Save ghetzel/d780121a7e5a887f4c0a8ce15e9db1e5 to your computer and use it in GitHub Desktop.
package main
import (
"bufio"
"bytes"
)
// SubsequenceHandlerFunc is the callback invoked once for every occurrence of a
// registered subsequence. seq is a private copy of the matched bytes, so the
// handler may retain or modify it freely.
type SubsequenceHandlerFunc func(seq []byte)

// A ScanInterceptor is used as a SplitFunc on a bufio.Scanner. It looks at the
// stream of bytes being scanned for specific subsequences and calls the handler
// registered for a subsequence every time that subsequence is seen. The
// passthrough SplitFunc is then called as normal, so a stream can be split and
// processed while also being inspected for specific content.
//
// A ScanInterceptor is not safe for concurrent use.
type ScanInterceptor struct {
	// subsequences maps each watched byte sequence (stored as a string key) to
	// its handler.
	subsequences map[string]SubsequenceHandlerFunc

	// totalWritten is the number of bytes the passthrough SplitFunc has
	// consumed so far, i.e. the stream offset of data[0] on the next Scan call.
	totalWritten int64

	// highWaterMark holds, per subsequence, the stream offset just past the
	// last match that was reported, so repeated calls over the same bytes do
	// not fire a handler twice.
	highWaterMark map[string]int64

	// passthrough is the SplitFunc that performs the actual tokenization.
	passthrough bufio.SplitFunc
}

// NewScanInterceptor returns an interceptor that tokenizes with passthrough.
// A nil passthrough falls back to bufio.ScanLines (the bufio.Scanner default).
//
// If one or more intercept maps are given, the first one supplies the initial
// set of subsequences and handlers; it is used directly (not copied), and any
// further maps are ignored. A nil map is treated as empty.
func NewScanInterceptor(passthrough bufio.SplitFunc, intercepts ...map[string]SubsequenceHandlerFunc) *ScanInterceptor {
	if passthrough == nil {
		passthrough = bufio.ScanLines
	}

	var initial map[string]SubsequenceHandlerFunc
	if len(intercepts) > 0 {
		initial = intercepts[0]
	}
	if initial == nil {
		// Writing to a nil map panics, so Intercept needs a real map.
		initial = make(map[string]SubsequenceHandlerFunc)
	}

	return &ScanInterceptor{
		passthrough:   passthrough,
		subsequences:  initial,
		highWaterMark: make(map[string]int64),
	}
}

// Intercept adds an intercept sequence and handler. If the sequence is already
// registered, its handler function is replaced with this one. Empty sequences
// and nil handlers are stored but never fire.
func (si *ScanInterceptor) Intercept(sequence string, handler SubsequenceHandlerFunc) {
	si.subsequences[sequence] = handler
}

// Scan implements bufio.SplitFunc. It reports every not-yet-reported occurrence
// of each registered subsequence in data to its handler, then delegates to the
// passthrough SplitFunc and returns that function's results unchanged.
func (si *ScanInterceptor) Scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
	si.dispatch(data)

	advance, token, err = si.passthrough(data, atEOF)

	// Only forward progress moves the stream offset; a negative advance is an
	// error the Scanner reports on its own.
	if advance > 0 {
		si.totalWritten += int64(advance)
	}

	return advance, token, err
}

// dispatch fires the handlers for all matches in data that lie beyond each
// subsequence's high water mark.
//
// bufio.Scanner hands the SplitFunc its entire unconsumed buffer on every call,
// so data always begins at stream offset totalWritten and may repeat bytes seen
// on a previous call. Searching data directly (rather than an accumulated copy)
// keeps match positions aligned with real stream offsets, and a match that was
// cut off at the end of one call is found once the buffer has grown.
func (si *ScanInterceptor) dispatch(data []byte) {
	base := si.totalWritten

	for seq, handler := range si.subsequences {
		if len(seq) == 0 || handler == nil {
			continue
		}

		needle := []byte(seq)

		// Resume just past the last reported match, if it lies inside data.
		from := 0
		if hwm := si.highWaterMark[seq]; hwm > base {
			from = int(hwm - base)
		}

		for from <= len(data)-len(needle) {
			idx := bytes.Index(data[from:], needle)
			if idx < 0 {
				break
			}

			from += idx + len(needle)
			si.highWaterMark[seq] = base + int64(from)

			// The match equals seq byte-for-byte; hand out a fresh copy so the
			// handler never aliases the scanner's reusable buffer.
			handler([]byte(seq))
		}
	}
}

// BytesScanned returns the total number of bytes this scanner has scanned.
func (si *ScanInterceptor) BytesScanned() int64 {
	return si.totalWritten
}
package main
import (
"bufio"
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestScanInterceptorNothing(t *testing.T) {
assert := require.New(t)
var lines []string
splitter := NewScanInterceptor(bufio.ScanLines)
data := bytes.NewBuffer([]byte("first\nsecond\nthird\n"))
scanner := bufio.NewScanner(data)
scanner.Split(splitter.Scan)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
assert.NoError(scanner.Err())
assert.Equal([]string{
`first`,
`second`,
`third`,
}, lines)
}
// test single subsequence
// ---------------------------------------------------------------------------------------------
func TestScanInterceptorSingle(t *testing.T) {
assert := require.New(t)
errors := 0
prompts := 0
var lines []string
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{
`[error] `: func(seq []byte) {
errors += 1
},
` password: `: func(seq []byte) {
prompts += 1
},
`Password: `: func(seq []byte) {
prompts += 1
},
})
data := bytes.NewBuffer([]byte(
"Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.\n" +
"test@127.0.0.1's password: ",
))
scanner := bufio.NewScanner(data)
scanner.Split(splitter.Scan)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
assert.NoError(scanner.Err())
assert.Equal(0, errors)
assert.Equal(1, prompts)
assert.Equal([]string{
`Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.`,
`test@127.0.0.1's password: `,
}, lines)
}
// test multiple subsequences
// ---------------------------------------------------------------------------------------------
func TestScanInterceptorMultiple(t *testing.T) {
assert := require.New(t)
errors := 0
prompts := 0
var lines []string
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{
`[error] `: func(seq []byte) {
errors += 1
},
` password: `: func(seq []byte) {
prompts += 1
},
`Password: `: func(seq []byte) {
prompts += 1
},
})
data := bytes.NewBuffer([]byte(
"Password: [error] something cool went wrong\n" +
"test@127.0.0.1's password: ",
))
scanner := bufio.NewScanner(data)
scanner.Split(splitter.Scan)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
assert.NoError(scanner.Err())
assert.Equal(1, errors)
assert.Equal(2, prompts)
assert.Equal([]string{
`Password: [error] something cool went wrong`,
`test@127.0.0.1's password: `,
}, lines)
}
// test add intercept after the fact
// ---------------------------------------------------------------------------------------------
func TestScanInterceptorAddIntercept(t *testing.T) {
assert := require.New(t)
errors := 0
warnings := 0
var lines []string
splitter := NewScanInterceptor(bufio.ScanLines, map[string]SubsequenceHandlerFunc{
`[error] `: func(seq []byte) {
errors += 1
},
})
data := bytes.NewBuffer([]byte(
"Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.\n" +
"[error] something cool went wrong\n",
))
scanner := bufio.NewScanner(data)
scanner.Split(splitter.Scan)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
assert.NoError(scanner.Err())
assert.Equal(1, errors)
assert.Equal(0, warnings)
assert.Equal([]string{
`Warning: Permanently added '[127.0.0.1]:2200' (ECDSA) to the list of known hosts.`,
`[error] something cool went wrong`,
}, lines)
// new scanner, same interceptor, add new data
scanner = bufio.NewScanner(data)
scanner.Split(splitter.Scan)
splitter.Intercept(`Warning:`, func(seq []byte) {
warnings += 1
})
lines = nil
data.WriteString("some cool stuff going on OH NOOOO Warning: NOOOOOOO\n")
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
assert.NoError(scanner.Err())
assert.Equal(1, warnings)
assert.Equal([]string{
`some cool stuff going on OH NOOOO Warning: NOOOOOOO`,
}, lines)
}
func TestScanInterceptorBinarySubsequence(t *testing.T) {
assert := require.New(t)
terminators := 0
splitter := NewScanInterceptor(bufio.ScanBytes)
data := bytes.NewBuffer([]byte{
0x71, 0x00, 0x5d, 0x13, 0xfe, 0x05, 0xff, 0xff,
0xe7, 0xfe, 0x00, 0x16, 0x20, 0x02, 0x07, 0x5d,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xaa, 0x55,
})
splitter.Intercept(string([]byte{0xAA, 0x55}), func(seq []byte) {
terminators += 1
})
scanner := bufio.NewScanner(data)
scanner.Split(splitter.Scan)
for scanner.Scan() {
continue
}
assert.NoError(scanner.Err())
assert.Equal(1, terminators)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment