Skip to content

Instantly share code, notes, and snippets.

@lundha
Created November 2, 2023 08:41
Show Gist options
  • Save lundha/4dab4999dea4581cd5a53c921ed8258a to your computer and use it in GitHub Desktop.
Save lundha/4dab4999dea4581cd5a53c921ed8258a to your computer and use it in GitHub Desktop.
SELECT ... FOR UPDATE
package main
import (
"context"
"database/sql"
"fmt"
"log"
"sort"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
_ "github.com/lib/pq"
"golang.org/x/sync/errgroup"
)
const dbName = "selectforupdate"

// main runs the SELECT ... FOR UPDATE demo against a local PostgreSQL
// database and exits with a non-zero status if any step fails.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
	fmt.Println("Done!")
}

// run connects to the database and executes three experiments, resetting
// every user's favorite_letters before each one:
//  1. serial updates (no concurrency),
//  2. concurrent updates without row locks,
//  3. concurrent updates using SELECT ... FOR UPDATE.
//
// It returns errors instead of calling log.Fatal so that the deferred
// db.Close always runs.
func run() error {
	connStr := fmt.Sprintf("host=localhost port=6432 user=postgres password=postgres dbname=%s sslmode=disable", dbName)
	db, err := sqlx.Connect("postgres", connStr)
	if err != nil {
		// There is no usable connection to continue with.
		return fmt.Errorf("opening database: %w", err)
	}
	defer db.Close()

	// Test the connection.
	if err := db.Ping(); err != nil {
		return fmt.Errorf("pinging database: %w", err)
	}
	fmt.Println("Connected to the PostgreSQL database!")

	// One-time setup; uncomment on a fresh database.
	// NOTE(review): the experiments update users with id 1 and 2, so two rows
	// must exist before they run.
	// if err := createTable(db); err != nil {
	// 	return err
	// }
	// if err := insertUser(db.DB, "John Doe", []string{"A", "B", "C"}); err != nil {
	// 	return err
	// }

	experiments := []func() error{
		func() error { return runInSerial(db) },
		func() error { return runInParallel(db, false) }, // withForUpdate = false
		func() error { return runInParallel(db, true) },  // withForUpdate = true
	}
	for _, experiment := range experiments {
		// Every experiment starts from empty letter arrays.
		if err := cleanFavoriteLetters(db.DB); err != nil {
			return fmt.Errorf("resetting favorite letters: %w", err)
		}
		if err := experiment(); err != nil {
			return err
		}
	}
	return nil
}
// runInParallel applies every batch in listOfLists to users 1 and 2. Each
// batch runs in its own goroutine and its own transaction. Afterwards it
// reports whether both users ended up with the same multiset of letters.
//
// Each transaction reads favorite_letters, appends to it in Go and writes it
// back. When withForUpdate is false, the plain SELECT takes no row lock.
// Concurrent read-modify-write cycles on the same row may therefore overwrite
// each other's appends (lost updates), and the comparison can report false.
// When withForUpdate is true, the SELECT ... FOR UPDATE lock makes the
// updates to each row wait for one another.
//
// Errors from any update, or from reading the results, are returned.
func runInParallel(db *sqlx.DB, withForUpdate bool) error {
	// The derived context is not used: updateUser does not accept one.
	g, _ := errgroup.WithContext(context.Background())
	for _, userID := range []int{1, 2} {
		for _, letters := range listOfLists {
			userID, letters := userID, letters // per-iteration copies for the closure (pre-Go 1.22)
			g.Go(func() error {
				return updateUser(db.DB, userID, withForUpdate, letters...)
			})
		}
	}
	if err := g.Wait(); err != nil {
		return fmt.Errorf("parallel update (withForUpdate: %t): %w", withForUpdate, err)
	}
	fmt.Println("[parallell] Done!")

	letters1, err := readFavoriteLetters(db, 1)
	if err != nil {
		return fmt.Errorf("reading letters for user 1: %w", err)
	}
	letters2, err := readFavoriteLetters(db, 2)
	if err != nil {
		return fmt.Errorf("reading letters for user 2: %w", err)
	}
	ok := compareLetters(letters1, letters2)
	fmt.Printf("[parallell - withForUpdate: %t] Letters for user 1 are equal user 2: %t\n", withForUpdate, ok)
	return nil
}
// runInSerial is the baseline experiment. It applies every batch in
// listOfLists to user 1 and then to user 2, one transaction at a time and
// without FOR UPDATE. It then reports whether both users ended up with the
// same multiset of letters.
//
// Errors from any update, or from reading the results, are returned.
func runInSerial(db *sqlx.DB) error {
	for _, userID := range []int{1, 2} {
		for _, letters := range listOfLists {
			if err := updateUser(db.DB, userID, false, letters...); err != nil {
				return fmt.Errorf("serial update of user %d: %w", userID, err)
			}
		}
	}
	fmt.Println("[serial] Done!")

	letters1, err := readFavoriteLetters(db, 1)
	if err != nil {
		return fmt.Errorf("reading letters for user 1: %w", err)
	}
	letters2, err := readFavoriteLetters(db, 2)
	if err != nil {
		return fmt.Errorf("reading letters for user 2: %w", err)
	}
	ok := compareLetters(letters1, letters2)
	fmt.Printf("[serial] Letters for user 1 are equal user 2: %t\n", ok)
	return nil
}
// createTable creates the users table (serial id, name, and a text[] of
// favorite letters) unless it already exists.
// Any database error is wrapped with %w so callers can inspect the cause.
func createTable(db *sqlx.DB) error {
	createTableSQL := `
CREATE TABLE IF NOT EXISTS users (
id serial PRIMARY KEY,
name text,
favorite_letters text[]
)
`
	if _, err := db.Exec(createTableSQL); err != nil {
		return fmt.Errorf("error creating table: %w", err)
	}
	// Printed also when the table already existed (IF NOT EXISTS).
	fmt.Println("Table 'users' created successfully!")
	return nil
}
// insertUser inserts one row into users with the given name and favorite
// letters. The id is not supplied, so it comes from the table's id column
// default. Any database error is wrapped with %w.
func insertUser(db *sql.DB, name string, favoriteLetters []string) error {
	insertSQL := `
INSERT INTO users (name, favorite_letters)
VALUES ($1, $2)
`
	if _, err := db.Exec(insertSQL, name, pq.Array(favoriteLetters)); err != nil {
		return fmt.Errorf("error inserting user: %w", err)
	}
	fmt.Printf("User '%s' created successfully!\n", name)
	return nil
}
func updateUser(db *sql.DB, userID int, withForUpdate bool, newLetters ...string) error {
tx, err := db.Begin()
if err != nil {
log.Fatal(err)
}
defer func() {
if err != nil {
tx.Rollback() // Rollback if an error occurs or commit at the end
}
tx.Commit()
}()
letters, err := readFavoriteLettersTx(tx, userID, withForUpdate)
if err != nil {
return fmt.Errorf("error reading favorite letters: %q", err)
}
lettersToInsert := append(letters, newLetters...)
err = updateFavoriteLetters(tx, userID, lettersToInsert)
if err != nil {
return fmt.Errorf("error updating favorite letters: %q", err)
}
return err
}
// readFavoriteLettersTx returns the favorite_letters array of the given user,
// read inside tx. When withForUpdate is true the row is selected FOR UPDATE.
// Other transactions that also lock the row then wait until tx commits or
// rolls back.
//
// It returns the error "user not found" when no row has that id; any other
// query or scan error is returned unchanged.
func readFavoriteLettersTx(tx *sql.Tx, userID int, withForUpdate bool) ([]string, error) {
	selectSQL := `SELECT favorite_letters FROM users WHERE id = $1`
	if withForUpdate {
		selectSQL += ` FOR UPDATE`
	}

	var favoriteLetters []string
	// QueryRow reports query errors through Scan, so a failing query is not
	// mistaken for a missing row.
	err := tx.QueryRow(selectSQL, userID).Scan(pq.Array(&favoriteLetters))
	if err == sql.ErrNoRows {
		return nil, fmt.Errorf("user not found")
	}
	if err != nil {
		return nil, err
	}
	return favoriteLetters, nil
}
// updateFavoriteLetters replaces the whole favorite_letters array of the user
// with the given id, executing inside tx. The number of affected rows is not
// checked; only the Exec error is returned.
func updateFavoriteLetters(tx *sql.Tx, userID int, newFavoriteLetters []string) error {
	const stmt = `
UPDATE users SET favorite_letters = $1 WHERE id = $2
`
	if _, execErr := tx.Exec(stmt, pq.Array(newFavoriteLetters), userID); execErr != nil {
		return execErr
	}
	return nil
}
// cleanFavoriteLetters sets favorite_letters to an empty array for every row
// in users and returns any Exec error.
func cleanFavoriteLetters(db *sql.DB) error {
	const resetSQL = `
UPDATE users SET favorite_letters = '{}'
`
	if _, execErr := db.Exec(resetSQL); execErr != nil {
		return execErr
	}
	return nil
}
// readFavoriteLetters returns the favorite_letters array of the given user,
// read outside any explicit transaction.
//
// It returns the error "user not found" when no row has that id; any other
// query or scan error is returned unchanged.
func readFavoriteLetters(db *sqlx.DB, userID int) ([]string, error) {
	const selectSQL = `SELECT favorite_letters FROM users WHERE id = $1`

	var favoriteLetters []string
	// QueryRow reports query errors through Scan, so a failing query is not
	// mistaken for a missing row.
	err := db.QueryRow(selectSQL, userID).Scan(pq.Array(&favoriteLetters))
	if err == sql.ErrNoRows {
		return nil, fmt.Errorf("user not found")
	}
	if err != nil {
		return nil, err
	}
	return favoriteLetters, nil
}
// compareLetters reports whether letters1 and letters2 contain the same
// strings with the same multiplicities, ignoring order. A nil slice and an
// empty slice compare equal. The inputs are not modified.
func compareLetters(letters1, letters2 []string) bool {
	if len(letters1) != len(letters2) {
		return false
	}
	// Sort copies so the caller's slices keep their original order.
	a := append([]string(nil), letters1...)
	b := append([]string(nil), letters2...)
	sort.Strings(a)
	sort.Strings(b)
	// Element-wise comparison: joining via %v would treat e.g. {"a b"} and
	// {"a", "b"} as equal.
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}
// listOfLists holds the batches of letters that runInSerial and runInParallel
// append to each user's favorite_letters, one batch per updateUser call.
// The eleven batches from {"X", "Y"} to {"V", "W"} appear twice, in the same
// order.
var listOfLists = [][]string{
	// first pass
	{"X", "Y"}, {"A", "B", "C"}, {"D", "E"}, {"F", "G"}, {"H", "I", "J"}, {"K", "L"},
	{"M", "N"}, {"O", "P"}, {"Q", "R", "S"}, {"T", "U"}, {"V", "W"},
	// second pass, identical to the first
	{"X", "Y"}, {"A", "B", "C"}, {"D", "E"}, {"F", "G"}, {"H", "I", "J"}, {"K", "L"},
	{"M", "N"}, {"O", "P"}, {"Q", "R", "S"}, {"T", "U"}, {"V", "W"},
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment