Last active
July 2, 2020 05:10
-
-
Save 17twenty/0a5eb637006692b68e81b8f67b3ef2cd to your computer and use it in GitHub Desktop.
JSONFilteredByWhitelist filters an incoming JSON payload against a whitelist, then applies the filtered payload to the destination struct.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package types | |
import (
	"encoding/json"
	"fmt"
	"io"
	"reflect"
	"strings"
)
// JSONFilteredByWhitelist filters an incoming JSON payload against a whitelist,
// then applies the filtered payload to the destination struct.
//
// The payload is decoded as a JSON object. A top-level key is applied only if
// it appears in whitelist AND matches a field of destination, as resolved by
// extract. All other keys are ignored. Each applied key is decoded into a
// fresh value of the field's type, so the field's previous contents are
// replaced rather than merged.
//
// Every whitelisted key is decoded before any field is assigned. If one of
// them fails to decode, the error is returned and no field is modified. Note
// that extract may still have allocated nil embedded struct pointers in
// destination.
//
// Note: this only works for structs and embedded structs; slices are not traversed.
//
// Use Cases:
//
//   - Patch operations where only specific fields are exposed, eg "first_name", "last_name" can be patched but not "created_at"
//
//   - Put operations where you either create or update and want to only expose specific fields if it's an update operation
//
// An error is returned in three cases:
//   - destination is not a non-nil pointer to a struct;
//   - decoding the payload fails;
//   - a whitelisted value does not decode into its field's type.
//
// Decoding errors from encoding/json are returned unwrapped.
func JSONFilteredByWhitelist(jsonBody io.Reader, destination interface{}, whitelist ...string) error {
	// Assigning through reflection requires an addressable struct; reject
	// anything else up front instead of panicking in reflect.Value.Set.
	rv := reflect.ValueOf(destination)
	if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("types: destination must be a non-nil pointer to a struct, got %T", destination)
	}

	tmp := map[string]json.RawMessage{}
	if err := json.NewDecoder(jsonBody).Decode(&tmp); err != nil {
		return err
	}

	type update struct {
		field reflect.Value
		value reflect.Value
	}
	fields := extract(destination)
	updates := make([]update, 0, len(whitelist))
	for _, name := range whitelist {
		field, ok := fields[name]
		if !ok || !field.CanSet() {
			continue
		}
		raw, ok := tmp[name]
		if !ok {
			continue
		}
		val := reflect.New(field.Type())
		if err := json.Unmarshal(raw, val.Interface()); err != nil {
			return err
		}
		updates = append(updates, update{field: field, value: val.Elem()})
	}

	// Apply only after every value decoded successfully (no partial updates).
	for _, u := range updates {
		u.field.Set(u.value)
	}
	return nil
}

// extract maps the JSON name of every exported field of dst to that field's
// reflect.Value. dst may be a struct value or a pointer to a struct.
//
// Naming and filtering rules:
//   - A field's name is the name part of its `json` tag, or the Go field name
//     when the tag has no name.
//   - Fields tagged exactly `json:"-"` are skipped, as are unexported fields.
//
// Embedded fields:
//   - Fields of embedded structs are flattened into the result, including
//     unexported embedded struct types.
//   - A field declared directly on a struct takes precedence over a same-named
//     field promoted from one of its embedded structs.
//   - Exported embedded structs are additionally registered under their own name.
//   - Embedded non-struct types are treated as ordinary fields.
//
// Nil embedded struct pointers are allocated (and stored back into dst) when
// they are settable, so their fields can be written. Only values reached
// through a pointer are settable; a struct passed by value can be read but not
// assigned.
//
// If dst is neither a struct nor a non-nil pointer to one, an empty map is
// returned.
func extract(dst interface{}) map[string]reflect.Value {
	v := reflect.ValueOf(dst)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return map[string]reflect.Value{}
	}
	return structFields(v)
}

// structFields implements extract for the struct value v; see extract for the rules.
func structFields(v reflect.Value) map[string]reflect.Value {
	t := v.Type()
	direct := map[string]reflect.Value{}   // fields declared on this struct
	promoted := map[string]reflect.Value{} // fields flattened from embedded structs

	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		tag := sf.Tag.Get("json")
		if tag == "-" {
			continue
		}
		// Per encoding/json, `json:"-,"` names the field "-" rather than skipping it.
		name := strings.SplitN(tag, ",", 2)[0]
		if name == "" {
			name = sf.Name
		}
		exported := sf.PkgPath == ""
		fv := v.Field(i)

		if sf.Anonymous {
			ft := sf.Type
			if ft.Kind() == reflect.Ptr {
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Struct {
				if fv.Kind() == reflect.Ptr {
					if fv.IsNil() {
						if !fv.CanSet() {
							// Cannot allocate (unexported embed or dst passed by value).
							continue
						}
						fv.Set(reflect.New(ft))
					}
					fv = fv.Elem()
				}
				// Later embeds overwrite earlier ones among promoted fields.
				for key, value := range structFields(fv) {
					promoted[key] = value
				}
				if exported {
					direct[name] = fv
				}
				continue
			}
			// Embedded non-struct: handled below as an ordinary field.
		}

		if !exported {
			continue
		}
		direct[name] = fv
	}

	for key, value := range promoted {
		if _, taken := direct[key]; !taken {
			direct[key] = value
		}
	}
	return direct
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package types | |
import ( | |
"bytes" | |
"encoding/json" | |
"testing" | |
"github.com/stretchr/testify/assert" | |
) | |
func Test_extract(t *testing.T) { | |
foo := User{ | |
FirstName: "bar", | |
LastName: "bin", | |
} | |
got := extract(&foo) | |
assert.Equal(t, foo.FirstName, got["first_name"].Interface()) | |
assert.Equal(t, foo.LastName, got["last_name"].Interface()) | |
} | |
func TestJSONFilteredByWhitelist(t *testing.T) { | |
incomingPayload := User{ | |
FirstName: "bunny", | |
MiddleName: "yeah naa bro it changed", | |
LastName: "foofoo", | |
Address: Address{ | |
RawAddress: "foo", | |
}, | |
} | |
data, err := json.Marshal(incomingPayload) | |
b := bytes.NewBuffer(data) | |
whitelist := []string{"first_name", "last_name", "raw_address"} | |
dataToBeStored := User{ | |
FirstName: "firstname change", | |
LastName: "lastname change", | |
MiddleName: "middlename no change", | |
Address: Address{ | |
RawAddress: "raw address change", | |
}, | |
} | |
err = JSONFilteredByWhitelist(b, &dataToBeStored, whitelist...) | |
assert.Nil(t, err) | |
assert.Equal(t, incomingPayload.FirstName, dataToBeStored.FirstName) | |
assert.Equal(t, incomingPayload.LastName, dataToBeStored.LastName) | |
assert.Equal(t, "middlename no change", dataToBeStored.MiddleName) | |
assert.Equal(t, incomingPayload.Address.RawAddress, dataToBeStored.Address.RawAddress) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment