// Hooking into Go SQL drivers for triggering race conditions.
//
// Originally shared as a GitHub gist by @icio (created Jan 14, 2021).
package repo_test
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"reflect"
	"strings"
	"sync"
	"testing"

	"./repo"
)
func TestPersonRepo_Update_Conflict(t *testing.T) {
var realDB = connectToPostgres()
defer realDB.Close()
// These are the updates that we want to apply. Each update is going to
// happen in parallel and all will complete their reads before any start
// their writes.
pid := 1
updates := []repo.Person{
{ID: pid, Name: "Paul"},
{ID: pid, Age: 32,
// Adding more than the number of retries will result in some workers
// may start getting too many serialisation errors from postgres.
}
// read WaitGroup is used to block all workers until they've all read,
// when the write channel will be closed and all reads/writes are unblocked.
write := make(chan bool)
var read sync.WaitGroup
read.Add(len(updates))
go func() {
read.Wait()
close(write)
}()
// hookDB lets us synchronise with read and write after reading the account.
db := hookDB(realDB, sqlHooks{
QueryPost: func(query string, args []driver.Value, rows driver.Rows, err error) (driver.Rows, error) {
if strings.Contains(query, "FROM people") {
select {
case <-write:
// The synchronisation has already completed. Carry on.
default:
// Synchronise post-read/pre-write.
read.Done()
<-write
}
}
return rows, err
},
})
// Write the starting person.
r := repo.PersonRepo{DB: db}
err := r.Create(repo.Person{
ID: pid,
Age: 31,
})
if err != nil {
t.Fatalf("Creating account: %s", err)
}
// Have the workers perform the updates.
var wg sync.WaitGroup
wg.Add(len(updates))
for _, upd := range updates {
go func(upd repo.Person) {
defer wg.Done()
err := r.Update(upd)
if err != nil {
t.Errorf("PersonRepo.Update(%#v) returned error: %s", upd, err)
}
}(upd)
}
read.Wait() // Check synchronisation occurred.
wg.Wait() // Wait for the workers to complete.
if t.Failed() {
return
}
// Check that all updates were applied.
exp := repo.Person{ID: pid, Name: "Paul", Age: 32}
act, err := r.Read(pid)
if err != nil {
t.Fatalf("Reading final account: %s", err)
}
if diff := cmp.Diff(exp, act); diff != "" {
t.Fatalf("Expected account (-) but got (+):\n%s", diff)
}
}
// sqlHooks holds the callbacks installed by hookDB. Each one runs after the
// wrapped driver statement has executed. It receives the query text, its
// arguments, and the driver's own result and error, and it returns the
// result and error that database/sql will see. A hook can therefore observe,
// delay, or rewrite the outcome of a statement. hookDB substitutes a
// pass-through for any field left nil.
type sqlHooks struct {
	// ExecPost runs after each statement Exec.
	ExecPost func(query string, args []driver.Value, result driver.Result, execErr error) (driver.Result, error)
	// QueryPost runs after each statement Query.
	QueryPost func(query string, args []driver.Value, rows driver.Rows, queryErr error) (driver.Rows, error)
}
// hookDB returns a new *sql.DB layered on top of db whose statements report
// to hooks. hooks.ExecPost runs after every statement Exec, and
// hooks.QueryPost runs after every statement Query, whatever the SQL is. Each
// hook receives the query text, the arguments, and the driver's own result
// and error. It returns the result and error that database/sql sees instead.
// A nil hook is replaced with one that passes the driver's values through
// untouched.
//
// The wrapping happens at two levels. Connections are borrowed from db and
// wrapped in sqlHookConn. sqlHookConn.Prepare then wraps each driver.Stmt in
// sqlHookStmt, whose Exec and Query call the hooks. For example:
//
//	hooked := hookDB(realDB, sqlHooks{})
//	tx, _ := hooked.BeginTx(ctx, nil) // sqlHookConnector.Connect -> realDB.Conn -> sqlHookConn
//	tx.Exec(`UPDATE ...`)             // sqlHookConn.Prepare -> sqlHookStmt.Exec -> hooks.ExecPost
//
// The shortcut interfaces from database/sql/driver are deliberately left
// unimplemented: driver.Execer, driver.Queryer, their Context variants, and
// the context-aware statement methods. database/sql therefore falls back to
// Prepare for every query, so all of them pass through sqlHookStmt and the
// hook logic lives in exactly one place.
func hookDB(db *sql.DB, hooks sqlHooks) *sql.DB {
	passExec := func(_ string, _ []driver.Value, res driver.Result, err error) (driver.Result, error) {
		return res, err
	}
	passQuery := func(_ string, _ []driver.Value, rows driver.Rows, err error) (driver.Rows, error) {
		return rows, err
	}
	if hooks.ExecPost == nil {
		hooks.ExecPost = passExec
	}
	if hooks.QueryPost == nil {
		hooks.QueryPost = passQuery
	}
	connector := &sqlHookConnector{hooks: hooks, db: db}
	return sql.OpenDB(connector)
}
// sqlHookConnector is the driver.Connector behind hookDB. Every connection it
// opens is borrowed from db and wrapped in a sqlHookConn carrying hooks.
type sqlHookConnector struct {
	hooks sqlHooks
	db    *sql.DB
}

// Compile-time checks that the wrappers implement what database/sql needs.
var (
	_ driver.Connector = (*sqlHookConnector)(nil)
	_ driverConnFull   = (*sqlHookConn)(nil)
)

// errConnNotBeginTx is returned by Connect when the wrapped pool's driver
// connection does not implement driver.ConnBeginTx.
type errConnNotBeginTx struct{}

func (errConnNotBeginTx) Error() string {
	return "sqlHookConnector: driver connection does not implement driver.ConnBeginTx"
}

// Connect checks a connection out of c.db and wraps its underlying driver
// connection in a sqlHookConn.
//
// sql.Conn.Raw documents that the driver connection must not be used after
// the callback returns, and this wrapper uses it anyway. To limit the damage,
// the *sql.Conn is kept for as long as the wrapper lives. c.db therefore keeps
// treating the connection as checked out and will not hand it to another
// caller. The wrapper's Close then returns it to c.db.
func (c *sqlHookConnector) Connect(ctx context.Context) (driver.Conn, error) {
	sqlConn, err := c.db.Conn(ctx)
	if err != nil {
		return nil, err
	}
	var conn driverConnFull
	err = sqlConn.Raw(func(driverConn interface{}) error {
		full, ok := driverConn.(driverConnFull)
		if !ok {
			return errConnNotBeginTx{}
		}
		conn = full
		return nil
	})
	if err != nil {
		// Hand the connection straight back to c.db. The Raw error is the one
		// worth reporting, so a failure to release the connection is ignored.
		_ = sqlConn.Close()
		return nil, err
	}
	return &sqlHookConn{hooks: c.hooks, driverConnFull: conn, sqlConn: sqlConn}, nil
}

// Driver returns the driver of the wrapped pool.
func (c *sqlHookConnector) Driver() driver.Driver {
	return c.db.Driver()
}

// sqlHookConn wraps a driver connection borrowed from another *sql.DB so that
// every statement it prepares is a sqlHookStmt. Only the methods of
// driverConnFull are promoted. Optional fast paths such as
// driver.ExecerContext or driver.QueryerContext stay hidden, so database/sql
// always goes through Prepare.
type sqlHookConn struct {
	hooks sqlHooks
	driverConnFull
	// sqlConn keeps driverConnFull checked out of the pool it was borrowed
	// from; Close releases it.
	sqlConn *sql.Conn
}

// driverConnFull is the set of driver interfaces sqlHookConn exposes:
// driver.Conn plus driver.ConnBeginTx. The latter lets database/sql pass
// BeginTx's context and options through to the real driver.
type driverConnFull interface {
	driver.Conn
	driver.ConnBeginTx
}

// Close releases the borrowed connection back to the pool it came from
// instead of closing it. That pool still owns the physical connection and,
// if it were never told, would count it as in use forever.
func (c *sqlHookConn) Close() error {
	return c.sqlConn.Close()
}

// Prepare prepares query on the wrapped connection and wraps the statement in
// a sqlHookStmt so that its Exec and Query calls reach c.hooks.
func (c *sqlHookConn) Prepare(query string) (driver.Stmt, error) {
	stmt, err := c.driverConnFull.Prepare(query)
	if err != nil {
		return nil, err
	}
	return &sqlHookStmt{hooks: c.hooks, query: query, Stmt: stmt}, nil
}
// sqlHookStmt wraps a prepared driver.Stmt so that each Exec or Query is
// followed by the matching callback in hooks. query is the SQL text the
// statement was prepared from; it is passed to the callbacks so they can tell
// statements apart. Close and NumInput come straight from the embedded
// driver.Stmt.
type sqlHookStmt struct {
	hooks sqlHooks
	query string
	driver.Stmt
}

// Exec runs the wrapped statement and hands its outcome to hooks.ExecPost,
// whose return values become the statement's result.
func (s *sqlHookStmt) Exec(args []driver.Value) (driver.Result, error) {
	result, execErr := s.Stmt.Exec(args)
	post := s.hooks.ExecPost
	return post(s.query, args, result, execErr)
}

// Query runs the wrapped statement and hands its outcome to hooks.QueryPost,
// whose return values become the statement's rows.
func (s *sqlHookStmt) Query(args []driver.Value) (driver.Rows, error) {
	rows, queryErr := s.Stmt.Query(args)
	post := s.hooks.QueryPost
	return post(s.query, args, rows, queryErr)
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment