Skip to content

Instantly share code, notes, and snippets.

@benjbaron
Last active April 11, 2019 08:45
Show Gist options
  • Save benjbaron/bc4ebbc1146dfe6c564d55c344c9b60f to your computer and use it in GitHub Desktop.
Save benjbaron/bc4ebbc1146dfe6c564d55c344c9b60f to your computer and use it in GitHub Desktop.
package types
import (
"encoding/json"
"time"
)
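// FIXME: The nullable types in this file do not currently work with JSON's
// omitempty option: omitempty never treats a struct value as empty, so an
// invalid (NULL) field is still written out as null instead of being omitted.
// See: https://github.com/golang/go/issues/11939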
// NullString represents a string that may be null.
//
// Use NullString as follows:
//
//	if s.Valid {
//		// use s.String
//	} else {
//		// NULL value
//	}
type NullString struct {
	String string
	Valid  bool // Valid is true if String is not NULL
}

// NewNullString creates a new NullString holding s with the given validity.
func NewNullString(s string, valid bool) NullString {
	return NullString{String: s, Valid: valid}
}

// NullStringFrom creates a new NullString that will always be valid,
// even when s is the empty string.
func NullStringFrom(s string) NullString {
	return NewNullString(s, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullString encodes as JSON null; a valid one encodes as a JSON
// string.
func (s NullString) MarshalJSON() ([]byte, error) {
	if !s.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(s.String)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports string and null input. Blank string input does not produce a
// null NullString. On a decoding error the receiver is left untouched and the
// error from encoding/json is returned as is.
func (s *NullString) UnmarshalJSON(b []byte) error {
	var x *string
	if err := json.Unmarshal(b, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale String next to Valid == false.
		*s = NullString{}
		return nil
	}
	*s = NullString{String: *x, Valid: true}
	return nil
}
// NullInt represents an int that may be null.
type NullInt struct {
	Int   int
	Valid bool // Valid is true if Int is not NULL
}

// NewNullInt creates a new NullInt holding i with the given validity.
func NewNullInt(i int, valid bool) NullInt {
	return NullInt{Int: i, Valid: valid}
}

// NullIntFrom creates a new NullInt that will always be valid.
func NullIntFrom(i int) NullInt {
	return NewNullInt(i, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullInt encodes as JSON null; a valid one encodes as a JSON
// number.
func (i NullInt) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(i.Int)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports integer and null input. On a decoding error the receiver is
// left untouched and the error from encoding/json is returned as is.
func (i *NullInt) UnmarshalJSON(b []byte) error {
	var x *int
	if err := json.Unmarshal(b, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale Int next to Valid == false.
		*i = NullInt{}
		return nil
	}
	*i = NullInt{Int: *x, Valid: true}
	return nil
}
// NullUInt represents a uint that may be null.
type NullUInt struct {
	UInt  uint
	Valid bool // Valid is true if UInt is not NULL
}

// NewNullUInt creates a new NullUInt holding i with the given validity.
func NewNullUInt(i uint, valid bool) NullUInt {
	return NullUInt{UInt: i, Valid: valid}
}

// NullUIntFrom creates a new NullUInt that will always be valid.
func NullUIntFrom(i uint) NullUInt {
	return NewNullUInt(i, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullUInt encodes as JSON null; a valid one encodes as a JSON
// number.
func (i NullUInt) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(i.UInt)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports unsigned integer and null input. On a decoding error the
// receiver is left untouched and the error from encoding/json is returned as
// is.
func (i *NullUInt) UnmarshalJSON(b []byte) error {
	var x *uint
	if err := json.Unmarshal(b, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale UInt next to Valid == false.
		*i = NullUInt{}
		return nil
	}
	*i = NullUInt{UInt: *x, Valid: true}
	return nil
}
// NullInt64 represents an int64 that may be null.
type NullInt64 struct {
	Int64 int64
	Valid bool // Valid is true if Int64 is not NULL
}

// NewNullInt64 creates a new NullInt64 holding i with the given validity.
func NewNullInt64(i int64, valid bool) NullInt64 {
	return NullInt64{Int64: i, Valid: valid}
}

// NullInt64From creates a new NullInt64 that will always be valid.
func NullInt64From(i int64) NullInt64 {
	return NewNullInt64(i, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullInt64 encodes as JSON null; a valid one encodes as a JSON
// number.
func (i NullInt64) MarshalJSON() ([]byte, error) {
	if !i.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(i.Int64)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports integer and null input. On a decoding error the receiver is
// left untouched and the error from encoding/json is returned as is.
func (i *NullInt64) UnmarshalJSON(b []byte) error {
	var x *int64
	if err := json.Unmarshal(b, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale Int64 next to Valid == false.
		*i = NullInt64{}
		return nil
	}
	*i = NullInt64{Int64: *x, Valid: true}
	return nil
}
// NullBool represents a bool that may be null.
type NullBool struct {
	Bool  bool
	Valid bool // Valid is true if Bool is not NULL
}

// NewNullBool creates a new NullBool holding b with the given validity.
func NewNullBool(b bool, valid bool) NullBool {
	return NullBool{Bool: b, Valid: valid}
}

// NullBoolFrom creates a new NullBool that will always be valid.
func NullBoolFrom(b bool) NullBool {
	return NewNullBool(b, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullBool encodes as JSON null; a valid one encodes as true or
// false.
func (b NullBool) MarshalJSON() ([]byte, error) {
	if !b.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(b.Bool)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports boolean and null input. On a decoding error the receiver is
// left untouched and the error from encoding/json is returned as is.
func (b *NullBool) UnmarshalJSON(bytes []byte) error {
	var x *bool
	if err := json.Unmarshal(bytes, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale Bool next to Valid == false.
		*b = NullBool{}
		return nil
	}
	*b = NullBool{Bool: *x, Valid: true}
	return nil
}
// NullFloat64 represents a float64 that may be null.
type NullFloat64 struct {
	Float64 float64
	Valid   bool // Valid is true if Float64 is not NULL
}

// NewNullFloat64 creates a new NullFloat64 holding f with the given validity.
func NewNullFloat64(f float64, valid bool) NullFloat64 {
	return NullFloat64{Float64: f, Valid: valid}
}

// NullFloat64From creates a new NullFloat64 that will always be valid.
func NullFloat64From(f float64) NullFloat64 {
	return NewNullFloat64(f, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullFloat64 encodes as JSON null; a valid one is encoded by
// json.Marshal as a JSON number, and any error it reports is returned.
func (f NullFloat64) MarshalJSON() ([]byte, error) {
	if !f.Valid {
		return []byte("null"), nil
	}
	return json.Marshal(f.Float64)
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports number and null input. On a decoding error the receiver is left
// untouched and the error from encoding/json is returned as is.
func (f *NullFloat64) UnmarshalJSON(bytes []byte) error {
	var x *float64
	if err := json.Unmarshal(bytes, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale Float64 next to Valid == false.
		*f = NullFloat64{}
		return nil
	}
	*f = NullFloat64{Float64: *x, Valid: true}
	return nil
}
// NullTime represents a time.Time that may be null.
//
// It is a distinct struct type, not a Go type alias.
// NOTE(review): the original comment described it as an alias for
// mysql.NullTime; presumably it is meant to mirror that type's Time/Valid
// layout — confirm against the driver in use.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if time.Time is not NULL
}

// NewNullTime creates a new NullTime holding t with the given validity.
func NewNullTime(t time.Time, valid bool) NullTime {
	return NullTime{Time: t, Valid: valid}
}

// NullTimeFrom creates a new NullTime that will always be valid,
// even when t is the zero time.
func NullTimeFrom(t time.Time) NullTime {
	return NewNullTime(t, true)
}

// MarshalJSON implements json.Marshaler.
// An invalid NullTime encodes as JSON null; a valid one is encoded by
// time.Time.MarshalJSON.
func (t NullTime) MarshalJSON() ([]byte, error) {
	if !t.Valid {
		return []byte("null"), nil
	}
	return t.Time.MarshalJSON()
}

// UnmarshalJSON implements json.Unmarshaler.
// It supports null and any input accepted by time.Time's JSON decoding. On a
// decoding error the receiver is left untouched and the error is returned as
// is.
func (t *NullTime) UnmarshalJSON(bytes []byte) error {
	var x *time.Time
	if err := json.Unmarshal(bytes, &x); err != nil {
		return err
	}
	if x == nil {
		// Reset the whole value so that a reused receiver does not keep a
		// stale Time next to Valid == false.
		*t = NullTime{}
		return nil
	}
	*t = NullTime{Time: *x, Valid: true}
	return nil
}
package types
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestNullStringFrom checks that NullStringFrom stores the given string and
// marks the result valid, for both a non-empty and an empty input.
func TestNullStringFrom(t *testing.T) {
	for _, in := range []string{"test", ""} {
		got := NullStringFrom(in)
		require.Equal(t, in, got.String)
		require.True(t, got.Valid)
	}
}
// testNullString pairs a JSON encoding with the NullString it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullString = []struct {
	JSON  []byte
	Value NullString
}{
	{JSON: []byte(`"test"`), Value: NullString{String: "test", Valid: true}},
	{JSON: []byte(`""`), Value: NullString{String: "", Valid: true}},
	{JSON: []byte(`null`), Value: NullString{Valid: false}},
}
// TestNullString_UnmarshalJSON decodes every testNullString entry into a
// fresh NullString and compares it with the expected value.
func TestNullString_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullString {
		name := fmt.Sprintf("TestNullString_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullString
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullString_MarshalJSON encodes every testNullString entry and compares
// the output with the expected JSON bytes.
func TestNullString_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullString {
		name := fmt.Sprintf("TestNullString_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// TestNullIntFrom checks that NullIntFrom stores the given int and marks the
// result valid, for both a non-zero and a zero input.
func TestNullIntFrom(t *testing.T) {
	for _, in := range []int{42, 0} {
		got := NullIntFrom(in)
		require.Equal(t, in, got.Int)
		require.True(t, got.Valid)
	}
}
// testNullInt pairs a JSON encoding with the NullInt it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullInt = []struct {
	JSON  []byte
	Value NullInt
}{
	{JSON: []byte(`12345`), Value: NullInt{Int: 12345, Valid: true}},
	{JSON: []byte(`0`), Value: NullInt{Int: 0, Valid: true}},
	{JSON: []byte(`null`), Value: NullInt{Valid: false}},
}
// TestNullInt_UnmarshalJSON decodes every testNullInt entry into a fresh
// NullInt and compares it with the expected value.
func TestNullInt_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullInt {
		name := fmt.Sprintf("TestNullInt_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullInt
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullInt_MarshalJSON encodes every testNullInt entry and compares the
// output with the expected JSON bytes.
func TestNullInt_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullInt {
		name := fmt.Sprintf("TestNullInt_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// TestNullUIntFrom checks that NullUIntFrom stores the given uint and marks
// the result valid, for both a non-zero and a zero input.
func TestNullUIntFrom(t *testing.T) {
	for _, in := range []uint{42, 0} {
		got := NullUIntFrom(in)
		require.Equal(t, in, got.UInt)
		require.True(t, got.Valid)
	}
}
// testNullUInt pairs a JSON encoding with the NullUInt it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullUInt = []struct {
	JSON  []byte
	Value NullUInt
}{
	{JSON: []byte(`12345`), Value: NullUInt{UInt: 12345, Valid: true}},
	{JSON: []byte(`0`), Value: NullUInt{UInt: 0, Valid: true}},
	{JSON: []byte(`null`), Value: NullUInt{Valid: false}},
}
// TestNullUInt_UnmarshalJSON decodes every testNullUInt entry into a fresh
// NullUInt and compares it with the expected value.
func TestNullUInt_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullUInt {
		name := fmt.Sprintf("TestNullUInt_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullUInt
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullUInt_MarshalJSON encodes every testNullUInt entry and compares the
// output with the expected JSON bytes.
func TestNullUInt_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullUInt {
		name := fmt.Sprintf("TestNullUInt_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// TestNullInt64From checks that NullInt64From stores the given int64 and
// marks the result valid, for both a non-zero and a zero input.
func TestNullInt64From(t *testing.T) {
	for _, in := range []int64{42, 0} {
		got := NullInt64From(in)
		require.Equal(t, in, got.Int64)
		require.True(t, got.Valid)
	}
}
// testNullInt64 pairs a JSON encoding with the NullInt64 it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullInt64 = []struct {
	JSON  []byte
	Value NullInt64
}{
	{JSON: []byte(`12345`), Value: NullInt64{Int64: 12345, Valid: true}},
	{JSON: []byte(`0`), Value: NullInt64{Int64: 0, Valid: true}},
	{JSON: []byte(`null`), Value: NullInt64{Valid: false}},
}
// TestNullInt64_UnmarshalJSON decodes every testNullInt64 entry into a fresh
// NullInt64 and compares it with the expected value.
func TestNullInt64_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullInt64 {
		name := fmt.Sprintf("TestNullInt64_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullInt64
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullInt64_MarshalJSON encodes every testNullInt64 entry and compares
// the output with the expected JSON bytes.
func TestNullInt64_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullInt64 {
		name := fmt.Sprintf("TestNullInt64_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// TestNullBoolFrom checks that NullBoolFrom stores the given bool and marks
// the result valid, for both true and false.
func TestNullBoolFrom(t *testing.T) {
	for _, in := range []bool{true, false} {
		got := NullBoolFrom(in)
		require.Equal(t, in, got.Bool)
		require.True(t, got.Valid)
	}
}
// testNullBool pairs a JSON encoding with the NullBool it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullBool = []struct {
	JSON  []byte
	Value NullBool
}{
	{JSON: []byte(`true`), Value: NullBool{Bool: true, Valid: true}},
	{JSON: []byte(`false`), Value: NullBool{Bool: false, Valid: true}},
	{JSON: []byte(`null`), Value: NullBool{Valid: false}},
}
// TestNullBool_UnmarshalJSON decodes every testNullBool entry into a fresh
// NullBool and compares it with the expected value.
func TestNullBool_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullBool {
		name := fmt.Sprintf("TestNullBool_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullBool
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullBool_MarshalJSON encodes every testNullBool entry and compares the
// output with the expected JSON bytes.
func TestNullBool_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullBool {
		name := fmt.Sprintf("TestNullBool_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// TestNullFloat64From checks that NullFloat64From stores the given float64
// and marks the result valid, for both a non-zero and a zero input.
func TestNullFloat64From(t *testing.T) {
	for _, in := range []float64{1.2345, 0.0} {
		got := NullFloat64From(in)
		require.Equal(t, in, got.Float64)
		require.True(t, got.Valid)
	}
}
// testNullFloat64 pairs a JSON encoding with the NullFloat64 it corresponds
// to. The marshal and unmarshal tests both range over this table.
var testNullFloat64 = []struct {
	JSON  []byte
	Value NullFloat64
}{
	{JSON: []byte(`1.2345`), Value: NullFloat64{Float64: 1.2345, Valid: true}},
	{JSON: []byte(`0`), Value: NullFloat64{Float64: 0.0, Valid: true}},
	{JSON: []byte(`null`), Value: NullFloat64{Valid: false}},
}
// TestNullFloat64_UnmarshalJSON decodes every testNullFloat64 entry into a
// fresh NullFloat64 and compares it with the expected value.
func TestNullFloat64_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullFloat64 {
		name := fmt.Sprintf("TestNullFloat64_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullFloat64
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullFloat64_MarshalJSON encodes every testNullFloat64 entry and
// compares the output with the expected JSON bytes.
func TestNullFloat64_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullFloat64 {
		name := fmt.Sprintf("TestNullFloat64_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// Shared time fixtures for the NullTime and struct-level JSON tests.
var (
	// timeString is the RFC 3339 text of the non-zero fixture time.
	timeString = "2012-12-21T21:21:21Z"
	// timeZero is the RFC 3339 text that the zero time.Time encodes to.
	timeZero = "0001-01-01T00:00:00Z"
	// timeValue is timeString parsed as RFC 3339.
	timeValue = mustParseRFC3339(timeString)
)

// mustParseRFC3339 parses s as an RFC 3339 timestamp and panics if it is
// malformed, so that a broken fixture fails loudly at package initialization
// instead of silently yielding the zero time.
func mustParseRFC3339(s string) time.Time {
	parsed, err := time.Parse(time.RFC3339, s)
	if err != nil {
		panic(err)
	}
	return parsed
}
func TestNullTimeFrom(t *testing.T) {
v := NullTimeFrom(timeValue)
require.Equal(t, timeValue, v.Time)
require.True(t, v.Valid)
zero := NullTimeFrom(time.Time{})
require.Equal(t, time.Time{}, zero.Time)
require.True(t, zero.Valid)
}
// testNullTime pairs a JSON encoding with the NullTime it corresponds to.
// The marshal and unmarshal tests both range over this table.
var testNullTime = []struct {
	JSON  []byte
	Value NullTime
}{
	{JSON: []byte(`"` + timeString + `"`), Value: NullTime{Time: timeValue, Valid: true}},
	{JSON: []byte(`"` + timeZero + `"`), Value: NullTime{Time: time.Time{}, Valid: true}},
	{JSON: []byte(`null`), Value: NullTime{Valid: false}},
}
// TestNullTime_UnmarshalJSON decodes every testNullTime entry into a fresh
// NullTime and compares it with the expected value.
func TestNullTime_UnmarshalJSON(t *testing.T) {
	for idx, tc := range testNullTime {
		name := fmt.Sprintf("TestNullTime_UnmarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got NullTime
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestNullTime_MarshalJSON encodes every testNullTime entry and compares the
// output with the expected JSON bytes.
func TestNullTime_MarshalJSON(t *testing.T) {
	for idx, tc := range testNullTime {
		name := fmt.Sprintf("TestNullTime_MarshalJSON (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
// testStruct combines several nullable types in one struct so the tests can
// exercise them as fields of a JSON object.
type testStruct struct {
	ID        NullInt64   `json:"id"`
	AppID     NullString  `json:"app_id"`
	CreatedAt NullTime    `json:"created_at"`
	Score     NullFloat64 `json:"score"`
	Blocked   NullBool    `json:"blocked"`
}
// testParam pairs a JSON object with the testStruct it corresponds to.
// TestJSONUnmarshalling and TestJSONMarshalling both range over this table.
var testParam = []struct {
	JSON  []byte
	Value testStruct
}{
	// Every field present and non-zero.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Empty app_id string is still valid.
	{
		JSON: []byte(`{"id":1,"app_id":"","created_at":"` + timeString + `","score":1.2,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom(""),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Null app_id.
	{
		JSON: []byte(`{"id":1,"app_id":null,"created_at":"` + timeString + `","score":1.2,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Zero id together with the zero time.
	{
		JSON: []byte(`{"id":0,"app_id":"test","created_at":"` + timeZero + `","score":1.2,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(0),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(time.Time{}),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Null created_at together with a zero score.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":null,"score":0,"blocked":true}`),
		Value: testStruct{
			ID:      NullInt64From(1),
			AppID:   NullStringFrom("test"),
			Score:   NullFloat64From(0),
			Blocked: NullBoolFrom(true),
		},
	},
	// Zero id is still valid.
	{
		JSON: []byte(`{"id":0,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(0),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Zero score is still valid.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":0,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(0),
			Blocked:   NullBoolFrom(true),
		},
	},
	// Null score.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":null,"blocked":true}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Blocked:   NullBoolFrom(true),
		},
	},
	// False blocked is still valid.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":false}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
			Blocked:   NullBoolFrom(false),
		},
	},
	// Null blocked.
	{
		JSON: []byte(`{"id":1,"app_id":"test","created_at":"` + timeString + `","score":1.2,"blocked":null}`),
		Value: testStruct{
			ID:        NullInt64From(1),
			AppID:     NullStringFrom("test"),
			CreatedAt: NullTimeFrom(timeValue),
			Score:     NullFloat64From(1.2),
		},
	},
}
// TestJSONUnmarshalling decodes every testParam JSON object into a fresh
// testStruct and compares it with the expected value.
func TestJSONUnmarshalling(t *testing.T) {
	for idx, tc := range testParam {
		name := fmt.Sprintf("TestJSONUnmarshalling (%d)", idx)
		t.Run(name, func(t *testing.T) {
			var got testStruct
			decodeErr := json.Unmarshal(tc.JSON, &got)
			require.NoError(t, decodeErr)
			require.Equal(t, tc.Value, got)
		})
	}
}
// TestJSONMarshalling encodes every testParam value and compares the output
// with the expected JSON bytes.
func TestJSONMarshalling(t *testing.T) {
	for idx, tc := range testParam {
		name := fmt.Sprintf("TestJSONMarshalling (%d)", idx)
		t.Run(name, func(t *testing.T) {
			encoded, encodeErr := json.Marshal(tc.Value)
			require.NoError(t, encodeErr)
			require.Equal(t, tc.JSON, encoded)
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment