Skip to content

Instantly share code, notes, and snippets.

@SVilgelm
Last active April 7, 2023 08:41
Show Gist options
  • Save SVilgelm/0854d06308e36228857d08571d20aaf1 to your computer and use it in GitHub Desktop.
Save SVilgelm/0854d06308e36228857d08571d20aaf1 to your computer and use it in GitHub Desktop.
Generic single-or-array type in Go: a field that accepts either one value or a list of values (JSON and YAML)
package spec
import (
	"encoding/json"
	"errors"

	"gopkg.in/yaml.v3"
)
// SingleOrArray holds a list of T whose serialized form may be either a
// single bare T or an array of T.
//
// Decoding (JSON or YAML) first tries an array/sequence of T and, if that
// fails, a single T that is stored as a one-element list. Encoding writes a
// list with exactly one element as that bare element and any other length as
// a regular array.
type SingleOrArray[T any] []T
// NewSingleOrArray returns a SingleOrArray holding the given values.
//
// The values are copied into a freshly allocated slice, so later changes to a
// slice spread into v... do not leak into the result. With no arguments the
// result is empty but non-nil, so MarshalJSON writes it as [] rather than null.
func NewSingleOrArray[T any](v ...T) SingleOrArray[T] {
	out := make(SingleOrArray[T], len(v))
	copy(out, v)
	return out
}
// UnmarshalJSON implements json.Unmarshaler interface.
//
// data is first decoded as a JSON array of T; if that fails it is decoded as
// a single T and stored as a one-element list. JSON null yields a nil list.
//
// When both attempts fail and the single-value attempt failed only because
// the top-level value is an array, the array attempt's error is returned: it
// points at the element that does not fit T, whereas the single-value error
// would merely say "cannot unmarshal array into Go value of type T".
func (o *SingleOrArray[T]) UnmarshalJSON(data []byte) error {
	var ret []T
	arrErr := json.Unmarshal(data, &ret)
	if arrErr != nil {
		var s T
		if err := json.Unmarshal(data, &s); err != nil {
			// Field == "" restricts this to the top-level value; an array
			// inside a struct field of T is a genuine single-value error.
			var typeErr *json.UnmarshalTypeError
			if errors.As(err, &typeErr) && typeErr.Value == "array" && typeErr.Field == "" {
				return arrErr
			}
			return err
		}
		ret = []T{s}
	}
	*o = ret
	return nil
}
// MarshalJSON implements json.Marshaler interface.
//
// A list holding exactly one element is written as that bare element. Any
// other length is written as the plain []T (no custom marshaler), so a nil
// list becomes null and an empty one becomes [].
func (o SingleOrArray[T]) MarshalJSON() ([]byte, error) {
	var payload any = []T(o)
	if len(o) == 1 {
		payload = o[0]
	}
	return json.Marshal(payload)
}
// UnmarshalYAML implements yaml.Unmarshaler interface.
//
// node is first decoded as a sequence of T; if that fails it is decoded as a
// single T and stored as a one-element list.
//
// When both attempts fail on a sequence node, the sequence attempt's error is
// returned, because it describes the items that do not fit T; the single-value
// error typically only reports that the whole sequence cannot become a T.
func (o *SingleOrArray[T]) UnmarshalYAML(node *yaml.Node) error {
	var ret []T
	seqErr := node.Decode(&ret)
	if seqErr != nil {
		var s T
		if err := node.Decode(&s); err != nil {
			if node.Kind == yaml.SequenceNode {
				return seqErr
			}
			return err
		}
		ret = []T{s}
	}
	*o = ret
	return nil
}
// MarshalYAML implements yaml.Marshaler interface.
//
// It hands the encoder the bare element when the list holds exactly one
// value, and the plain []T (without this method) for any other length.
func (o SingleOrArray[T]) MarshalYAML() (any, error) {
	if len(o) != 1 {
		return []T(o), nil
	}
	return o[0], nil
}
package spec_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"github.com/sv-tools/openapi/spec"
)
// singleOrArrayCase is one table-driven case for the SingleOrArray
// round-trip helpers testSingleOrArrayJSON and testSingleOrArrayYAML.
type singleOrArrayCase[T any] struct {
	// name is the subtest name.
	name string
	// data is the encoded input; for successful cases it is also the exact
	// output expected when the decoded value is encoded again (YAML output is
	// compared after trimming surrounding whitespace).
	data []byte
	// expected is the value data must decode to; unused when wantErr is set.
	expected spec.SingleOrArray[T]
	// wantErr marks inputs that must fail to decode.
	wantErr bool
}
// testSingleOrArrayJSON runs each case as a parallel subtest: it decodes
// tt.data from JSON into a SingleOrArray[T], checks the error or the decoded
// value, and for successful cases re-encodes the value and requires the
// output to equal the input byte for byte.
func testSingleOrArrayJSON[T any](t *testing.T, tests []singleOrArrayCase[T]) {
	for _, tt := range tests {
		// Before Go 1.22 the range variable is shared by all iterations, and
		// paused parallel subtests only resume after the calling test
		// function returns; without this copy every subtest would see the
		// last case.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var o spec.SingleOrArray[T]
			err := json.Unmarshal(tt.data, &o)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.expected, o)

			newData, err := json.Marshal(&o)
			require.NoError(t, err)
			t.Log("orig: ", string(tt.data))
			t.Log(" new: ", string(newData))
			require.Equal(t, tt.data, newData)
		})
	}
}
// TestSingleOrArrayJSON checks JSON decoding of SingleOrArray for string, int
// and struct element types, and that successfully decoded values encode back
// to exactly the input bytes.
func TestSingleOrArrayJSON(t *testing.T) {
	t.Run("string", func(t *testing.T) {
		testSingleOrArrayJSON(t, []singleOrArrayCase[string]{
			{
				name:     "single",
				data:     []byte(`"single"`),
				expected: spec.NewSingleOrArray("single"),
			},
			{
				name:     "multi",
				data:     []byte(`["first","second"]`),
				expected: spec.NewSingleOrArray("first", "second"),
			},
			{
				// An explicit empty array must stay a non-nil empty list so
				// it re-encodes as [] rather than null.
				name:     "empty array",
				data:     []byte(`[]`),
				expected: spec.NewSingleOrArray[string](),
			},
			{
				name: "null",
				data: []byte(`null`),
			},
			{
				name:    "int for string",
				data:    []byte(`42`),
				wantErr: true,
			},
			{
				name:    "array of int for string",
				data:    []byte(`[42, 103]`),
				wantErr: true,
			},
			{
				name:    "empty for string",
				data:    []byte(``),
				wantErr: true,
			},
		})
	})
	t.Run("int", func(t *testing.T) {
		testSingleOrArrayJSON(t, []singleOrArrayCase[int]{
			{
				name:     "single",
				data:     []byte(`1`),
				expected: spec.NewSingleOrArray(1),
			},
			{
				name:     "multi",
				data:     []byte(`[1,2]`),
				expected: spec.NewSingleOrArray(1, 2),
			},
			{
				name:     "empty array",
				data:     []byte(`[]`),
				expected: spec.NewSingleOrArray[int](),
			},
			{
				name: "null",
				data: []byte(`null`),
			},
			{
				name:    "string for int",
				data:    []byte(`"single"`),
				wantErr: true,
			},
			{
				name:    "array of string for int",
				data:    []byte(`["first","second"]`),
				wantErr: true,
			},
			{
				name:    "empty for int",
				data:    []byte(``),
				wantErr: true,
			},
		})
	})
	type Foo struct {
		A string
		B int
	}
	t.Run("struct", func(t *testing.T) {
		testSingleOrArrayJSON(t, []singleOrArrayCase[Foo]{
			{
				name:     "single",
				data:     []byte(`{"A":"single","B":42}`),
				expected: spec.NewSingleOrArray(Foo{A: "single", B: 42}),
			},
			{
				name:     "multi",
				data:     []byte(`[{"A":"first","B":1},{"A":"second","B":2}]`),
				expected: spec.NewSingleOrArray(Foo{A: "first", B: 1}, Foo{A: "second", B: 2}),
			},
			{
				name:     "empty array",
				data:     []byte(`[]`),
				expected: spec.NewSingleOrArray[Foo](),
			},
			{
				name: "null",
				data: []byte(`null`),
			},
			{
				name:    "string for struct",
				data:    []byte(`"single"`),
				wantErr: true,
			},
			{
				name:    "array of string for struct",
				data:    []byte(`["first","second"]`),
				wantErr: true,
			},
			{
				name:    "empty for struct",
				data:    []byte(``),
				wantErr: true,
			},
		})
	})
}
// testSingleOrArrayYAML runs each case as a parallel subtest: it decodes
// tt.data from YAML into a SingleOrArray[T], checks the error or the decoded
// value, and for successful cases re-encodes the value and requires the
// output (surrounding whitespace trimmed) to equal the input.
func testSingleOrArrayYAML[T any](t *testing.T, tests []singleOrArrayCase[T]) {
	for _, tt := range tests {
		// Before Go 1.22 the range variable is shared by all iterations, and
		// paused parallel subtests only resume after the calling test
		// function returns; this copy keeps each subtest on its own case.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var o spec.SingleOrArray[T]
			err := yaml.Unmarshal(tt.data, &o)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.expected, o)

			newData, err := yaml.Marshal(&o)
			require.NoError(t, err)
			// Fixtures carry no surrounding whitespace (e.g. the trailing
			// newline the encoder appends), so compare trimmed output.
			newData = bytes.TrimSpace(newData)
			t.Log("orig: ", string(tt.data))
			t.Log(" new: ", string(newData))
			require.Equal(t, tt.data, newData)
		})
	}
}
// TestSingleOrArrayYAML checks YAML decoding of SingleOrArray for string, int
// and struct element types, and that successfully decoded values encode back
// to the same YAML text.
func TestSingleOrArrayYAML(t *testing.T) {
	t.Run("string", func(t *testing.T) {
		testSingleOrArrayYAML(t, []singleOrArrayCase[string]{
			{
				name:     "single",
				data:     []byte(`single`),
				expected: spec.NewSingleOrArray("single"),
			},
			{
				name: "multi",
				data: []byte(`- first
- second`),
				expected: spec.NewSingleOrArray("first", "second"),
			},
		})
	})
	t.Run("int", func(t *testing.T) {
		testSingleOrArrayYAML(t, []singleOrArrayCase[int]{
			{
				name:     "single",
				data:     []byte(`1`),
				expected: spec.NewSingleOrArray(1),
			},
			{
				name: "multi",
				data: []byte(`- 1
- 2`),
				expected: spec.NewSingleOrArray(1, 2),
			},
			{
				name:    "string for int",
				data:    []byte(`single`),
				wantErr: true,
			},
			{
				name: "array of string for int",
				data: []byte(`- first
- second`),
				wantErr: true,
			},
		})
	})
	type Foo struct {
		A string
		B int
	}
	t.Run("struct", func(t *testing.T) {
		testSingleOrArrayYAML(t, []singleOrArrayCase[Foo]{
			{
				name: "single",
				data: []byte(`a: single
b: 42`),
				expected: spec.NewSingleOrArray(Foo{A: "single", B: 42}),
			},
			{
				// The second key of each sequence item must be indented to
				// line up with the first one; at column 0 the text is not a
				// valid YAML sequence of mappings.
				name: "multi",
				data: []byte(`- a: first
  b: 1
- a: second
  b: 2`),
				expected: spec.NewSingleOrArray(Foo{A: "first", B: 1}, Foo{A: "second", B: 2}),
			},
			{
				name:    "string for struct",
				data:    []byte(`single`),
				wantErr: true,
			},
			{
				name: "array of string for struct",
				data: []byte(`- first
- second`),
				wantErr: true,
			},
		})
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment