Skip to content

Instantly share code, notes, and snippets.

@slene
Created January 7, 2014 10:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save slene/8297565 to your computer and use it in GitHub Desktop.
Save slene/8297565 to your computer and use it in GitHub Desktop.
qbs_disable_stmt_cache full file
package qbs
import (
"database/sql"
"errors"
"fmt"
"log"
"os"
"reflect"
"strings"
"sync"
"time"
)
// Package-level state shared by every Qbs instance. It is filled in by
// Register and by the Set*/Change* configuration functions.
var (
	// driver, driverSource and dbName hold the arguments given to Register.
	driver, driverSource, dbName string
	// dial is the Dialect given to Register; GetQbs copies it into each Qbs.
	dial Dialect

	// connectionLimit is a counting semaphore: GetQbs sends into it and Close
	// receives from it. A nil channel means "no limit".
	connectionLimit chan struct{}
	// blockingOnLimit makes GetQbs wait for a free slot instead of failing
	// with ConnectionLimitError.
	blockingOnLimit bool

	// ConnectionLimitError is returned by GetQbs when the connection limit is
	// reached and blocking mode is off.
	ConnectionLimitError = errors.New("Connection limit reached")

	// db is the shared handle opened by Register.
	db *sql.DB
	// stmtMap and mu back a prepared-statement cache; its use in prepare is
	// currently commented out.
	stmtMap map[string]*sql.Stmt
	mu      *sync.RWMutex

	// Default loggers; replace them with SetLogger.
	queryLogger = log.New(os.Stdout, "qbs:", log.LstdFlags)
	errorLogger = log.New(os.Stderr, "qbs:", log.LstdFlags)
)
// Qbs carries the dialect used to build SQL, the criteria accumulated by the
// builder methods (Where, Limit, OrderBy, ...) and, after Begin, the open
// transaction. Obtain one with GetQbs or WithQbs and release it with Close.
type Qbs struct {
// Dialect generates and quotes SQL for the registered database.
Dialect Dialect
Log bool //Set to true to print out sql statement.
// tx is the open transaction, or nil outside Begin ... Commit/Rollback.
tx *sql.Tx
// txStmtMap holds statements that Commit/Rollback close when the transaction ends.
txStmtMap map[string]*sql.Stmt
// criteria accumulates clauses for the next query; Reset replaces it.
criteria *criteria
// firstTxError is the first error passed to updateTxError; Commit returns it.
firstTxError error
}
// Validator is implemented by models that check their own values before they
// are written. Save, BulkInsert and Update call Validate first and abort with
// its error when it is non-nil.
type Validator interface {
Validate(*Qbs) error
}
// Register configures the package-wide database and should be called once at
// application start-up, before GetQbs.
//
// The driver name, data source, database name and dialect are recorded on
// every call. The *sql.DB, however, is opened only on the first call; later
// calls keep the existing handle. Register panics if sql.Open fails.
func Register(driverName, driverSourceName, databaseName string, dialect Dialect) {
	driver, driverSource, dial, dbName = driverName, driverSourceName, dialect, databaseName
	if db != nil {
		return
	}
	opened, err := sql.Open(driver, driverSource)
	if err != nil {
		panic(err)
	}
	db = opened
	db.SetMaxIdleConns(100)
	stmtMap = make(map[string]*sql.Stmt)
	mu = new(sync.RWMutex)
}
// WithQbs obtains a Qbs from GetQbs, runs task with it and closes it
// afterwards (the Close is deferred, so it also runs if task panics).
// It returns GetQbs's error if that fails, otherwise task's result.
func WithQbs(task func(*Qbs) error) error {
	q, err := GetQbs()
	if err == nil {
		defer q.Close()
		err = task(q)
	}
	return err
}
// GetQbs returns a new Qbs bound to the registered dialect. Release it with
// Close, for example:
//
//	q, err := qbs.GetQbs()
//	if err != nil {
//		fmt.Println(err)
//		return
//	}
//	defer q.Close()
//	...
//
// It panics if Register has not been called. When a connection limit is set
// (see SetConnectionLimit) it either waits for a free slot or, in
// non-blocking mode, returns ConnectionLimitError.
func GetQbs() (q *Qbs, err error) {
	if dial == nil || driver == "" {
		panic("database driver has not been registered, should call Register first.")
	}
	switch {
	case connectionLimit == nil:
		// No limit configured.
	case blockingOnLimit:
		connectionLimit <- struct{}{}
	default:
		select {
		case connectionLimit <- struct{}{}:
		default:
			return nil, ConnectionLimitError
		}
	}
	return &Qbs{Dialect: dial, criteria: new(criteria)}, nil
}
// ChangePoolSize sets the maximum number of idle connections kept by the
// shared *sql.DB. Register initialises it to 100. Register must run first;
// before that the handle is nil.
func ChangePoolSize(size int) {
	pool := db
	pool.SetMaxIdleConns(size)
}
// SetLogger replaces the logger used for SQL statements (query) and the one
// used for errors (err). A nil logger silences the corresponding output.
func SetLogger(query *log.Logger, err *log.Logger) {
	queryLogger, errorLogger = query, err
}
// SetConnectionLimit caps how many Qbs instances GetQbs hands out at once.
// There is no limit by default. The effect of maxCon:
//   - positive: installs a fresh limit of that size;
//   - negative: removes the limit;
//   - zero: leaves the current limit unchanged.
//
// blocking selects what GetQbs does at the limit: wait for a slot (true) or
// return ConnectionLimitError (false).
//
// NOTE(review): Close releases a slot on whichever channel is current when it
// runs, so changing the limit while instances are outstanding looks unsafe.
// Prefer calling this during start-up.
func SetConnectionLimit(maxCon int, blocking bool) {
	switch {
	case maxCon > 0:
		connectionLimit = make(chan struct{}, maxCon)
	case maxCon < 0:
		connectionLimit = nil
	}
	blockingOnLimit = blocking
}
// Reset discards the accumulated criteria (condition, limit, offset,
// ordering, omitted fields, model), so the next query starts from scratch.
func (q *Qbs) Reset() {
	fresh := new(criteria)
	q.criteria = fresh
}
// Begin starts a transaction on the shared database. Until Commit or
// Rollback, queries made through the same Qbs run inside it. Calls on this
// Qbs look the same whether or not a transaction is open.
// It panics if a transaction is already open.
//
// The recorded first error is cleared on success. As a result, Commit reports
// only errors raised inside this transaction, not ones left over from earlier
// work (or an earlier transaction) on the same Qbs.
func (q *Qbs) Begin() error {
	if q.tx != nil {
		panic("cannot start nested transaction")
	}
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	q.tx = tx
	q.txStmtMap = make(map[string]*sql.Stmt)
	q.firstTxError = nil
	return nil
}
// InTransaction reports whether a transaction opened by Begin is still
// pending, i.e. Commit or Rollback has not been called yet.
func (q *Qbs) InTransaction() bool {
	active := q.tx != nil
	return active
}
// updateTxError logs a non-nil e to errorLogger (when one is set) and stores
// it as firstTxError if nothing was stored yet. It returns e unchanged, which
// lets callers write `return q.updateTxError(err)`.
func (q *Qbs) updateTxError(e error) error {
	if e == nil {
		return nil
	}
	if errorLogger != nil {
		errorLogger.Println(e)
	}
	// Never overwrite the first recorded error.
	if q.firstTxError == nil {
		q.firstTxError = e
	}
	return e
}
// Commit commits the transaction opened by Begin, closes the statements
// tracked for it and leaves the Qbs outside a transaction.
//
// It returns the first error recorded on this Qbs via updateTxError. That may
// be the commit's own error or an earlier one from inside the transaction.
// It must only be called after a successful Begin.
func (q *Qbs) Commit() error {
	q.updateTxError(q.tx.Commit())
	q.tx = nil
	for _, stmt := range q.txStmtMap {
		stmt.Close()
	}
	q.txStmtMap = nil
	return q.firstTxError
}
// Rollback aborts the transaction opened by Begin, closes the statements
// tracked for it and leaves the Qbs outside a transaction. It returns the
// rollback's own error, which is also recorded via updateTxError.
// It must only be called after a successful Begin.
func (q *Qbs) Rollback() error {
	rollbackErr := q.tx.Rollback()
	q.tx = nil
	for _, stmt := range q.txStmtMap {
		stmt.Close()
	}
	q.txStmtMap = nil
	return q.updateTxError(rollbackErr)
}
// Where sets the WHERE condition from a raw expression and its arguments.
// It is shorthand for Condition(NewCondition(expr, args...)).
func (q *Qbs) Where(expr string, args ...interface{}) *Qbs {
	c := q.criteria
	c.condition = NewCondition(expr, args...)
	return q
}

// WhereEqual sets the WHERE condition built by NewEqualCondition for a
// snake_case column name and a value.
func (q *Qbs) WhereEqual(column string, value interface{}) *Qbs {
	c := q.criteria
	c.condition = NewEqualCondition(column, value)
	return q
}

// WhereIn sets the WHERE condition built by NewInCondition for column and
// values.
func (q *Qbs) WhereIn(column string, values []interface{}) *Qbs {
	c := q.criteria
	c.condition = NewInCondition(column, values)
	return q
}

// Condition sets the SQL WHERE clause and replaces any condition set before.
// If Find can infer another condition from the struct's primary key, the two
// are combined with AND.
func (q *Qbs) Condition(condition *Condition) *Qbs {
	c := q.criteria
	c.condition = condition
	return q
}

// Limit sets the maximum number of rows for the next query.
func (q *Qbs) Limit(limit int) *Qbs {
	c := q.criteria
	c.limit = limit
	return q
}

// Offset sets how many rows the next query skips.
func (q *Qbs) Offset(offset int) *Qbs {
	c := q.criteria
	c.offset = offset
	return q
}

// OrderBy appends an ascending ORDER BY on the quoted path.
func (q *Qbs) OrderBy(path string) *Qbs {
	asc := order{q.Dialect.quote(path), false}
	q.criteria.orderBys = append(q.criteria.orderBys, asc)
	return q
}

// OrderByDesc appends a descending ORDER BY on the quoted path.
func (q *Qbs) OrderByDesc(path string) *Qbs {
	desc := order{q.Dialect.quote(path), true}
	q.criteria.orderBys = append(q.criteria.orderBys, desc)
	return q
}

// OmitFields excludes the given CamelCase struct fields from the next query.
// It replaces any list set earlier.
func (q *Qbs) OmitFields(fieldName ...string) *Qbs {
	c := q.criteria
	c.omitFields = fieldName
	return q
}

// OmitJoin disables the join on referenced structs for the next query.
func (q *Qbs) OmitJoin() *Qbs {
	c := q.criteria
	c.omitJoin = true
	return q
}
// Find selects a single row and fills its values into structPtr.
//
// The select list is built from the struct's fields of supported types:
//   - A non-zero primary key becomes a `table.pk = ?` condition, ANDed with
//     any condition already set.
//   - Unless OmitJoin was called, a foreign-key field paired with a pointer to
//     the referenced struct produces a join, and that pointer is filled too.
//
// The limit is forced to 1. Find returns sql.ErrNoRows when nothing matches.
func (q *Qbs) Find(structPtr interface{}) error {
	c := q.criteria
	c.model = structPtrToModel(structPtr, !c.omitJoin, c.omitFields)
	c.limit = 1
	if !c.model.pkZero() {
		pkColumn := q.Dialect.quote(c.model.table) + "." + q.Dialect.quote(c.model.pk.name)
		byPk := NewCondition(pkColumn+" = ?", c.model.pk.value)
		if c.condition != nil {
			c.condition = byPk.AndCondition(c.condition)
		} else {
			c.condition = byPk
		}
	}
	query, args := q.Dialect.querySql(c)
	return q.doQueryRow(structPtr, query, args...)
}
// FindAll is the multi-row counterpart of Find. ptrOfSliceOfStructPtr must be
// a pointer to a slice of struct pointers (*[]*T), and every matching row is
// appended to that slice.
//
// Unlike Find, it does not force a limit and adds no primary-key condition.
func (q *Qbs) FindAll(ptrOfSliceOfStructPtr interface{}) error {
	elemType := reflect.TypeOf(ptrOfSliceOfStructPtr).Elem().Elem().Elem()
	template := reflect.New(elemType).Interface()
	c := q.criteria
	c.model = structPtrToModel(template, !c.omitJoin, c.omitFields)
	query, args := q.Dialect.querySql(c)
	return q.doQueryRows(ptrOfSliceOfStructPtr, query, args...)
}
// doQueryRow runs query and scans the first row into out, a struct pointer.
// The criteria are reset afterwards.
//
// It returns sql.ErrNoRows when the result set is empty. If reading the first
// row failed, it returns the error reported by rows.Err instead.
func (q *Qbs) doQueryRow(out interface{}, query string, args ...interface{}) error {
	defer q.Reset()
	rowValue := reflect.ValueOf(out)
	q.log(query, args...)
	stmt, err := q.prepare(query)
	if err != nil {
		return q.updateTxError(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query(args...)
	if err != nil {
		return q.updateTxError(err)
	}
	defer rows.Close()
	if !rows.Next() {
		// Next also returns false on a driver/network error. Don't disguise
		// that as "no rows".
		if err = rows.Err(); err != nil {
			return q.updateTxError(err)
		}
		return sql.ErrNoRows
	}
	return q.scanRows(rowValue, rows)
}
// doQueryRows runs query and appends one new struct pointer per row to the
// slice that out points to (out must be *[]*T). The criteria are reset
// afterwards.
//
// An error hit while iterating (rows.Err) is returned rather than silently
// yielding a truncated slice.
func (q *Qbs) doQueryRows(out interface{}, query string, args ...interface{}) error {
	defer q.Reset()
	sliceValue := reflect.Indirect(reflect.ValueOf(out))
	structType := sliceValue.Type().Elem().Elem()
	q.log(query, args...)
	stmt, err := q.prepare(query)
	if err != nil {
		return q.updateTxError(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query(args...)
	if err != nil {
		return q.updateTxError(err)
	}
	defer rows.Close()
	for rows.Next() {
		rowValue := reflect.New(structType)
		if err = q.scanRows(rowValue, rows); err != nil {
			return err
		}
		sliceValue.Set(reflect.Append(sliceValue, rowValue))
	}
	// rows.Next also stops on a read error; surface it.
	return q.updateTxError(rows.Err())
}
// scanRows scans the current row of rows into the struct that rowValue points
// to. Each column is handled as follows:
//   - A column named "<table>___<column>" is written into the struct behind
//     the pointer field named after <table>. That struct is allocated if the
//     pointer is nil.
//   - Any other column maps directly to a field of the outer struct.
//   - NULL values and columns without a matching field are skipped.
//
// Values are stored through Dialect.setModelValue.
func (q *Qbs) scanRows(rowValue reflect.Value, rows *sql.Rows) (err error) {
	cols, err := rows.Columns()
	if err != nil {
		return err
	}
	containers := make([]interface{}, len(cols))
	for i := range containers {
		containers[i] = new(interface{})
	}
	if err = rows.Scan(containers...); err != nil {
		return err
	}
	structValue := rowValue.Elem()
	for i, container := range containers {
		value := reflect.ValueOf(container).Elem()
		if !value.Elem().IsValid() {
			continue // NULL
		}
		key := cols[i]
		paths := strings.Split(key, "___")
		var field reflect.Value
		if len(paths) == 2 {
			subStruct := structValue.FieldByName(TableNameToStructName(paths[0]))
			// A joined column without a matching *Struct field would otherwise
			// panic in IsNil / FieldByName; skip it like an unknown column.
			if !subStruct.IsValid() || subStruct.Kind() != reflect.Ptr ||
				subStruct.Type().Elem().Kind() != reflect.Struct {
				continue
			}
			if subStruct.IsNil() {
				subStruct.Set(reflect.New(subStruct.Type().Elem()))
			}
			field = subStruct.Elem().FieldByName(ColumnNameToFieldName(paths[1]))
		} else {
			field = structValue.FieldByName(ColumnNameToFieldName(key))
		}
		if !field.IsValid() {
			continue
		}
		if err = q.Dialect.setModelValue(value, field); err != nil {
			return err
		}
	}
	return nil
}
// Exec runs a statement that returns no rows. Like sql.DB.Exec /
// sql.Tx.Exec, it runs inside the current transaction when one is open.
//
// query is first passed through Dialect.substituteMarkers. The criteria are
// reset afterwards.
func (q *Qbs) Exec(query string, args ...interface{}) (sql.Result, error) {
	defer q.Reset()
	query = q.Dialect.substituteMarkers(query)
	q.log(query, args...)
	stmt, err := q.prepare(query)
	if err == nil {
		defer stmt.Close()
		var result sql.Result
		if result, err = stmt.Exec(args...); err == nil {
			return result, nil
		}
	}
	return nil, q.updateTxError(err)
}
// QueryRow runs a query expected to return at most one row. Like
// sql.DB.QueryRow / sql.Tx.QueryRow, it runs inside the current transaction
// when one is open. query is first passed through Dialect.substituteMarkers,
// and the substituted text is what gets logged.
//
// The query is issued through Tx.QueryRow / DB.QueryRow rather than an
// explicitly prepared statement. This has two effects:
//   - The statement stays alive until the returned *sql.Row is scanned.
//   - A failure is reported by Row.Scan instead of by returning a nil
//     *sql.Row, on which callers would panic.
//
// Outside a transaction a trailing ";" is appended, matching prepare.
func (q *Qbs) QueryRow(query string, args ...interface{}) *sql.Row {
	query = q.Dialect.substituteMarkers(query)
	q.log(query, args...)
	if q.tx != nil {
		return q.tx.QueryRow(query, args...)
	}
	return db.QueryRow(query+";", args...)
}
// Query runs a query returning rows. Like sql.DB.Query / sql.Tx.Query, it
// runs inside the current transaction when one is open. query is first
// passed through Dialect.substituteMarkers, and the substituted text is what
// gets logged. The caller must Close the returned rows.
//
// The query is issued through Tx.Query / DB.Query. The previous code prepared
// a statement and closed it on return, while the returned rows were still
// unread. Here the implicit statement is released when the rows are closed.
// Outside a transaction a trailing ";" is appended, matching prepare.
func (q *Qbs) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
	query = q.Dialect.substituteMarkers(query)
	q.log(query, args...)
	if q.tx != nil {
		rows, err = q.tx.Query(query, args...)
	} else {
		rows, err = db.Query(query+";", args...)
	}
	if err != nil {
		return nil, q.updateTxError(err)
	}
	return rows, nil
}
// prepare prepares query on the open transaction, or on the shared *sql.DB
// when no transaction is open. The caller owns the returned statement and
// must Close it.
//
// Statement caching (stmtMap / txStmtMap) is disabled: every call prepares
// afresh. A failure is recorded with updateTxError before being returned.
// NOTE(review): the callers in this file also pass the same error to
// updateTxError, so each failure is logged twice.
func (q *Qbs) prepare(query string) (stmt *sql.Stmt, err error) {
	if q.tx == nil {
		// NOTE(review): only the non-transactional path appends ";"; the
		// reason for the asymmetry is not visible here.
		stmt, err = db.Prepare(query + ";")
	} else {
		stmt, err = q.tx.Prepare(query)
	}
	if err != nil {
		q.updateTxError(err)
	}
	return
}
// Save inserts or updates the row for structPtr.
//
// If structPtr implements Validator, Validate runs first and its error is
// returned as-is. Save panics if the model has no primary key.
//
// Insert vs. update:
//   - If the primary key is non-zero and Count finds a row with that key, the
//     row is updated through Dialect.update.
//   - Otherwise it is inserted through Dialect.insert.
//
// Time fields and keys:
//   - The model's "updated" time field is set to the current time in both
//     cases; the "created" time field only on insert. Both are copied back
//     into the struct on success.
//   - When the primary key is an int64 and the dialect returned a non-zero
//     id (only the insert path sets id), that id is written into the struct.
//
// affected is the dialect's count for an update, or 1 for a successful
// insert. Update/insert errors go through updateTxError.
func (q *Qbs) Save(structPtr interface{}) (affected int64, err error) {
if v, ok := structPtr.(Validator); ok {
err = v.Validate(q)
if err != nil {
return
}
}
model := structPtrToModel(structPtr, true, q.criteria.omitFields)
if model.pk == nil {
panic("no primary key field")
}
q.criteria.model = model
// One timestamp is used for the row and for the values copied back into the struct.
now := time.Now()
var id int64 = 0
updateModelField := model.timeField("updated")
if updateModelField != nil {
updateModelField.value = now
}
createdModelField := model.timeField("created")
var isInsert bool
// WhereEqual leaves the primary-key condition on q.criteria. The dialect's
// update presumably uses it to target the row — TODO confirm.
if !model.pkZero() && q.WhereEqual(model.pk.name, model.pk.value).Count(model.table) > 0 { //id is given, can be an update operation.
affected, err = q.Dialect.update(q)
} else {
if createdModelField != nil {
createdModelField.value = now
}
id, err = q.Dialect.insert(q)
isInsert = true
if err == nil {
affected = 1
}
}
// Reflect generated id and timestamps back into the caller's struct.
if err == nil {
structValue := reflect.Indirect(reflect.ValueOf(structPtr))
if _, ok := model.pk.value.(int64); ok && id != 0 {
idField := structValue.FieldByName(model.pk.camelName)
idField.SetInt(id)
}
if updateModelField != nil {
updateField := structValue.FieldByName(updateModelField.camelName)
updateField.Set(reflect.ValueOf(now))
}
if isInsert {
if createdModelField != nil {
createdField := structValue.FieldByName(createdModelField.camelName)
createdField.Set(reflect.ValueOf(now))
}
}
}
return affected, q.updateTxError(err)
}
// BulkInsert inserts every element of sliceOfStructPtr, a slice of struct
// pointers.
//   - Elements implementing Validator are validated first.
//   - Joins and omitted fields are not applied.
//   - When the primary key is an int64, the id returned by the dialect is
//     written back (if non-zero).
//   - It panics if an element's model has no primary key.
//
// If no transaction is open, BulkInsert runs in its own transaction. That
// transaction is committed on success and rolled back on the first error or
// on a panic. A failure to begin or commit it is returned to the caller.
// Inside an existing transaction the caller stays responsible for Commit or
// Rollback.
func (q *Qbs) BulkInsert(sliceOfStructPtr interface{}) (err error) {
	defer q.Reset()
	if q.tx == nil {
		if err = q.Begin(); err != nil {
			// Without this check a failed Begin left q.tx nil and the
			// deferred Commit dereferenced it.
			return q.updateTxError(err)
		}
		// Commit returns the first error recorded on q. Drop anything left
		// over from earlier work so only this batch's errors count.
		q.firstTxError = nil
		defer func() {
			if r := recover(); r != nil {
				q.Rollback() // never commit a partially inserted batch
				panic(r)
			}
			if err != nil {
				q.Rollback() // keep the original error; Rollback logs its own
				return
			}
			// Report a failed commit instead of claiming success.
			err = q.Commit()
		}()
	}
	sliceValue := reflect.ValueOf(sliceOfStructPtr)
	for i := 0; i < sliceValue.Len(); i++ {
		structPtr := sliceValue.Index(i)
		structPtrInter := structPtr.Interface()
		if v, ok := structPtrInter.(Validator); ok {
			if err = v.Validate(q); err != nil {
				return q.updateTxError(err)
			}
		}
		model := structPtrToModel(structPtrInter, false, nil)
		if model.pk == nil {
			panic("no primary key field")
		}
		q.criteria.model = model
		var id int64
		if id, err = q.Dialect.insert(q); err != nil {
			return q.updateTxError(err)
		}
		if _, ok := model.pk.value.(int64); ok && id != 0 {
			structPtr.Elem().FieldByName(model.pk.camelName).SetInt(id)
		}
	}
	return nil
}
// Update writes the fields of structPtr to the matching rows.
//
// If the struct implements Validator, Validate runs first and its error
// aborts the update. To avoid accidentally writing zero values, prefer a
// temporary struct that declares only the fields to change. Such a struct
// usually does not implement Validator, so validate its values manually.
//
// The condition is the current one merged with the struct's primary key (see
// mergePkCondition). Update panics when neither yields a condition.
func (q *Qbs) Update(structPtr interface{}) (affected int64, err error) {
	if validator, ok := structPtr.(Validator); ok {
		if verr := validator.Validate(q); verr != nil {
			return 0, verr
		}
	}
	c := q.criteria
	c.model = structPtrToModel(structPtr, true, c.omitFields)
	c.mergePkCondition(q.Dialect)
	if c.condition == nil {
		panic("Can not update without condition")
	}
	return q.Dialect.update(q)
}
// Delete removes the rows matching the current condition merged with the
// struct's primary key (see mergePkCondition). It panics when neither yields
// a condition, so it can never delete a whole table by accident.
func (q *Qbs) Delete(structPtr interface{}) (affected int64, err error) {
	c := q.criteria
	c.model = structPtrToModel(structPtr, true, c.omitFields)
	c.mergePkCondition(q.Dialect)
	if c.condition == nil {
		panic("Can not delete without condition")
	}
	return q.Dialect.delete(q)
}
// ContainsValue reports whether table already has a row whose column equals
// value. It is useful for checking a unique column before saving. table may
// be a table-name string or a struct pointer.
//
// "Not found" is the normal outcome of such a check, so sql.ErrNoRows is not
// recorded as a transaction error (it used to make a later Commit fail).
// Other errors are recorded and reported as false.
func (q *Qbs) ContainsValue(table interface{}, column string, value interface{}) bool {
	quotedColumn := q.Dialect.quote(column)
	quotedTable := q.Dialect.quote(tableName(table))
	query := fmt.Sprintf("SELECT %v FROM %v WHERE %v = ?", quotedColumn, quotedTable, quotedColumn)
	row := q.QueryRow(query, value)
	if row == nil {
		// Defensive: QueryRow can return nil when preparing fails.
		return false
	}
	var result interface{}
	err := row.Scan(&result)
	if err == sql.ErrNoRows {
		return false
	}
	q.updateTxError(err)
	return err == nil
}
// Close ends the use of q and returns the rollback error, if any.
//   - A transaction still open is rolled back.
//   - The slot GetQbs took under a connection limit is released.
//
// The slot is freed only after the rollback (via defer), so the limit is not
// exceeded while this Qbs still holds a transaction.
func (q *Qbs) Close() error {
	if limit := connectionLimit; limit != nil {
		defer func() { <-limit }()
	}
	if q.tx != nil {
		return q.Rollback()
	}
	return nil
}
// Count returns the number of rows in table, which may be a table-name string
// or a struct pointer. When a condition is set, only matching rows are
// counted.
//
// Errors are recorded with updateTxError and reported as 0. The criteria are
// not reset. NOTE(review): Save appears to rely on the condition surviving
// this call.
func (q *Qbs) Count(table interface{}) int64 {
	query := "SELECT COUNT(*) FROM " + q.Dialect.quote(tableName(table))
	var args []interface{}
	if q.criteria.condition != nil {
		conditionSql, conditionArgs := q.criteria.condition.Merge()
		query += " WHERE " + conditionSql
		args = conditionArgs
	}
	row := q.QueryRow(query, args...)
	if row == nil {
		// Defensive: QueryRow can return nil when preparing fails; calling
		// Scan on it would panic.
		return 0
	}
	var count int64
	if err := row.Scan(&count); err != nil {
		if err != sql.ErrNoRows {
			q.updateTxError(err)
		}
		return 0
	}
	return count
}
// QueryMap runs a raw SQL query and returns its first row as a map from
// column name to value. []byte values are converted to string.
// It returns sql.ErrNoRows when the query yields no rows, and the query's
// own error when it fails.
func (q *Qbs) QueryMap(query string, args ...interface{}) (map[string]interface{}, error) {
	mapSlice, err := q.doQueryMap(query, true, args...)
	if err != nil {
		// Don't mask real failures as "no rows".
		return nil, err
	}
	if len(mapSlice) == 0 {
		return nil, sql.ErrNoRows
	}
	return mapSlice[0], nil
}
// QueryMapSlice runs a raw SQL query and returns every row as a map from
// column name to value. []byte values are converted to string. The slice is
// nil when the query yields no rows.
func (q *Qbs) QueryMapSlice(query string, args ...interface{}) ([]map[string]interface{}, error) {
	const firstRowOnly = false
	return q.doQueryMap(query, firstRowOnly, args...)
}
// doQueryMap runs a raw query and returns each row as a map from column name
// to value. The query is passed through Dialect.substituteMarkers and logged
// like the other query paths.
//   - []byte values are converted to string; NULL becomes a nil value.
//   - When once is true, only the first row is returned.
//   - Errors from Columns or from iteration (rows.Err) are returned instead
//     of yielding a silently truncated result.
func (q *Qbs) doQueryMap(query string, once bool, args ...interface{}) ([]map[string]interface{}, error) {
	query = q.Dialect.substituteMarkers(query)
	q.log(query, args...)
	stmt, err := q.prepare(query)
	if err != nil {
		return nil, q.updateTxError(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, q.updateTxError(err)
	}
	defer rows.Close()
	columns, err := rows.Columns()
	if err != nil {
		return nil, q.updateTxError(err)
	}
	// One scan target per column, reused for every row. Each value is copied
	// into that row's map before the next Scan overwrites it.
	containers := make([]interface{}, len(columns))
	for i := range containers {
		containers[i] = new(interface{})
	}
	var results []map[string]interface{}
	for rows.Next() {
		if err := rows.Scan(containers...); err != nil {
			return nil, q.updateTxError(err)
		}
		result := make(map[string]interface{}, len(columns))
		for i, key := range columns {
			value := *containers[i].(*interface{})
			if b, ok := value.([]byte); ok {
				result[key] = string(b)
			} else {
				result[key] = value
			}
		}
		results = append(results, result)
		if once {
			return results, nil
		}
	}
	if err := rows.Err(); err != nil {
		return nil, q.updateTxError(err)
	}
	return results, nil
}
// QueryStruct runs a raw SQL query and stores the result in dest.
//
// dest is either a struct pointer (*T, only the first row is used) or a
// pointer to a slice of struct pointers (*[]*T, one element appended per
// row). Columns are matched to fields by converting the snake_case column
// name to UpperCamelCase. Columns without a field are read and discarded.
// Pointer fields are not supported.
//
// For a single struct and an empty result, dest is left untouched and nil is
// returned (unlike Find). Errors from Columns or from iteration (rows.Err)
// are returned.
func (q *Qbs) QueryStruct(dest interface{}, query string, args ...interface{}) error {
	query = q.Dialect.substituteMarkers(query)
	q.log(query, args...)
	stmt, err := q.prepare(query)
	if err != nil {
		return q.updateTxError(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query(args...)
	if err != nil {
		return q.updateTxError(err)
	}
	defer rows.Close()
	outPtr := reflect.ValueOf(dest)
	outValue := outPtr.Elem()
	single := outValue.Kind() != reflect.Slice
	var structType reflect.Type
	if single {
		structType = outValue.Type()
	} else {
		structType = outValue.Type().Elem().Elem()
	}
	columns, err := rows.Columns()
	if err != nil {
		return q.updateTxError(err)
	}
	// Resolve column -> field once. "-" marks columns with no matching field.
	fieldNames := make([]string, len(columns))
	for i, col := range columns {
		upper := snakeToUpperCamel(col)
		if _, ok := structType.FieldByName(upper); ok {
			fieldNames[i] = upper
		} else {
			fieldNames[i] = "-"
		}
	}
	for rows.Next() {
		rowStructPointer := outPtr
		if !single {
			rowStructPointer = reflect.New(structType)
		}
		dests := make([]interface{}, len(columns))
		for i, fieldName := range fieldNames {
			if fieldName == "-" {
				dests[i] = new(interface{})
				continue
			}
			dests[i] = rowStructPointer.Elem().FieldByName(fieldName).Addr().Interface()
		}
		if err = rows.Scan(dests...); err != nil {
			return err
		}
		if single {
			return nil
		}
		outValue.Set(reflect.Append(outValue, rowStructPointer))
	}
	// rows.Next also stops on a read error; surface it.
	return q.updateTxError(rows.Err())
}
// Iterate runs the query described by the current criteria and structPtr,
// then calls do once per row. Before each call, the row's values are scanned
// into structPtr.
//
// Iteration stops at the first error returned by do or by scanning, and that
// error is returned. A read error reported by rows.Err after the loop is also
// returned rather than ending the iteration silently. The criteria are reset
// afterwards.
func (q *Qbs) Iterate(structPtr interface{}, do func() error) error {
	q.criteria.model = structPtrToModel(structPtr, !q.criteria.omitJoin, q.criteria.omitFields)
	query, args := q.Dialect.querySql(q.criteria)
	q.log(query, args...)
	defer q.Reset()
	stmt, err := q.prepare(query)
	if err != nil {
		return q.updateTxError(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query(args...)
	if err != nil {
		return q.updateTxError(err)
	}
	defer rows.Close()
	rowValue := reflect.ValueOf(structPtr)
	for rows.Next() {
		if err = q.scanRows(rowValue, rows); err != nil {
			return err
		}
		if err = do(); err != nil {
			return err
		}
	}
	return q.updateTxError(rows.Err())
}
// log prints query, then its arguments on a separate call, to queryLogger.
// It does so only when q.Log is enabled and a query logger is configured.
func (q *Qbs) log(query string, args ...interface{}) {
	if !q.Log || queryLogger == nil {
		return
	}
	queryLogger.Print(query)
	queryLogger.Println(args...)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment