Skip to content

Instantly share code, notes, and snippets.

@stilyng94
Forked from pseudomuto/main_1.go
Created May 17, 2024 08:07
Show Gist options
  • Save stilyng94/5a180fed047a53c4451f2d2eece12918 to your computer and use it in GitHub Desktop.
Save stilyng94/5a180fed047a53c4451f2d2eece12918 to your computer and use it in GitHub Desktop.
Blog Code: Clean SQL Transactions in Golang
package main
import (
"database/sql"
"log"
)
// main demonstrates managing a transaction by hand. It inserts a row into
// table1, then inserts a row into table2 that references table1's generated
// id, and commits both together. Any failure after Begin rolls the
// transaction back before the program exits via log.Fatal.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	tx, err := db.Begin()
	handleError(err)

	// abort rolls back explicitly because log.Fatal calls os.Exit, which does
	// not run deferred calls. The rollback error is ignored: we are already
	// exiting with the original failure.
	abort := func(err error) {
		_ = tx.Rollback()
		log.Fatal(err)
	}

	// insert a record into table1
	res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
	if err != nil {
		abort(err)
	}

	// fetch the auto incremented id
	id, err := res.LastInsertId()
	if err != nil {
		abort(err)
	}

	// insert record into table2, referencing the first record from table1
	if _, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name"); err != nil {
		abort(err)
	}

	// commit the transaction
	handleError(tx.Commit())
	log.Println("Done.")
}
// handleError terminates the program through log.Fatal when err is non-nil.
// A nil err is ignored.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"log"
)
// main demonstrates WithTransaction. The two inserts run inside a TxFn.
// WithTransaction commits them when the function returns nil and rolls them
// back when it returns an error.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	err = WithTransaction(db, func(tx Transaction) error {
		// insert a record into table1
		res, err := tx.Exec("INSERT INTO table1(name) VALUES(?)", "some name")
		if err != nil {
			return err
		}

		// fetch the auto incremented id
		id, err := res.LastInsertId()
		if err != nil {
			return err
		}

		// insert record into table2, referencing the first record from table1.
		// The returned error (nil on success) decides commit vs. rollback.
		_, err = tx.Exec("INSERT INTO table2(table1_id, name) VALUES(?, ?)", id, "other name")
		return err
	})
	handleError(err)
	log.Println("Done.")
}
// handleError exits the program with log.Fatal if err is non-nil.
// Otherwise it does nothing.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"log"
)
// main demonstrates RunPipeline inside WithTransaction. The second statement
// uses the {LAST_INS_ID} placeholder to reference the id generated by the
// first insert.
func main() {
	db, err := sql.Open("VENDOR_HERE", "YOUR_DSN_HERE")
	handleError(err)
	defer db.Close()

	first := NewPipelineStmt("INSERT INTO table1(name) VALUES(?)", "some name")
	second := NewPipelineStmt("INSERT INTO table2(table1_id, name) VALUES({LAST_INS_ID}, ?)", "other name")

	handleError(WithTransaction(db, func(tx Transaction) error {
		if _, pipeErr := RunPipeline(tx, first, second); pipeErr != nil {
			return pipeErr
		}
		return nil
	}))
	log.Println("Done.")
}
// handleError is a fail-fast helper: a non-nil err is passed to log.Fatal,
// and a nil err returns immediately.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
package main
import (
"database/sql"
"strconv"
"strings"
)
// PipelineStmt is one step of a statement pipeline. It holds a query string
// together with the positional arguments bound to its placeholders.
type PipelineStmt struct {
	query string        // SQL text, possibly containing {LAST_INS_ID}
	args  []interface{} // values passed through to Exec
}

// NewPipelineStmt returns a PipelineStmt for query with the given arguments.
// The variadic args slice is stored as is, not copied.
func NewPipelineStmt(query string, args ...interface{}) *PipelineStmt {
	stmt := PipelineStmt{
		query: query,
		args:  args,
	}
	return &stmt
}
// Exec runs the statement within tx. Before execution, every occurrence of
// the literal string `{LAST_INS_ID}` in the query is replaced with
// lastInsertID in base 10. This makes chaining PipelineStmt objects together
// simple.
//
// The id is formatted directly as an int64. Converting it through int would
// silently truncate ids above 2^31-1 on 32-bit platforms. Splicing the value
// into the SQL text is safe only because it is a strconv-formatted integer,
// never caller-supplied text.
func (ps *PipelineStmt) Exec(tx Transaction, lastInsertID int64) (sql.Result, error) {
	query := strings.ReplaceAll(ps.query, "{LAST_INS_ID}", strconv.FormatInt(lastInsertID, 10))
	return tx.Exec(query, ps.args...)
}
// RunPipeline executes stmts in order within tx and returns the sql.Result of
// the last statement.
//
// Each statement receives the LastInsertId of the statement before it (0 for
// the first one). A query can therefore refer to the previous insert via
// `{LAST_INS_ID}`.
//
// Execution stops at the first Exec or LastInsertId call that fails. That
// error is returned unchanged, with a nil Result. LastInsertId is called after
// every statement, and not every driver supports it, so such a driver makes
// every non-empty pipeline fail.
//
// RunPipeline does not roll back by itself. Run it inside WithTransaction, or
// roll back tx yourself, so that a failure discards the earlier statements.
// With no statements it returns a nil Result and a nil error.
func RunPipeline(tx Transaction, stmts ...*PipelineStmt) (sql.Result, error) {
	var (
		res          sql.Result
		lastInsertID int64
	)
	for _, ps := range stmts {
		var err error
		if res, err = ps.Exec(tx, lastInsertID); err != nil {
			return nil, err
		}
		if lastInsertID, err = res.LastInsertId(); err != nil {
			return nil, err
		}
	}
	return res, nil
}
package main
import (
"database/sql"
)
// Transaction is the subset of *sql.Tx that a TxFn may use.
//
// Commit and Rollback are deliberately omitted. WithTransaction owns the
// transaction's lifecycle, so a TxFn can only run statements and queries.
type Transaction interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	Prepare(query string) (*sql.Stmt, error)
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row
}

// TxFn is a unit of work executed inside a transaction managed by
// WithTransaction. Returning a non-nil error causes the transaction to be
// rolled back.
type TxFn func(Transaction) error

// WithTransaction begins a transaction on db, passes it to fn, and then
// commits or rolls back:
//
//   - fn returns nil: the transaction is committed, and Commit's error (if
//     any) is returned.
//   - fn returns an error: the transaction is rolled back, and fn's error is
//     returned unchanged.
//   - fn panics or calls runtime.Goexit: the transaction is rolled back, and
//     the panic or exit keeps propagating untouched.
//
// If db.Begin fails, its error is returned and fn is not called. Rollback
// errors are not reported.
func WithTransaction(db *sql.DB, fn TxFn) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// One deferred Rollback covers every abnormal exit: error, panic, and
	// runtime.Goexit. After Commit has been called, the sql.Tx is finished,
	// so this call is a no-op that returns sql.ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	if err = fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment