Skip to content

Instantly share code, notes, and snippets.

@17twenty
Last active July 2, 2020 05:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 17twenty/0a5eb637006692b68e81b8f67b3ef2cd to your computer and use it in GitHub Desktop.
Save 17twenty/0a5eb637006692b68e81b8f67b3ef2cd to your computer and use it in GitHub Desktop.
JSONFilteredByWhitelist filters an incoming JSON payload against a whitelist, then applies the filtered payload to the destination struct.
package types
import (
	"encoding/json"
	"fmt"
	"io"
	"reflect"
	"strings"
)
// JSONFilteredByWhitelist filters an incoming JSON payload against a whitelist,
// then applies the filtered payload to the destination struct.
//
// destination must be a non-nil pointer to a struct; anything else yields an
// error. A top-level key of the JSON object is applied only when it appears in
// whitelist and names a field of destination (the field's `json` tag name, or
// its Go name when untagged). Matching is exact and case-sensitive. Every
// other key is ignored and every other field is left untouched. A matched
// field is replaced wholesale by its decoded value rather than merged into the
// existing one.
//
// Fields promoted from untagged embedded structs can be whitelisted too. When
// an outer field and a promoted field share a name, the outer (shallower) one
// wins, as in encoding/json. Nil embedded struct pointers are allocated as a
// side effect when reflect allows setting them. Slices are not traversed.
//
// Only the first JSON value in jsonBody is read. Decoding errors are returned
// unwrapped. If a field fails to decode, fields earlier in whitelist order
// have already been applied.
//
// Use Cases:
//
// - Patch operations where only specific fields are exposed, eg "first_name", "last_name" can be patched but not "created_at"
//
// - Put operations where you either create or update and want to only expose specific fields if it's an update operation
func JSONFilteredByWhitelist(jsonBody io.Reader, destination interface{}, whitelist ...string) error {
	// Reject unusable destinations up front. Otherwise reflect can panic when
	// setting a field of a nil, non-pointer or non-struct value.
	rv := reflect.ValueOf(destination)
	if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("destination must be a non-nil pointer to a struct, got %T", destination)
	}

	tmp := map[string]json.RawMessage{}
	if err := json.NewDecoder(jsonBody).Decode(&tmp); err != nil {
		return err
	}
	fields := extract(destination)
	for _, name := range whitelist {
		field, ok := fields[name]
		if !ok {
			continue
		}
		raw, ok := tmp[name]
		if !ok {
			continue
		}
		// Decode into a fresh value so the field is replaced, not merged.
		val := reflect.New(field.Type())
		if err := json.Unmarshal(raw, val.Interface()); err != nil {
			return err
		}
		field.Set(val.Elem())
	}
	return nil
}

// extract returns the fields of the struct that dst points to, keyed by JSON
// name: the `json` tag name, or the Go field name when untagged. dst may also
// be a struct value, in which case the returned values are not settable.
//
// Fields tagged exactly `json:"-"` and unexported fields are omitted. Exported
// fields promoted through an embedded struct (exported or not) are included.
// An exported untagged embedded struct is also recorded under its own field
// name. Nil embedded struct pointers are allocated when settable. An empty map
// is returned when dst is nil or does not refer to a struct.
func extract(dst interface{}) map[string]reflect.Value {
	fields := map[string]reflect.Value{}
	v := reflect.ValueOf(dst)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return fields
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return fields
	}
	collectFields(v, 0, fields, map[string]int{})
	return fields
}

// collectFields records the fields of struct value v in fields, keyed by JSON
// name. depth is v's embedding depth below the destination. depths remembers
// the depth at which each name was recorded, so that a shallower field shadows
// a deeper one with the same name, as in encoding/json. Among fields at equal
// depth, the last one visited wins.
func collectFields(v reflect.Value, depth int, fields map[string]reflect.Value, depths map[string]int) {
	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		exported := field.PkgPath == ""
		if !exported && !field.Anonymous {
			// encoding/json ignores unexported fields, and reflect cannot Set them.
			continue
		}
		tag := field.Tag.Get("json")
		if tag == "-" {
			// Only an exact "-" hides a field; `json:"-,"` means the key "-".
			continue
		}
		name := strings.Split(tag, ",")[0]
		tagged := name != ""
		if !tagged {
			name = field.Name
		}
		fieldValue := v.Field(i)

		// Untagged embedded structs (or pointers to them) have their fields
		// promoted. Tagged ones are treated as ordinary named fields.
		if field.Anonymous && !tagged {
			embeddedType := field.Type
			if embeddedType.Kind() == reflect.Ptr {
				embeddedType = embeddedType.Elem()
			}
			if embeddedType.Kind() == reflect.Struct {
				if fieldValue.Kind() == reflect.Ptr {
					if fieldValue.IsNil() {
						if !fieldValue.CanSet() {
							// A nil pointer to an unexported embedded struct
							// cannot be allocated through reflect, so its
							// fields are unreachable. Skip it instead of
							// panicking.
							continue
						}
						fieldValue.Set(reflect.New(embeddedType))
					}
					fieldValue = fieldValue.Elem()
				}
				// Recurse on the reflect.Value itself: calling Interface() on
				// a value reached through an unexported embedded field panics.
				collectFields(fieldValue, depth+1, fields, depths)
			}
		}

		if !exported {
			// Unexported embedded structs contribute promoted fields only.
			continue
		}
		if seenDepth, seen := depths[name]; seen && seenDepth < depth {
			continue
		}
		fields[name] = fieldValue
		depths[name] = depth
	}
}
package types
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
// Test_extract checks that extract exposes struct fields under their JSON tag
// names and that each returned value mirrors the source struct's field.
func Test_extract(t *testing.T) {
	source := User{
		FirstName: "bar",
		LastName:  "bin",
	}
	extracted := extract(&source)

	expectations := []struct {
		key  string
		want interface{}
	}{
		{key: "first_name", want: source.FirstName},
		{key: "last_name", want: source.LastName},
	}
	for _, e := range expectations {
		assert.Equal(t, e.want, extracted[e.key].Interface())
	}
}
// TestJSONFilteredByWhitelist verifies that only whitelisted fields are copied
// from the payload into the stored value, while MiddleName (not whitelisted)
// keeps its stored value. "raw_address" is expected to reach
// Address.RawAddress, presumably because Address is embedded in User.
func TestJSONFilteredByWhitelist(t *testing.T) {
	incomingPayload := User{
		FirstName:  "bunny",
		MiddleName: "yeah naa bro it changed",
		LastName:   "foofoo",
		Address: Address{
			RawAddress: "foo",
		},
	}
	data, err := json.Marshal(incomingPayload)
	// The marshal error used to be overwritten unchecked; fail fast instead.
	if !assert.NoError(t, err) {
		return
	}
	whitelist := []string{"first_name", "last_name", "raw_address"}
	dataToBeStored := User{
		FirstName:  "firstname change",
		LastName:   "lastname change",
		MiddleName: "middlename no change",
		Address: Address{
			RawAddress: "raw address change",
		},
	}
	err = JSONFilteredByWhitelist(bytes.NewBuffer(data), &dataToBeStored, whitelist...)
	// Stop here on error: the field assertions below would only add noise.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, incomingPayload.FirstName, dataToBeStored.FirstName)
	assert.Equal(t, incomingPayload.LastName, dataToBeStored.LastName)
	assert.Equal(t, "middlename no change", dataToBeStored.MiddleName)
	assert.Equal(t, incomingPayload.Address.RawAddress, dataToBeStored.Address.RawAddress)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment