Created
April 2, 2017 18:13
-
-
Save cstockton/81525f8882fe56bd604f04e60e87173a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package internal implements testing instrumentation for the wait package and | |
// serves as a staging area for experimental Waiter implementations. | |
package internal | |
import ( | |
"fmt" | |
"log" | |
"path/filepath" | |
"runtime" | |
"strings" | |
"sync" | |
) | |
// Ok returns a func() error that does nothing and always reports success by
// returning nil.
func Ok() func() error {
	succeed := func() error { return nil }
	return succeed
}
// Err returns a func() error that hands back err, unchanged, on every call.
// A nil err therefore yields a func that always returns nil.
func Err(err error) func() error {
	fail := func() error { return err }
	return fail
}
// BlockIn wraps fn so that the returned func() error waits for a release
// signal before invoking fn. The second return value sends that signal.
//
// The signal channel buffers a single token, so releasing before the wrapper
// runs is fine. A release issued while a token is already pending is dropped
// instead of blocking the caller. Each wrapper invocation consumes one token.
func BlockIn(fn func() error) (func() error, func()) {
	gate := make(chan struct{}, 1)

	release := func() {
		// Non-blocking send: never stall the releaser.
		select {
		case gate <- struct{}{}:
		default:
		}
	}

	wait := func() error {
		<-gate
		return fn()
	}

	return wait, release
}
// BlockOut wraps fn so that the returned func() error runs fn right away but
// holds on to fn's result until a release signal arrives. The second return
// value sends that signal.
//
// The signal channel buffers a single token, so releasing early is fine. A
// release issued while a token is already pending is dropped instead of
// blocking. Each wrapper invocation consumes one token.
func BlockOut(fn func() error) (func() error, func()) {
	gate := make(chan struct{}, 1)

	release := func() {
		// Non-blocking send: never stall the releaser.
		select {
		case gate <- struct{}{}:
		default:
		}
	}

	wait := func() error {
		result := fn()
		<-gate
		return result
	}

	return wait, release
}
// MutexWaiter serializes calls to Fn. Wait holds the embedded mutex for the
// whole duration of each Fn call, so concurrent Waits never overlap.
type MutexWaiter struct {
	sync.Mutex
	Fn func() error
}

// Mutex returns a MutexWaiter guarding fn.
func Mutex(fn func() error) *MutexWaiter {
	w := new(MutexWaiter)
	w.Fn = fn
	return w
}

// Wait locks w, runs Fn, and returns Fn's error. The unlock is deferred, so
// the lock is released even if Fn panics.
func (w *MutexWaiter) Wait() error {
	w.Lock()
	defer w.Unlock()
	return w.Fn()
}
// SentinelError is used for tracking error propagation. Values are handled
// as pointers, so two distinct *SentinelError values compare unequal even
// when their Label and Idx match. Label and Idx only make the message
// readable.
type SentinelError struct {
	Label string
	Idx   int
}

// Next increments w.Idx and returns a new SentinelError that carries w's
// Label and the incremented index. w acts as a generator: successive calls
// yield indexes 1, 2, 3, ... for a fresh w.
func (w *SentinelError) Next() *SentinelError {
	w.Idx++
	// Propagate the generator's label instead of hard-coding one, so custom
	// labelled generators produce correctly labelled errors.
	return &SentinelError{Label: w.Label, Idx: w.Idx}
}

// Error implements the error interface, formatting as "Error(<label> #<idx>)".
func (w *SentinelError) Error() string {
	return fmt.Sprintf(`Error(%v #%v)`, w.Label, w.Idx)
}

// Error returns a newly allocated sentinel labelled "sentinel" with index 1.
// Every call starts a fresh generator, so all results print the same message.
// They are still distinct values and are told apart by pointer identity.
func Error() *SentinelError {
	se := &SentinelError{Label: `sentinel`}
	return se.Next()
}
// SentinelWaiter returns unique errors to validate error propagation. | |
type SentinelWaiter struct { | |
Err error | |
} | |
func Sentinel() *SentinelWaiter { | |
return &SentinelWaiter{Err: Error()} | |
} | |
func (w *SentinelWaiter) Check(err error) error { | |
if w.Err == err { | |
return nil | |
} | |
return fmt.Errorf(`exp %v; got %v`, w.Err, err) | |
} | |
func (w *SentinelWaiter) Wait() error { | |
return w.Err | |
} | |
func (w *SentinelWaiter) String() string { | |
return fmt.Sprintf(`Sentinel(%v)`, w.Err) | |
} | |
// PropagationWaiter helps test error propagation | |
type PropagationWaiter struct { | |
*SentinelWaiter | |
Fn func() error | |
} | |
func Propagation(s *SentinelWaiter, fn func() error) *PropagationWaiter { | |
return &PropagationWaiter{SentinelWaiter: s, Fn: fn} | |
} | |
func (w *PropagationWaiter) Wait() error { | |
err := w.Fn() | |
if err = w.SentinelWaiter.Check(err); err != nil { | |
return fmt.Errorf( | |
`propagation error, expected %v to propagate`, w.SentinelWaiter) | |
} | |
return nil | |
} | |
func (w *PropagationWaiter) String() string { | |
return fmt.Sprintf(`Propagation(%v)`, w.SentinelWaiter) | |
} | |
// CountWaiter wraps Fn and records how many times Wait has run. Calls is a
// plain int and CountWaiter does no locking of its own.
type CountWaiter struct {
	Calls int
	Fn    func() error
}

// Count returns a CountWaiter around fn with a zero call count.
func Count(fn func() error) *CountWaiter {
	w := &CountWaiter{}
	w.Fn = fn
	return w
}

// Called reports whether Wait has run at least once.
func (w *CountWaiter) Called() bool {
	return w.Calls >= 1
}

// Check returns nil when Wait has run exactly n times, and a descriptive
// error otherwise.
func (w *CountWaiter) Check(n int) error {
	if got := w.Calls; got != n {
		return fmt.Errorf(`expected %d calls to CountWaiter; got %d`, n, got)
	}
	return nil
}

// Wait increments the call count before running Fn, so the count is updated
// even if Fn panics. It returns Fn's result.
func (w *CountWaiter) Wait() error {
	w.Calls++
	inner := w.Fn
	return inner()
}

// String renders the waiter as "Count(<calls>)".
func (w *CountWaiter) String() string {
	calls := w.Calls
	return fmt.Sprintf(`Count(%v)`, calls)
}
// TokenWaiter adds a token for tracking to CountWaiter. | |
type TokenWaiter struct { | |
*CountWaiter | |
Token int | |
} | |
func Token(token int, fn func() error) *TokenWaiter { | |
if fn == nil { | |
fn = Ok() | |
} | |
return &TokenWaiter{CountWaiter: Count(fn), Token: token} | |
} | |
func (w *TokenWaiter) String() string { | |
return fmt.Sprintf(`Token(%v: %v)`, w.CountWaiter.Calls, w.Token) | |
} | |
// Tracker issues funcs to be Tracked as a collective group. Check will verify | |
// that each tracked function has been called. | |
type Tracker struct { | |
sync.Mutex | |
Idx int | |
Buf []byte | |
Ptrs []uintptr | |
Label string | |
Created []*TokenWaiter | |
Tracked []int | |
} | |
func Track(label string) *Tracker { | |
return &Tracker{Label: label} | |
} | |
// Track the given func. | |
func (t *Tracker) Track(fn func() error) func() error { | |
t.Lock() | |
defer t.Unlock() | |
t.Idx++ | |
idx := t.Idx | |
tw := Token(idx, func() error { | |
t.Lock() | |
t.Tracked = append(t.Tracked, idx) | |
t.Unlock() | |
return fn() | |
}) | |
t.Created = append(t.Created, tw) | |
mFn := Mutex(func() error { | |
return tw.Wait() | |
}) | |
return func() error { | |
return mFn.Wait() | |
} | |
} | |
// Next creates a new tracked func | |
func (t *Tracker) Next() func() error { | |
return t.Track(Ok()) | |
} | |
// Check makes sure each created func was called once | |
func (t *Tracker) Check() error { | |
if err := t.Calls(1); err != nil { | |
return err | |
} | |
return t.Stopped() | |
} | |
// Calls makes sure each created func was called count times | |
func (t *Tracker) Calls(count int) error { | |
t.Lock() | |
defer t.Unlock() | |
if exp, got := len(t.Created), len(t.Tracked); exp != got { | |
return fmt.Errorf("exp %v tracked funcs; got %v", exp, got) | |
} | |
saw := make(map[int]struct{}) | |
for _, idx := range t.Tracked { | |
saw[idx] = struct{}{} | |
} | |
for _, tw := range t.Created { | |
if _, ok := saw[tw.Token]; !ok { | |
return fmt.Errorf(`did not track created %v`, tw) | |
} | |
if err := tw.Check(count); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
// Ordered makes sure each func was called in order it was issued | |
func (t *Tracker) Ordered() error { | |
t.Lock() | |
defer t.Unlock() | |
if exp, got := len(t.Created), len(t.Tracked); exp != got { | |
return fmt.Errorf("exp %v tracked funcs; got %v", exp, got) | |
} | |
for idx, tw := range t.Created { | |
saw := t.Tracked[idx] | |
if saw != tw.Token { | |
return fmt.Errorf(`token out of order, exp %v; got %v`, tw.Token, saw) | |
} | |
} | |
return nil | |
} | |
// Running returns the number of goroutines started with this tracker by | |
// inspecting the current stack frames. | |
func (t *Tracker) Running() int { | |
t.Lock() | |
defer t.Unlock() | |
return 0 | |
} | |
// Stopped makes sure all goroutines that where tracked have exited. | |
func (t *Tracker) Stopped() error { | |
t.Lock() | |
defer t.Unlock() | |
return nil | |
} | |
// Caller returns a string dump of the runtime.Stack(). | |
func (t *Tracker) Caller(skip int) error { | |
skip++ | |
ptr, file, line, ok := runtime.Caller(skip) | |
if !ok { | |
return fmt.Errorf(`unable to get caller information from runtime`) | |
} | |
log.Println(`caller`) | |
log.Println(ptr, file, line, ok) | |
// runtime | |
fn := runtime.FuncForPC(ptr) | |
log.Println(fn, fn.Name()) | |
log.Println(`FileLine(ptr)`) | |
log.Println(fn.FileLine(ptr)) | |
return nil | |
} | |
// Stack returns a string dump of the runtime.Stack(). | |
func (t *Tracker) Stack() string { | |
t.Lock() | |
defer t.Unlock() | |
if t.Buf == nil { | |
t.Buf = make([]byte, 1e4) | |
} | |
var str string | |
for { | |
bs := len(t.Buf) | |
ws := runtime.Stack(t.Buf, true) | |
if bs > ws || ws == 0 { | |
str = string(t.Buf[:ws]) | |
break | |
} | |
} | |
return str | |
} | |
func (t *Tracker) Callers(skip int) error { | |
t.Lock() | |
defer t.Unlock() | |
// if t.Ptrs == nil { | |
// t.Ptrs = make([]uintptr, 100) | |
// } | |
// | |
// var pcs []uintptr | |
// for { | |
// bs := len(t.Ptrs) | |
// ws := runtime.Callers(0, t.Ptrs) | |
// if bs > ws || ws == 0 { | |
// pcs = t.Ptrs[:ws] | |
// break | |
// } | |
// } | |
// | |
// frames := runtime.CallersFrames(pcs) | |
// for { | |
// frame, more := frames.Next() | |
// if frame.Function != "" { | |
// | |
// } | |
// if !more { | |
// break | |
// } | |
// } | |
return nil | |
} | |
// TrackerCallers has the same fields and String format as TrackerCallSite: a
// function name, directory, file and line, plus an Err field. Nothing in this
// file constructs one; presumably it is reserved for Callers. TODO confirm.
type TrackerCallers struct {
	Name, Dir, File string
	Line            int
	Err             error
}

// String renders the value as "CallSite(<file>:<line> <name>())".
func (cs TrackerCallers) String() string {
	file, line, name := cs.File, cs.Line, cs.Name
	return fmt.Sprintf(`CallSite(%v:%v %v())`, file, line, name)
}
// TrackerCallSite describes the call site resolved by CallSite: the short
// function name, the directory and file name of the source path, and the line
// number. When resolution fails, only Err is set.
type TrackerCallSite struct {
	Name, Dir, File string
	Line            int
	Err             error
}

// String renders the call site as "CallSite(<file>:<line> <name>())".
func (cs TrackerCallSite) String() string {
	file, line, name := cs.File, cs.Line, cs.Name
	return fmt.Sprintf(`CallSite(%v:%v %v())`, file, line, name)
}
func CallSite(skip int) *TrackerCallSite { | |
skip++ | |
ptr, pth, line, ok := runtime.Caller(skip) | |
if !ok { | |
return &TrackerCallSite{Err: fmt.Errorf(`unable to get caller information from runtime`)} | |
} | |
fn := runtime.FuncForPC(ptr) | |
name := fn.Name() | |
fnPth, fnLine := fn.FileLine(ptr) | |
if fnPth != pth { | |
log.Println(`pthdiff`, fnPth, pth) | |
} | |
if fnLine != line { | |
log.Println(`klinediff`, fnLine, line) | |
} | |
dir, file := filepath.Split(pth) | |
return &TrackerCallSite{ | |
Name: name[strings.LastIndex(name, "/")+1:], // no func ends with / | |
File: file, | |
Dir: dir, | |
Line: line, | |
} | |
} | |
func (t *Tracker) String() string { | |
return fmt.Sprintf(`Tracker(%v)`, t.Label) | |
} | |
// TrackerGroup is a logical grouping of Tracker that calls each trackers | |
// method for checks. | |
type TrackerGroup []*Tracker | |
func Trackers(tracker ...*Tracker) TrackerGroup { | |
return TrackerGroup(tracker) | |
} | |
func (g TrackerGroup) Check() error { | |
for _, tw := range g { | |
if err := tw.Check(); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (g TrackerGroup) Calls(count int) error { | |
for _, tw := range g { | |
if err := tw.Calls(count); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func (g TrackerGroup) Ordered() error { | |
for _, tw := range g { | |
if err := tw.Ordered(); err != nil { | |
return err | |
} | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment