Last active
February 24, 2024 17:02
-
-
Save lukewatts/15e6c74b4c7f6427a02e2cc10707bc5c to your computer and use it in GitHub Desktop.
Test a MySQL SELECT/UPDATE query sequence for isolation issues (e.g. race conditions, concurrent reads/writes, phantom reads, etc.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Schema for the users table exercised by the Go race-test program.
-- The program only touches the row with id = 1: it reads `balance` and then
-- debits 100 from it. Its FAIL verdict (final balance -100 after two
-- successful debits) assumes that row starts at 100.00, the column default,
-- so re-seed the row before every run, e.g.:
--   INSERT INTO users (id, name, email, password)
--   VALUES (1, 'race-test', 'race-test@example.com', 'x')
--   ON DUPLICATE KEY UPDATE balance = 100.00;
-- `balance` is a signed DECIMAL, so a double debit is stored as -100.00
-- rather than being rejected.
CREATE TABLE `users` (
  `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
  `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
  `email` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
  `email_verified_at` timestamp NULL DEFAULT NULL,
  `password` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL,
  `balance` decimal(8,2) NOT NULL DEFAULT '100.00',
  `remember_token` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL,
  `created_at` timestamp NULL DEFAULT NULL,
  `updated_at` timestamp NULL DEFAULT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `users_email_unique` (`email`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main

import (
	"database/sql"
	"fmt"
	"sync"

	// Registers the "mysql" driver name used by sql.Open.
	_ "github.com/go-sql-driver/mysql"
)
// SQL for the check-then-act sequence under test.
//
// selectQuery is the "check": it returns user 1's balance only when at least
// 100 is available (no row otherwise). updateQuery is the "act": it debits 100
// from user 1 without re-checking the balance, so nothing in the UPDATE itself
// stops two debits that were both approved by earlier SELECTs.
const (
	selectQuery = `SELECT balance
FROM users
WHERE
id = 1
AND
balance >= 100`

	updateQuery = `UPDATE users
SET balance = balance - 100
WHERE id = 1`
)
// connectDatabase opens a connection pool to the local sql_race_tests MySQL
// database and verifies that the server is reachable.
//
// sql.Open only validates its arguments and never dials the server, so Ping
// is used to fail fast. Any failure is printed and then panics with the
// underlying error. The program cannot do anything useful without a database,
// and returning a nil or unreachable *sql.DB would only move the failure to a
// less obvious place (e.g. a nil-pointer panic in a worker goroutine).
func connectDatabase() *sql.DB {
	db, err := sql.Open("mysql", "root@tcp(localhost)/sql_race_tests")
	if err != nil {
		fmt.Println("Error opening database: ", err)
		panic(err)
	}
	if err := db.Ping(); err != nil {
		fmt.Println("Error opening database: ", err)
		db.Close() // best effort: we are about to panic anyway
		panic(err)
	}
	fmt.Println("Main: Opened database connection")
	return db
}
func selectUser(db *sql.DB, usersBalance chan float32) { | |
fmt.Println("Selecting balance from users table where id = 1 and balance >= 100") | |
rows, err := db.Query(selectQuery) | |
if err != nil { | |
fmt.Println("Error selecting balance: ", err) | |
return | |
} | |
defer rows.Close() | |
rows.Next() | |
var balance float32 | |
rows.Scan(&balance) | |
fmt.Println("Balance: ", balance) | |
usersBalance <- balance | |
} | |
func updateUser(usersBalance chan float32, numSuccessfulUpdates int64, db *sql.DB) int64 { | |
defer func() { | |
if r := recover(); r != nil { | |
fmt.Println("Recovered from panic: ", r) | |
} | |
}() | |
if <-usersBalance < 100 { | |
fmt.Println("Balance is less than 100...not updating balance") | |
return numSuccessfulUpdates | |
} | |
fmt.Println("Balance is >= 100...updating balance") | |
res, err := db.Exec(updateQuery) | |
if err != nil { | |
fmt.Println("Error updating balance: ", err) | |
return numSuccessfulUpdates | |
} | |
fmt.Println("Balance updated") | |
rowsAffected, err := res.RowsAffected() | |
if err != nil { | |
fmt.Println("Error getting rows affected: ", err) | |
return numSuccessfulUpdates | |
} | |
fmt.Println("Rows affected: ", rowsAffected) | |
return rowsAffected + numSuccessfulUpdates | |
} | |
// checkBalance reads user 1's current balance without the balance >= 100
// filter that selectQuery applies, so a negative (overdrawn) balance is
// visible.
//
// On any error, including sql.ErrNoRows when user 1 does not exist, the error
// is printed and 0 is returned. Callers cannot distinguish that from a genuine
// balance of 0.
func checkBalance(db *sql.DB) float32 {
	fmt.Println("Checking final balance")

	var finalBalance float32
	row := db.QueryRow("SELECT balance FROM users WHERE id = 1")
	if scanErr := row.Scan(&finalBalance); scanErr != nil {
		fmt.Println("Error selecting balance: ", scanErr)
	}
	return finalBalance
}
func MySQL() (float32, int64) { | |
// 2 workers, which will run the select and update queries in parallel | |
var workers int = 2 | |
// userBalance channel will store the balance | |
// of the user between the select and update queries | |
usersBalance := make(chan float32, workers) | |
// Waitgroups will ensure the selects happen before the updates | |
// otherwise the order would not be guaranteed | |
var wg sync.WaitGroup | |
db := connectDatabase() | |
defer db.Close() | |
fmt.Println("Begin WaitGroups") | |
// First 2 waitgroup for the select queries | |
for i := 0; i < workers; i++ { | |
wg.Add(1) | |
// we will run 2 select queries in parallel using goroutines | |
go func(i int) { | |
fmt.Printf("WaitGroup 1 Start: Worker %d: Running selectUser\n", i) | |
// Once this go routine has completed is done this will let the next waitgroup know it can start | |
defer wg.Done() | |
selectUser(db, usersBalance) | |
fmt.Printf("WaitGroup 1 End: Worker %d: Running selectUser\n", i) | |
}(i) | |
} | |
var numSuccessfulUpdates int64 = 0 | |
// Third and 4th Waitgroup for the update queries | |
for i := 0; i < workers; i++ { | |
wg.Add(1) | |
go func(i int) { | |
fmt.Printf("WaitGroup 2 Start: Worker %d: Running updateUser\n", i) | |
defer wg.Done() | |
numSuccessfulUpdates = updateUser(usersBalance, numSuccessfulUpdates, db) + numSuccessfulUpdates | |
fmt.Printf("WaitGroup 2 End: Worker %d: Running updateUser\n", i) | |
}(i) | |
} | |
// Wait for all waitgroups to complete | |
wg.Wait() | |
// Get final balance | |
var finalBalance float32 = 0 | |
finalBalance = checkBalance(db) | |
return finalBalance, numSuccessfulUpdates | |
} | |
func main() { | |
finalBalance, numSuccessfulUpdates := MySQL() | |
fmt.Println("Main: Final balance: ", finalBalance) | |
fmt.Println("----- PRE-FLIGHT CHECKS -----") | |
// This prevents a false pass after a failed test, where the data was not reset | |
if numSuccessfulUpdates == 0 { | |
fmt.Println(">>> PRE-FLIGHT FAILED: No updates were successful. Please check your dataset and code for issues and try again.") | |
return | |
} | |
fmt.Println(">>> PRE-FLIGHT PASSED: The dataset has been reset and the code is ready to be tested.") | |
fmt.Println("----- TEST RESULTS -----") | |
if finalBalance == -100 && numSuccessfulUpdates == 2 { | |
fmt.Println(">>> FAIL: This SELECT and UPDATE sequence are vulnerable to an ACIDRain Attack! Please use FOR UPDATE or FOR SHARE on the SELECT to fix this issue!") | |
} else { | |
fmt.Println(">>> PASS: This SELECT and UPDATE are not vulnerable to an ACIDRain attack.") | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment