Skip to content

Instantly share code, notes, and snippets.

@dhermes
Last active June 22, 2024 06:05
Show Gist options
  • Save dhermes/cc150b24cd156b6352259744a2a90645 to your computer and use it in GitHub Desktop.
Save dhermes/cc150b24cd156b6352259744a2a90645 to your computer and use it in GitHub Desktop.
[2024-06-21] Guts of `sql.Tx` <-> `pgx.Tx`

Guts of sql.Tx <-> pgx.Tx

Placeholder

module gist.github.com/dhermes/cc150b24cd156b6352259744a2a90645
go 1.22.4
require github.com/jackc/pgx/v5 v5.6.0
require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/text v0.14.0 // indirect
)
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"unsafe"
pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/jackc/pgx/v5/stdlib"
)
// wrapTx is vendored in from the pgx source:
// https://github.com/jackc/pgx/blob/v5.6.0/stdlib/sql.go#L874-L877
//
// It mirrors the unexported `driver.Tx` implementation of the pgx stdlib
// driver so that `unsafeConvertWrapTx()` can reinterpret a value of the
// upstream type through `unsafe.Pointer`. The field names, types and order
// MUST stay identical to the upstream definition for the pinned pgx version;
// any drift makes the unsafe cast read the wrong memory.
type wrapTx struct {
// ctx is copied verbatim from upstream.
ctx context.Context
// tx is the native pgx transaction that `dissectSQLTx()` hands back.
tx pgx.Tx
}
// initStdlibPool opens a `database/sql` pool using the "pgx" driver
// (registered by the blank import of `github.com/jackc/pgx/v5/stdlib`) and
// pings it so that connection problems surface immediately rather than on
// first use (`sql.Open` alone does not connect).
//
// If the ping fails, the pool is closed before returning and any close error
// is joined onto the ping error, so callers never have to clean up after a
// failed call.
func initStdlibPool(ctx context.Context, connectionURL string) (*sql.DB, error) {
	pool, err := sql.Open("pgx", connectionURL)
	if err != nil {
		return nil, err
	}
	if err := pool.PingContext(ctx); err != nil {
		// Don't leak the pool (and anything it opened) when the ping fails.
		return nil, errors.Join(err, pool.Close())
	}
	return pool, nil
}
// finalizeStdlibTx rolls back `tx` (when non-nil) and merges any rollback
// failure into `err`. It is meant for deferred cleanup, so it is always safe
// to call even if the transaction was already committed or rolled back.
//
// `sql.ErrTxDone` (returned when the transaction is already finished) is
// treated as success. It is matched with `errors.Is` so that a wrapped
// `ErrTxDone` is also recognized.
func finalizeStdlibTx(tx *sql.Tx, err error) error {
	if tx == nil {
		return err
	}
	rollbackErr := tx.Rollback()
	if rollbackErr == nil || errors.Is(rollbackErr, sql.ErrTxDone) {
		return err
	}
	return errors.Join(err, rollbackErr)
}
// finalizeStdlibPool closes `pool` when it is non-nil and joins the close
// result onto `err`; with a nil pool, `err` is passed through untouched.
// It is used from `run()`'s deferred cleanup.
func finalizeStdlibPool(pool *sql.DB, err error) error {
	if pool != nil {
		return errors.Join(err, pool.Close())
	}
	return err
}
// finalizePgxPool closes `pool` when it is non-nil and returns `err`
// unchanged. `pgxpool.Pool.Close()` has no error result, so there is nothing
// to merge.
func finalizePgxPool(pool *pgxpool.Pool, err error) error {
	if pool == nil {
		return err
	}
	pool.Close()
	return err
}
// copyReflectPointer returns a new `reflect.Value` aliasing the same memory
// as `v`, but without the read-only flag that values reached through
// unexported struct fields carry. The result can therefore be used with
// `Interface()` and the `Set*` methods. `v` must be addressable, otherwise
// `UnsafeAddr()` panics.
func copyReflectPointer(v reflect.Value) reflect.Value {
	// H/T: https://stackoverflow.com/a/43918797/1068170
	addr := unsafe.Pointer(v.UnsafeAddr())
	alias := reflect.NewAt(v.Type(), addr)
	return alias.Elem()
}
// copyReflectStruct returns a copy of the value held by `v`, stored in freshly
// allocated memory (via `reflect.New`). The copy is therefore addressable and
// settable even when `v` is not, e.g. a value produced by `reflect.ValueOf(x)`.
//
// It returns an error, rather than panicking, when `v` is the zero
// `reflect.Value` or when `v` is read-only because it was obtained through an
// unexported struct field; `reflect.Value.Set` panics on such values.
func copyReflectStruct(v reflect.Value) (reflect.Value, error) {
	// H/T: https://stackoverflow.com/a/43918797/1068170
	if !v.IsValid() {
		return reflect.Value{}, errors.New("cannot copy invalid reflect.Value")
	}
	vt := v.Type()
	if !v.CanInterface() {
		return reflect.Value{}, fmt.Errorf("cannot copy read-only value; (%s).%s", vt.PkgPath(), vt.Name())
	}
	v2 := reflect.New(vt).Elem()
	// `reflect.New(...).Elem()` is always settable; kept as a defensive check.
	if !v2.CanSet() {
		return v2, fmt.Errorf("cannot set copy of struct value; (%s).%s", vt.PkgPath(), vt.Name())
	}
	v2.Set(v)
	return v2, nil
}
// unsafeConvertWrapTx reinterprets a `reflect.Value` holding a
// `(github.com/jackc/pgx/v5/stdlib).wrapTx` as a pointer to the locally
// vendored `wrapTx`, giving access to its unexported fields.
//
// The value is first copied into freshly allocated, addressable memory via
// `copyReflectStruct()`, and the returned pointer refers to that copy. The
// final cast goes through `unsafe.Pointer`. Before casting, the upstream
// type's name and field layout (field names, types, offsets and total size)
// are verified against the vendored struct. Any mismatch returns an error
// instead of silently reading the wrong memory.
func unsafeConvertWrapTx(wrapTxValue reflect.Value) (*wrapTx, error) {
	if !wrapTxValue.IsValid() {
		return nil, errors.New("cannot convert invalid reflect.Value to wrapTx")
	}
	wrapTxType := wrapTxValue.Type()
	if wrapTxType.PkgPath() != "github.com/jackc/pgx/v5/stdlib" || wrapTxType.Name() != "wrapTx" {
		return nil, fmt.Errorf("unexpected type; (%s).%s", wrapTxType.PkgPath(), wrapTxType.Name())
	}
	if err := checkWrapTxLayout(wrapTxType); err != nil {
		return nil, err
	}
	copied, err := copyReflectStruct(wrapTxValue)
	if err != nil {
		return nil, err
	}
	if !copied.CanAddr() {
		return nil, errors.New("cannot address wrapTx")
	}
	return (*wrapTx)(unsafe.Pointer(copied.UnsafeAddr())), nil
}

// checkWrapTxLayout returns an error unless `t` is a struct with exactly the
// same size and fields (name, type and offset, in order) as the vendored
// `wrapTx`.
func checkWrapTxLayout(t reflect.Type) error {
	want := reflect.TypeOf(wrapTx{})
	if t.Kind() != reflect.Struct || t.Size() != want.Size() || t.NumField() != want.NumField() {
		return fmt.Errorf("layout of (%s).%s does not match vendored wrapTx", t.PkgPath(), t.Name())
	}
	for i := 0; i < want.NumField(); i++ {
		got, exp := t.Field(i), want.Field(i)
		if got.Name != exp.Name || got.Type != exp.Type || got.Offset != exp.Offset {
			return fmt.Errorf(
				"layout of (%s).%s does not match vendored wrapTx at field %d: got %s %s, want %s %s",
				t.PkgPath(), t.Name(), i, got.Name, got.Type, exp.Name, exp.Type,
			)
		}
	}
	return nil
}
// dissectSQLTx extracts the `pgx.Tx` wrapped inside a `*sql.Tx` that was begun
// on a `database/sql` pool backed by the pgx stdlib driver.
//
// It depends on unexported implementation details of both packages:
//   - `database/sql`: the `txi` field on `sql.Tx` holds the `driver.Tx`.
//   - the pgx stdlib driver: its `driver.Tx` is a `wrapTx` value.
//
// It reaches into these with `reflect` and `unsafe`. Every step is checked,
// so a nil `tx` or an unexpected shape yields an error rather than a panic.
func dissectSQLTx(tx *sql.Tx) (pgxTX pgx.Tx, err error) {
	if tx == nil {
		return nil, errors.New("cannot dissect nil *sql.Tx")
	}
	// First get a reflect Value for the underlying `sql.Tx` value.
	txValue := reflect.ValueOf(tx).Elem()
	// Then grab the unexported `txi`, ensure it's addressable and copy
	// it onto a `Value` that we can interface with.
	txiValue := txValue.FieldByName("txi")
	if !txiValue.IsValid() {
		// Without this check, building the error below (`txiValue.Type()`)
		// would itself panic.
		return nil, errors.New("missing txi field on (database/sql).Tx")
	}
	if !txiValue.CanAddr() {
		return nil, fmt.Errorf("cannot address txi; (%s).%s", txiValue.Type().PkgPath(), txiValue.Type().Name())
	}
	txiValue = copyReflectPointer(txiValue)
	// Resolve the `driver.Tx` interface (`txi` field) to an actual underlying
	// value.
	if !txiValue.CanInterface() {
		return nil, fmt.Errorf("cannot interface txi; (%s).%s", txiValue.Type().PkgPath(), txiValue.Type().Name())
	}
	if txiValue.Kind() == reflect.Interface && txiValue.IsNil() {
		// A nil interface would become an invalid `reflect.Value` below.
		return nil, errors.New("txi on (database/sql).Tx is nil")
	}
	wrapTxValue := reflect.ValueOf(txiValue.Interface())
	wt, err := unsafeConvertWrapTx(wrapTxValue)
	if err != nil {
		return nil, err
	}
	return wt.tx, nil
}
// initPgxPool builds a native pgx connection pool from `connectionURL`.
// It sets `search_path` to `tmp` in the connection config's
// `RuntimeParams`, so every connection in the pool starts with that setting.
func initPgxPool(ctx context.Context, connectionURL string) (*pgxpool.Pool, error) {
	cfg, parseErr := pgxpool.ParseConfig(connectionURL)
	if parseErr != nil {
		return nil, parseErr
	}
	params := cfg.ConnConfig.RuntimeParams
	params["search_path"] = "tmp"
	return pgxpool.NewWithConfig(ctx, cfg)
}
// showSearchPath queries the current `search_path` on `tx` and prints it to
// stdout. `extra` is a free-form label (e.g. "BEFORE"/"AFTER") included in
// the output line.
func showSearchPath(ctx context.Context, tx pgx.Tx, extra string) error {
	var searchPath string
	if err := tx.QueryRow(ctx, "SHOW search_path").Scan(&searchPath); err != nil {
		return err
	}
	fmt.Printf("search_path (%s): %s\n", extra, searchPath)
	return nil
}
// setSearchPath executes `SET search_path = 'tmp'` on `tx` and discards the
// command tag, returning only the execution error (if any).
func setSearchPath(ctx context.Context, tx pgx.Tx) error {
	const stmt = "SET search_path = 'tmp'"
	if _, err := tx.Exec(ctx, stmt); err != nil {
		return err
	}
	return nil
}
// run executes the demo end to end:
//  1. opens a `database/sql` pool on $CONNECTION_URL via the pgx stdlib driver,
//  2. begins a `*sql.Tx` and extracts the underlying `pgx.Tx` from it,
//  3. opens a native `pgxpool.Pool` configured with `search_path = tmp`,
//  4. prints `search_path` on the extracted transaction, changes it with a
//     plain `SET`, and prints it again.
//
// All resources are released in a deferred cleanup whose errors are merged
// into the returned error. The transaction is never committed; the cleanup
// always rolls it back.
func run() (err error) {
	var (
		stdlibPool *sql.DB
		tx         *sql.Tx
		pgxPool    *pgxpool.Pool
	)
	defer func() {
		err = finalizeStdlibTx(tx, err)
		err = finalizeStdlibPool(stdlibPool, err)
		err = finalizePgxPool(pgxPool, err)
	}()

	ctx := context.Background()
	connectionURL, found := os.LookupEnv("CONNECTION_URL")
	if !found {
		return errors.New("missing CONNECTION_URL environment variable")
	}

	if stdlibPool, err = initStdlibPool(ctx, connectionURL); err != nil {
		return err
	}
	if tx, err = stdlibPool.BeginTx(ctx, nil); err != nil {
		return err
	}
	pgxTx, err := dissectSQLTx(tx)
	if err != nil {
		return err
	}
	// NOTE: `pgxPool` is only created (and closed during cleanup); it is not
	// queried in this function.
	if pgxPool, err = initPgxPool(ctx, connectionURL); err != nil {
		return err
	}

	if err = showSearchPath(ctx, pgxTx, "BEFORE"); err != nil {
		return err
	}
	if err = setSearchPath(ctx, pgxTx); err != nil {
		return err
	}
	// NOTE: Issuing `SET search_path` by hand, as `setSearchPath()` does, is a
	// **SMELL**. The pgx-native approach is `RuntimeParams` (as used in
	// `initPgxPool()`); this is how e.g. `RuntimeParams` are set on a
	// connection via `pgx`:
	// - https://github.com/jackc/pgx/blob/v5.6.0/pgxpool/pool.go#L227
	// - https://github.com/jackc/pgx/blob/v5.6.0/conn.go#L135
	// - https://github.com/jackc/pgx/blob/v5.6.0/conn.go#L256-L259
	// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L156
	// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L261
	// - https://github.com/jackc/pgx/blob/v5.6.0/pgconn/pgconn.go#L357-L359
	//
	// Also note that this `SET` runs inside a transaction that the deferred
	// cleanup rolls back. PostgreSQL reverts a plain `SET` (not only
	// `SET LOCAL`) when its transaction is rolled back.
	return showSearchPath(ctx, pgxTx, "AFTER")
}
// main runs the demo. On failure it prints the error to stderr and exits
// with status 1.
func main() {
	if err := run(); err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment