Skip to content

Instantly share code, notes, and snippets.

@acoshift
Created November 3, 2018 08:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save acoshift/aad4f3843f8cc7b146370cba4c76322b to your computer and use it in GitHub Desktop.
Save acoshift/aad4f3843f8cc7b146370cba4c76322b to your computer and use it in GitHub Desktop.
package sqldb
import (
"context"
"database/sql"
"net/http"
"github.com/acoshift/middleware"
"github.com/acoshift/pgsql"
)
// ctxKeyDB is the context key under which Middleware stores the *sql.DB that
// RunInTx uses to begin transactions. Being an unexported type, no other
// package can construct it, so the key cannot collide with foreign values.
type ctxKeyDB struct{}

// ctxKeyQueryer is the context key for the active queryer: the *sql.DB set by
// Middleware, or the *sql.Tx that RunInTx substitutes for the duration of its
// callback.
type ctxKeyQueryer struct{}
var (
	// Abort is the sentinel error re-exported from pgsql.ErrAbortTx. Return it
	// from the function passed to RunInTx to abort the transaction.
	// NOTE(review): whether RunInTx then reports nil or Abort to its caller is
	// decided by pgsql.RunInTxContext — confirm there before relying on it.
	Abort = pgsql.ErrAbortTx
)
// Middleware returns HTTP middleware that makes db available to downstream
// handlers via the request context. db is stored under two keys: as the
// database RunInTx begins transactions on, and as the default queryer that
// QueryRow, Query and Exec run against.
func Middleware(db *sql.DB) middleware.Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			withDB := context.WithValue(r.Context(), ctxKeyDB{}, db)
			withQueryer := context.WithValue(withDB, ctxKeyQueryer{}, db)
			next.ServeHTTP(w, r.WithContext(withQueryer))
		})
	}
}
// queryer is the subset of database/sql methods this package needs. Both
// *sql.DB and *sql.Tx satisfy it, which lets QueryRow, Query and Exec run the
// same way inside and outside a transaction.
type queryer interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// q returns the queryer carried by ctx. That is the *sql.DB stored by
// Middleware, or the *sql.Tx stored by RunInTx while its callback runs.
//
// It panics with a descriptive message when ctx carries no queryer. That means
// the request never passed through Middleware, which is a programmer error.
// QueryRow cannot return an error, so a panic is the only way to report it;
// the message just makes the cause obvious instead of a bare
// "interface conversion" failure.
func q(ctx context.Context) queryer {
	qr, ok := ctx.Value(ctxKeyQueryer{}).(queryer)
	if !ok {
		panic("sqldb: no database in context; wrap the handler with sqldb.Middleware")
	}
	return qr
}
// QueryRow executes query with args against the queryer in ctx. That is the
// transaction when called inside RunInTx, otherwise the *sql.DB installed by
// Middleware. As with database/sql, any error is deferred until Scan is
// called on the returned *sql.Row.
func QueryRow(ctx context.Context, query string, args ...interface{}) *sql.Row {
	qr := q(ctx)
	return qr.QueryRowContext(ctx, query, args...)
}
// Query executes query with args against the queryer in ctx. That is the
// transaction when called inside RunInTx, otherwise the *sql.DB installed by
// Middleware. As with database/sql, the caller must Close the returned
// *sql.Rows.
func Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	qr := q(ctx)
	return qr.QueryContext(ctx, query, args...)
}
// Exec executes a statement that returns no rows, such as INSERT, UPDATE or
// DELETE. It runs against the queryer in ctx: the transaction when called
// inside RunInTx, otherwise the *sql.DB installed by Middleware.
func Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	qr := q(ctx)
	return qr.ExecContext(ctx, query, args...)
}
// RunInTx runs f inside a database transaction.
//
// If ctx already carries a *sql.Tx, RunInTx is being called from inside
// another RunInTx. In that case f is simply invoked with ctx, so it joins the
// outer transaction instead of opening a nested one. Its error, including
// Abort, is returned straight to the caller rather than ending the
// transaction here.
//
// Otherwise a transaction is started with pgsql.RunInTxContext on the *sql.DB
// stored by Middleware. f receives a derived context whose queryer is that
// *sql.Tx, so QueryRow, Query and Exec called with it run inside the
// transaction. Commit and rollback handling is delegated to
// pgsql.RunInTxContext. NOTE(review): confirm its treatment of Abort and any
// retry policy there.
//
// It panics with a descriptive message when ctx carries no *sql.DB, meaning
// the request never passed through Middleware (a programmer error).
func RunInTx(ctx context.Context, f func(context.Context) error) error {
	if _, inTx := ctx.Value(ctxKeyQueryer{}).(*sql.Tx); inTx {
		return f(ctx)
	}

	db, ok := ctx.Value(ctxKeyDB{}).(*sql.DB)
	if !ok {
		panic("sqldb: no database in context; wrap the handler with sqldb.Middleware")
	}

	return pgsql.RunInTxContext(ctx, db, nil, func(tx *sql.Tx) error {
		// Shadow the queryer so every helper called with txCtx uses tx.
		txCtx := context.WithValue(ctx, ctxKeyQueryer{}, tx)
		return f(txCtx)
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment