Skip to content

Instantly share code, notes, and snippets.

@acoshift
Last active May 5, 2018 04:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save acoshift/f8539b1d31e1cb4d70f27216094392bf to your computer and use it in GitHub Desktop.
Save acoshift/f8539b1d31e1cb4d70f27216094392bf to your computer and use it in GitHub Desktop.
db, tx in ctx
package main
import (
"context"
"database/sql"
"io"
"log"
"net/http"
)
// ctxKey is a package-private key type for values stored in a context.Context.
// Because it is a distinct defined type, its keys cannot collide with keys
// chosen by other packages.
type ctxKey int

// Context keys used by this file. The zero ctxKey is deliberately unused, so
// the keys start at 1.
const (
	ctxKeyQueryer ctxKey = iota + 1 // active queryer: the *sql.DB, or a *sql.Tx inside WithTx
	ctxKeyDB                        // the *sql.DB connection pool itself
)
// queryer is the single-row query method shared by *sql.DB and *sql.Tx. It
// lets the same lookup code run either directly on the pool or inside a
// transaction, depending on what the context carries.
type queryer interface {
	QueryRow(query string, args ...interface{}) *sql.Row
}
// txController runs a callback as one transactional unit of work. WithTx calls
// fn with a context derived from ctx and returns an error. txCtrl, the
// implementation in this file, puts the open *sql.Tx into that context and
// commits only when fn returns nil.
type txController interface {
	WithTx(ctx context.Context, fn func(ctx context.Context) error) error
}
// withDB returns a child of ctx that carries db under two keys:
// ctxKeyDB holds the connection pool itself, and ctxKeyQueryer holds the
// default queryer. txCtrl.WithTx later shadows ctxKeyQueryer with a *sql.Tx.
func withDB(ctx context.Context, db *sql.DB) context.Context {
	withPool := context.WithValue(ctx, ctxKeyDB, db)
	return context.WithValue(withPool, ctxKeyQueryer, db)
}
// txCtrl is the database-backed txController.
type txCtrl struct{}

// WithTx begins a transaction on the *sql.DB stored in ctx. It calls f with a
// derived context whose queryer is that *sql.Tx, and commits if f returns nil.
//
// If the transaction cannot be started, f is not called and the error is
// returned. If f returns an error or panics, the deferred Rollback undoes the
// transaction, and f's error is returned in the error case.
//
// The transaction is started with BeginTx(ctx, nil), so it is bound to ctx.
// An already-cancelled ctx makes WithTx fail before f runs. Per database/sql,
// cancelling ctx later rolls the transaction back and makes Commit return an
// error.
func (txCtrl) WithTx(ctx context.Context, f func(context.Context) error) error {
	db := getDB(ctx)
	// can add retry logic here
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit is a no-op that returns sql.ErrTxDone,
	// so deferring it unconditionally is safe and its error is ignored.
	defer tx.Rollback()

	// Shadow the default queryer so helpers called by f run inside this tx.
	txCtx := context.WithValue(ctx, ctxKeyQueryer, tx)
	if err = f(txCtx); err != nil {
		return err
	}
	return tx.Commit()
}
// getQueryer returns the queryer stored in ctx under ctxKeyQueryer. That is
// the *sql.DB placed there by withDB, or the *sql.Tx placed there by
// txCtrl.WithTx. It panics if ctx holds no queryer.
func getQueryer(ctx context.Context) queryer {
	stored := ctx.Value(ctxKeyQueryer)
	q := stored.(queryer) // unchecked assertion: panics when the key is absent
	return q
}
func getDB(ctx context.Context) *sql.DB {
return ctx.Value(ctxKeyQueryer).(*sql.DB)
}
// queryUsername returns the username of the user with the given id. It runs
// the query on whatever queryer ctx carries: the *sql.DB, or the *sql.Tx when
// called inside txCtrl.WithTx. A missing user surfaces as sql.ErrNoRows from
// Scan. On error the returned username is empty.
func queryUsername(ctx context.Context, userID string) (username string, err error) {
	const query = "select username from users where id = $1"

	// Both *sql.DB and *sql.Tx implement QueryRowContext. Prefer it so that
	// cancelling ctx (e.g. the HTTP client disconnecting) also aborts the
	// query; fall back to plain QueryRow for any other queryer.
	type contextQueryer interface {
		QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
	}

	var row *sql.Row
	q := getQueryer(ctx)
	if cq, ok := q.(contextQueryer); ok {
		row = cq.QueryRowContext(ctx, query, userID)
	} else {
		row = q.QueryRow(query, userID)
	}
	if err = row.Scan(&username); err != nil {
		return "", err
	}
	return username, nil
}
// handlers holds the dependencies shared by the HTTP handler methods.
type handlers struct {
	// txCtrl opens and completes transactions for handlers that run several
	// queries as one unit (see h2).
	txCtrl txController
}
// h1 looks up the username for "id1" outside any transaction, using the
// queryer carried by the request context. It replies "ok" on success. If the
// lookup fails, it replies 500 with the error text.
func (c *handlers) h1(w http.ResponseWriter, r *http.Request) {
	// Report lookup failures (including sql.ErrNoRows) instead of claiming
	// success; the response shape matches h2's error handling.
	if _, err := queryUsername(r.Context(), "id1"); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	io.WriteString(w, "ok")
}
// h2 looks up the usernames for "id1" and "id2" inside one transaction opened
// through c.txCtrl. It stops at the first failed lookup and returns that error
// from the callback, so txCtrl.WithTx rolls back and the client gets a 500.
// Otherwise it replies "ok".
func (c *handlers) h2(w http.ResponseWriter, r *http.Request) {
	err := c.txCtrl.WithTx(r.Context(), func(ctx context.Context) error {
		for _, id := range []string{"id1", "id2"} {
			// Propagating the error is what makes WithTx roll back instead of
			// committing a partially failed unit of work.
			if _, err := queryUsername(ctx, id); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	io.WriteString(w, "ok")
}
// injectDB returns middleware that wraps a handler. On every request, the
// wrapper passes the request's context through withDB, so downstream code
// finds db in the context, then delegates to the wrapped handler.
func injectDB(db *sql.DB) func(h http.Handler) http.Handler {
	attach := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r.WithContext(withDB(r.Context(), db)))
		}
		return http.HandlerFunc(serve)
	}
	return attach
}
// main opens the database, mounts /h1 and /h2 behind the injectDB middleware,
// and serves HTTP on :8080.
//
// NOTE(review): none of this file's imports registers a database/sql driver
// named "postgres". sql.Open below therefore fails with an "unknown driver"
// error unless a driver package (e.g. github.com/lib/pq) is blank-imported
// somewhere. Confirm.
func main() {
	db, err := sql.Open("postgres", "postgresql://postgres@localhost?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}

	hs := handlers{txCtrl{}}
	mux := http.NewServeMux()
	mux.HandleFunc("/h1", hs.h1)
	mux.HandleFunc("/h2", hs.h2)
	h := injectDB(db)(mux)

	// ListenAndServe always returns a non-nil error (e.g. the port is already
	// in use). Log it and exit non-zero instead of exiting silently.
	log.Fatal(http.ListenAndServe(":8080", h))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment