Skip to content

Instantly share code, notes, and snippets.

@jaekwon
Created December 3, 2013 11:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaekwon/7767747 to your computer and use it in GitHub Desktop.
Save jaekwon/7767747 to your computer and use it in GitHub Desktop.
// RowScanner is the single-method read contract used throughout this file:
// copy the columns of the current result row into the given destinations.
// Standard-library values such as *sql.Row and *sql.Rows satisfy it.
type RowScanner interface {
	Scan(targets ...interface{}) error
}
// ModelInfo describes how a Go struct type maps onto a database table.
// Only fields carrying a non-empty `db` struct tag take part in the mapping.
type ModelInfo struct {
	Type           reflect.Type           // the struct type (pointer types are dereferenced first)
	TableName      string                 // strings.ToLower of the struct type's name
	Fields         []*reflect.StructField // db-tagged fields, in declaration order
	FieldsSimple   string                 // "col1, col2, ..." built from the db tags
	FieldsPrefixed string                 // "table.col1, table.col2, ..."; "" when there are no fields
	Placeholders   string                 // "?, ?, ..." with one "?" per field
}

// allModelInfos caches one ModelInfo per struct type so reflection runs once
// per type. It is keyed by reflect.Type (not the bare type name) so that
// same-named types from different packages or scopes, and anonymous struct
// types, never share an entry.
// NOTE(review): the map is not guarded by a lock; first use of a new model
// type from several goroutines at once is a data race. Warm the cache during
// package initialization or add a mutex if that can happen.
var allModelInfos = map[reflect.Type]*ModelInfo{}

// FieldValues returns the values of m's db-tagged fields read from i, in the
// same order as m.Fields (and therefore as FieldsSimple and Placeholders).
// i must be a value of m.Type or a non-nil pointer to one; anything else is
// a programming error at the call site and panics.
func (m *ModelInfo) FieldValues(i interface{}) []interface{} {
	v := reflect.ValueOf(i)
	// A nil pointer is left as-is so it fails the type check below with a
	// readable message instead of panicking inside reflect.
	if v.Kind() == reflect.Ptr && !v.IsNil() {
		v = v.Elem()
	}
	if !v.IsValid() || v.Type() != m.Type {
		log.Panicf("Invalid argument for FieldValues: Type mismatch. Expected %v but got %T",
			m.Type, i)
	}
	fvs := make([]interface{}, 0, len(m.Fields))
	for _, field := range m.Fields {
		fvs = append(fvs, v.FieldByIndex(field.Index).Interface())
	}
	return fvs
}

// GetModelInfo returns the ModelInfo for the dynamic type of i, which may be
// a struct or a pointer to a struct. It returns nil when i is nil or is not
// struct-shaped.
func GetModelInfo(i interface{}) *ModelInfo {
	return GetModelInfoFromType(reflect.TypeOf(i))
}

// GetModelInfoFromType returns the cached ModelInfo for modelType (or for its
// element type when modelType is a pointer), building it on first use.
// It returns nil for a nil Type and for non-struct types.
func GetModelInfoFromType(modelType reflect.Type) *ModelInfo {
	// reflect.TypeOf(nil) yields a nil Type; calling Kind on it would panic.
	if modelType == nil {
		return nil
	}
	if modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}
	if modelType.Kind() != reflect.Struct {
		return nil
	}
	if m, ok := allModelInfos[modelType]; ok {
		return m
	}
	// Build fully before publishing, so the cache never holds a partially
	// populated entry.
	m := buildModelInfo(modelType)
	allModelInfos[modelType] = m
	return m
}

// buildModelInfo reflects over the struct type t once and precomputes the
// SQL column lists and placeholders for its db-tagged fields.
func buildModelInfo(t reflect.Type) *ModelInfo {
	m := &ModelInfo{
		Type:      t,
		TableName: strings.ToLower(t.Name()),
	}
	var columns, prefixed, placeholders []string
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i) // fresh variable per iteration, so &field is distinct
		column := field.Tag.Get("db")
		if column == "" {
			continue
		}
		m.Fields = append(m.Fields, &field)
		columns = append(columns, column)
		prefixed = append(prefixed, m.TableName+"."+column)
		placeholders = append(placeholders, "?")
	}
	m.FieldsSimple = strings.Join(columns, ", ")
	m.FieldsPrefixed = strings.Join(prefixed, ", ")
	m.Placeholders = strings.Join(placeholders, ", ")
	return m
}
// expandArgs flattens query arguments for database/sql. An argument that is
// a model struct, or a pointer to one, with at least one db-tagged field is
// replaced by that struct's field values in FieldsSimple order. Every other
// argument is passed through unchanged.
//
// Passing through structs that have no db-tagged fields (time.Time,
// sql.NullString, ...) keeps them as single values instead of letting them
// vanish. Nil arguments (SQL NULL) are passed through without any reflection.
func expandArgs(args []interface{}) []interface{} {
	expanded := make([]interface{}, 0, len(args))
	for _, arg := range args {
		if arg == nil {
			expanded = append(expanded, arg)
			continue
		}
		modelInfo := GetModelInfo(arg)
		if modelInfo == nil || len(modelInfo.Fields) == 0 {
			expanded = append(expanded, arg)
			continue
		}
		expanded = append(expanded, modelInfo.FieldValues(arg)...)
	}
	return expanded
}
// Exec executes query against the handle returned by GetDB, after
// flattening model-struct arguments into column values via expandArgs.
func Exec(query string, args ...interface{}) (sql.Result, error) {
	db := GetDB()
	flat := expandArgs(args)
	return db.Exec(query, flat...)
}
// QueryRow runs a single-row query against the handle returned by GetDB,
// after flattening model-struct arguments via expandArgs. The row is
// wrapped in a StructScanner so that Scan can fill model structs directly.
func QueryRow(query string, args ...interface{}) RowScanner {
	db := GetDB()
	row := db.QueryRow(query, expandArgs(args)...)
	return &StructScanner{Scanner: row}
}
// StructRows is the dynamic type of the RowScanner returned by Query. It
// embeds *sql.Rows, so Next, Err, Columns and Close work as usual. Its Scan
// expands model structs the same way StructScanner.Scan does.
type StructRows struct {
	*sql.Rows
}

// Scan reads the current row (after a successful Next) into dest, expanding
// pointers to model structs into pointers to their db-tagged fields.
func (r *StructRows) Scan(dest ...interface{}) error {
	return (&StructScanner{Scanner: r.Rows}).Scan(dest...)
}

// Query runs query against the handle returned by GetDB, first flattening
// model-struct arguments into their db-tagged field values via expandArgs.
//
// The static return type stays RowScanner for compatibility, but the value is
// always a *StructRows. A bare RowScanner exposes neither Next, which must
// precede every Scan on a result set, nor Close, which releases the
// connection. Callers should type-assert and iterate:
//
//	rs, err := Query(q, args...)
//	if err != nil { ... }
//	rows := rs.(*StructRows)
//	defer rows.Close()
//	for rows.Next() {
//		if err := rows.Scan(&model); err != nil { ... }
//	}
//	if err := rows.Err(); err != nil { ... }
//
// NOTE(review): assumes GetDB().Query returns *sql.Rows, as *sql.DB's Query
// does. Exec's sql.Result return type suggests GetDB yields a *sql.DB — confirm.
func Query(query string, args ...interface{}) (RowScanner, error) {
	rows, err := GetDB().Query(query, expandArgs(args)...)
	if err != nil {
		return nil, err
	}
	return &StructRows{Rows: rows}, nil
}
// StructScanner wraps a RowScanner (such as the row from QueryRow) so that
// Scan destinations pointing at model structs are expanded into pointers to
// their db-tagged fields.
type StructScanner struct {
	Scanner RowScanner
}

// Scan reads the current row into dest.
//
// A destination that is a non-nil pointer to a struct with db-tagged fields
// consumes one column per tagged field, in FieldsSimple order. Every other
// destination is handed to the underlying Scanner unchanged, so that Scanner
// reports any invalid destination itself. "Every other" includes plain
// pointers like *int, pointers to untagged structs such as *time.Time or
// *sql.NullString, and nil or non-pointer values.
func (s *StructScanner) Scan(dest ...interface{}) error {
	targets := make([]interface{}, 0, len(dest))
	for _, d := range dest {
		ptr := reflect.ValueOf(d)
		if ptr.Kind() != reflect.Ptr || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
			// d is already the pointer the driver needs. reflect.ValueOf(d) is
			// not addressable, so calling Addr on it would panic.
			targets = append(targets, d)
			continue
		}
		target := ptr.Elem() // addressable: reached through a pointer
		m := GetModelInfoFromType(target.Type())
		if m == nil || len(m.Fields) == 0 {
			targets = append(targets, d)
			continue
		}
		for _, field := range m.Fields {
			targets = append(targets, target.FieldByIndex(field.Index).Addr().Interface())
		}
	}
	return s.Scanner.Scan(targets...)
}
////////////// USAGE
type User struct {
Id string `db:"id"`
Email string `db:"email"`
}
var UserModel = GetModelIfno(new(User))
func test() {
// inserting a struct
db.Exec(`INSERT INTO user(`+UserModel.FieldsSimple+`) VALUES (`+UserModel.Placeholders+`)`, user)
// loading a struct
var user User
db.QueryRow(`SELECT `+UserModel.FieldsSimple+` FROM user WHERE email=?`, email).Scan(&user)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment