Skip to content

Instantly share code, notes, and snippets.

@au-phiware
Last active November 26, 2023 22:32
Show Gist options
  • Save au-phiware/1bc0b2dc0e9680251778d289e5493826 to your computer and use it in GitHub Desktop.
Save au-phiware/1bc0b2dc0e9680251778d289e5493826 to your computer and use it in GitHub Desktop.
mock module
go.sum -diff -merge

mock module

Tired of mocking libraries with cumbersome APIs? Frustrated with numerous and complicated options? Looking for a mock that works well with a composite of small interfaces, or that loves higher-order functions? Introducing mock, the simple mocking support that will enthusiastically accept a function that can be tailored to any bespoke test case. mock is guided by a central principle: test code must have full control of the code that runs in the mocked object. This means mock behaviour has access to anything in the test fixture and to the testing.T value. This module provides a number of functions that can be used as building blocks for your own mocks.

Installation

To use the mock module, add it to your project (`go get github.com/Versent/go-mock`) and import it:

import mock "github.com/Versent/go-mock"

Basic Usage

  1. Define an Interface

Create one or more interfaces that your mock needs to satisfy. For example:

package my

// Getter is an example interface: Get takes a key and returns a value and a
// bool.
type Getter interface {
	Get(key string) (value any, ok bool)
}

// Putter is an example interface: Put takes a key and a value and returns an
// error.
type Putter interface {
	Put(key string, value any) error
}
  2. Create a Mock Implementation

Implement the interface with mock methods. For example:

// mockObject satisfies Getter and Putter by forwarding each method to the
// delegates registered with mock.New. The blank byte field keeps the type
// from being zero-sized.
type mockObject struct {
	_ byte // prevent zero-sized type
}

// Get has two results, so it forwards through mock.Call2.
func (o *mockObject) Get(key string) (any, bool) {
	return mock.Call2[any, bool](o, "Get", key)
}

// Put has one result, so it forwards through mock.Call1.
func (o *mockObject) Put(key string, value any) error {
	return mock.Call1[error](o, "Put", key, value)
}

Note that mock.New will panic if more than one mock of a zero-sized type is constructed; the `_ byte` field above exists to prevent this.

  3. (Optional) Define Helpers

Implement Expect functions for greater readability. For example:

// ExpectGet registers delegate to serve exactly one call to Get, hiding the
// "Get" method-name string from test code.
func ExpectGet(delegate func(testing.TB, string) (any, bool)) func(*mockObject) {
	opt := mock.Expect[mockObject]("Get", delegate)
	return opt
}

// ExpectPut registers delegate to serve exactly one call to Put, hiding the
// "Put" method-name string from test code.
func ExpectPut(delegate func(testing.TB, string, any) error) func(*mockObject) {
	opt := mock.Expect[mockObject]("Put", delegate)
	return opt
}
  4. Using the Mock in Tests

Create a mock instance in your test and use it as needed. For instance:

func TestObject(t *testing.T) {
	m := mock.New(t,
  	mock.ExpectGet(func(t testing.TB, key string) (any, bool) {...}),
  	mock.ExpectPut(func(t testing.TB, key string, value any) error {...}),
  )

	// Use the mock instance in your test

	// Assert that all expected methods were called
  mock.AssertExpectedCalls(t, m)
}

Beyond Basic Usage

Be sure to check out the examples in the tests.

Expect Variants

In addition to the mock.Expect function, which corresponds to a single call of a method, there is also mock.ExpectMany, which will consume all remaining calls of a method.

Expect functions accept a delegate function that matches the signature of the named method. The delegate may also accept a *testing.T or testing.TB value as its first argument; this is the same testing.T that was used to construct the mock (the first argument to mock.New). In addition, an ExpectMany delegate may optionally accept the method's call count (a mock.CallCount).

Ordered Calls

mock.ExpectInOrder ensures that calls occur in the specified order. For example, this will fail if Put is called before Get:

mock.New(t, mock.ExpectInOrder(mock.Expect[mockObject]("Get", ...), mock.Expect[mockObject]("Put", ...)))
package mock
import (
"errors"
"fmt"
"reflect"
"testing"
)
// Callable is implemented by anything that can stand in for a mocked method.
// Call receives the mock's testing.TB, the zero-based number of calls the
// method has already received, and the method arguments, and returns the
// method results.
type Callable interface {
	Call(t testing.TB, n CallCount, args []reflect.Value) []reflect.Value
}

// MultiCallable is implemented by Callables that may serve more than one
// call. When the last Callable of a Callables slice reports true, it serves
// every call beyond the end of the slice.
type MultiCallable interface {
	MultiCallable() bool
}
// Callables is an ordered list of Callable values; element i serves the i-th
// (zero-based) call of a method.
type Callables []Callable

// Len reports how many Callables the slice holds.
func (c Callables) Len() int { return len(c) }

// Cap reports the capacity of the underlying slice.
func (c Callables) Cap() int { return cap(c) }

// Append returns c with the given Callables appended. It has the semantics
// of the built-in append, so callers must use the returned value.
func (c Callables) Append(callable ...Callable) Callables {
	return append(c, callable...)
}

// Call dispatches to the Callable at position index. An index past the end
// is routed to the final Callable when that one is a MultiCallable;
// otherwise Call panics with an index-out-of-range message.
func (c Callables) Call(t testing.TB, index CallCount, in []reflect.Value) []reflect.Value {
	switch {
	case int(index) < len(c):
		return c[index].Call(t, index, in)
	case c.MultiCallable():
		return c[len(c)-1].Call(t, index, in)
	default:
		panic(fmt.Sprintf("Callables.Call: index out of range [%d] with length %d", index, len(c)))
	}
}

// MultiCallable reports whether the slice is non-empty and its final element
// implements MultiCallable and reports true.
func (c Callables) MultiCallable() bool {
	n := len(c)
	if n == 0 {
		return false
	}
	m, ok := c[n-1].(MultiCallable)
	return ok && m.MultiCallable()
}
// Value is a Callable backed by a function held in a reflect.Value, together
// with the ordering metadata captured when it was registered.
type Value struct {
	reflect.Value
	ordered
}

// Call invokes the wrapped function with in. When the function declares
// exactly one more parameter than len(in), t is passed as the leading
// argument so the delegate can receive the mock's testing.TB. For a variadic
// function the last element of in must already be a slice (it is passed via
// CallSlice); otherwise reflect panics. Call panics if the wrapped value is
// not a func.
func (v Value) Call(t testing.TB, i CallCount, in []reflect.Value) []reflect.Value {
	fn := v.Value
	if fn.Kind() != reflect.Func {
		// Describe the wrapped value itself: formatting v with %T would always
		// print "mock.Value" and hide what was actually registered.
		got := "invalid value"
		if fn.IsValid() {
			got = fn.Type().String()
		}
		panic(fmt.Sprintf("Value.Call: expected func, got %s", got))
	}
	if fn.Type().NumIn() == len(in)+1 {
		in = append([]reflect.Value{reflect.ValueOf(t)}, in...)
	}
	if fn.Type().IsVariadic() {
		return fn.CallSlice(in)
	}
	return fn.Call(in)
}
// multi is a Value that may serve any number of calls; ExpectMany registers
// callables of this type.
type multi Value

// MultiCallable always reports true for multi.
func (v multi) MultiCallable() bool { return true }

// Call forwards to Value.Call. If either of the wrapped function's first two
// parameters has type CallCount, the call index i is inserted ahead of the
// method arguments so the delegate can receive it.
func (v multi) Call(t testing.TB, i CallCount, in []reflect.Value) []reflect.Value {
	countType := reflect.TypeOf(i)
	sig := v.Value.Type()
	wantsCount := (sig.NumIn() > 0 && sig.In(0) == countType) ||
		(sig.NumIn() > 1 && sig.In(1) == countType)
	if wantsCount {
		in = append([]reflect.Value{reflect.ValueOf(i)}, in...)
	}
	return Value(v).Call(t, i, in)
}
// errType is the type of the error interface.
var errType = reflect.TypeOf((*error)(nil)).Elem()
// CallDelegate calls the next Callable of the Delegate with the given name and
// given arguments. If the delegate is variadic then the last argument must be
// a slice, otherwise this function panics. If the next Callable does not
// exist or the last Callable is not MultiCallable, then the mock object will
// be marked as failed. In the case of a fail and if the delegate function
// returns an error as its last return value, then the error will be set and
// returned otherwise the function returns zero values for all of the return
// values.
func CallDelegate[T any](key *T, name string, outTypes []reflect.Type, in ...reflect.Value) (out []reflect.Value) {
mock := registry[key]
t := mock.TB
t.Helper()
delegate := delegateByName(mock, name)
delegate.Lock()
defer delegate.Unlock()
if int(delegate.callCount) >= delegate.Len() && !delegate.MultiCallable() {
msg := "unexpected call to " + name
t.Error(msg)
out = make([]reflect.Value, 0, len(outTypes))
for _, typ := range outTypes {
out = append(out, reflect.Zero(typ))
}
// set last out to error
if i := len(out) - 1; i >= 0 && outTypes[i].Implements(errType) {
out[i] = reflect.ValueOf(errors.New(msg))
}
return
}
var (
fn Value
ok bool
)
if int(delegate.callCount) < delegate.Len() {
fn, ok = delegate.Callables[delegate.callCount].(Value)
} else {
fn, ok = delegate.Callables[delegate.Len()-1].(Value)
}
if fn.inOrder {
mock.ordinal++
}
if ok && fn.ordinal != mock.ordinal {
err := fmt.Sprintf("out of order call to %s: expected %d, got %d", name, fn.ordinal, mock.ordinal)
t.Error(err)
}
t.Logf("call to %s: %d/%d", name, delegate.callCount, mock.ordinal)
defer func() { delegate.callCount++ }()
return delegate.Call(t, delegate.callCount, in)
}
// toValues wraps each argument in a reflect.Value, preserving order. A nil
// argument becomes the zero (invalid) reflect.Value. The result is never nil,
// even when no arguments are given.
func toValues(in ...any) []reflect.Value {
	vals := make([]reflect.Value, 0, len(in))
	for _, x := range in {
		vals = append(vals, reflect.ValueOf(x))
	}
	return vals
}
// doCall calls the next Callable of the Delegate with the given name and given
// arguments, and stores the Callable's results through the pointers in out
// (each element of out must be a reflect.Value holding a pointer).
//
// If the number of results differs from len(out), or a non-zero result is
// not assignable to its slot, the mock is marked as failed. The error is then
// stored in the last out slot when it is convertible to that slot's type;
// otherwise (including when out is empty) doCall panics with the error.
func doCall[T any](key *T, name string, in []reflect.Value, out []reflect.Value) {
	registry[key].Helper()
	outTypes := make([]reflect.Type, len(out))
	for i := range out {
		outTypes[i] = out[i].Type().Elem()
	}
	results := CallDelegate(key, name, outTypes, in...)
	var err error
	if len(results) != len(outTypes) {
		err = fmt.Errorf("unexpected number of results: expected %d, got %d", len(outTypes), len(results))
	}
	for i := 0; err == nil && i < len(out); i++ {
		// Zero results need no copy: out[i] already points at a zero value.
		if results[i].IsZero() {
			continue
		}
		if !results[i].Type().AssignableTo(outTypes[i]) {
			// Name the result type and the declared slot type; formatting
			// out[i] with %T would only show the pointer type (e.g. *int).
			err = fmt.Errorf("unexpected type %s for result parameter %d of type %s", results[i].Type(), i, outTypes[i])
			break
		}
		out[i].Elem().Set(results[i])
	}
	if err == nil {
		return
	}
	registry[key].Error(err)
	last := len(outTypes) - 1
	if last < 0 {
		// No slot to carry the error back to the caller.
		panic(err)
	}
	if slot := outTypes[last]; reflect.TypeOf(err).ConvertibleTo(slot) {
		out[last].Elem().Set(reflect.ValueOf(err).Convert(slot))
		return
	}
	panic(err)
}
package mock
import (
"errors"
"reflect"
"testing"
)
// TestDoCall drives doCall against a hand-built registry entry for each case
// and checks whether the mock's testing.T failed, whether doCall panicked,
// and which values were written through the out pointers.
func TestDoCall(t *testing.T) {
	tests := []struct {
		name        string
		callables   Callables
		in          []reflect.Value
		out         []reflect.Value
		results     []reflect.Value
		expectFail  bool
		expectPanic bool
	}{
		{
			name: "Matching types and values",
			callables: Callables{Value{Value: reflect.ValueOf(func(t testing.TB, in string) string {
				if in != "input" {
					t.Errorf("unexpected input: expected %q, got %q", "input", in)
				}
				return "result"
			})}},
			in:         toValues("input"),
			out:        toValues(new(string)),
			results:    toValues("result"),
			expectFail: false,
		},
		{
			name: "Matching types and values, multi",
			callables: Callables{multi{Value: reflect.ValueOf(func(t testing.TB, count CallCount, in string) string {
				if count != 0 {
					t.Errorf("unexpected count: expected %d, got %d", 0, count)
				}
				if in != "input" {
					t.Errorf("unexpected input: expected %q, got %q", "input", in)
				}
				return "result"
			})}},
			in:         toValues("input"),
			out:        toValues(new(string)),
			results:    toValues("result"),
			expectFail: false,
		},
		{
			name: "Matching types and values, variadic",
			callables: Callables{Value{Value: reflect.ValueOf(func(t testing.TB, in ...string) string {
				if in[0] != "input" {
					t.Errorf("unexpected input: expected %q, got %q", "input", in)
				}
				return "result"
			})}},
			in:         toValues([]string{"input"}),
			out:        toValues(new(string)),
			results:    toValues("result"),
			expectFail: false,
		},
		{
			name: "Type mismatch",
			callables: Callables{Value{Value: reflect.ValueOf(func() string {
				return "result"
			})}},
			in:          toValues(),
			out:         toValues(new(int)),
			results:     toValues(0),
			expectFail:  true,
			expectPanic: true,
		},
		{
			name:        "Unexpected number of results, panic",
			callables:   Callables{Value{Value: reflect.ValueOf(func() {})}},
			in:          toValues(),
			out:         toValues(new(int)),
			results:     toValues(0),
			expectFail:  true,
			expectPanic: true,
		},
		{
			name:       "Unexpected number of results, error",
			callables:  Callables{Value{Value: reflect.ValueOf(func() {})}},
			in:         toValues(),
			out:        toValues(new(error)),
			results:    toValues(errors.New("unexpected number of results: expected 1, got 0")),
			expectFail: true,
		},
		{
			name:       "Unexpected number of calls",
			callables:  Callables{},
			in:         toValues(),
			out:        toValues(new(error)),
			results:    toValues(errors.New("unexpected call to testMethod")),
			expectFail: true,
		},
	}
	for _, tt := range tests {
		// go.mod targets Go 1.20, where the range variable is shared across
		// iterations; copy it so each subtest (and its registry key) is distinct.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			key := &tt.name
			mockT := new(testing.T)
			defer func() {
				// Check for errors in test output
				if tt.expectFail && !mockT.Failed() {
					t.Errorf("expected a failure, got none")
				} else if !tt.expectFail && mockT.Failed() {
					t.Errorf("expected no failure, got fail")
				}
				if len(tt.out) != len(tt.results) {
					t.Fatalf("expected %d results, got %d", len(tt.results), len(tt.out))
				}
				for i := range tt.results {
					if !reflect.DeepEqual(tt.out[i].Elem().Interface(), tt.results[i].Interface()) {
						// %v rather than %q: results are not always strings.
						t.Errorf("out[%d]: expected %v, got %v", i, tt.results[i].Interface(), tt.out[i].Elem().Interface())
					}
				}
			}()
			if tt.expectPanic {
				// Registered last, so it runs first and recovers before the
				// checks above inspect the outputs.
				defer func() {
					if recover() == nil {
						t.Errorf("Expected a panic, got none")
					}
				}()
			}
			registry[key] = &mock{
				TB: mockT,
				Delegates: Delegates{
					"testMethod": &Delegate{
						Callables: tt.callables,
					},
				},
			}
			t.Cleanup(func() {
				delete(registry, key)
			})
			// Call doCall
			doCall(key, "testMethod", tt.in, tt.out)
		})
	}
}
package mock
import "sync"
// CallCount is the zero-based number of calls a method has already received.
// ExpectMany delegates may declare a parameter of this type to receive it.
type CallCount int

// Delegate holds the Callables registered for one method name and how many
// times that method has been called. The embedded Mutex is taken while
// Callables are appended and while a call is being dispatched.
type Delegate struct {
	sync.Mutex
	Callables
	callCount CallCount
}

// Append adds the given Callables under the delegate's lock and returns the
// updated list.
func (d *Delegate) Append(callable ...Callable) Callables {
	d.Lock()
	defer d.Unlock()
	updated := d.Callables.Append(callable...)
	d.Callables = updated
	return updated
}
// delegateByName returns the Delegate registered on mock for name, creating
// and registering an empty one if none exists. Both the lookup and the
// insertion happen under mock's lock, so concurrent callers neither race on
// the Delegates map nor create duplicate Delegates for the same name.
func delegateByName(mock *mock, name string) *Delegate {
	mock.Lock()
	defer mock.Unlock()
	// No unlocked "fast path": reading a Go map while another goroutine
	// writes it (the insertion below) is a data race.
	delegate, ok := mock.Delegates[name]
	if !ok {
		delegate = new(Delegate)
		mock.Delegates[name] = delegate
	}
	return delegate
}
package mock_test
import (
"fmt"
"testing"
mock "github.com/Versent/go-mock"
)
// Cache is the interface mocked by these examples. Its methods cover zero,
// one and two results, plus a variadic parameter.
type Cache interface {
	Put(string, any) error
	Get(string) (any, bool)
	Delete(string)
	Load(...string)
}

// mockCache implements Cache by routing every method through a mock.CallN
// helper. Any type would do, but zero-sized types are problematic because
// mock.New panics when a mock of one is constructed twice.
type mockCache struct {
	_ byte // prevent zero-sized type
}

// Put has a single result, so it goes through mock.Call1.
func (c *mockCache) Put(key string, value any) error {
	return mock.Call1[error](c, "Put", key, value)
}

// Get has two results, so it goes through mock.Call2.
func (c *mockCache) Get(key string) (any, bool) {
	return mock.Call2[any, bool](c, "Get", key)
}

// Delete has no results, so it goes through mock.Call0.
func (c *mockCache) Delete(key string) {
	mock.Call0(c, "Delete", key)
}

// Load is variadic: keys is forwarded as a single slice argument, which is
// how the mock.CallN helpers expect variadic arguments to be passed.
func (c *mockCache) Load(keys ...string) {
	mock.Call0(c, "Load", keys)
}

// UnusedCache is a nil option. Passing it to mock.New documents that a test
// expects none of the Cache methods to be called.
var UnusedCache func(*mockCache)
// ExampleUnusedCache builds a mock from the nil UnusedCache option alone:
// nothing is expected and nothing is called, so t is not marked as failed.
func ExampleUnusedCache() {
t := &exampleT{} // or any testing.TB, your test does not create this
// 1. Create a mock object.
var cache Cache = mock.New(t, UnusedCache)
// 2. Use the mock object in your code under test.
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will fail a test if a call is made to an unexpected method or if
// the expected methods are not called.
fmt.Println("less than expected:", t.Failed())
// Output:
// less than expected: false
}
// ExpectDelete wraps mock.Expect for the Delete method, so tests need not
// spell out the method name as a string. The delegate serves one call.
func ExpectDelete(delegate func(t testing.TB, key string)) func(*mockCache) {
	opt := mock.Expect[mockCache]("Delete", delegate)
	return opt
}
// Example_pass registers exactly one expectation per call and then makes
// exactly those calls. Because t is an exampleT, the mock's Logf trace
// ("call to NAME: count/ordinal") is printed and becomes part of the output.
func Example_pass() {
t := &exampleT{} // or any testing.TB, your test does not create this
// 1. Create a mock object with expected calls.
var cache Cache = mock.New(t,
// delegate function can receive testing.TB
mock.Expect[mockCache]("Get", func(t testing.TB, key string) (any, bool) {
return "bar", true
}),
mock.Expect[mockCache]("Put", func(t testing.TB, key string, value any) error {
return nil
}),
// or only the method arguments
mock.Expect[mockCache]("Delete", func(key string) {}),
// you may prefer to define a helper function
ExpectDelete(func(t testing.TB, key string) {}),
)
// 2. Use the mock object in your code under test.
cache.Put("foo", "bar")
cache.Get("foo")
cache.Delete("foo")
cache.Delete("foo")
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will not fail the test
fmt.Println("less than expected:", t.Failed())
// Output:
// call to Put: 0/0
// call to Get: 0/0
// call to Delete: 0/0
// call to Delete: 1/0
// less than expected: false
}
// Example_unmetExpectation registers two Delete expectations but calls Delete
// only once, so AssertExpectedCalls marks t as failed.
func Example_unmetExpectation() {
t := &testing.T{} // or any testing.TB, your test does not create this
// 1. Create a mock object with expected calls.
var cache Cache = mock.New(t,
// delegate function can receive testing.TB
mock.Expect[mockCache]("Put", func(t testing.TB, key string, value any) error {
fmt.Println("put", key, value)
return nil
}),
// or *testing.T
mock.Expect[mockCache]("Get", func(t *testing.T, key string) (any, bool) {
fmt.Println("get", key)
return "bar", true
}),
// or only the method arguments
mock.Expect[mockCache]("Delete", func(key string) {
fmt.Println("delete", key)
}),
// you may prefer to define a helper function
ExpectDelete(func(t testing.TB, key string) {
t.Log("this is not going to be called; causing t.Fail() to be called by mock.AssertExpectedCalls")
}),
)
// 2. Use the mock object in your code under test.
cache.Put("foo", "bar")
cache.Get("foo")
cache.Delete("foo")
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will fail the test because the second call to Delete is not met.
fmt.Println("less than expected:", t.Failed())
// Output:
// put foo bar
// get foo
// delete foo
// less than expected: true
}
// Example_unexpectedCall calls Get even though no Get expectation was
// registered; the mock reports the unexpected call and t is marked as failed.
func Example_unexpectedCall() {
t := &testing.T{} // or any testing.TB, your test does not create this
// 1. Create a mock object with expected calls.
var cache Cache = mock.New(t,
// delegate function can receive testing.TB
mock.Expect[mockCache]("Put", func(t testing.TB, key string, value any) error {
fmt.Println("put", key, value)
return nil
}),
// or only the method arguments
mock.Expect[mockCache]("Delete", func(key string) {
fmt.Println("delete", key)
}),
)
// 2. Use the mock object in your code under test.
cache.Put("foo", "bar")
cache.Get("foo")
cache.Delete("foo")
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will fail the test because the call to Get is not expected.
fmt.Println("more than expected:", t.Failed())
// Output:
// put foo bar
// delete foo
// more than expected: true
}
// Example_allowRepeatedCalls registers five ExpectMany delegates for Load.
// Each of the first five calls is served by the delegate at that position,
// and the last delegate keeps serving every further call (hence "load 4"
// twice in the output).
func Example_allowRepeatedCalls() {
t := &testing.T{} // or any testing.TB, your test does not create this
// 1. Create a mock object with ExpectMany.
var cache Cache = mock.New(t,
// delegate function may receive a call counter and the method arguments
mock.ExpectMany[mockCache]("Load", func(n mock.CallCount, keys ...string) {
fmt.Println("load", n, keys)
}),
// and testing.TB
mock.ExpectMany[mockCache]("Load", func(t testing.TB, n mock.CallCount, keys ...string) {
fmt.Println("load", n, keys)
}),
// or *testing.T
mock.ExpectMany[mockCache]("Load", func(t *testing.T, n mock.CallCount, keys ...string) {
fmt.Println("load", n, keys)
}),
// or only testing.TB/*testing.T
mock.ExpectMany[mockCache]("Load", func(t testing.TB, keys ...string) {
fmt.Println("load 3", keys)
}),
// or only the method arguments
mock.ExpectMany[mockCache]("Load", func(keys ...string) {
fmt.Println("load 4", keys)
}),
)
// 2. Use the mock object in your code under test.
cache.Load("foo", "bar")
cache.Load("baz")
cache.Load("foo")
cache.Load("bar")
cache.Load("baz")
cache.Load("foo", "bar", "baz")
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will not fail the test because ExpectMany allows repeated calls.
fmt.Println("more than expected:", t.Failed())
// Output:
// load 0 [foo bar]
// load 1 [baz]
// load 2 [foo]
// load 3 [bar]
// load 4 [baz]
// load 4 [foo bar baz]
// more than expected: false
}
// Example_orderedCalls wraps the Put and Get expectations in ExpectInOrder
// and then calls them in the opposite order, which marks t as failed.
func Example_orderedCalls() {
t := &testing.T{} // or any testing.TB, your test does not create this
// 1. Create a mock object with ExpectInOrder.
var cache Cache = mock.New(t,
mock.ExpectInOrder(
mock.Expect[mockCache]("Put", func(key string, value any) error {
fmt.Println("put", key, value)
return nil
}),
mock.Expect[mockCache]("Get", func(key string) (any, bool) {
fmt.Println("get", key)
return "bar", true
}),
),
)
// 2. Use the mock object in your code under test.
cache.Get("foo")
cache.Put("foo", "bar")
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will fail the test because the call to Get is before the call
// to Put.
fmt.Println("less than expected:", t.Failed())
// Output:
// get foo
// put foo bar
// less than expected: true
}
// Example_mixedOrderedCalls surrounds an ExpectInOrder group with unordered
// expectations for the same methods. Because t is an exampleT, each call's
// trace line shows the method's call count and the mock's ordinal as
// "call to NAME: count/ordinal".
func Example_mixedOrderedCalls() {
t := &exampleT{} // or any testing.TB, your test does not create this
// 1. Create a mock object with ExpectInOrder.
get := mock.Expect[mockCache]("Get", func(key string) (any, bool) {
return "bar", true
})
put := mock.Expect[mockCache]("Put", func(key string, value any) error {
return nil
})
var cache Cache = mock.New(t,
get, put,
mock.ExpectInOrder(put, get),
get, put,
)
// 2. Use the mock object in your code under test.
for i := 0; i < 3; i++ {
cache.Put(fmt.Sprint("foo", i), "bar")
cache.Get(fmt.Sprint("foo", i))
}
// 3. Assert that all expected methods were called.
mock.AssertExpectedCalls(t, cache)
// mock will not fail the test
fmt.Println("less than expected:", t.Failed())
// Output:
// call to Put: 0/0
// call to Get: 0/0
// call to Put: 1/1
// call to Get: 1/2
// call to Put: 2/2
// call to Get: 2/2
// less than expected: false
}
// Compile-time check that *exampleT satisfies testing.TB.
var _ testing.TB = (*exampleT)(nil)

// exampleT is a testing.TB for Example functions. It embeds a testing.T for
// failure bookkeeping, but prints log and error messages to stdout so that
// they become part of the checked example output.
type exampleT struct {
	testing.T
}

// Log prints args the way fmt.Println does.
func (et *exampleT) Log(args ...any) {
	fmt.Println(args...)
}

// Logf prints a formatted message followed by a newline.
func (et *exampleT) Logf(format string, args ...any) {
	fmt.Printf(format+"\n", args...)
}

// Error prints args and marks the test as failed.
func (et *exampleT) Error(args ...any) {
	et.Log(args...)
	et.T.Fail()
}

// Errorf prints a formatted message and marks the test as failed.
func (et *exampleT) Errorf(format string, args ...any) {
	et.Logf(format, args...)
	et.T.Fail()
}

// Fatal prints args, then stops the test via FailNow.
func (et *exampleT) Fatal(args ...any) {
	et.Log(args...)
	et.T.FailNow()
}

// Fatalf prints a formatted message, then stops the test via FailNow.
func (et *exampleT) Fatalf(format string, args ...any) {
	et.Logf(format, args...)
	et.T.FailNow()
}
module github.com/Versent/go-mock
go 1.20
package mock
import "testing"
// AssertExpectedCalls asserts that all expected callables of all delegates of
// the given mocks were called.
func AssertExpectedCalls(t testing.TB, mocks ...any) {
t.Helper()
for _, key := range mocks {
if key == nil {
continue
}
if mock, ok := key.(interface{ AssertExpectedCalls(testing.TB) }); ok {
mock.AssertExpectedCalls(t)
continue
}
mock, ok := registry[key]
if !ok {
t.Fatalf("mock not found: %T", key)
}
for name, delegate := range mock.Delegates {
if count := delegate.callCount; int(count) < delegate.Len() {
if count == 0 {
t.Errorf("failed to make call to %s", name)
} else if count == 1 {
t.Errorf("failed to make call to %s: only got one call", name)
} else {
t.Errorf("failed to make call to %s: only got %d calls", name, count)
}
}
}
}
}
// Call0 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return no result values, otherwise the mock will be marked as failed and
// this function will panic.
func Call0[T any](key *T, name string, in ...any) {
	registry[key].Helper()
	// An empty out list makes doCall verify that the delegate returned no
	// results; previously any returned values were silently ignored, so the
	// documented fail-and-panic contract was never enforced.
	doCall(key, name, toValues(in...), nil)
}

// Call1 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return one result value, otherwise the mock will be marked as failed and
// this function will return the error as its result when an error can be
// converted to T1 (e.g. T1 is error), or this function will panic.
func Call1[T1, T any](key *T, name string, in ...any) (v T1) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v))
	return
}

// Call2 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return two result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T2, or this function will panic.
func Call2[T1, T2, T any](key *T, name string, in ...any) (v1 T1, v2 T2) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2))
	return
}

// Call3 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return three result values, otherwise the mock will be marked as failed
// and this function will return the error as its last result when an error
// can be converted to T3, or this function will panic.
func Call3[T1, T2, T3, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3))
	return
}

// Call4 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return four result values, otherwise the mock will be marked as failed
// and this function will return the error as its last result when an error
// can be converted to T4, or this function will panic.
func Call4[T1, T2, T3, T4, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4))
	return
}

// Call5 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return 5 result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T5, or this function will panic.
func Call5[T1, T2, T3, T4, T5, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4, v5 T5) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4, &v5))
	return
}

// Call6 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return 6 result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T6, or this function will panic.
func Call6[T1, T2, T3, T4, T5, T6, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4, &v5, &v6))
	return
}

// Call7 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return 7 result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T7, or this function will panic.
func Call7[T1, T2, T3, T4, T5, T6, T7, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4, &v5, &v6, &v7))
	return
}

// Call8 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return 8 result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T8, or this function will panic.
func Call8[T1, T2, T3, T4, T5, T6, T7, T8, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8))
	return
}

// Call9 calls the function of the given name for the given mock with the
// given arguments. If the function is variadic then the last argument must be
// passed as a slice, otherwise this function panics. The function is expected
// to return 9 result values, otherwise the mock will be marked as failed and
// this function will return the error as its last result when an error can be
// converted to T9, or this function will panic.
func Call9[T1, T2, T3, T4, T5, T6, T7, T8, T9, T any](key *T, name string, in ...any) (v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8, v9 T9) {
	registry[key].Helper()
	doCall(key, name, toValues(in...), toValues(&v1, &v2, &v3, &v4, &v5, &v6, &v7, &v8, &v9))
	return
}
MIT License
Copyright (c) 2023 Versent
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
// Package mock provides a flexible and functional mocking framework for Go
// tests.
package mock
import (
"fmt"
"reflect"
"sync"
"testing"
)
// registry maps each mock object (the *T returned by New, stored as any) to
// its per-mock state.
var registry = map[any]*mock{}

// Delegates maps method names to their Delegate.
type Delegates = map[string]*Delegate
// Option configures a mock object; it receives the mock's key (the *T
// returned by New).
type Option[T any] func(*T)

// Options combines opts into a single Option that applies them in order.
// nil entries are skipped, matching how New treats nil options, so nil
// placeholder options (such as a `var Unused func(*T)`) can be grouped safely.
func Options[T any](opts ...Option[T]) Option[T] {
	return func(key *T) {
		for _, opt := range opts {
			if opt == nil {
				continue
			}
			opt(key)
		}
	}
}
// mock is the per-object state kept in registry: the testing.TB given to New,
// a lock guarding Delegates, the Delegates themselves, and the ordering
// state used by ExpectInOrder.
type mock struct {
	testing.TB
	sync.Mutex
	Delegates
	ordered
}

// New allocates a fresh *T, registers it as a mock that reports to t, applies
// opts in order (nil options are skipped), and returns it. The registration
// is removed by a t.Cleanup hook. Pointers to distinct zero-size variables
// may be equal, so New panics if the new key is already registered, which
// can happen when a zero-sized T is constructed more than once.
func New[T any](t testing.TB, opts ...Option[T]) *T {
	key := new(T)
	if _, exists := registry[key]; exists {
		panic(fmt.Sprintf("mock.New: zero-sized type used to construct more than one mock: %T", key))
	}
	m := &mock{TB: t, Delegates: Delegates{}}
	registry[key] = m
	t.Cleanup(func() { delete(registry, key) })
	for _, opt := range opts {
		if opt != nil {
			opt(key)
		}
	}
	// Registering in-order expectations advances the ordinal; rewind it so
	// that counting during the calls themselves starts from zero.
	m.ordinal = 0
	return key
}
// Expect registers a function to be called exactly once when a method with the
// given name is invoked on the mock object.
// The function signature of fn must match the named method signature,
// except that the first argument may optionally be a testing.TB or *testing.T.
// Panics if fn is not a function (including when fn is nil).
func Expect[T any](name string, fn any) Option[T] {
	funcType := reflect.TypeOf(fn)
	// reflect.TypeOf(nil) returns a nil Type; check it so a nil fn gets the
	// descriptive panic below rather than a nil-pointer dereference on Kind.
	if funcType == nil || funcType.Kind() != reflect.Func {
		panic(fmt.Sprintf("mock.Expect: expected function, got %T", fn))
	}
	return func(key *T) {
		mock := registry[key]
		mock.Helper()
		delegate := delegateByName(mock, name)
		// Inside ExpectInOrder each expectation takes the next ordinal.
		if mock.inOrder {
			mock.ordinal++
		}
		delegate.Append(Value{
			Value:   reflect.ValueOf(fn),
			ordered: mock.ordered,
		})
	}
}
// ExpectMany registers a function to be called at least once for a method with
// the given name on the mock object.
// Like Expect, the arguments of fn must match the named method signature and may optionally be
// preceded by a testing.TB or *testing.T.
// In addition, the first argument of fn may optionally be of type CallCount, in such cases fn will
// be passed the total number of times the method has been called (starting at 0).
// Panics if fn is not a function (including when fn is nil).
func ExpectMany[T any](name string, fn any) Option[T] {
	funcType := reflect.TypeOf(fn)
	// reflect.TypeOf(nil) returns a nil Type; check it so a nil fn gets the
	// descriptive panic below rather than a nil-pointer dereference on Kind.
	if funcType == nil || funcType.Kind() != reflect.Func {
		panic(fmt.Sprintf("mock.ExpectMany: expected function, got %T", fn))
	}
	return func(key *T) {
		mock := registry[key]
		mock.Helper()
		// Inside ExpectInOrder each expectation takes the next ordinal.
		if mock.inOrder {
			mock.ordinal++
		}
		delegateByName(mock, name).Append(multi{
			Value:   reflect.ValueOf(fn),
			ordered: mock.ordered,
		})
	}
}
package mock_test
import (
"testing"
mock "github.com/Versent/go-mock"
)
// TestNew_identity checks that New returns distinct objects for non-zero-sized
// types and panics when a zero-sized type is constructed twice.
func TestNew_identity(t *testing.T) {
	t.Run("mockCache", func(t *testing.T) {
		first, second := mock.New[mockCache](t), mock.New[mockCache](t)
		if first == second {
			t.Error("expected different mocks")
		}
	})
	t.Run("mock.Delegates", func(t *testing.T) {
		type T mock.Delegates
		first, second := mock.New[T](t), mock.New[T](t)
		if first == second {
			t.Error("expected different mocks")
		}
	})
	t.Run("zero-sized", func(t *testing.T) {
		const want = "mock.New: zero-sized type used to construct more than one mock: *mock_test.T"
		defer func() {
			switch r := recover(); {
			case r == nil:
				t.Error("expected panic")
			case r != want:
				t.Error("unexpected panic:", r)
			}
		}()
		type T struct{}
		_ = mock.New[T](t)
		_ = mock.New[T](t)
	})
}
// TestNew_Expect checks that delegates registered with mock.Expect receive the
// mocked method's arguments and that their results reach the caller.
func TestNew_Expect(t *testing.T) {
	called := false
	var cache Cache = mock.New(&testing.T{},
		mock.Expect[mockCache]("Put", func(_ testing.TB, key string, value any) error {
			// Either argument being wrong is a failure, hence || (the
			// original && only fired when both were wrong).
			if key != "foo" || value != "bar" {
				t.Error("unexpected arguments")
			}
			called = true
			return nil
		}),
		mock.Expect[mockCache]("Get", func(_ *testing.T, key string) (any, bool) {
			if key != "foo" {
				t.Error("unexpected arguments")
			}
			called = true
			return "bar", true
		}),
		mock.Expect[mockCache]("Delete", func(key string) {
			if key != "foo" {
				t.Error("unexpected arguments")
			}
			called = true
		}),
		// Registered but never consumed: Delete is called only once below.
		ExpectDelete(func(_ testing.TB, key string) {
			t.Error("this should not be called")
		}),
	)
	called = false
	if err := cache.Put("foo", "bar"); err != nil {
		t.Error("unexpected error:", err)
	}
	if !called {
		t.Error("expected call to Put delegate")
	}
	called = false
	// Fail when either the value or the ok flag is wrong.
	if result, ok := cache.Get("foo"); result != "bar" || !ok {
		t.Error("unexpected result")
	}
	if !called {
		t.Error("expected call to Get delegate")
	}
	called = false
	cache.Delete("foo")
	if !called {
		t.Error("expected call to Delete delegate")
	}
}
package mock
// ordered holds ordering metadata. On a mock, inOrder says whether options
// are currently being applied inside ExpectInOrder and ordinal is a running
// counter. On a registered Value it is a copy of the mock's state taken at
// registration time.
type ordered struct {
	inOrder bool
	ordinal uint
}

// orderedOption returns an Option that applies options with the mock's
// inOrder flag set to inOrder. The previous flag is restored afterwards via
// defer, so nesting works and a panicking option does not leak the flag. nil
// options are skipped, as New does.
func orderedOption[T any](inOrder bool, options []Option[T]) Option[T] {
	return func(key *T) {
		mock := registry[key]
		defer func(restore bool) {
			mock.inOrder = restore
		}(mock.inOrder)
		mock.inOrder = inOrder
		for _, option := range options {
			if option == nil {
				continue
			}
			option(key)
		}
	}
}

// ExpectInOrder groups options whose expectations must be called in the
// order they are listed. The order is checked when the mocked methods are
// called, and violations are reported as test errors.
func ExpectInOrder[T any](options ...Option[T]) Option[T] {
	return orderedOption(true, options)
}

// ExpectAnyOrder groups options that are registered without the in-order
// flag, for example to opt out inside an enclosing ExpectInOrder.
func ExpectAnyOrder[T any](options ...Option[T]) Option[T] {
	return orderedOption(false, options)
}
package mock
import (
"testing"
)
// TestNew checks that New records the returned object in registry.
func TestNew(t *testing.T) {
	type T Delegates
	key := New[T](t)
	if _, found := registry[key]; !found {
		t.Fatalf("mock not found")
	}
}

// TestExpect checks that a single Expect option adds exactly one Callable to
// the named delegate.
func TestExpect(t *testing.T) {
	type T Delegates
	key := New(t, Expect[T]("foo", func() {}))
	m, found := registry[key]
	if !found {
		t.Fatalf("mock not found")
	}
	d, found := m.Delegates["foo"]
	if !found {
		t.Fatalf("delegate not found")
	}
	if d.Len() != 1 {
		t.Fatalf("expected one delegate")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment