Skip to content

Instantly share code, notes, and snippets.

@wtask
Last active February 14, 2021 00:35
Show Gist options
  • Save wtask/a9d3418f3615d7d8b06d8c00c5c61a6f to your computer and use it in GitHub Desktop.
Save wtask/a9d3418f3615d7d8b06d8c00c5c61a6f to your computer and use it in GitHub Desktop.
MongoDB bson.ValueMarshaler/bson.ValueUnmarshaler Golang example
package model
import (
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Reference is the hex string representation of a BSON ObjectID.
//
// Marshaling rules:
//   - nil *Reference      --> Null
//   - empty Reference     --> Null
//   - non-empty Reference --> primitive.ObjectID parsed from the hex string;
//     a string that is not a valid ObjectID hex representation is an error
//
// Unmarshaling rules:
//   - nil *Reference receiver --> bson.ErrDecodeToNil
//   - Null                    --> empty Reference
//   - ObjectID                --> non-empty Reference holding the hex ID
//   - any other BSON type     --> error
type Reference string
// MarshalBSONValue implements the bson.ValueMarshaler interface.
//
// An empty reference is encoded as BSON Null. Any other value is parsed with
// primitive.ObjectIDFromHex and encoded as a BSON ObjectID. If parsing fails,
// the returned error wraps the parse error and names the offending reference.
//
// IMPORTANT: keep the value receiver. A method with a value receiver belongs
// to the method sets of both Reference and *Reference. Struct fields of either
// type are therefore marshaled through this method.
func (ref Reference) MarshalBSONValue() (bsontype.Type, []byte, error) {
	if ref == "" {
		return bsontype.Null, nil, nil
	}
	oid, err := primitive.ObjectIDFromHex(string(ref))
	if err != nil {
		// %w keeps the original parse error reachable via errors.Is/As.
		return 0, nil, fmt.Errorf("marshal reference %q: %w", string(ref), err)
	}
	return bson.MarshalValue(oid)
}
// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface.
//
// It maps BSON values as follows:
//   - nil receiver: returns bson.ErrDecodeToNil.
//   - BSON Null: stores the empty reference.
//   - BSON ObjectID: stores its hex representation.
//   - Anything that does not decode as an ObjectID: returns an error and leaves
//     the receiver unchanged.
func (ref *Reference) UnmarshalBSONValue(t bsontype.Type, value []byte) error {
	if ref == nil {
		return bson.ErrDecodeToNil
	}
	if t == bsontype.Null {
		*ref = ""
		return nil
	}
	// ObjectIDOK reports false both for a non-ObjectID type and for a value
	// that is too short to hold an ObjectID.
	oid, ok := (bson.RawValue{Type: t, Value: value}).ObjectIDOK()
	if !ok {
		return fmt.Errorf("cannot decode object ID from BSON type %v", t)
	}
	*ref = Reference(oid.Hex())
	return nil
}
package model
import (
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// TestObjectID__MarshalValue checks a bson package expectation. Marshaling a
// *primitive.ObjectID with bson.MarshalValue must produce a raw value that
// decodes back to the same ObjectID.
func TestObjectID__MarshalValue(t *testing.T) {
	sourceID := primitive.NewObjectID()
	bsonType, bsonValue, err := bson.MarshalValue(&sourceID)
	if err != nil {
		t.Fatal(err)
	}
	targetID, ok := (bson.RawValue{Type: bsonType, Value: bsonValue}).ObjectIDOK()
	if !ok {
		t.Fatalf("cannot decode ID from BSON type %v [%x]", bsonType, bsonValue)
	}
	if sourceID != targetID {
		t.Fatalf("expected %s got %s", sourceID.Hex(), targetID.Hex())
	}
}
// TestReference__MarshalValue drives bson.MarshalValue with *Reference inputs.
// Nil and empty references must yield Null. Valid hex strings must yield
// ObjectIDs whose hex form equals the input. A malformed hex string must fail.
func TestReference__MarshalValue(t *testing.T) {
	validRef := Reference(primitive.NewObjectID().Hex())
	zeroRef := Reference("000000000000000000000000") // nil object id
	badRef := Reference("00")
	table := []struct {
		ref      *Reference
		expType  bsontype.Type
		expError bool
	}{
		{ref: nil, expType: bsontype.Null},
		{ref: new(Reference), expType: bsontype.Null},
		{ref: &validRef, expType: bsontype.ObjectID},
		{ref: &zeroRef, expType: bsontype.ObjectID},
		{ref: &badRef, expType: 0, expError: true},
	}
	for n, tc := range table {
		gotType, raw, err := bson.MarshalValue(tc.ref)
		switch {
		case err != nil && tc.expError:
			t.Logf("#%d: error expectation OK: %s", n, err)
		case err != nil:
			t.Fatalf("#%d: %s", n, err)
		case tc.expError:
			t.Fatalf("#%d: error expectation FAILED", n)
		}
		if gotType != tc.expType {
			t.Fatalf("#%d: unexpected BSON type %v instead of %v", n, gotType, tc.expType)
		}
		// Only ObjectID results carry a payload worth decoding.
		if tc.expType != bsontype.ObjectID {
			continue
		}
		decoded, ok := bson.RawValue{Type: gotType, Value: raw}.ObjectIDOK()
		if !ok {
			t.Fatalf("#%d: cannot decode ID from BSON [%x]", n, raw)
		}
		if Reference(decoded.Hex()) != *tc.ref {
			t.Fatalf("#%d: unexpected hex ID %x instead of %s", n, decoded, *tc.ref)
		}
	}
}
// TestReference__Marshal verifies that a document whose _id is a *Reference
// encodes to exactly the same bytes as the matching document whose _id is a
// *primitive.ObjectID. This covers nil pointers, the empty reference, the
// all-zero ID and a freshly generated ID.
func TestReference__Marshal(t *testing.T) {
	type nativeDoc struct {
		ID *primitive.ObjectID `bson:"_id"`
	}
	type refDoc struct {
		Ref *Reference `bson:"_id"`
	}
	oid := primitive.NewObjectID()
	oidRef := Reference(oid.Hex())
	zeroRef := Reference("000000000000000000000000") // nil object id
	emptyRef := Reference("")
	pairs := []struct {
		native *primitive.ObjectID
		ref    *Reference
	}{
		{native: nil, ref: nil},
		{native: nil, ref: &emptyRef},
		{native: &primitive.NilObjectID, ref: &zeroRef},
		{native: &oid, ref: &oidRef},
	}
	for n, p := range pairs {
		want, err := bson.Marshal(&nativeDoc{ID: p.native})
		if err != nil {
			t.Fatalf("#%d: %s", n, err)
		}
		got, err := bson.Marshal(&refDoc{Ref: p.ref})
		if err != nil {
			t.Fatalf("#%d: %s", n, err)
		}
		if !reflect.DeepEqual(want, got) {
			t.Fatalf("#%d: unexpected BSON %x instead of native %x ", n, got, want)
		}
	}
}
// TestReference__Marshal_error verifies that a malformed hex reference makes
// bson.Marshal fail. The _id field is tried both as *Reference and as a
// plain Reference value.
func TestReference__Marshal_error(t *testing.T) {
	bad := Reference("0011")
	docs := []interface{}{
		struct {
			ID *Reference `bson:"_id"`
		}{ID: &bad},
		struct {
			ID Reference `bson:"_id"`
		}{ID: bad},
	}
	for n, doc := range docs {
		if _, err := bson.Marshal(doc); err != nil {
			t.Logf("#%d: error expectation OK: %s", n, err)
			continue
		}
		t.Fatalf("#%d: error expectation FAILED", n)
	}
}
// TestReference__Unmarshal decodes several _id encodings into a *Reference
// field and into a plain Reference field.
//
// Expected outcomes:
//   - Null (from an untyped nil or a nil *primitive.ObjectID) --> empty Reference
//   - ObjectID (from *primitive.ObjectID, *Reference or Reference) --> its hex form
//   - BSON string, even one that looks like ObjectID hex --> decode error
func TestReference__Unmarshal(t *testing.T) {
	objID := primitive.NewObjectID()
	ref := Reference(objID.Hex())
	str := "000000000000000000000000"
	anyNil, err := bson.Marshal(
		struct {
			ID interface{} `bson:"_id"`
		}{nil},
	)
	if err != nil {
		t.Fatal(err)
	}
	nativeNil, err := bson.Marshal(
		struct {
			ID *primitive.ObjectID `bson:"_id"`
		}{nil},
	)
	if err != nil {
		t.Fatal(err)
	}
	nativePointer, err := bson.Marshal(
		struct {
			ID *primitive.ObjectID `bson:"_id"`
		}{&objID},
	)
	if err != nil {
		t.Fatal(err)
	}
	refPointer, err := bson.Marshal(
		struct {
			Ref *Reference `bson:"_id"`
		}{&ref},
	)
	if err != nil {
		t.Fatal(err)
	}
	refValue, err := bson.Marshal(
		struct {
			Ref Reference `bson:"_id"`
		}{ref},
	)
	if err != nil {
		t.Fatal(err)
	}
	strPointer, err := bson.Marshal(
		struct {
			StrID *string `bson:"_id"`
		}{&str},
	)
	if err != nil {
		t.Fatal(err)
	}
	cases := []struct {
		data     []byte
		expError bool
		expValue string
	}{
		{anyNil, false, ""},
		{nativeNil, false, ""},
		{nativePointer, false, objID.Hex()},
		{refPointer, false, string(ref)},
		// A plain (non-pointer) Reference field is also marshaled through
		// MarshalBSONValue, so it decodes back as an ObjectID.
		{refValue, false, string(ref)},
		{strPointer, true, ""},
	}

	// checkError fails the (sub)test when the presence of err does not match
	// expError. It logs the error when one was expected.
	checkError := func(t *testing.T, i int, err error, expError bool) {
		t.Helper()
		switch {
		case err != nil && expError:
			t.Logf("#%d: error expectation OK: %s", i, err)
		case err != nil:
			t.Fatalf("#%d: %s", i, err)
		case expError:
			t.Fatalf("#%d: error expectation FAILED", i)
		}
	}

	t.Run("pointer", func(t *testing.T) {
		for i, c := range cases {
			actual := struct {
				Pointer *Reference `bson:"_id"`
			}{}
			err := bson.Unmarshal(c.data, &actual)
			checkError(t, i, err, c.expError)
			if c.expError {
				continue
			}
			// NOTE(review): depending on the mongo-driver version, BSON Null may
			// leave the pointer nil instead of calling UnmarshalBSONValue.
			// Dereferencing it unconditionally would panic, so nil is treated
			// as the empty reference.
			var got Reference
			if actual.Pointer != nil {
				got = *actual.Pointer
			}
			if got != Reference(c.expValue) {
				t.Fatalf("#%d: unexpected Reference %q instead of %q", i, got, c.expValue)
			}
		}
	})
	t.Run("value", func(t *testing.T) {
		for i, c := range cases {
			actual := struct {
				Value Reference `bson:"_id"`
			}{}
			err := bson.Unmarshal(c.data, &actual)
			checkError(t, i, err, c.expError)
			if c.expError {
				continue
			}
			if actual.Value != Reference(c.expValue) {
				t.Fatalf("#%d: unexpected Reference %q instead of %q", i, actual.Value, c.expValue)
			}
		}
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment