Skip to content

Instantly share code, notes, and snippets.

@pokutuna
Created March 11, 2019 06:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pokutuna/169764dfdf44f69ce13beac24e281493 to your computer and use it in GitHub Desktop.
Save pokutuna/169764dfdf44f69ce13beac24e281493 to your computer and use it in GitHub Desktop.
context with transaction
package main
import (
"context"
"fmt"
"time"
"github.com/gin-gonic/gin"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
)
// OpenDB opens the MySQL "test" database and recreates the test_go_context
// table that the demo handlers insert into. Any setup failure panics; the only
// caller is main, which cannot do anything useful without the table.
func OpenDB() *sqlx.DB {
	db, err := sqlx.Open("mysql", "readwrite:readwrite@tcp(localhost:3306)/test")
	if err != nil {
		panic(err)
	}
	// sqlx.Open may only validate the DSN without connecting, so the first
	// Exec is where an unreachable server surfaces. Previously these errors
	// were discarded and the failure only showed up later, inside requests.
	if _, err := db.Exec("DROP TABLE IF EXISTS test_go_context;"); err != nil {
		panic(err)
	}
	if _, err := db.Exec(`
CREATE TABLE test_go_context(
id bigint unsigned,
value text,
PRIMARY KEY(id)
);`); err != nil {
		panic(err)
	}
	return db
}
// TransactionWithTimeout runs txFunc inside a database transaction bound to ctx.
//
// The transaction is committed when txFunc returns nil. It is rolled back when
// txFunc returns an error, when txFunc panics, or when ctx is done first.
// Because it is started with BeginTxx(ctx, nil), database/sql also rolls the
// transaction back by itself once ctx is done and refuses to commit it after
// that.
//
// txFunc runs in a worker goroutine so the caller can react to ctx.Done()
// while txFunc is still working. On cancellation this function still waits for
// the worker to finish, so the transaction is never used after it returns.
// Statements that txFunc issues without a context (tx.Exec rather than
// tx.ExecContext) cannot be interrupted, and that wait lasts until they finish.
//
// A panic in txFunc is recovered in the worker, the transaction is rolled back,
// and the panic is re-raised in the caller's goroutine. A panic re-raised
// inside the worker would terminate the whole process.
//
// It returns the BeginTxx error, txFunc's error, the Commit error, or
// ctx.Err() when ctx finished first and nothing was committed.
func TransactionWithTimeout(ctx context.Context, db *sqlx.DB, txFunc func(tx *sqlx.Tx) error) error {
	tx, err := db.BeginTxx(ctx, nil)
	if err != nil {
		return err
	}

	// Buffered so the worker can always deliver its result and exit.
	resCh := make(chan txOutcome, 1)
	go func() {
		resCh <- runInTx(tx, txFunc)
	}()

	select {
	case res := <-resCh:
		return res.unwrap()
	case <-ctx.Done():
		// May return sql.ErrTxDone if database/sql already rolled back because
		// of ctx, or if the worker finished first; neither needs reporting.
		_ = tx.Rollback()
		res := <-resCh
		if res.panicked {
			panic(res.panicVal)
		}
		if res.err == nil {
			// The worker's Commit succeeded before the rollback took effect,
			// so the work is durable; reporting ctx.Err() would be wrong.
			return nil
		}
		return ctx.Err()
	}
}

// txOutcome is the result the transaction worker hands back to the caller.
type txOutcome struct {
	err      error
	panicked bool
	panicVal interface{}
}

// unwrap re-raises a forwarded panic in the current goroutine, or returns err.
func (o txOutcome) unwrap() error {
	if o.panicked {
		panic(o.panicVal)
	}
	return o.err
}

// runInTx executes txFunc and then commits or rolls back tx. A panic is
// recovered, the transaction is rolled back, and the panic is returned in the
// outcome instead of propagating.
func runInTx(tx *sqlx.Tx, txFunc func(tx *sqlx.Tx) error) (out txOutcome) {
	// Registered before txFunc runs so that a panic inside it is caught.
	defer func() {
		if p := recover(); p != nil {
			fmt.Printf("panic: %+v\n", p)
			_ = tx.Rollback() // best effort; the panic is the error that matters
			out = txOutcome{panicked: true, panicVal: p}
		}
	}()

	if err := txFunc(tx); err != nil {
		fmt.Printf("rollback: %+v\n", err)
		_ = tx.Rollback() // best effort; txFunc's error is what the caller needs
		return txOutcome{err: err}
	}

	fmt.Printf("commit\n")
	// Commit completes before the result is sent, so its error reaches the caller.
	return txOutcome{err: tx.Commit()}
}
// execInTx runs the given SQL statements in order inside a single transaction
// (via TransactionWithTimeout), stopping at the first error. Each statement is
// issued with ExecContext(ctx, ...), so a long-running one (e.g. DO SLEEP) is
// aborted when ctx's deadline passes instead of running to completion.
func execInTx(ctx context.Context, db *sqlx.DB, stmts ...string) error {
	return TransactionWithTimeout(ctx, db, func(tx *sqlx.Tx) error {
		for _, q := range stmts {
			if _, err := tx.ExecContext(ctx, q); err != nil {
				return err
			}
		}
		return nil
	})
}

// slowTask inserts two rows with a 3-second server-side sleep between them, in
// one transaction. Under the /timeout handler's 3-second deadline it is meant
// to be cancelled and rolled back.
func slowTask(ctx context.Context, db *sqlx.DB) error {
	return execInTx(ctx, db,
		"INSERT INTO test_go_context SET id = UUID_SHORT(), value = 'slow1'",
		"DO SLEEP(3)",
		"INSERT INTO test_go_context SET id = UUID_SHORT(), value = 'slow2'",
	)
}

// usualTask inserts two rows with a 1-second server-side sleep between them,
// in one transaction. It normally finishes well within the request deadline.
func usualTask(ctx context.Context, db *sqlx.DB) error {
	return execInTx(ctx, db,
		"INSERT INTO test_go_context SET id = UUID_SHORT(), value = 'usual1'",
		"DO SLEEP(1)",
		"INSERT INTO test_go_context SET id = UUID_SHORT(), value = 'usual2'",
	)
}
func main() {
r := gin.Default()
db := OpenDB()
r.GET("/usual", func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c, 3*time.Second)
defer cancel()
err := usualTask(ctx, db)
if err != nil {
c.String(500, err.Error())
return
}
c.String(200, "ok")
})
r.GET("/timeout", func(c *gin.Context) {
ctx, cancel := context.WithTimeout(c, 3*time.Second)
defer cancel()
err := slowTask(ctx, db)
if err != nil {
c.String(500, err.Error())
return
}
c.String(200, "ok")
})
r.Run()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment