Skip to content

Instantly share code, notes, and snippets.

@tadasv
Last active February 14, 2022 19:19
Show Gist options
  • Save tadasv/525ab661bd7916d79cee83e6e1635a56 to your computer and use it in GitHub Desktop.
Save tadasv/525ab661bd7916d79cee83e6e1635a56 to your computer and use it in GitHub Desktop.
Simple ORM for Golang
/*
Package orm is a simple ORM that supports CRUD operations on a single object,
with struct fields mapped to columns (optionally renamed with the `db_col`
struct tag).

Example:

	type Account struct {
		orm.Model
		UUID                   string
		Name                   string
		Email                  string
		ActiveOrganizationUUID *string
		MemberSince            time.Time
		CreatedAt              time.Time
		ModifiedAt             time.Time
	}

	func (a *Account) BindTo(connector orm.Connector) *Account {
		a.Model.Connector = connector
		a.Model.Object = a
		return a
	}

	acc := (&Account{UUID: uuid}).BindTo(db)
	err := acc.Load()
	fmt.Printf("%v %v\n", err, acc)

Unlike other ORM frameworks I've seen, this one binds the DB connection to
your object. That is very useful for rapid web development: you can pass
objects to a template renderer and pull more related data from the DB while
rendering.
*/
package orm
import (
"database/sql"
sq "github.com/Masterminds/squirrel"
"reflect"
"regexp"
"strings"
)
var (
	// matchFirstCap matches any character followed by a capitalised word
	// (e.g. "eOrganization" in "ActiveOrganization"); toSnakeCase inserts an
	// underscore between the two captured groups.
	matchFirstCap = regexp.MustCompile(`(.)([A-Z][a-z]+)`)

	// matchAllCap matches a lowercase letter or digit immediately followed by
	// an uppercase letter (e.g. "nU" in "OrganizationUUID"); toSnakeCase
	// applies it after matchFirstCap.
	matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)

	// defaultPrimaryKeyColumnNames lists the column names getPrimaryKey
	// recognises as a primary key.
	defaultPrimaryKeyColumnNames = []string{"id", "uuid"}
)
// Connector is the subset of database/sql behaviour the ORM relies on to run
// the SQL it generates. Both *sql.DB and *sql.Tx satisfy it, so a model can be
// bound either to a connection pool or to an open transaction.
type Connector interface {
	// QueryRow runs a query that is expected to return at most one row.
	QueryRow(query string, args ...interface{}) *sql.Row
	// Query runs a query that returns rows.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// Exec runs a statement that does not return rows.
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// Model is meant to be embedded in a user struct to give it Load, Save, Create
// and Delete. Connector and Object must be set (e.g. by a BindTo helper as in
// the package example) before any of those methods is called.
type Model struct {
	// PrimaryKeyColumnName is intended as an explicit primary-key column name.
	// NOTE(review): confirm getPrimaryKey consults it.
	// TableName overrides the table name. When it is empty, the name is
	// derived from Object's type name.
	PrimaryKeyColumnName, TableName string

	// Connector executes the SQL built by the CRUD methods.
	Connector Connector

	// Object should point at the struct that embeds this Model. Its fields
	// are mapped to columns via reflection, and field addresses are taken,
	// so a pointer is required.
	Object interface{}
}
// getTableName returns the table the bound object maps to. A non-empty
// TableName wins. Otherwise the result is the lower-cased name of Object's
// type, with one level of pointer dereferenced.
//
// Unlike column names, the type name is not converted to snake_case, so
// "UserAccount" becomes "useraccount". When TableName is empty, Object must be
// non-nil: reflect.TypeOf(nil) returns a nil Type and the Kind call panics.
func (o Model) getTableName() string {
	if len(o.TableName) > 0 {
		return o.TableName
	}
	typ := reflect.TypeOf(o.Object)
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	return strings.ToLower(typ.Name())
}
// getColumnsAndNames maps the fields of the struct bound to o.Object onto SQL
// columns. It returns two parallel slices in struct-field order:
//   - columns holds a pointer to each mapped field, usable both as a query
//     argument and as a Scan destination;
//   - names holds the matching column name.
//
// A field's column name comes from its `db_col` tag when present, and the tag
// value "-" excludes the field. Untagged fields use toSnakeCase(field name).
// Two kinds of field are skipped:
//   - fields whose type is assignable to Model (the embedded Model itself);
//   - unexported fields, because reflect.Value.Interface panics on values
//     reached through them.
//
// o.Object must be a non-nil pointer to a struct. Field addresses are taken
// with reflect.Value.Addr, which panics on non-addressable values.
func (o Model) getColumnsAndNames() ([]interface{}, []string) {
	val := reflect.Indirect(reflect.ValueOf(o.Object))
	valType := val.Type()
	modelType := reflect.TypeOf(o)

	names := make([]string, 0, valType.NumField())
	columns := make([]interface{}, 0, valType.NumField())
	for i := 0; i < valType.NumField(); i++ {
		field := valType.Field(i)
		if field.PkgPath != "" {
			// Unexported: Addr().Interface() below would panic.
			continue
		}
		if field.Type.AssignableTo(modelType) {
			// Ignore embedded Model internals when deriving column names.
			continue
		}
		name, tagged := field.Tag.Lookup("db_col")
		switch {
		case tagged && name == "-":
			continue
		case !tagged:
			name = toSnakeCase(field.Name)
		}
		names = append(names, name)
		columns = append(columns, val.Field(i).Addr().Interface())
	}
	return columns, names
}
// getPrimaryKey picks the primary-key column out of the parallel names/columns
// slices produced by getColumnsAndNames. Note that the argument order is the
// reverse of that function's results. It returns the column name and the
// matching columns entry, which is a pointer to the struct field.
//
// If o.PrimaryKeyColumnName is set, only that column name is accepted.
// Otherwise the first column, in field order, whose name is in
// defaultPrimaryKeyColumnNames is chosen. When nothing matches, the result is
// ("", nil).
// NOTE(review): the callers (Load, getUpdateQuery, Delete) put that result
// straight into a WHERE clause, which produces invalid SQL. Consider rejecting
// it explicitly.
func (o Model) getPrimaryKey(names []string, columns []interface{}) (string, interface{}) {
	candidates := defaultPrimaryKeyColumnNames
	if o.PrimaryKeyColumnName != "" {
		// An explicit column name replaces the defaults entirely. Silently
		// falling back to "id"/"uuid" could target the wrong row.
		candidates = []string{o.PrimaryKeyColumnName}
	}
	for i, name := range names {
		for _, pkName := range candidates {
			if name == pkName {
				return name, columns[i]
			}
		}
	}
	return "", nil
}
// Load selects the row whose primary key equals the bound object's
// primary-key field. It scans every mapped column of that row back into the
// object's fields.
//
// The returned error comes from building the query or from (*sql.Row).Scan,
// which reports sql.ErrNoRows when no row matched.
// NOTE(review): if getPrimaryKey finds no key column, the WHERE clause is built
// from an empty column name.
func (o Model) Load() error {
	fieldPtrs, colNames := o.getColumnsAndNames()
	pkCol, pkVal := o.getPrimaryKey(colNames, fieldPtrs)

	query, args, err := sq.Select(colNames...).
		From(o.getTableName()).
		Where(sq.Eq{pkCol: pkVal}).
		Limit(1).
		ToSql()
	if err != nil {
		return err
	}
	return o.Connector.QueryRow(query, args...).Scan(fieldPtrs...)
}
// getUpdateQuery builds, but does not run, an UPDATE that writes every mapped
// column of the bound object back to its table. Columns are SET in field
// order, and the primary-key column itself is included. The statement is
// restricted to the row matching the primary key and carries LIMIT 1.
// NOTE(review): UPDATE ... LIMIT is MySQL syntax; other databases may reject it.
func (o Model) getUpdateQuery() sq.UpdateBuilder {
	fieldPtrs, colNames := o.getColumnsAndNames()
	pkCol, pkVal := o.getPrimaryKey(colNames, fieldPtrs)

	builder := sq.Update(o.getTableName())
	for idx := range colNames {
		builder = builder.Set(colNames[idx], fieldPtrs[idx])
	}
	return builder.Where(sq.Eq{pkCol: pkVal}).Limit(1)
}
// Save writes the bound object's current field values to its row, using the
// UPDATE built by getUpdateQuery. It returns the first error from building or
// executing the statement. The sql.Result is discarded.
func (o Model) Save() error {
	query, args, err := o.getUpdateQuery().ToSql()
	if err == nil {
		_, err = o.Connector.Exec(query, args...)
	}
	return err
}
// Create inserts the bound object as a new row containing every mapped
// column. The sql.Result is discarded, so database-generated values (such as
// an auto-increment id) are not read back into the object.
func (o Model) Create() error {
	fieldPtrs, colNames := o.getColumnsAndNames()
	insert := sq.Insert(o.getTableName()).
		Columns(colNames...).
		Values(fieldPtrs...)

	query, args, err := insert.ToSql()
	if err != nil {
		return err
	}
	if _, err := o.Connector.Exec(query, args...); err != nil {
		return err
	}
	return nil
}
// Delete removes the row whose primary key equals the bound object's
// primary-key field, limited to one row. It returns any error from building or
// executing the statement.
// NOTE(review): if getPrimaryKey finds no key column, the WHERE clause is built
// from an empty column name. DELETE ... LIMIT is also MySQL syntax.
func (o Model) Delete() error {
	fieldPtrs, colNames := o.getColumnsAndNames()
	pkCol, pkVal := o.getPrimaryKey(colNames, fieldPtrs)

	stmt := sq.Delete(o.getTableName()).
		Where(sq.Eq{pkCol: pkVal}).
		Limit(1)
	query, args, err := stmt.ToSql()
	if err != nil {
		return err
	}
	_, err = o.Connector.Exec(query, args...)
	return err
}
// toSnakeCase converts a Go identifier such as "ActiveOrganizationUUID" into a
// snake_case column name ("active_organization_uuid").
//
// It runs two passes:
//   - split before each capitalised word (matchFirstCap);
//   - split at each lowercase/digit-to-uppercase boundary (matchAllCap).
//
// The result is then lower-cased, so acronyms stay together: "HTTPServer"
// becomes "http_server".
func toSnakeCase(str string) string {
	const splitAt = "${1}_${2}"
	wordsSplit := matchFirstCap.ReplaceAllString(str, splitAt)
	boundariesSplit := matchAllCap.ReplaceAllString(wordsSplit, splitAt)
	return strings.ToLower(boundariesSplit)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment