Skip to content

Instantly share code, notes, and snippets.

@arnehormann
Last active Aug 29, 2015
Embed
What would you like to do?
go sql helper
import (
"database/sql"
"fmt"
"reflect"
)
// Field describes a single field of a struct inspected through reflection:
// its Go name, its struct tag, and a pointer to its storage so values can be
// written into it (for example by (*sql.Rows).Scan).
type Field interface {
	// Name returns the Go name of the field.
	Name() string
	// Ptr returns a pointer to the field's storage, or nil when no such
	// pointer can be handed out (e.g. the field is unexported).
	Ptr() interface{}
	// Tags returns the complete, unparsed struct tag.
	Tags() string
	// Tag returns the value stored under key in the struct tag, or "".
	Tag(string) string
}

// structField is the reflection-backed Field: value is the field's value
// inside its parent struct, field is its static description.
type structField struct {
	value reflect.Value
	field reflect.StructField
}

var _ Field = (*structField)(nil)

// Name returns the field's Go identifier.
func (f *structField) Name() string {
	return f.field.Name
}

// Ptr returns a pointer to the field, or nil if the field is not addressable
// or was reached through an unexported field. reflect panics when Interface
// is called on such values, so they are reported as nil instead.
func (f *structField) Ptr() interface{} {
	if !f.value.CanAddr() {
		return nil
	}
	p := f.value.Addr()
	if !p.CanInterface() {
		return nil
	}
	return p.Interface()
}

// Tags returns the raw struct tag of the field.
func (f *structField) Tags() string {
	return string(f.field.Tag)
}

// Tag looks up key in the field's struct tag.
func (f *structField) Tag(key string) string {
	return f.field.Tag.Get(key)
}
// structPtr is a reflection view over the fields of a struct reached through
// a pointer. Because the struct is addressed through that pointer, every
// field value is addressable, and writes through Field(i).Ptr() update the
// original struct.
type structPtr struct {
	fields []structField
}

// Struct inspects ptr, which must be a non-nil pointer to a struct, and
// returns a view over that struct's fields in declaration order. Any other
// argument (nil, a non-pointer, a nil pointer, or a pointer to a non-struct)
// yields an error.
func Struct(ptr interface{}) (*structPtr, error) {
	v := reflect.ValueOf(ptr)
	// Elem of a nil pointer is the zero Value, whose Kind is Invalid, so a
	// nil *T is rejected here as well.
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		return nil, fmt.Errorf("%T is not a non-nil pointer to a struct", ptr)
	}
	v = v.Elem()
	t := v.Type()
	fields := make([]structField, t.NumField())
	for i := range fields {
		fields[i] = structField{value: v.Field(i), field: t.Field(i)}
	}
	return &structPtr{fields: fields}, nil
}

// NumField reports how many fields the struct has.
func (s *structPtr) NumField() int {
	return len(s.fields)
}

// Field returns the i-th field in declaration order. It panics if i is
// outside [0, NumField()).
func (s *structPtr) Field(i int) Field {
	return &s.fields[i]
}
// fetch runs query with args against db, scans each result row into the
// target struct, and calls ref.Append once per row.
//
// ref must be a pointer to a struct whose first field (typically an embedded
// struct) is the target struct. Every row is scanned into that same target,
// so Append is the implementation's chance to copy the row out before the
// next row overwrites it.
//
// Columns are matched to target fields by exact, case-sensitive Go field
// name. If any column has no matching field, fetch returns an error naming
// those columns and scans no rows.
func fetch(ref Appender, db *sql.DB, query string, args ...interface{}) error {
	outer, err := Struct(ref)
	if err != nil || outer.NumField() < 1 {
		return fmt.Errorf("%T is not a pointer to a valid struct", ref)
	}
	f0 := outer.Field(0)
	target, err := Struct(f0.Ptr())
	if err != nil {
		return fmt.Errorf("reference field %s (%T) is not a struct", f0.Name(), f0.Ptr())
	}

	rows, err := db.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return err
	}

	// rows.Scan needs exactly one destination per column. Every column starts
	// out pointing at a throwaway sink. Columns whose name matches a target
	// field are then redirected to that field.
	var sink interface{}
	bind := make([]interface{}, len(cols))
	for c := range bind {
		bind[c] = &sink
	}
	matched := make([]bool, len(cols))
	for i, n := 0, target.NumField(); i < n; i++ {
		field := target.Field(i)
		name := field.Name()
		for c, col := range cols {
			if col != name {
				continue
			}
			// Only ask for the pointer on a name match; a nil pointer
			// (a field that cannot be exposed) counts as no match.
			if ptr := field.Ptr(); ptr != nil {
				bind[c] = ptr
				matched[c] = true
			}
		}
	}

	// Report columns that have no matching field instead of silently
	// discarding their data. This check could be made optional, because the
	// sink above already keeps rows.Scan working for unmatched columns.
	var missing []string
	for c, ok := range matched {
		if !ok {
			missing = append(missing, cols[c])
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("missing fields for columns (%#v)", missing)
	}

	for rows.Next() {
		if err := rows.Scan(bind...); err != nil {
			return err
		}
		ref.Append()
	}
	return rows.Err()
}
// Appender is the reference value accepted by fetch. fetch calls Append once
// after scanning each result row, so an implementation can copy the current
// values out (for example into a slice) before the next row is scanned into
// the same storage.
type Appender interface{ Append() }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment