Skip to content

Instantly share code, notes, and snippets.

@maratori
Created July 29, 2022 20:43
Show Gist options
  • Save maratori/812049324580487f562027de0162d503 to your computer and use it in GitHub Desktop.
Save maratori/812049324580487f562027de0162d503 to your computer and use it in GitHub Desktop.
db.InTransaction
package db
import (
"context"
"database/sql"
"sync"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// ForTx is a separate *sql.DB instance (connection pool) to be used only for
// transactions. We need two different connection pools (*sql.DB) to avoid a
// deadlock.
//
// We had an incident in another service when the same pool was used both for
// transactions and for queries outside of a transaction.
//
// Code example:
//
//	db, _ := sql.Open(driverName, dsn)
//	db.SetMaxOpenConns(1)
//	tx, _ := db.BeginTx(ctx, nil)       // takes the only connection from the pool
//	tx.QueryRowContext(ctx, "select 1") // ok: runs on the transaction's connection
//	db.QueryRowContext(ctx, "select 1") // deadlock: waits for a free connection
//	tx.Commit()
//
// In some cases we call db.Query (not tx.Query) by design, in some cases it's
// a bug, and in some cases it doesn't matter. Using ForTx prevents the
// deadlock in all of these cases.
type ForTx interface {
	// DBForTx is a marker method used only for a compile-time check: a plain
	// *sql.DB does not have this method, so the general-purpose pool cannot be
	// passed where a ForTx is expected by accident.
	DBForTx()
	// BeginTx starts a transaction; it has the same signature as
	// (*sql.DB).BeginTx.
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// Tx wraps sql.Tx.
//
// It is the transaction handle passed to the callback of InTransaction; the
// transaction itself is committed or rolled back by InTransaction.
type Tx interface {
	// The following methods have the same signatures as the corresponding
	// methods of *sql.Tx.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt

	// AfterCommit registers fn to be called after the transaction has been
	// committed successfully. Calling it several times registers several
	// functions; they are all called one after another.
	AfterCommit(fn func())

	// Tx exposes the underlying *sql.Tx. It exists only because some
	// libraries (e.g. sqlc) need direct access to sql.Tx.
	Tx() *sql.Tx
}
// InTransaction begins a transaction on dbForTx, runs do with it and commits
// the transaction if do returns nil.
//
// The transaction is rolled back when:
//   - do returns an error;
//   - do panics (the panic keeps propagating after the rollback);
//   - Commit fails.
//
// An error from that rollback is combined with the original error using
// multierr. sql.ErrTxDone from the rollback is ignored: it only means the
// transaction had already been finished. Typical causes are a Commit attempt,
// or database/sql rolling it back itself because ctx was cancelled. It is not
// an actual rollback failure.
//
// After a successful commit, every function registered via Tx.AfterCommit is
// called sequentially in registration order. The callbacks run without the
// internal lock held, so a callback may call AfterCommit itself; functions
// registered that way run after the current batch.
func InTransaction(
	ctx context.Context,
	dbForTx ForTx,
	do func(context.Context, Tx) error,
) (err error) {
	sqlTx, err := dbForTx.BeginTx(ctx, nil)
	if err != nil {
		return errors.Wrap(err, "begin transaction")
	}

	// Need to rollback if panic or any error (even from sqlTx.Commit, because it
	// doesn't guarantee that tx is aborted).
	// Also, we don't want to rollback on happy path (even if it can be done with
	// ignoring sql.ErrTxDone).
	needRollback := true
	defer func() {
		if !needRollback {
			return
		}
		rbErr := sqlTx.Rollback()
		if errors.Is(rbErr, sql.ErrTxDone) {
			// The transaction was already finished; nothing failed here, so don't
			// pollute the returned error with it.
			rbErr = nil
		}
		// errors.Wrap(nil, ...) is nil and multierr.Append(err, nil) is err.
		err = multierr.Append(err, errors.Wrap(rbErr, "rollback"))
	}()

	tx := &xTx{tx: sqlTx}
	err = do(ctx, tx)
	if err != nil {
		return errors.Wrap(err, "in transaction")
	}

	err = sqlTx.Commit()
	if err != nil {
		return errors.Wrap(err, "commit")
	}
	needRollback = false

	// Take a snapshot of the callbacks under the lock and run them with the lock
	// released: holding tx.mu while calling them would deadlock any callback
	// that calls tx.AfterCommit. Repeat until no new callbacks were registered.
	for {
		tx.mu.Lock()
		fns := tx.afterCommit
		tx.afterCommit = nil
		tx.mu.Unlock()

		if len(fns) == 0 {
			return nil
		}
		for _, fn := range fns {
			fn()
		}
	}
}
// xTx is the Tx implementation that InTransaction hands to its callback.
// Query methods are forwarded to the wrapped *sql.Tx; AfterCommit callbacks
// are collected in afterCommit, which is guarded by mu.
type xTx struct {
	tx          *sql.Tx
	mu          sync.Mutex // guards afterCommit
	afterCommit []func()
}

// AfterCommit appends fn to the list of functions that InTransaction calls
// after a successful commit.
func (w *xTx) AfterCommit(fn func()) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.afterCommit = append(w.afterCommit, fn)
}

// Tx returns the wrapped *sql.Tx.
func (w *xTx) Tx() *sql.Tx {
	return w.tx
}

// ExecContext forwards to (*sql.Tx).ExecContext.
func (w *xTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return w.Tx().ExecContext(ctx, query, args...)
}

// PrepareContext forwards to (*sql.Tx).PrepareContext.
func (w *xTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return w.Tx().PrepareContext(ctx, query)
}

// QueryContext forwards to (*sql.Tx).QueryContext.
func (w *xTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	return w.Tx().QueryContext(ctx, query, args...)
}

// QueryRowContext forwards to (*sql.Tx).QueryRowContext.
func (w *xTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	return w.Tx().QueryRowContext(ctx, query, args...)
}

// StmtContext forwards to (*sql.Tx).StmtContext.
func (w *xTx) StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt {
	return w.Tx().StmtContext(ctx, stmt)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment