Skip to content

Instantly share code, notes, and snippets.

@pseudomuto
Last active May 23, 2024 17:03
Show Gist options
  • Save pseudomuto/0900a7a3605470760579752fcf0fc2b7 to your computer and use it in GitHub Desktop.
Save pseudomuto/0900a7a3605470760579752fcf0fc2b7 to your computer and use it in GitHub Desktop.
Blog Code: Clean SQL Transactions in Golang
package main
import (
"database/sql"
"log"
)
// main inserts a row into table1 and a row into table2 that references it,
// managing the transaction by hand. Every failure after Begin rolls the
// transaction back before exiting, because log.Fatal exits immediately and
// would otherwise leave the transaction open.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	tx, err := db.Begin()
	handleError(err)

	// insert a record into table1
	res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
	if err != nil {
		tx.Rollback() // rollback error ignored: the Exec error is the one reported
		log.Fatal(err)
	}

	// fetch the auto incremented id
	id, err := res.LastInsertId()
	if err != nil {
		tx.Rollback() // rollback error ignored: the LastInsertId error is the one reported
		log.Fatal(err)
	}

	// insert record into table2, referencing the first record from table1
	if _, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name"); err != nil {
		tx.Rollback() // rollback error ignored: the Exec error is the one reported
		log.Fatal(err)
	}

	// commit the transaction
	handleError(tx.Commit())
	log.Println("Done.")
}
// handleError terminates the program through log.Fatal when err is non-nil;
// for a nil error it simply returns.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"log"
)
// main inserts a row into table1 and a dependent row into table2 inside a
// single transaction managed by WithTransaction. Returning an error from the
// closure rolls the transaction back; returning nil commits it.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	err = WithTransaction(db, func(tx Transaction) error {
		// insert a record into table1
		res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
		if err != nil {
			return err
		}

		// fetch the auto incremented id
		id, err := res.LastInsertId()
		if err != nil {
			return err
		}

		// insert record into table2, referencing the first record from table1.
		// Returning err directly also covers the success case (nil => commit);
		// without a final return the closure does not compile.
		_, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
		return err
	})
	handleError(err)

	log.Println("Done.")
}
// handleError aborts the process via log.Fatal if err is set and does
// nothing otherwise.
func handleError(err error) {
	switch {
	case err != nil:
		log.Fatal(err)
	}
}
package main
import (
"database/sql"
"log"
)
// main runs two dependent INSERTs as a pipeline inside one transaction. The
// second statement's {LAST_INS_ID} placeholder is substituted with the
// LastInsertId produced by the first statement.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	insertParent := NewPipelineStmt("INSERT INTO table1(name) VALUES(?)", "some name")
	insertChild := NewPipelineStmt("INSERT INTO table2(table1_id, name) VALUES({LAST_INS_ID}, ?)", "other name")

	runAll := func(tx Transaction) error {
		_, pipeErr := RunPipeline(tx, insertParent, insertChild)
		return pipeErr
	}
	handleError(WithTransaction(db, runAll))

	log.Println("Done.")
}
// handleError is a fail-fast helper: a non-nil err is passed to log.Fatal,
// which ends the program; a nil err is ignored.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"strconv"
"strings"
)
// A PipelineStmt is a simple wrapper for creating a statement consisting of
// a query and a set of arguments to be passed to that query.
type PipelineStmt struct {
	// query is the SQL text; any literal `{LAST_INS_ID}` in it is replaced by Exec.
	query string
	// args are passed unchanged to tx.Exec alongside the (substituted) query.
	args []interface{}
}
// NewPipelineStmt returns a *PipelineStmt that runs query with the given
// positional args. The query may contain the literal `{LAST_INS_ID}`
// placeholder, which PipelineStmt.Exec fills in.
func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
	stmt := &PipelineStmt{
		query: query,
		args:  args,
	}
	return stmt
}
// Exec executes the statement within the supplied transaction. Every literal
// `{LAST_INS_ID}` in the query is replaced with the decimal form of
// lastInsertId, which makes chaining `PipelineStmt` objects together simple.
//
// The id is formatted as a full int64. Converting through int would truncate
// large ids on platforms where int is 32 bits. Only this integer is spliced
// into the SQL text; ps.args are still passed to tx.Exec separately.
func (ps *PipelineStmt) Exec(tx Transaction, lastInsertId int64) (sql.Result, error) {
	query := strings.Replace(ps.query, "{LAST_INS_ID}", strconv.FormatInt(lastInsertId, 10), -1)
	return tx.Exec(query, ps.args...)
}
// RunPipeline executes the supplied statements, in order, within tx and
// returns the result of the last one (nil if stmts is empty).
//
// Execution stops at the first failure and that error is returned.
// RunPipeline does NOT roll back tx itself, because Transaction exposes no
// Rollback. The caller must do that, e.g. by returning the error from a TxFn
// run by WithTransaction.
//
// A statement whose query contains `{LAST_INS_ID}` receives the LastInsertId
// of the statement executed immediately before it. The zero value (0) is used
// for the first statement. LastInsertId is only requested when a statement
// actually references the placeholder. Pipelines that never chain ids
// therefore also work with drivers that do not support LastInsertId.
func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) {
	var res sql.Result
	for _, ps := range stmts {
		// 0 for the first statement and for statements that don't use the placeholder.
		var lastInsID int64
		if res != nil && strings.Contains(ps.query, "{LAST_INS_ID}") {
			id, err := res.LastInsertId()
			if err != nil {
				return nil, err
			}
			lastInsID = id
		}

		next, err := ps.Exec(tx, lastInsID)
		if err != nil {
			return nil, err
		}
		res = next
	}
	return res, nil
}
package main
import (
"database/sql"
)
// Transaction is the subset of a `database/sql` transaction that a TxFn may
// use. Both *sql.Tx and *sql.DB satisfy it.
//
// Commit and Rollback are deliberately left out. Finishing the transaction is
// the job of `WithTransaction`, so `TxFn` funcs must not be able to do it
// themselves.
type Transaction interface {
	// Reads.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row

	// Writes.
	Exec(query string, args ...interface{}) (sql.Result, error)

	// Prepared statements.
	Prepare(query string) (*sql.Stmt, error)
}
// A TxFn is a function that will be called with an initialized `Transaction` object
// that can be used for executing statements and queries against a database.
//
// When run by WithTransaction, returning a non-nil error rolls the transaction
// back and returning nil commits it.
type TxFn func(Transaction) error
// WithTransaction begins a transaction on db, hands it to fn, and finishes it
// according to how fn exits:
//   - fn returns nil: the transaction is committed and any Commit error is returned;
//   - fn returns an error: the transaction is rolled back and fn's error is returned;
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//
// If the transaction cannot be started, the Begin error is returned and fn is
// never called.
func WithTransaction(db *sql.DB, fn TxFn) (err error) {
	tx, beginErr := db.Begin()
	if beginErr != nil {
		return beginErr
	}

	// The deferred func inspects the named result err, which holds fn's return
	// value by the time it runs.
	defer func() {
		switch p := recover(); {
		case p != nil:
			// Rollback's own error is discarded so the panic propagates unchanged.
			_ = tx.Rollback()
			panic(p)
		case err != nil:
			// Rollback's own error is discarded so fn's error reaches the caller.
			_ = tx.Rollback()
		default:
			err = tx.Commit()
		}
	}()

	return fn(tx)
}
@pranayhere
Copy link

I want to be able to create transactions at the service/use-case layer. The problem I'm facing is that the `db *sql.DB` required by `func WithTransaction(db *sql.DB, fn TxFn)` is not available in the service layer. Is there a way to get the transaction at the service layer?

@umardev500
Copy link

> I want to be able to create transactions at the service/use case layer. The problem I'm facing is db *sql.DB required by func WithTransaction(db *sql.DB, fn TxFn) is not available in the service layer. Is there a way to get the transaction at the service layer?

package tx

import (
	"context"
	"database/sql"
	"noname/constant"

	"github.com/rs/zerolog/log"
)

// Queries lists the statement-execution methods common to *sql.DB and
// *sql.Tx, each paired with its context-aware variant.
type Queries interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

	Prepare(query string) (*sql.Stmt, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

	QueryRow(query string, args ...interface{}) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// DB is the transaction-starting subset of *sql.DB that Tr depends on.
type DB interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	Begin() (*sql.Tx, error)
}

// Tr runs functions inside database transactions started on db (see WithTransaction).
type Tr struct {
	db DB // transactions are started from here via BeginTx
}

// NewTransaction returns a Tr whose transactions are started on db.
func NewTransaction(db DB) *Tr {
	tr := new(Tr)
	tr.db = db
	return tr
}

// TxFn is the unit of work run by Tr.WithTransaction. The ctx it receives
// carries the active *sql.Tx under constant.KeyTx. Returning a non-nil error
// causes the transaction to be rolled back.
type TxFn func(context.Context) error

// WithTransaction begins a transaction on tr's DB with BeginTx(ctx, nil).
// It stores the *sql.Tx in the context passed to fn under constant.KeyTx and
// then finishes the transaction based on how fn exits:
//   - fn returns nil: the transaction is committed and the Commit error (if any) is returned;
//   - fn returns an error: the transaction is rolled back and fn's error is returned;
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//
// If BeginTx fails, its error is returned and fn is not called.
func (tr *Tr) WithTransaction(ctx context.Context, fn TxFn) (err error) {
	tx, err := tr.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	// The deferred func reads and may overwrite the named result err.
	defer func() {
		if p := recover(); p != nil {
			// Roll back before propagating the panic so the transaction is not left open.
			_ = tx.Rollback()
			panic(p)
		}
		if err != nil {
			log.Info().Msg("Rollback")
			// Rollback's own error is discarded so fn's error reaches the caller.
			_ = tx.Rollback()
			return
		}
		// Propagate the Commit error so a failed commit is not reported as success.
		err = tx.Commit()
	}()

	ctx = context.WithValue(ctx, constant.KeyTx, tx)
	return fn(ctx)
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment