Skip to content

Instantly share code, notes, and snippets.

@mmcken3
Last active June 3, 2020 14:28
Show Gist options
  • Save mmcken3/f0898dac288580052fe44af877850463 to your computer and use it in GitHub Desktop.
Save mmcken3/f0898dac288580052fe44af877850463 to your computer and use it in GitHub Desktop.
How to write a transaction wrapper function for sqlx in Go, so that your database work runs inside a single transaction block.
package main
import (
"fmt"
"log"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
// postgres driver
_ "github.com/lib/pq"
)
// DB holds the application's database connection pool.
//
// The *sqlx.DB is embedded so that its methods (Ping, Beginx, ...) are
// promoted onto DB, and the pool itself is reachable as the DB field
// (e.g. DB{DB: pool}, db.DB.Close()).
type DB struct {
	*sqlx.DB
}
// main demonstrates how to connect to the database and then run a
// statement inside a transaction using DB.Transact.
func main() {
	// Connect to the DB. Replace the placeholder with a real connection string.
	db, err := CreateDBConnection("mydbconnectionstring")
	if err != nil {
		log.Fatalf("error connecting to the db: %v", err)
	}

	// Run a transaction on the DB. Add your sqlx calls on tx inside the
	// callback; returning a non-nil error makes Transact roll back.
	err = db.Transact(func(tx *sqlx.Tx) error {
		// Use a callback-local error instead of writing to main's err
		// through the closure.
		_, execErr := tx.Exec(`DROP TABLE example_table`)
		// errors.Wrap returns nil when execErr is nil, so success passes through.
		return errors.Wrap(execErr, "error during the DB transaction")
	})

	// Release the pool before the final check: log.Fatalf calls os.Exit,
	// which would skip a deferred Close.
	if closeErr := db.Close(); closeErr != nil {
		log.Printf("error closing the db: %v", closeErr)
	}

	// Check transaction error.
	if err != nil {
		log.Fatalf("error during transaction: %v", err)
	}
	log.Println("DB transaction successful")
}
// CreateDBConnection opens a postgres connection pool for the given
// connection string, verifies it with a ping, and returns it wrapped in a DB.
// On error the returned *DB is nil and no pool is left open.
func CreateDBConnection(conn string) (*DB, error) {
	// This example uses the postgres driver (registered by the blank import
	// of github.com/lib/pq); the same approach works for other drivers.
	sqlDB, err := sqlx.Open("postgres", conn)
	if err != nil {
		return nil, errors.Wrap(err, "Error creating db")
	}

	// Open may only validate its arguments; Ping confirms the DB is reachable.
	if err = sqlDB.Ping(); err != nil {
		// Best effort: don't leak the pool; the ping error is what matters.
		_ = sqlDB.Close()
		return nil, errors.Wrap(err, "error pinging db")
	}
	return &DB{DB: sqlDB}, nil
}
// Transact runs txFunc inside a database transaction.
//
// The transaction is committed when txFunc returns nil. It is rolled back
// when txFunc returns an error or panics. A panic in txFunc is recovered and
// returned as an error rather than re-raised.
//
// The returned error comes from one of these sources: Beginx, txFunc, the
// recovered panic value, or Commit.
func (db *DB) Transact(txFunc func(*sqlx.Tx) error) (err error) {
	tx, err := db.Beginx()
	if err != nil {
		return err
	}

	// The named result err lets this deferred func see txFunc's outcome and
	// replace it with a recovered panic or a commit failure.
	defer func() {
		if r := recover(); r != nil {
			switch r := r.(type) {
			case error:
				err = r
			default:
				// %v formats any panic value; %s would garble non-strings.
				err = fmt.Errorf("%v", r)
			}
		}
		if err != nil {
			log.Println("rolling back changes from db error")
			// Best effort: the original err is what the caller needs, so a
			// rollback failure is only logged.
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Printf("error rolling back transaction: %v", rbErr)
			}
			return
		}
		// Surface commit failures; ignoring them would report a lost
		// transaction as success.
		if commitErr := tx.Commit(); commitErr != nil {
			err = errors.Wrap(commitErr, "error committing transaction")
		}
	}()

	return txFunc(tx)
}
// Close shuts down the underlying connection pool. Any failure is returned
// wrapped with "cannot close db"; success returns nil.
func (db *DB) Close() error {
	// errors.Wrap yields nil for a nil error, so the success path needs no branch.
	return errors.Wrap(db.DB.Close(), "cannot close db")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment