Skip to content

Instantly share code, notes, and snippets.

@nubbel
Last active April 27, 2016 13:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nubbel/171dd530fb5e9b86ac5a to your computer and use it in GitHub Desktop.
Save nubbel/171dd530fb5e9b86ac5a to your computer and use it in GitHub Desktop.
Go inTransaction usage example
package main
import (
"flag"
"log"
"database/sql"
_ "github.com/lib/pq"
)
// main reads the user's email and name from the -email and -name flags,
// upserts that user into the "postgres" database at localhost:5432/go via
// upsertUser, and logs whether an existing user was updated or a new one
// inserted. Any error from opening the database or upserting is fatal.
func main() {
	email := flag.String("email", "nicky.nubbel@gmail.com", "The user's email")
	name := flag.String("name", "Dominique d'Argent", "The user's name")
	flag.Parse()

	db, err := sql.Open("postgres", "postgres://localhost:5432/go?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	updated, err := upsertUser(db, *email, *name)
	if err != nil {
		log.Fatal(err)
	}

	if updated {
		log.Printf("User updated: %v, %v", *email, *name)
		return
	}
	log.Printf("User inserted: %v, %v", *email, *name)
}
// upsertUser sets the name of the user with the given email, inserting a new
// row into users if no row with that email exists yet.
//
// It reports true when an existing row was updated and false when a new row
// was inserted. On error the boolean is false and carries no meaning.
//
// The existence check and the write run inside one transaction started by
// inTransaction. NOTE(review): the check-then-write pattern is still not
// atomic against concurrent callers at the default isolation level. Two calls
// for the same new email may both INSERT. A unique constraint on users.email
// (not visible here — confirm) or INSERT ... ON CONFLICT would close that gap.
func upsertUser(db *sql.DB, email, name string) (bool, error) {
	var exists bool
	err := inTransaction(db, func(tx *sql.Tx) error {
		// Every statement must go through tx, not db: statements issued on
		// db run on other pooled connections, outside this transaction.
		if err := tx.QueryRow(`
			SELECT EXISTS(
				SELECT 1
				FROM users
				WHERE email = $1
			)`, email,
		).Scan(&exists); err != nil {
			return err
		}

		if exists {
			_, err := tx.Exec(`
				UPDATE users
				SET name = $1
				WHERE email = $2
			`, name, email)
			return err
		}

		_, err := tx.Exec(`
			INSERT INTO users (email, name)
			VALUES ($1, $2)
		`, email, name)
		return err
	})
	if err != nil {
		return false, err
	}
	return exists, nil
}
// inTransaction begins a transaction on db and runs f with it.
//
// The transaction is committed when f returns nil. It is rolled back when f
// returns a non-nil error or panics; a panic is re-raised after the rollback.
// The returned error is the one from db.Begin, from f, or from tx.Commit.
//
// f must issue its statements through the *sql.Tx it receives; statements run
// directly on db execute outside the transaction.
func inTransaction(db *sql.DB, f func(*sql.Tx) error) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if p := recover(); p != nil {
			// f panicked, so err is still nil here. Without this branch the
			// partially applied transaction would be committed below.
			_ = tx.Rollback() // best-effort; the panic is what matters
			panic(p)
		}
		if err != nil {
			// Best-effort: the error from f is more useful to the caller
			// than a secondary rollback failure.
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()
	err = f(tx)
	return err
}
@maxsz
Copy link

maxsz commented Apr 27, 2016

Improved version of `inTransaction` that supports nested calls: when it is passed an existing `*sqlx.Tx`, it runs `f` inside that transaction instead of starting a new one.

// inTransaction runs f inside a transaction and supports nested calls.
//
// If db is already a *sqlx.Tx, f runs directly in that transaction. Committing
// or rolling back is left to the call that started it. If db is a *sqlx.DB, a
// new transaction is begun. It is committed when f returns nil and rolled
// back when f returns an error or panics; a panic is re-raised after the
// rollback. Any other DB value yields an error.
//
// NOTE(review): DB is not declared in this snippet. It is presumably an
// interface satisfied by both *sqlx.DB and *sqlx.Tx — confirm.
func inTransaction(db DB, f func(*sqlx.Tx) error) (err error) {
	switch dbOrTx := db.(type) {
	case *sqlx.Tx:
		// Nested call: reuse the caller's transaction.
		return f(dbOrTx)
	case *sqlx.DB:
		tx, beginErr := dbOrTx.Beginx()
		if beginErr != nil {
			return beginErr
		}
		defer func() {
			if p := recover(); p != nil {
				// f panicked, so err is still nil. Without this branch the
				// partial transaction would be committed below.
				_ = tx.Rollback() // best-effort; the panic is what matters
				panic(p)
			}
			if err != nil {
				// Best-effort: keep f's error rather than a rollback failure.
				_ = tx.Rollback()
				return
			}
			err = tx.Commit()
		}()
		// The named result err is set before the deferred func inspects it.
		return f(tx)
	default:
		return errors.New("Invalid `db` type")
	}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment