Skip to content

Instantly share code, notes, and snippets.

@pazams
Created March 25, 2021 19:20
Show Gist options
  • Save pazams/fc2e75420cd6c40c48207a3543c8f9f4 to your computer and use it in GitHub Desktop.
Save pazams/fc2e75420cd6c40c48207a3543c8f9f4 to your computer and use it in GitHub Desktop.
// with permission and based on https://stackoverflow.com/a/23598731
package data
import (
	"fmt"
	"reflect"
	"sort"
	"strings"
)
// safeDynamicPartialUpdateQuery builds a parameterized UPDATE statement for
// table that sets only the non-nil pointer fields of patchModel (as selected
// by StructNonNilPointerFieldsToMap, keyed by their "db" tag) on the rows
// matching every column = value pair in whereModel, joined with AND.
//
// It returns the query string, the positional params in placeholder order
// (SET values first, then WHERE values), and an error when patchModel is not a
// struct (or pointer to one), when it has no non-nil tagged pointer fields, or
// when whereModel is empty.
//
// Only the values are parameterized: table, the db tags of patchModel and the
// keys of whereModel are interpolated verbatim into the SQL, so they must come
// from trusted code and never from user input.
//
// Columns are emitted in sorted order so the generated SQL is deterministic
// (Go map iteration order is randomized).
func safeDynamicPartialUpdateQuery(table string, whereModel map[string]interface{}, patchModel interface{}) (string, []interface{}, error) {
	patchMap, err := StructNonNilPointerFieldsToMap(patchModel)
	if err != nil {
		return "", nil, err
	}
	if len(patchMap) == 0 {
		return "", nil, fmt.Errorf("no updates were found in model (all values nil)")
	}
	// Without conditions the query would render "WHERE ;" — invalid SQL at best,
	// and an accidental full-table update if "fixed" naively.
	if len(whereModel) == 0 {
		return "", nil, fmt.Errorf("no where conditions given; refusing to build an unconditional update")
	}

	sortedKeys := func(m map[string]interface{}) []string {
		keys := make([]string, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return keys
	}

	params := make([]interface{}, 0, len(patchMap)+len(whereModel))

	setColumns := make([]string, 0, len(patchMap))
	for _, column := range sortedKeys(patchMap) {
		setColumns = append(setColumns, fmt.Sprintf("%s = ?", column))
		params = append(params, patchMap[column])
	}

	whereConditions := make([]string, 0, len(whereModel))
	for _, column := range sortedKeys(whereModel) {
		whereConditions = append(whereConditions, fmt.Sprintf("%s = ?", column))
		params = append(params, whereModel[column])
	}

	set := strings.Join(setColumns, ", ")
	where := strings.Join(whereConditions, " AND ")
	query := fmt.Sprintf(`UPDATE %s SET %s WHERE %s;`, table, set, where)
	return query, params, nil
}
// StructNonNilPointerFieldsToMap returns a map from "db" struct tag to the
// pointed-to value for every exported, db-tagged pointer field of strct that
// is non-nil. strct may be a struct or a pointer to a struct.
//
// Non-pointer fields, nil pointers, fields without a "db" tag, and unexported
// fields are skipped, so a struct of pointer fields can describe a partial
// update in which nil means "leave unchanged".
//
// An error is returned when strct is not a struct or a pointer to a struct
// (including a nil pointer or a nil interface).
//
// Based on https://stackoverflow.com/a/23598731
func StructNonNilPointerFieldsToMap(strct interface{}) (map[string]interface{}, error) {
	v := reflect.ValueOf(strct)
	// De-reference in case of pointer. A nil pointer yields an invalid Value,
	// which the Kind check below rejects.
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		// Report the caller's type; formatting v would always print reflect.Value.
		return nil, fmt.Errorf("StructNonNilPointerFieldsToMap only accepts structs; got %T", strct)
	}

	typ := v.Type()
	out := make(map[string]interface{})
	for i := 0; i < v.NumField(); i++ {
		fi := typ.Field(i)
		// Unexported fields cannot be read through Interface() (it panics).
		if fi.PkgPath != "" {
			continue
		}
		fv := v.Field(i)
		// Only set (non-nil) pointer fields take part in the patch.
		if fv.Kind() != reflect.Ptr || fv.IsNil() {
			continue
		}
		if tag := fi.Tag.Get("db"); tag != "" {
			out[tag] = fv.Elem().Interface()
		}
	}
	return out, nil
}
// with permission and based on https://stackoverflow.com/a/23598731
package data
import (
"testing"
. "github.com/onsi/gomega"
)
// Foo is the patch model used by these tests. Field1 and Field2 are pointer
// fields, so they are included in an update only when non-nil. IgnoredNonPtr
// has a db tag but is not a pointer, so it is never part of the update.
type Foo struct {
	Field1        *string `db:"field1"`
	Field2        *int    `db:"field2"`
	IgnoredNonPtr string  `db:"IgnoredNonPtr"`
}
func TestSafeDynamicPartialUpdate(t *testing.T) {
table := "fooTable"
field1Val := "value1"
field2Val := 2
patches := Foo{
Field1: &field1Val,
Field2: &field2Val,
IgnoredNonPtr: "this will be ignore",
}
where := map[string]interface{}{
"id": "fooId",
}
query, params, err := safeDynamicPartialUpdateQuery(table, where, patches)
expected := "UPDATE fooTable SET field1 = ?, field2 = ? WHERE id = ?;"
if err != nil {
t.Error(err)
}
if query != expected {
t.Error("expected:", expected, "actual:", query)
}
expectedParams := []interface{}{"value1", 2, "fooId"}
RegisterTestingT(t)
Expect(params).To(ConsistOf(expectedParams))
}
func TestSafeDynamicPartialUpdateWithNilPointers(t *testing.T) {
table := "fooTable"
field2Val := 2
patches := Foo{
Field1: nil,
Field2: &field2Val,
IgnoredNonPtr: "this will be ignore",
}
where := map[string]interface{}{
"id": "fooId",
}
query, params, err := safeDynamicPartialUpdateQuery(table, where, patches)
expected := "UPDATE fooTable SET field2 = ? WHERE id = ?;"
if err != nil {
t.Error(err)
}
if query != expected {
t.Error("expected:", expected, "actual:", query)
}
expectedParams := []interface{}{2, "fooId"}
RegisterTestingT(t)
Expect(params).To(ConsistOf(expectedParams))
}
func TestSafeDynamicPartialUpdateWithNoUpdatesShouldError(t *testing.T) {
table := "fooTable"
patches := Foo{
Field1: nil,
Field2: nil,
IgnoredNonPtr: "this will be ignore",
}
where := map[string]interface{}{
"id": "fooId",
}
query, params, err := safeDynamicPartialUpdateQuery(table, where, patches)
expected := ""
if query != expected {
t.Error("expected:", expected, "actual:", query)
}
if err == nil {
t.Error("Expected an error, but got nil")
}
expectedParams := []interface{}{}
RegisterTestingT(t)
Expect(params).To(ConsistOf(expectedParams))
}
func TestSafeDynamicPartialUpdateMultipleWhere(t *testing.T) {
table := "fooTable"
field1Val := "value1"
field2Val := 2
patches := Foo{
Field1: &field1Val,
Field2: &field2Val,
IgnoredNonPtr: "this will be ignore",
}
where := map[string]interface{}{
"id": "fooId",
"id2": "barId2",
}
query, params, err := safeDynamicPartialUpdateQuery(table, where, patches)
expected := "UPDATE fooTable SET field1 = ?, field2 = ? WHERE id = ? AND id2 = ?;"
if err != nil {
t.Error(err)
}
if query != expected {
t.Error("expected:", expected, "actual:", query)
}
expectedParams := []interface{}{"value1", 2, "fooId", "barId2"}
RegisterTestingT(t)
Expect(params).To(ConsistOf(expectedParams))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment