Skip to content

Instantly share code, notes, and snippets.

@lukewatts
Last active February 24, 2024 17:02
Show Gist options
  • Save lukewatts/15e6c74b4c7f6427a02e2cc10707bc5c to your computer and use it in GitHub Desktop.
Save lukewatts/15e6c74b4c7f6427a02e2cc10707bc5c to your computer and use it in GitHub Desktop.
Test MySQL SELECT/UPDATE query sequence for isolation issues (e.g. race condition scenarios, concurrent read/write, phantom reads, etc.)
-- Schema for the `users` table that the Go program below reads and updates.
-- The program only touches the `balance` column of the row with id = 1; this
-- statement creates the table but does not insert that row.
CREATE TABLE `users` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
`email` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
`email_verified_at` timestamp NULL DEFAULT NULL,
`password` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
-- New rows start with a balance of 100.00. The column is not declared
-- unsigned, so it can go negative (the Go check looks for -100).
`balance` decimal(8,2) NOT NULL DEFAULT '100.00',
`remember_token` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL,
`created_at` timestamp NULL DEFAULT NULL,
`updated_at` timestamp NULL DEFAULT NULL,
PRIMARY KEY (`id`),
-- Each email address may appear only once.
UNIQUE KEY `users_email_unique` (`email`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
package main
import (
"database/sql"
"fmt"
"sync"
_ "github.com/go-sql-driver/mysql"
)
// SQL statements exercised by the race test. Each statement is assembled line
// by line, with the newlines spelled out explicitly.
const (
	// selectQuery reads the balance of user 1, but only when that balance is
	// at least 100. It takes no row lock (no FOR UPDATE / FOR SHARE).
	selectQuery = "SELECT balance\n" +
		"FROM users\n" +
		"WHERE\n" +
		"id = 1\n" +
		"AND\n" +
		"balance >= 100"

	// updateQuery unconditionally subtracts 100 from the balance of user 1.
	updateQuery = "UPDATE users\n" +
		"SET balance = balance - 100\n" +
		"WHERE id = 1"
)
// connectDatabase opens the MySQL handle used by the race test and confirms
// that the server can actually be reached.
//
// sql.Open only validates its arguments and may not establish a connection,
// so the handle is pinged before success is reported. If the handle cannot be
// opened or the ping fails, the function panics with the underlying error
// rather than returning an unusable (possibly nil) handle that would crash
// later with a less helpful message.
//
// The returned *sql.DB is never nil; the caller is responsible for closing it.
func connectDatabase() *sql.DB {
	db, err := sql.Open("mysql", "root@tcp(localhost)/sql_race_tests")
	if err != nil {
		panic(fmt.Errorf("opening database: %w", err))
	}

	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
		panic(fmt.Errorf("connecting to database: %w", err))
	}

	fmt.Println("Main: Opened database connection")
	return db
}
// selectUser runs selectQuery and hands the balance it read to the update
// phase through usersBalance.
//
// Exactly one value is sent on usersBalance per call, whatever happens:
//   - the balance returned by the query when a row matched;
//   - 0 when no row matched (the balance is below 100);
//   - 0 when the query, the row iteration or the scan failed (the error is
//     printed).
//
// Always sending matters because every updateUser call blocks until it
// receives a value; skipping the send on error would leave it waiting forever.
func selectUser(db *sql.DB, usersBalance chan float32) {
	fmt.Println("Selecting balance from users table where id = 1 and balance >= 100")

	balance, err := queryBalance(db)
	if err != nil {
		fmt.Println("Error selecting balance: ", err)
		usersBalance <- 0
		return
	}

	fmt.Println("Balance: ", balance)
	usersBalance <- balance
}

// queryBalance executes selectQuery and returns the balance of the first row.
// It returns 0 with a nil error when the query matched no row.
func queryBalance(db *sql.DB) (float32, error) {
	rows, err := db.Query(selectQuery)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both for "no rows" and for an iteration error;
		// rows.Err tells the two apart (nil means no row matched).
		return 0, rows.Err()
	}

	var balance float32
	if err := rows.Scan(&balance); err != nil {
		return 0, err
	}
	return balance, nil
}
// updateUser receives one balance from usersBalance and, when that balance is
// at least 100, executes updateQuery.
//
// The return value is numSuccessfulUpdates plus the number of rows the UPDATE
// affected. numSuccessfulUpdates is returned unchanged when the received
// balance is below 100, when the UPDATE fails, or when the affected-row count
// cannot be read. A panic raised inside the function is recovered and printed;
// in that case the function returns 0.
func updateUser(usersBalance chan float32, numSuccessfulUpdates int64, db *sql.DB) int64 {
	defer func() {
		if recovered := recover(); recovered != nil {
			fmt.Println("Recovered from panic: ", recovered)
		}
	}()

	// Blocks until a select worker has published the balance it saw.
	if seen := <-usersBalance; seen < 100 {
		fmt.Println("Balance is less than 100...not updating balance")
		return numSuccessfulUpdates
	}
	fmt.Println("Balance is >= 100...updating balance")

	result, execErr := db.Exec(updateQuery)
	if execErr != nil {
		fmt.Println("Error updating balance: ", execErr)
		return numSuccessfulUpdates
	}
	fmt.Println("Balance updated")

	affected, countErr := result.RowsAffected()
	if countErr != nil {
		fmt.Println("Error getting rows affected: ", countErr)
		return numSuccessfulUpdates
	}
	fmt.Println("Rows affected: ", affected)

	return numSuccessfulUpdates + affected
}
// checkBalance reads the current balance of user 1 and returns it. If the
// query or scan fails, the error is printed and 0 is returned.
func checkBalance(db *sql.DB) float32 {
	fmt.Println("Checking final balance")

	var finalBalance float32
	row := db.QueryRow("SELECT balance FROM users WHERE id = 1")
	if scanErr := row.Scan(&finalBalance); scanErr != nil {
		fmt.Println("Error selecting balance: ", scanErr)
	}
	return finalBalance
}
// MySQL runs the SELECT-then-UPDATE sequence from two concurrent workers and
// reports the outcome.
//
// The run has two phases:
//  1. every select worker reads the balance and publishes it on a channel;
//  2. once all selects have finished, every update worker takes one published
//     balance and decides from it whether to issue the UPDATE.
//
// Waiting for phase 1 to finish before starting phase 2 guarantees that all
// workers decide from a balance read before any UPDATE ran, which is the
// interleaving the test is meant to exercise.
//
// It returns the final balance of user 1 and the total number of rows changed
// by the update workers. If no database handle is available it returns 0, 0.
func MySQL() (float32, int64) {
	// Number of concurrent workers in each phase.
	const workers = 2

	// Carries each balance read by a select worker to an update worker.
	// Buffered so the select workers never block on the send.
	usersBalance := make(chan float32, workers)

	db := connectDatabase()
	if db == nil {
		// Without a handle there is nothing to test; a zero update count makes
		// the caller's pre-flight check fail.
		return 0, 0
	}
	defer db.Close()

	fmt.Println("Begin WaitGroups")

	// Phase 1: run the select queries in parallel and wait for all of them.
	var selects sync.WaitGroup
	for i := 0; i < workers; i++ {
		selects.Add(1)
		go func(i int) {
			defer selects.Done()
			fmt.Printf("WaitGroup 1 Start: Worker %d: Running selectUser\n", i)
			selectUser(db, usersBalance)
			fmt.Printf("WaitGroup 1 End: Worker %d: Running selectUser\n", i)
		}(i)
	}
	selects.Wait()

	// Phase 2: run the update queries in parallel. The counter is shared by
	// the workers, so it is only touched while holding mu.
	var (
		updates              sync.WaitGroup
		mu                   sync.Mutex
		numSuccessfulUpdates int64
	)
	for i := 0; i < workers; i++ {
		updates.Add(1)
		go func(i int) {
			defer updates.Done()
			fmt.Printf("WaitGroup 2 Start: Worker %d: Running updateUser\n", i)
			// Passing 0 as the running total makes updateUser return only the
			// rows this worker changed, so each row is counted exactly once.
			changed := updateUser(usersBalance, 0, db)
			mu.Lock()
			numSuccessfulUpdates += changed
			mu.Unlock()
			fmt.Printf("WaitGroup 2 End: Worker %d: Running updateUser\n", i)
		}(i)
	}
	updates.Wait()

	// All workers are done, so the counter can be read without the lock.
	return checkBalance(db), numSuccessfulUpdates
}
// main runs the race test once and prints a verdict.
//
// A run with zero successful updates is reported as a pre-flight failure
// rather than a pass, because it usually means the data was not reset after a
// previous run. Otherwise the sequence is reported as vulnerable when both
// updates succeeded and the balance ended at -100.
func main() {
	balance, updates := MySQL()
	fmt.Println("Main: Final balance: ", balance)

	fmt.Println("----- PRE-FLIGHT CHECKS -----")
	// This prevents a false pass after a failed test, where the data was not reset.
	if updates == 0 {
		fmt.Println(">>> PRE-FLIGHT FAILED: No updates were successful. Please check your dataset and code for issues and try again.")
		return
	}
	fmt.Println(">>> PRE-FLIGHT PASSED: The dataset has been reset and the code is ready to be tested.")

	fmt.Println("----- TEST RESULTS -----")
	verdict := ">>> PASS: This SELECT and UPDATE are not vulnerable to an ACIDRain attack."
	if balance == -100 && updates == 2 {
		verdict = ">>> FAIL: This SELECT and UPDATE sequence are vulnerable to an ACIDRain Attack! Please use FOR UPDATE or FOR SHARE on the SELECT to fix this issue!"
	}
	fmt.Println(verdict)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment