Skip to content

Instantly share code, notes, and snippets.

@jxskiss
Created December 19, 2021 02:04
Show Gist options
  • Save jxskiss/3d8a0c1362e961a2edb178473b91a154 to your computer and use it in GitHub Desktop.
Save jxskiss/3d8a0c1362e961a2edb178473b91a154 to your computer and use it in GitHub Desktop.
Opinionated and simple Golang assertion helpers.
package assert
import (
"bytes"
"fmt"
"reflect"
"regexp"
"strings"
"testing"
)
// Contains asserts that container contains elem, as decided by checkContains.
// It returns true when the assertion holds. Otherwise it reports both values
// (plus any optional msgAndArgs) through t and returns false.
func Contains(t testing.TB, container interface{}, elem interface{}, msgAndArgs ...interface{}) bool {
	if !checkContains(container, elem) {
		t.Helper()
		return fail(t, fmt.Sprintf("%#v does not contain %#v", container, elem), msgAndArgs)
	}
	return true
}
// NotContains asserts that container does not contain elem, as decided by
// checkContains. It returns true when the assertion holds. Otherwise it reports
// both values (plus any optional msgAndArgs) through t and returns false.
func NotContains(t testing.TB, container interface{}, elem interface{}, msgAndArgs ...interface{}) bool {
	if checkContains(container, elem) {
		t.Helper()
		return fail(t, fmt.Sprintf("%#v does contain %#v", container, elem), msgAndArgs)
	}
	return true
}
// Equal asserts that left and right are equal according to isEqual.
// It returns true when they are. Otherwise it reports both values (plus any
// optional msgAndArgs) through t and returns false.
func Equal(t testing.TB, left, right interface{}, msgAndArgs ...interface{}) bool {
	if !isEqual(left, right) {
		t.Helper()
		return fail(t, fmt.Sprintf("%#v does not equal %#v", left, right), msgAndArgs)
	}
	return true
}
// NotEqual asserts that left and right are not equal according to isEqual.
// It returns true when they differ. Otherwise it reports both values (plus any
// optional msgAndArgs) through t and returns false.
func NotEqual(t testing.TB, left, right interface{}, msgAndArgs ...interface{}) bool {
	if isEqual(left, right) {
		t.Helper()
		return fail(t, fmt.Sprintf("%#v does equal %#v", left, right), msgAndArgs)
	}
	return true
}
// True asserts that value is true. On failure it reports through t and
// returns false.
func True(t testing.TB, value bool, msgAndArgs ...interface{}) bool {
	if !value {
		t.Helper()
		return fail(t, "should be true, got false", msgAndArgs)
	}
	return true
}
// False asserts that value is false. On failure it reports through t and
// returns false.
func False(t testing.TB, value bool, msgAndArgs ...interface{}) bool {
	if value {
		t.Helper()
		return fail(t, "should be false, got true", msgAndArgs)
	}
	return true
}
// Nil asserts that object is nil. Both the untyped nil and a typed nil of a
// nillable kind (as decided by isNil) count. It returns true on success. On
// failure it reports the offending value through t and returns false.
func Nil(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool {
	if isNil(object) {
		return true
	}
	t.Helper()
	errmsg := fmt.Sprintf("should be nil, got %#v", object)
	return fail(t, errmsg, msgAndArgs)
}
// NotNil asserts that object is not nil, as decided by isNil. It returns true
// on success. On failure it reports the value through t and returns false.
func NotNil(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool {
	if isNil(object) {
		t.Helper()
		return fail(t, fmt.Sprintf("should not be nil, got %#v", object), msgAndArgs)
	}
	return true
}
// Empty asserts that object is empty, as decided by isEmpty. It returns true
// on success. On failure it reports the value through t and returns false.
func Empty(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool {
	if !isEmpty(object) {
		t.Helper()
		return fail(t, fmt.Sprintf("should be empty, got %#v", object), msgAndArgs)
	}
	return true
}
// NotEmpty asserts that object is not empty, as decided by isEmpty. It returns
// true on success. On failure it reports the value through t and returns false.
func NotEmpty(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool {
	if !isEmpty(object) {
		return true
	}
	t.Helper()
	errmsg := fmt.Sprintf("should not be empty, got %#v", object)
	return fail(t, errmsg, msgAndArgs)
}
// Regexp asserts that str matches re.
//
// re may be a compiled *regexp.Regexp or a pattern string; a pattern string is
// compiled on every call. A nil *regexp.Regexp, a pattern that fails to
// compile, or an re of any other type is reported as a test failure through t
// instead of panicking. Returns true when str matches.
func Regexp(t testing.TB, re interface{}, str string, msgAndArgs ...interface{}) bool {
	var (
		match  bool
		errmsg string
	)
	switch r := re.(type) {
	case *regexp.Regexp:
		if r == nil {
			errmsg = "re is a nil *regexp.Regexp"
			break
		}
		match = r.MatchString(str)
	case string:
		// regexp.Compile rather than MustCompile: a bad pattern in a test
		// should produce a readable failure, not abort the test with a panic.
		compiled, err := regexp.Compile(r)
		if err != nil {
			errmsg = fmt.Sprintf("%q is not a valid regular expression: %v", r, err)
			break
		}
		match = compiled.MatchString(str)
	default:
		errmsg = fmt.Sprintf("%#v is not a valid re param", re)
	}
	if match {
		return true
	}
	t.Helper()
	if errmsg == "" {
		errmsg = fmt.Sprintf("%q does not match re %q", str, re)
	}
	return fail(t, errmsg, msgAndArgs)
}
// Panics asserts that calling f panics. The panic is recovered and its value
// discarded. Returns true when f panicked. Otherwise it reports a failure
// through t and returns false.
func Panics(t testing.TB, f func(), msgAndArgs ...interface{}) bool {
	// Assume a panic until f returns normally. Deriving this from
	// recover() != nil would miss panic(nil), for which recover returns nil
	// on Go versions before 1.21.
	didPanic := true
	func() {
		defer func() { _ = recover() }()
		f()
		didPanic = false
	}()
	if didPanic {
		return true
	}
	t.Helper()
	return fail(t, "the code does not panic", msgAndArgs)
}
// -------------------------------- //

// typeOf and valueOf are short aliases for reflect.TypeOf and reflect.ValueOf.
var (
	typeOf  = reflect.TypeOf
	valueOf = reflect.ValueOf
)
// fail reports errmsg through t.Error and always returns false, so assertion
// helpers can simply `return fail(...)`.
//
// Optional user context in userMsgAndArgs is appended on a new line:
//   - a single value is printed with fmt.Sprint;
//   - several values whose first element is a string containing '%' are
//     treated as a format string followed by its arguments;
//   - any other list of values is printed with fmt.Sprint.
func fail(t testing.TB, errmsg string, userMsgAndArgs []interface{}) bool {
	t.Helper()
	if n := len(userMsgAndArgs); n > 0 {
		var extra string
		format, isString := userMsgAndArgs[0].(string)
		if n > 1 && isString && strings.Contains(format, "%") {
			extra = fmt.Sprintf(format, userMsgAndArgs[1:]...)
		} else {
			extra = fmt.Sprint(userMsgAndArgs...)
		}
		errmsg = errmsg + "\n" + extra
	}
	t.Error(errmsg)
	return false
}
// checkContains reports whether container holds elem. The rules are:
//   - string container and string elem: substring search;
//   - byte-slice container and elem of the same type: bytes.Contains;
//   - other slice container and elem of the same slice type: contiguous
//     sub-slice search (an empty elem is always contained);
//   - map container: elem is looked up as a key when it is assignable to the
//     key type and comparable;
//   - slice or array container: elem is compared against each element with
//     reflect.DeepEqual. This is also the fallback when a same-type sub-slice
//     search fails.
//
// A nil container never contains anything; a nil elem can only match nil
// interface elements of a slice or array.
func checkContains(container, elem interface{}) bool {
	cVal, eVal := reflect.ValueOf(container), reflect.ValueOf(elem)
	if !cVal.IsValid() {
		return false
	}
	cTyp := cVal.Type()

	if eVal.IsValid() {
		eTyp := eVal.Type()
		switch cTyp.Kind() {
		case reflect.String:
			if eTyp.Kind() == reflect.String {
				return strings.Contains(cVal.String(), eVal.String())
			}
		case reflect.Slice:
			if cTyp == eTyp {
				// Uint8 (not Uint) identifies byte slices; Bytes() panics otherwise.
				if cTyp.Elem().Kind() == reflect.Uint8 {
					return bytes.Contains(cVal.Bytes(), eVal.Bytes())
				}
				if containsSubslice(cVal, eVal) {
					return true
				}
			}
		case reflect.Map:
			// MapIndex panics on unassignable or unhashable keys, so guard first.
			if eTyp.AssignableTo(cTyp.Key()) && eTyp.Comparable() {
				return cVal.MapIndex(eVal).IsValid()
			}
			return false
		}
	}

	switch cTyp.Kind() {
	case reflect.Slice, reflect.Array:
		return containsElem(cVal, elem)
	}
	return false
}

// containsSubslice reports whether sub occurs as a contiguous run inside s;
// both must be slices of the same type.
func containsSubslice(s, sub reflect.Value) bool {
	n, m := s.Len(), sub.Len()
	if m == 0 {
		return true
	}
outer:
	for i := 0; i+m <= n; i++ {
		for j := 0; j < m; j++ {
			if !reflect.DeepEqual(s.Index(i+j).Interface(), sub.Index(j).Interface()) {
				continue outer
			}
		}
		return true
	}
	return false
}

// containsElem reports whether any element of the slice or array v is deeply
// equal to elem. DeepEqual (rather than ==) avoids run-time panics on
// incomparable elements such as nested slices.
func containsElem(v reflect.Value, elem interface{}) bool {
	for i := 0; i < v.Len(); i++ {
		if reflect.DeepEqual(v.Index(i).Interface(), elem) {
			return true
		}
	}
	return false
}
func isEqual(left, right interface{}) bool {
aTyp, bTyp := typeOf(left), typeOf(right)
if aTyp != nil && bTyp != nil && aTyp.Comparable() && bTyp.Comparable() &&
(left == right || literalConvert(left) == literalConvert(right)) {
return true
}
return reflect.DeepEqual(left, right)
}
// isNil reports whether object is the untyped nil, or a typed nil whose kind
// can hold nil: chan, func, interface, map, pointer, slice or unsafe pointer.
// Values of any other kind are never nil.
func isNil(object interface{}) bool {
	if object == nil {
		return true
	}
	rv := reflect.ValueOf(object)
	switch rv.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return rv.IsNil()
	default:
		return false
	}
}
// isEmpty reports whether object should be considered empty. That means:
//   - nil;
//   - an array, chan, map or slice of length zero;
//   - a nil pointer, or a pointer to an empty value;
//   - any value deeply equal (reflect.DeepEqual) to its type's zero value.
func isEmpty(object interface{}) bool {
	if object == nil {
		return true
	}
	val := reflect.ValueOf(object)
	switch val.Kind() {
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
		if val.Len() == 0 {
			return true
		}
	case reflect.Ptr:
		// Recurse on the pointee's interface value. Passing val.Elem() itself
		// would inspect the reflect.Value wrapper struct, which is never empty.
		if val.IsNil() || isEmpty(val.Elem().Interface()) {
			return true
		}
	}
	return reflect.DeepEqual(object, reflect.Zero(val.Type()).Interface())
}
// literalConvert normalises a value so that scalars of related kinds can be
// compared with ==. The conversions are:
//   - bool kinds → bool;
//   - string kinds (including named string types) → string;
//   - float kinds → float64;
//   - complex kinds → complex128;
//   - unsigned integers and non-negative signed integers → uint64;
//   - negative signed integers → int64.
//
// Any other value (including nil) is returned as its reflect.Value.
func literalConvert(val interface{}) interface{} {
	rv := reflect.ValueOf(val)
	switch rv.Kind() {
	case reflect.Bool:
		return rv.Bool()
	case reflect.String:
		return rv.String()
	case reflect.Float32, reflect.Float64:
		return rv.Float()
	case reflect.Complex64, reflect.Complex128:
		return rv.Complex()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n := rv.Int()
		if n >= 0 {
			return uint64(n)
		}
		return n
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rv.Uint()
	}
	return rv
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment