Created
July 29, 2022 20:43
-
-
Save maratori/812049324580487f562027de0162d503 to your computer and use it in GitHub Desktop.
db.InTransaction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package db | |
import ( | |
"context" | |
"database/sql" | |
"sync" | |
"github.com/pkg/errors" | |
"go.uber.org/multierr" | |
) | |
// ForTx is a dedicated *sql.DB (connection pool) that must be used only for
// transactions. Keeping transactional work in its own pool, separate from the
// pool used for standalone queries, prevents a self-deadlock.
//
// Another service of ours suffered an outage when a single pool served both
// transactions and queries issued outside of them.
//
// How the deadlock happens:
//
//	db, _ := sql.Open(driverName, dsn)
//	db.SetMaxOpenConns(1)
//	tx, _ := db.BeginTx(ctx, nil)       // takes the only connection from the pool
//	tx.QueryRowContext(ctx, "select 1") // ok: runs on the transaction's connection
//	db.QueryRowContext(ctx, "select 1") // deadlock: blocks waiting for a free connection
//	tx.Commit()                         // never reached
//
// Calling db.Query instead of tx.Query inside a transaction is sometimes
// intentional, sometimes a bug, and sometimes harmless; a separate ForTx
// pool avoids the deadlock in every one of these cases.
type ForTx interface {
	// BeginTx starts a transaction; it has the same signature as (*sql.DB).BeginTx.
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)

	// DBForTx is a marker method used only for a compile-time check:
	// a plain *sql.DB does not have it, so it cannot be passed where a
	// ForTx is expected.
	DBForTx()
}
// Tx is the transaction handle passed to the callback of InTransaction.
// It wraps *sql.Tx: it exposes the query methods of *sql.Tx plus a hook for
// running code after a successful commit.
type Tx interface {
	// The following methods have the same names and signatures as the
	// corresponding methods of *sql.Tx.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt

	// AfterCommit registers fn to be called once the transaction has been
	// committed successfully. Registering several functions is allowed;
	// all of them are called sequentially.
	AfterCommit(fn func())

	// Tx returns the underlying *sql.Tx. It exists only because some
	// libraries (for example, sqlc) need direct access to *sql.Tx.
	Tx() *sql.Tx
}
func InTransaction( | |
ctx context.Context, | |
dbForTx ForTx, | |
do func(context.Context, Tx) error, | |
) (err error) { | |
sqlTx, err := dbForTx.BeginTx(ctx, nil) | |
if err != nil { | |
return errors.Wrap(err, "begin transaction") | |
} | |
// Need to rollback if panic or any error (even from sqlTx.Commit, because it doesn't guarantee that tx is aborted). | |
// Also, we don't want to rollback on happy path (even if it can be done with ignoring sql.ErrTxDone). | |
needRollback := true | |
defer func() { | |
if needRollback { | |
err = multierr.Append(err, errors.Wrap(sqlTx.Rollback(), "rollback")) | |
} | |
}() | |
tx := &xTx{tx: sqlTx} | |
err = do(ctx, tx) | |
if err != nil { | |
return errors.Wrap(err, "in transaction") | |
} | |
err = sqlTx.Commit() | |
if err != nil { | |
return errors.Wrap(err, "commit") | |
} | |
needRollback = false | |
tx.mu.Lock() | |
defer tx.mu.Unlock() | |
for _, fn := range tx.afterCommit { | |
fn() | |
} | |
return nil | |
} | |
// xTx is the Tx implementation created by InTransaction. Its query methods
// forward to the wrapped *sql.Tx; after-commit callbacks are accumulated in
// afterCommit, guarded by mu.
type xTx struct {
	tx *sql.Tx // transaction that every query method forwards to

	mu          sync.Mutex // guards afterCommit
	afterCommit []func()   // callbacks, in registration order
}

// AfterCommit appends fn to the pending callbacks. The append is done under
// mu, so concurrent AfterCommit calls are safe.
func (xt *xTx) AfterCommit(fn func()) {
	xt.mu.Lock()
	defer xt.mu.Unlock()
	xt.afterCommit = append(xt.afterCommit, fn)
}

// Tx returns the wrapped *sql.Tx.
func (xt *xTx) Tx() *sql.Tx {
	return xt.tx
}

// ExecContext forwards to (*sql.Tx).ExecContext.
func (xt *xTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return xt.tx.ExecContext(ctx, query, args...)
}

// PrepareContext forwards to (*sql.Tx).PrepareContext.
func (xt *xTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return xt.tx.PrepareContext(ctx, query)
}

// QueryContext forwards to (*sql.Tx).QueryContext.
func (xt *xTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	return xt.tx.QueryContext(ctx, query, args...)
}

// QueryRowContext forwards to (*sql.Tx).QueryRowContext.
func (xt *xTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	return xt.tx.QueryRowContext(ctx, query, args...)
}

// StmtContext forwards to (*sql.Tx).StmtContext.
func (xt *xTx) StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt {
	return xt.tx.StmtContext(ctx, stmt)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment