Created
August 23, 2023 16:09
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* a generic reflection based wrapper around common testing boilerplate | |
messages.On("Transaction", ctx, mock.AnythingOfType("func(storage.MessageStorer) error")). | |
Run(func(args mock.Arguments) { | |
err := args.Get(1).(func(tx storage.MessageStorer) error)(tx) | |
assert.NoError(err) | |
}). | |
Return(nil) | |
usage | |
messages := *storage.MockMessageStorer{} | |
tx := storage.MockMessageStorer{} | |
err = Transaction[storage.MessageStorer](ctx, messages, tx) | |
assert.NoError(err) | |
does it work? no | |
*/ | |
func Transaction[T any](ctx context.Context, mockStore T, mockTXStore T) error { | |
rt := reflect.TypeOf(mockStore) | |
if rt.Kind() != reflect.Pointer { | |
// store should be *mock.Mock which contain mutexes that need to be passed by reference | |
return errors.New("store is not a pointer") | |
} | |
on := reflect.ValueOf(mockStore).MethodByName("On") | |
if on.IsZero() { | |
return errors.New("expected On method") | |
} | |
args := make([]reflect.Value, 3) | |
args[0] = reflect.ValueOf("Transaction") | |
args[1] = reflect.ValueOf(ctx) | |
// trick to get T as string | |
txFnType := fmt.Sprintf("func(%s) error", reflect.TypeOf(new(T)).Elem().String()) | |
args[2] = reflect.ValueOf(mock.AnythingOfType(txFnType)) | |
res := on.Call(args) | |
if len(res) < 1 { | |
return errors.New("unexpected On return") | |
} | |
run := res[0].MethodByName("Run") | |
if run.IsZero() { | |
return errors.New("expected Run method") | |
} | |
var err error | |
fn := func(args mock.Arguments) { | |
err = args.Get(1).(func(tx T) error)(mockTXStore) | |
} | |
res = run.Call([]reflect.Value{reflect.ValueOf(fn)}) | |
if len(res) < 1 { | |
return errors.New("unexpected Run return") | |
} | |
if err != nil { | |
return errs.Wrap(err, "error running transaction") | |
} | |
ret := res[0].MethodByName("Return") | |
if ret.IsZero() { | |
return errors.New("expected Return method") | |
} | |
res = ret.Call([]reflect.Value{reflect.New(reflect.TypeOf(new(error)).Elem())}) | |
if len(res) < 1 { | |
return errors.New("unexpected Return return") | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
new(error) -> new(ErrCode)