Skip to content

Instantly share code, notes, and snippets.

@adamfdl
Created January 28, 2020 08:13
Show Gist options
  • Save adamfdl/49da2b850991c59939da2426fa85c1ab to your computer and use it in GitHub Desktop.
Save adamfdl/49da2b850991c59939da2426fa85c1ab to your computer and use it in GitHub Desktop.
package sql
import (
"database/sql"
)
// Sentinel errors re-exported from database/sql under this package's name.
var (
// ErrNoRows is the same value as database/sql's ErrNoRows.
ErrNoRows = sql.ErrNoRows
)
// DB wraps a *sql.DB from database/sql. Its methods forward to the wrapped
// handle.
type DB struct {
// DB is the underlying database/sql handle that every method delegates to.
DB *sql.DB
}
// NewSqlDB wraps sqlDB in a DB value and returns it as a Queryer.
func NewSqlDB(sqlDB *sql.DB) Queryer {
	var wrapped DB
	wrapped.DB = sqlDB
	return wrapped
}
// =========================================================================
// Implement base Queryer
// Exec forwards query and args to the wrapped *sql.DB's Exec and returns its
// result and error unchanged.
func (db DB) Exec(query string, args ...interface{}) (sql.Result, error) {
	handle := db.DB
	res, err := handle.Exec(query, args...)
	return res, err
}
// Query forwards query and args to the wrapped *sql.DB's Query and returns
// its rows and error unchanged.
func (db DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
	handle := db.DB
	rows, err := handle.Query(query, args...)
	return rows, err
}
// QueryRow forwards query and args to the wrapped *sql.DB's QueryRow and
// returns the resulting row.
func (db DB) QueryRow(query string, args ...interface{}) *sql.Row {
	handle := db.DB
	row := handle.QueryRow(query, args...)
	return row
}
// =========================================================================
// Add support for transaction
// Begin starts a transaction on the wrapped *sql.DB and returns it wrapped in
// a DBTx, together with whatever error the underlying Begin reported.
func (db DB) Begin() (DBTx, error) {
	sqlTx, beginErr := db.DB.Begin()
	wrapped := DBTx{DB: sqlTx}
	return wrapped, beginErr
}
package sql
import "database/sql"
// DBTx wraps a *sql.Tx from database/sql. Its methods forward to the wrapped
// transaction.
type DBTx struct {
// DB is the underlying database/sql transaction (despite the field name,
// it is a *sql.Tx, not a *sql.DB).
DB *sql.Tx
}
// =========================================================================
// Implement base Queryer
// Exec forwards query and args to the wrapped *sql.Tx's Exec and returns its
// result and error unchanged.
func (sdb DBTx) Exec(query string, args ...interface{}) (sql.Result, error) {
	txn := sdb.DB
	res, err := txn.Exec(query, args...)
	return res, err
}
// Query forwards query and args to the wrapped *sql.Tx's Query and returns
// its rows and error unchanged.
func (sdb DBTx) Query(query string, args ...interface{}) (*sql.Rows, error) {
	txn := sdb.DB
	rows, err := txn.Query(query, args...)
	return rows, err
}
// QueryRow forwards query and args to the wrapped *sql.Tx's QueryRow and
// returns the resulting row.
func (sdb DBTx) QueryRow(query string, args ...interface{}) *sql.Row {
	txn := sdb.DB
	row := txn.QueryRow(query, args...)
	return row
}
package sql
import "database/sql"
// Queryer is the set of query methods that both DB and DBTx provide, so code
// written against it can run on either a plain handle or a transaction.
type Queryer interface {
// Exec runs a statement and returns a database/sql Result.
Exec(query string, args ...interface{}) (sql.Result, error)
// Query runs a statement and returns database/sql Rows.
Query(query string, args ...interface{}) (*sql.Rows, error)
// QueryRow runs a statement and returns a single database/sql Row.
QueryRow(query string, args ...interface{}) *sql.Row
}
package repository
// transaction holds the database handle that Wrap begins transactions on.
type transaction struct {
// db is the handle Wrap calls Begin on.
// NOTE(review): this snippet shows no import; Wrap's use of Begin and
// sql.DBTx suggests the project's sql wrapper package rather than
// database/sql — confirm.
db sql.DB
}
// Create factory method here
// txEnabler is implemented by values that can be handed a sql.DBTx; Wrap
// calls TxEnable on each one before running the transactional function.
type txEnabler interface {
// TxEnable receives the transaction started by Wrap.
TxEnable(sql.DBTx)
}
// Wrap runs txFn inside a database transaction.
//
// It begins a transaction on tx.db, passes the resulting sql.DBTx to every
// txEnabler through TxEnable, and then calls txFn.
//
//   - If Begin fails, its error is returned and txFn is never called.
//   - If txFn returns an error, the transaction is rolled back and that error
//     is returned.
//   - If txFn or a TxEnable call panics, the transaction is rolled back and
//     the panic is re-raised.
//   - Otherwise the transaction is committed and the commit error, if any,
//     is returned.
//
// NOTE(review): nothing here detaches the txEnablers from the transaction
// once it has finished; confirm callers do not reuse them afterwards.
func (tx transaction) Wrap(txFn func() error, txEnablers ...txEnabler) (err error) {
	dbTx, beginErr := tx.db.Begin()
	if beginErr != nil {
		return beginErr
	}

	// Registered immediately after Begin so that a panic anywhere below,
	// including inside TxEnable, still ends the transaction. err is a named
	// result: the deferred function reads txFn's error from it and writes
	// the Commit error back into it so the caller actually sees a failed
	// commit.
	defer func() {
		if p := recover(); p != nil {
			// The panic is what matters; a rollback failure is ignored.
			_ = dbTx.DB.Rollback()
			panic(p)
		}
		if err != nil {
			// Report txFn's error; a rollback failure is ignored.
			_ = dbTx.DB.Rollback()
			return
		}
		err = dbTx.DB.Commit()
	}()

	// Enable the transaction for the passed repositories.
	for _, enabler := range txEnablers {
		enabler.TxEnable(dbTx)
	}

	return txFn()
}
package repository
import (
"github.com/adamfdl/tester/cleanarch/database/sql"
"github.com/adamfdl/tester/cleanarch/domain"
)
// userRepo is a user repository that issues its SQL through a sql.Queryer.
type userRepo struct {
// db is the Queryer that Create executes against; TxEnable assigns a
// sql.DBTx to it.
db sql.Queryer
}
// Create executes the insert statement through ur.db and returns the error
// reported by Exec, if any. The argument u is not currently used by the
// statement.
func (ur userRepo) Create(u domain.User) error {
	// The repository only talks to a Queryer, so it does not need to know
	// whether it is running inside a transaction or not.
	if _, execErr := ur.db.Exec("INSERT INTO ..."); execErr != nil {
		return execErr
	}
	return nil
}
// TxEnable makes the repository run its queries on sqlTx by storing it in
// ur.db.
//
// The receiver is a pointer so the assignment changes the caller's
// repository; with a value receiver only a temporary copy was updated and the
// call had no effect. Pass a *userRepo wherever a txEnabler is expected.
func (ur *userRepo) TxEnable(sqlTx sql.DBTx) {
	ur.db = sqlTx
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment