Skip to content

Instantly share code, notes, and snippets.

Created Jan 14, 2021
What would you like to do?
Hooking into Go SQL drivers for triggering race conditions
package repo_test
import (
// TestPersonRepo_Update_Conflict applies several PersonRepo.Update calls to
// the same person concurrently. It uses hookDB with the intent of holding every
// worker after its read of the people table until all workers have read, so
// the updates are forced to conflict. It then checks that the final row
// reflects every update.
//
// NOTE(review): this copy of the test appears to have lost lines during
// extraction (closing braces/brackets and several statements noted below), so
// it does not compile as shown. Restore them from the original before use.
func TestPersonRepo_Update_Conflict(t *testing.T) {
var realDB = connectToPostgres()
defer realDB.Close()
// These are the updates that we want to apply. Each update is going to
// happen in parallel and all will complete their reads before any start
// their writes.
pid := 1
updates := []repo.Person{
{ID: pid, Name: "Paul"},
{ID: pid, Age: 32,
// Adding more updates than the number of retries may result in some
// workers getting too many serialisation errors from postgres.
// The read WaitGroup blocks all workers until every one has finished its
// read; the write channel is then closed, unblocking all reads/writes.
write := make(chan bool)
var read sync.WaitGroup
go func() {
// NOTE(review): the body of this goroutine is missing; presumably it waits
// on read and then closes write — TODO confirm against the original.
// hookDB lets us synchronise with read and write after reading the account.
db := hookDB(realDB, sqlHooks{
QueryPost: func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error) {
// Only queries that read from the people table take part in the
// synchronisation; every other query passes straight through.
if strings.Contains(query, "FROM people") {
select {
case <-write:
// The synchronisation has already completed. Carry on.
// NOTE(review): a "default:" branch appears to be missing here;
// presumably it calls read.Done() and then blocks on <-write — TODO confirm.
// Synchronise post-read/pre-write.
return rows, err
// Write the starting person.
r := repo.PersonRepo{DB: db}
err := r.Create(repo.Person{
ID: pid,
Age: 31,
if err != nil {
t.Fatalf("Creating account: %s", err)
// Have the workers perform the updates.
var wg sync.WaitGroup
// NOTE(review): no wg.Add/read.Add calls are visible in this copy. Without
// them, wg.Done panics (negative WaitGroup counter) and the Wait calls
// below return immediately — restore them from the original.
for _, upd := range updates {
go func(upd repo.Person) {
defer wg.Done()
err := r.Update(upd)
if err != nil {
// Errorf (not Fatalf) because this runs on a worker goroutine.
t.Errorf("PersonRepo.Update(%#v) returned error: %s", upd, err)
read.Wait() // Check synchronisation occurred.
wg.Wait() // Wait for the workers to complete.
// NOTE(review): the body of this if is missing; presumably it returns early
// so the final-state check is skipped after a worker failure — TODO confirm.
if t.Failed() {
// Check that all updates were applied.
exp := repo.Person{ID: pid, Name: "Paul", Age: 32}
act, err := r.Read(pid)
if err != nil {
t.Fatalf("Reading final account: %s", err)
if diff := cmp.Diff(exp, act); diff != "" {
t.Fatalf("Expected account (-) but got (+):\n%s", diff)
// sqlHooks holds callbacks that run after a hooked statement has executed.
// Each hook receives the query text, its arguments and the driver's outcome,
// and returns the result and error that database/sql will actually see, so a
// hook can observe, delay, or replace the outcome. hookDB replaces nil hooks
// with pass-throughs.
type sqlHooks struct {
	// ExecPost runs after a statement is executed via Exec.
	ExecPost func(query string, args []driver.Value, res driver.Result, err error) (driver.Result, error)
	// QueryPost runs after a statement is executed via Query.
	QueryPost func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error)
}
// hookDB returns a *sql.DB that will call hooks.ExecPost for each INSERT/UPDATE
// query run, and hooks.QueryPost for each SELECT query run. This works by
// wrapping the database connections (driver.Conn) with our own sqlHookConn
// which in turn wraps the database statements (driver.Stmt) with our own
// sqlHookStmt which invokes the hooks.
// The flow looks something like this:
// var db *sql.DB = realDB
// db = hookDB(db, sqlHooks{})
// var tx *sql.Tx = db.BeginTx(...)
// This will now request a new database connection, which ends up chaining
// db.Conn() -> sqlHookConnector.Connect -> realDB.Connect -> sqlHookConn. We
// then try to perform tx.Exec(`UPDATE ...`) which ends up chaining
// sqlHookConn.Prepare -> sqlHookStmt -> sqlHookStmt.Exec -> sqlHooks.ExecPost.
// There are shortcut database/sql/driver interfaces such as Execer that we
// purposefully don't implement here to ensure that all queries are routed
// through sqlHookStmt and therefore not require us to implement the same logic
// in multiple places.
func hookDB(db *sql.DB, hooks sqlHooks) *sql.DB {
if hooks.ExecPost == nil {
hooks.ExecPost = func(query string, args []driver.Value, res driver.Result, err error) (driver.Result, error) {
return res, err
if hooks.QueryPost == nil {
hooks.QueryPost = func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error) {
return rows, err
return sql.OpenDB(&sqlHookConnector{hooks, db})
type sqlHookConnector struct {
hooks sqlHooks
db *sql.DB
func (c *sqlHookConnector) Connect(ctx context.Context) (driver.Conn, error) {
sqlConn, err := c.db.Conn(ctx)
if err != nil {
return nil, err
var conn driverConnFull
err = sqlConn.Raw(func(driverConn interface{}) error {
// Apparently we're not supposed to do this.
conn = driverConn.(driverConnFull)
return nil
return &sqlHookConn{c.hooks, conn}, err
func (c *sqlHookConnector) Driver() driver.Driver {
return c.db.Driver()
type sqlHookConn struct {
hooks sqlHooks
// driverConnFull is the set of driver connection interfaces that sqlHookConn
// forwards to the real connection.
//
// NOTE(review): the original member list was lost from this copy; this is a
// reconstruction. driver.Conn is required for sqlHookConn to be a driver.Conn
// at all, and driver.ConnBeginTx lets database/sql pass TxOptions (such as a
// non-default isolation level) through. Do NOT add driver.ConnPrepareContext,
// driver.ExecerContext or driver.QueryerContext (or their non-context
// variants). database/sql prefers those over Prepare, so adding them would
// bypass the hooks.
type driverConnFull interface {
	driver.Conn
	driver.ConnBeginTx
}
func (c *sqlHookConn) Prepare(query string) (driver.Stmt, error) {
stmt, err := c.driverConnFull.Prepare(query)
return &sqlHookStmt{c.hooks, query, stmt}, err
type sqlHookStmt struct {
hooks sqlHooks
query string
func (s *sqlHookStmt) Exec(args []driver.Value) (driver.Result, error) {
r, err := s.Stmt.Exec(args)
return s.hooks.ExecPost(s.query, args, r, err)
func (s *sqlHookStmt) Query(args []driver.Value) (driver.Rows, error) {
r, err := s.Stmt.Query(args)
return s.hooks.QueryPost(s.query, args, r, err)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment