Skip to content

Instantly share code, notes, and snippets.

@cstockton
Created April 2, 2017 18:13
Show Gist options
  • Save cstockton/81525f8882fe56bd604f04e60e87173a to your computer and use it in GitHub Desktop.
Save cstockton/81525f8882fe56bd604f04e60e87173a to your computer and use it in GitHub Desktop.
// Package internal implements testing instrumentation for the wait package and
// serves as a staging area for experimental Waiter implementations.
package internal
import (
"fmt"
"log"
"path/filepath"
"runtime"
"strings"
"sync"
)
// Ok returns a func() error that always succeeds: every call to the returned
// func yields a nil error.
func Ok() func() error {
	succeed := func() error { return nil }
	return succeed
}
// Err returns a func() error that yields the given err on every call. A nil
// err produces a func that always returns nil.
func Err(err error) func() error {
	fail := func() error { return err }
	return fail
}
// BlockIn wraps fn so that it does not start executing until it has been
// released. It returns the gated wrapper followed by the release func.
//
// Each release lets exactly one call of the wrapper proceed. The release func
// never blocks: one release can be queued ahead of time, and further releases
// issued while one is already queued are dropped.
func BlockIn(fn func() error) (func() error, func()) {
	gate := make(chan struct{}, 1)

	release := func() {
		// Non-blocking send; a release that is already pending is enough.
		select {
		case gate <- struct{}{}:
		default:
		}
	}

	gated := func() error {
		<-gate // hold here until released, then run fn
		return fn()
	}

	return gated, release
}
// BlockOut wraps fn so that fn runs immediately but the wrapper does not
// return until it has been released. It returns the wrapper followed by the
// release func.
//
// Each release lets exactly one call of the wrapper return. The release func
// never blocks: one release can be queued ahead of time, and further releases
// issued while one is already queued are dropped.
func BlockOut(fn func() error) (func() error, func()) {
	gate := make(chan struct{}, 1)

	release := func() {
		// Non-blocking send; a release that is already pending is enough.
		select {
		case gate <- struct{}{}:
		default:
		}
	}

	held := func() error {
		result := fn()
		<-gate // fn has finished; hold its result until released
		return result
	}

	return held, release
}
// MutexWaiter serializes calls to Fn: Wait holds the embedded mutex for the
// full duration of every Fn call.
type MutexWaiter struct {
	sync.Mutex
	Fn func() error
}

// Mutex returns a MutexWaiter that guards fn.
func Mutex(fn func() error) *MutexWaiter {
	w := new(MutexWaiter)
	w.Fn = fn
	return w
}

// Wait runs Fn while holding the mutex and returns Fn's error. The mutex is
// released when Wait returns.
func (w *MutexWaiter) Wait() error {
	w.Lock()
	defer w.Unlock()
	return w.Fn()
}
// SentinelError is an error value used for tracking error propagation. It is
// always handled by pointer, so two sentinels are distinguishable by identity
// even when their Label and Idx are equal.
type SentinelError struct {
	Label string
	Idx   int
}

// Next increments the receiver's Idx and returns a new SentinelError labeled
// "sentinel" that carries the incremented index.
func (w *SentinelError) Next() *SentinelError {
	w.Idx++
	next := &SentinelError{Label: `sentinel`, Idx: w.Idx}
	return next
}

// Error implements the error interface, rendering the label and index.
func (w *SentinelError) Error() string {
	label, idx := w.Label, w.Idx
	return fmt.Sprintf("Error(%v #%v)", label, idx)
}

// Error returns a new SentinelError obtained by calling Next on a zero-value
// SentinelError, so the result has Label "sentinel" and Idx 1.
func Error() *SentinelError {
	return new(SentinelError).Next()
}
// SentinelWaiter always fails with its own Err, which lets a test verify that
// exactly this error was propagated.
type SentinelWaiter struct {
	Err error
}

// Sentinel returns a SentinelWaiter whose Err is a fresh value from Error().
func Sentinel() *SentinelWaiter {
	w := new(SentinelWaiter)
	w.Err = Error()
	return w
}

// Check returns nil when err is equal (==) to the waiter's Err, and an error
// describing the expected and received values otherwise.
func (w *SentinelWaiter) Check(err error) error {
	if w.Err != err {
		return fmt.Errorf("exp %v; got %v", w.Err, err)
	}
	return nil
}

// Wait returns the waiter's Err.
func (w *SentinelWaiter) Wait() error {
	return w.Err
}

// String renders the waiter together with its Err.
func (w *SentinelWaiter) String() string {
	return fmt.Sprintf("Sentinel(%v)", w.Err)
}
// PropagationWaiter helps test error propagation: it runs Fn and verifies
// that the error Fn returned is the embedded SentinelWaiter's Err.
type PropagationWaiter struct {
	*SentinelWaiter
	Fn func() error
}

// Propagation returns a PropagationWaiter that expects fn to return s's Err.
func Propagation(s *SentinelWaiter, fn func() error) *PropagationWaiter {
	return &PropagationWaiter{SentinelWaiter: s, Fn: fn}
}

// Wait calls Fn and returns nil when the embedded SentinelWaiter's Check
// accepts Fn's error. Otherwise it returns an error that names the expected
// sentinel and includes the diagnostic produced by Check, so the failure
// shows what Fn actually returned.
func (w *PropagationWaiter) Wait() error {
	if err := w.SentinelWaiter.Check(w.Fn()); err != nil {
		return fmt.Errorf(
			`propagation error, expected %v to propagate: %v`, w.SentinelWaiter, err)
	}
	return nil
}

// String renders the waiter together with the sentinel it expects.
func (w *PropagationWaiter) String() string {
	return fmt.Sprintf(`Propagation(%v)`, w.SentinelWaiter)
}
// CountWaiter counts how many times Wait has been invoked. The counter is a
// plain int that Wait increments without any locking of its own.
type CountWaiter struct {
	Calls int
	Fn    func() error
}

// Count returns a CountWaiter wrapping fn with a call count of zero.
func Count(fn func() error) *CountWaiter {
	w := new(CountWaiter)
	w.Fn = fn
	return w
}

// Called reports whether Wait has been invoked at least once.
func (w *CountWaiter) Called() bool {
	return w.Calls >= 1
}

// Check returns nil when Wait has been invoked exactly n times, and an error
// stating the expected and actual counts otherwise.
func (w *CountWaiter) Check(n int) error {
	if w.Calls == n {
		return nil
	}
	return fmt.Errorf("expected %d calls to CountWaiter; got %d", n, w.Calls)
}

// Wait records the call and then returns the result of Fn.
func (w *CountWaiter) Wait() error {
	w.Calls++
	return w.Fn()
}

// String renders the waiter together with its current call count.
func (w *CountWaiter) String() string {
	return fmt.Sprintf("Count(%v)", w.Calls)
}
// TokenWaiter is a CountWaiter tagged with an integer token so that a
// particular waiter can be identified later.
type TokenWaiter struct {
	*CountWaiter
	Token int
}

// Token returns a TokenWaiter carrying token and counting calls to fn. A nil
// fn is replaced by Ok().
func Token(token int, fn func() error) *TokenWaiter {
	counted := fn
	if counted == nil {
		counted = Ok()
	}
	tw := &TokenWaiter{Token: token}
	tw.CountWaiter = Count(counted)
	return tw
}

// String renders the waiter as its call count followed by its token.
func (w *TokenWaiter) String() string {
	calls, token := w.CountWaiter.Calls, w.Token
	return fmt.Sprintf("Token(%v: %v)", calls, token)
}
// Tracker hands out wrapped funcs that are tracked as one collective group,
// so that checks can later verify that each issued func has been called.
type Tracker struct {
	sync.Mutex

	// Idx is the token given to the most recently issued func; Track
	// increments it before every use.
	Idx int
	// Buf is the scratch buffer Stack passes to runtime.Stack.
	Buf []byte
	// Ptrs is referenced only by the commented-out body of Callers.
	Ptrs []uintptr
	// Label names the tracker in String output.
	Label string
	// Created holds one TokenWaiter per issued func, in issue order.
	Created []*TokenWaiter
	// Tracked holds the token of an issued func for every time one was
	// invoked, in invocation order.
	Tracked []int
}

// Track returns an empty Tracker carrying the given label.
func Track(label string) *Tracker {
	t := new(Tracker)
	t.Label = label
	return t
}
// Track registers fn with the tracker and returns a wrapper for it. The
// wrapper is assigned the next token; every time it runs it appends that
// token to Tracked and then calls fn, returning fn's error. Invocations of
// one wrapper are counted by a TokenWaiter stored in Created and are
// serialized by a per-wrapper mutex.
func (t *Tracker) Track(fn func() error) func() error {
	t.Lock()
	defer t.Unlock()

	t.Idx++
	token := t.Idx

	// The tracker lock is held only while recording the token, not while fn
	// itself runs.
	record := func() error {
		t.Lock()
		t.Tracked = append(t.Tracked, token)
		t.Unlock()
		return fn()
	}

	tw := Token(token, record)
	t.Created = append(t.Created, tw)

	guard := Mutex(func() error {
		return tw.Wait()
	})
	return guard.Wait
}
// Next issues a new tracked func that wraps Ok(), i.e. a func whose
// underlying work always succeeds.
func (t *Tracker) Next() func() error {
	noop := Ok()
	return t.Track(noop)
}
// Check verifies via Calls(1) that each created func was called once and, if
// that holds, returns the result of Stopped.
func (t *Tracker) Check() error {
	err := t.Calls(1)
	if err == nil {
		err = t.Stopped()
	}
	return err
}
// Calls verifies that each created func was called exactly count times. It
// returns nil on success and otherwise an error describing the first
// mismatch: a wrong total number of recorded invocations, a created func
// that never appears in Tracked, or a func whose own call count differs from
// count.
func (t *Tracker) Calls(count int) error {
	t.Lock()
	defer t.Unlock()

	// Tracked gains one entry per invocation, so the expected total is one
	// entry per created func per expected call. Comparing against
	// len(t.Created) alone would only ever be satisfiable for count == 1.
	if exp, got := len(t.Created)*count, len(t.Tracked); exp != got {
		return fmt.Errorf("exp %v tracked funcs; got %v", exp, got)
	}

	saw := make(map[int]struct{}, len(t.Created))
	for _, idx := range t.Tracked {
		saw[idx] = struct{}{}
	}
	for _, tw := range t.Created {
		// With count == 0 a func is expected to be absent from Tracked, so
		// the presence check only applies when calls are expected.
		if _, ok := saw[tw.Token]; !ok && count > 0 {
			return fmt.Errorf(`did not track created %v`, tw)
		}
		if err := tw.Check(count); err != nil {
			return err
		}
	}
	return nil
}
// Ordered verifies that the funcs were invoked in the order they were issued:
// Tracked must have as many entries as Created, and the token at each
// position of Tracked must equal the token of the func created at that
// position. It returns an error describing the first mismatch, or nil.
func (t *Tracker) Ordered() error {
	t.Lock()
	defer t.Unlock()

	want, have := len(t.Created), len(t.Tracked)
	if want != have {
		return fmt.Errorf("exp %v tracked funcs; got %v", want, have)
	}
	for i := 0; i < want; i++ {
		exp, got := t.Created[i].Token, t.Tracked[i]
		if got != exp {
			return fmt.Errorf("token out of order, exp %v; got %v", exp, got)
		}
	}
	return nil
}
// Running is meant to return the number of goroutines started with this
// tracker by inspecting the current stack frames.
//
// NOTE(review): that inspection is not implemented here. The method only
// acquires and releases the tracker lock and always returns 0.
func (t *Tracker) Running() int {
t.Lock()
defer t.Unlock()
return 0
}
// Stopped is meant to make sure all goroutines that were tracked have exited.
//
// NOTE(review): that verification is not implemented here. The method only
// acquires and releases the tracker lock and always returns nil.
func (t *Tracker) Stopped() error {
t.Lock()
defer t.Unlock()
return nil
}
// Caller logs debugging details about one of the calling functions: the
// program counter, file and line reported by runtime.Caller, followed by the
// *runtime.Func, its name and its FileLine for that program counter.
//
// skip selects the frame: 0 is the function that called Caller, 1 is that
// function's caller, and so on. An error is returned when the runtime cannot
// supply information for the requested frame; otherwise the result is nil.
func (t *Tracker) Caller(skip int) error {
	// Add one frame for Caller itself so that skip is relative to our caller.
	pc, file, line, ok := runtime.Caller(skip + 1)
	if !ok {
		return fmt.Errorf(`unable to get caller information from runtime`)
	}
	log.Println(`caller`)
	log.Println(pc, file, line, ok)

	fn := runtime.FuncForPC(pc)
	if fn == nil {
		// FuncForPC returns nil for a pc it cannot map to a function, and
		// FileLine must not be called on a nil *Func.
		return fmt.Errorf(`unable to get func information for pc %v from runtime`, pc)
	}
	log.Println(fn, fn.Name())
	log.Println(`FileLine(ptr)`)
	log.Println(fn.FileLine(pc))
	return nil
}
// Stack returns a string dump of runtime.Stack for all goroutines.
//
// The dump is written into t.Buf, which is allocated on first use and kept
// for later calls. Whenever the dump fills the buffer completely it may have
// been truncated, so the buffer is doubled and the dump is taken again until
// it fits with room to spare.
func (t *Tracker) Stack() string {
	t.Lock()
	defer t.Unlock()

	// Also covers a non-nil but empty Buf, which could never hold a dump.
	if len(t.Buf) == 0 {
		t.Buf = make([]byte, 1e4)
	}
	for {
		n := runtime.Stack(t.Buf, true)
		if n < len(t.Buf) {
			return string(t.Buf[:n])
		}
		// The buffer was filled to the brim: grow it, otherwise every retry
		// would produce the same truncated result and never terminate.
		t.Buf = make([]byte, 2*len(t.Buf))
	}
}
// Callers currently acquires and releases the tracker lock and returns nil;
// the skip argument is not used.
//
// The commented-out body below is an unfinished sketch: it collects program
// counters into t.Ptrs with runtime.Callers and walks them with
// runtime.CallersFrames, without yet doing anything with each frame.
func (t *Tracker) Callers(skip int) error {
t.Lock()
defer t.Unlock()
// if t.Ptrs == nil {
// t.Ptrs = make([]uintptr, 100)
// }
//
// var pcs []uintptr
// for {
// bs := len(t.Ptrs)
// ws := runtime.Callers(0, t.Ptrs)
// if bs > ws || ws == 0 {
// pcs = t.Ptrs[:ws]
// break
// }
// }
//
// frames := runtime.CallersFrames(pcs)
// for {
// frame, more := frames.Next()
// if frame.Function != "" {
//
// }
// if !more {
// break
// }
// }
return nil
}
// TrackerCallers describes a location in source code: a function name, the
// directory and file name, the line number, and an error slot.
type TrackerCallers struct {
	Name string
	Dir  string
	File string
	Line int
	Err  error
}

// String renders the location as file:line followed by the function name.
// Dir and Err are not part of the output.
func (cs TrackerCallers) String() string {
	file, line, name := cs.File, cs.Line, cs.Name
	return fmt.Sprintf("CallSite(%v:%v %v())", file, line, name)
}
// TrackerCallSite describes the source location of a call: the function name
// with its package path stripped down to the last element, the directory and
// file name of the source file, and the line number. Err is set, and all
// other fields are left zero, when the location could not be determined.
type TrackerCallSite struct {
	Name, Dir, File string
	Line            int
	Err             error
}

// String renders the call site as file:line followed by the function name.
// Dir and Err are not part of the output.
func (cs TrackerCallSite) String() string {
	return fmt.Sprintf(`CallSite(%v:%v %v())`, cs.File, cs.Line, cs.Name)
}

// CallSite returns the location of one of the calling functions. skip selects
// the frame: 0 is the function that called CallSite, 1 is that function's
// caller, and so on.
//
// The result is never nil. When the runtime cannot supply information for the
// requested frame, the returned TrackerCallSite has only Err set.
func CallSite(skip int) *TrackerCallSite {
	// Add one frame for CallSite itself so that skip is relative to our caller.
	pc, pth, line, ok := runtime.Caller(skip + 1)
	if !ok {
		return &TrackerCallSite{Err: fmt.Errorf(`unable to get caller information from runtime`)}
	}
	fn := runtime.FuncForPC(pc)
	if fn == nil {
		// FuncForPC returns nil for a pc it cannot map to a function, and
		// FileLine must not be called on a nil *Func.
		return &TrackerCallSite{Err: fmt.Errorf(`unable to get func information for pc %v from runtime`, pc)}
	}
	name := fn.Name()

	// Log whenever the two runtime lookups disagree about the location.
	fnPth, fnLine := fn.FileLine(pc)
	if fnPth != pth {
		log.Println(`pthdiff`, fnPth, pth)
	}
	if fnLine != line {
		log.Println(`klinediff`, fnLine, line)
	}

	dir, file := filepath.Split(pth)
	return &TrackerCallSite{
		Name: name[strings.LastIndex(name, "/")+1:], // no func ends with /
		File: file,
		Dir:  dir,
		Line: line,
	}
}
// String renders the tracker together with its Label.
func (t *Tracker) String() string {
	return "Tracker(" + t.Label + ")"
}
// TrackerGroup is a logical grouping of Trackers. Each of its check methods
// runs the method of the same name on every tracker in order and stops at
// the first error.
type TrackerGroup []*Tracker

// Trackers groups the given trackers, preserving their order.
func Trackers(tracker ...*Tracker) TrackerGroup {
	var group TrackerGroup = tracker
	return group
}

// each applies check to the trackers in order and returns the first non-nil
// error, or nil when every tracker passes.
func (g TrackerGroup) each(check func(*Tracker) error) error {
	for i := range g {
		if err := check(g[i]); err != nil {
			return err
		}
	}
	return nil
}

// Check runs Check on every tracker and returns the first error.
func (g TrackerGroup) Check() error {
	return g.each((*Tracker).Check)
}

// Calls runs Calls(count) on every tracker and returns the first error.
func (g TrackerGroup) Calls(count int) error {
	return g.each(func(t *Tracker) error { return t.Calls(count) })
}

// Ordered runs Ordered on every tracker and returns the first error.
func (g TrackerGroup) Ordered() error {
	return g.each((*Tracker).Ordered)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment