Skip to content

Instantly share code, notes, and snippets.

@sharonjl
Created November 17, 2022 18:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sharonjl/9f7f2aa60b7ca73ba06231a724ad7d1f to your computer and use it in GitHub Desktop.
Save sharonjl/9f7f2aa60b7ca73ba06231a724ad7d1f to your computer and use it in GitHub Desktop.
PostgreSQL queue with GORM
package pg
import (
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var (
	// ErrNoJobs is returned by Queue.Work when no claimable job was found.
	// Work returns it as-is (not wrapped), so callers can compare it with ==
	// or errors.Is.
	ErrNoJobs = errors.New("no jobs")
)
// Namer is implemented by job payload types. Name identifies the payload kind;
// Job[T].TableName derives the queue table name from it ("queue_" + Name()).
// It is invoked on the zero value of the type, so it must not depend on any
// field values.
type Namer interface {
	Name() (name string)
}
// Job is one row of a queue table. T is the payload type stored in the params
// column; GORM must be able to store T in a single column (for example by T
// implementing sql.Scanner and driver.Valuer).
type Job[T Namer] struct {
	ID         uuid.UUID `gorm:"primaryKey;type:uuid"`
	// Status: 0 = pending (the only status dequeue selects), 1 = processed
	// successfully, -1 = failed and never picked up again.
	Status     int       `gorm:"index;default:0"`
	// Tries counts the processing attempts made so far.
	Tries      int       `gorm:"default:0"`
	// MaxTries is the attempt limit; dequeue only selects jobs whose tries are
	// within it.
	MaxTries   int       `gorm:"default:3"`
	// Params is the job payload. The column name is given with an explicit
	// `column:` setting; a bare word in a gorm tag is not a GORM setting.
	Params     T         `gorm:"column:params"`
	// FailReason holds the error text of the most recent failed attempt.
	FailReason string
	CreatedAt  time.Time `gorm:"default:current_timestamp"`
	UpdatedAt  time.Time `gorm:"default:current_timestamp"`
}
// TableName gives GORM the table for jobs with payload type T: "queue_"
// followed by T's Name(). Name is called on a zero T, never on the job's own
// payload.
// NOTE(review): if T is a pointer type its zero value is nil, so Name must
// tolerate a nil receiver.
func (j Job[T]) TableName() string {
	var zero T
	name := zero.Name()
	return "queue_" + name
}
// Processor handles one claimed job at a time. Process receives the database
// handle of the transaction in which the job's row was locked, so anything it
// writes through tx is part of that transaction. Returning a non-nil error
// counts the attempt as failed, and the error text is stored as the job's
// fail reason.
type Processor[T Namer] interface {
	Process(tx *gorm.DB, job Job[T]) error
}
// Queue is a PostgreSQL-backed job queue for payloads of type T. Jobs live in
// the table named by Job[T].TableName.
type Queue[T Namer] struct {
	// DB is the connection used both to insert and to claim jobs.
	DB       *gorm.DB
	// MaxTries is copied onto every enqueued job. A zero value lets GORM fall
	// back to the column default declared on Job.MaxTries (3).
	MaxTries int
}
// Enqueue inserts a new job carrying params. The job gets a fresh random UUID
// and the queue's MaxTries, and its status starts at the column default (0,
// pending). CreatedAt and UpdatedAt share a single timestamp, so a freshly
// enqueued row has identical creation and update times.
func (q Queue[T]) Enqueue(ctx context.Context, params T) error {
	now := time.Now()
	job := Job[T]{
		ID:        uuid.New(),
		MaxTries:  q.MaxTries,
		Params:    params,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err := q.DB.WithContext(ctx).Create(&job).Error; err != nil {
		return fmt.Errorf("error pushing job to queue: %w", err)
	}
	return nil
}
// Work claims the oldest pending job, if any, and runs p on it inside a
// transaction.
//
// It returns ErrNoJobs itself (unwrapped) when there is nothing to claim. A
// Processor error does not make Work fail; the failure is recorded on the job
// instead. Any other error, for example from the database after ctx has been
// cancelled, is returned wrapped with %w, so use errors.Is to inspect it.
func (q Queue[T]) Work(ctx context.Context, p Processor[T]) error {
	err := next[T](ctx, q.DB, p)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return ErrNoJobs
	default:
		// next already says "error processing job: error getting next job: ...",
		// so use a distinct prefix here instead of repeating it.
		return fmt.Errorf("error working queue: %w", err)
	}
}
// next claims one pending job and processes it, all in one transaction.
//
// dequeue locks the job's row (FOR UPDATE SKIP LOCKED), so concurrent workers
// never claim the same job. p.Process runs in a nested transaction, which GORM
// implements as a SAVEPOINT. If Process fails, whatever it wrote through the
// handle is rolled back to the savepoint, and the outer transaction stays
// usable for the bookkeeping below.
//
// Every attempt increments tries. The outcome sets the status:
//   - success: status 1.
//   - failure with attempts remaining: status stays 0 (pending) so a later
//     call retries the job, and the error text goes to fail_reason.
//   - failure on the last allowed attempt (MaxTries): status -1.
//
// A Processor error is recorded, not returned. The returned error wraps
// gorm.ErrRecordNotFound when nothing is pending, or any database error from
// claiming or updating the job; in that case the whole transaction is rolled
// back.
func next[T Namer](ctx context.Context, db *gorm.DB, p Processor[T]) error {
	err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		job, err := dequeue[T](tx)
		if err != nil {
			return fmt.Errorf("error getting next job: %w", err)
		}
		// Captured before incTries so the retry decision does not depend on how
		// GORM treats the in-memory struct during the update.
		attempt := job.Tries + 1

		procErr := tx.Transaction(func(ptx *gorm.DB) error {
			return p.Process(ptx, *job)
		})

		if err := incTries(tx, job); err != nil {
			return err
		}
		switch {
		case procErr == nil:
			return setStatus(tx, job, 1, "")
		case attempt < job.MaxTries:
			// Attempts remain: leave the job pending so it is retried.
			return setStatus(tx, job, 0, procErr.Error())
		default:
			return setStatus(tx, job, -1, procErr.Error())
		}
	})
	if err != nil {
		return fmt.Errorf("error processing job: %w", err)
	}
	return nil
}
// dequeue locks and returns the oldest pending job (status 0) that still has
// attempts left. SKIP LOCKED makes concurrent callers pass over rows another
// transaction has already locked instead of blocking on them. tx should be a
// transaction handle, because the row lock is held until that transaction
// ends. The returned error wraps gorm.ErrRecordNotFound when no row qualifies.
func dequeue[T Namer](tx *gorm.DB) (*Job[T], error) {
	var job Job[T]
	err := tx.
		Clauses(clause.Locking{Strength: "UPDATE", Options: "SKIP LOCKED"}).
		Where("status = ?", 0).
		// tries counts attempts already made, so "<" allows exactly max_tries
		// attempts ("<=" would allow one extra).
		Where("tries < max_tries").
		Order("created_at").
		First(&job). // First already applies LIMIT 1
		Error
	if err != nil {
		return nil, fmt.Errorf("error selecting job: %w", err)
	}
	return &job, nil
}
// incTries adds one to the tries column of job's row. The new value is
// computed in SQL (tries + 1) from the stored column, not from job.Tries.
func incTries[T Namer](tx *gorm.DB, job *Job[T]) error {
	res := tx.Model(job).
		Where("id", job.ID).
		Update("tries", gorm.Expr("tries + 1"))
	if res.Error == nil {
		return nil
	}
	return fmt.Errorf("error updating job try count: %w", res.Error)
}
// setStatus writes status to job's row. When reason is non-empty it is stored
// in fail_reason in the same UPDATE; an empty reason leaves fail_reason
// untouched. A map is used because GORM writes zero values from a map, which
// matters when status is 0.
func setStatus[T Namer](tx *gorm.DB, job *Job[T], status int, reason string) error {
	fields := map[string]any{"status": status}
	if reason != "" {
		fields["fail_reason"] = reason
	}
	if err := tx.Model(job).Where("id", job.ID).Updates(fields).Error; err != nil {
		return fmt.Errorf("error updating job status: %w", err)
	}
	return nil
}
package pg
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// example is the payload type used by the queue tests. It is stored as JSON in
// a single column through the sql.Scanner / driver.Valuer methods below.
type example struct {
	Text string `json:"text"`
}

// Name makes example a Namer; its jobs live in the table "queue_example".
func (example) Name() string {
	return "example"
}

// Scan decodes a JSON column value into ex. Drivers may return the raw value
// as either []byte or string, so both are accepted; anything else (including
// NULL) is an error.
func (ex *example) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
	}
	return json.Unmarshal(raw, ex)
}

// Value encodes ex as JSON bytes for storage.
func (ex example) Value() (driver.Value, error) {
	return json.Marshal(ex)
}
// exampleProcessor is a Processor[example] whose behaviour is supplied by the
// test as a plain function.
type exampleProcessor struct {
	fn func(db *gorm.DB, j Job[example]) error
}

// Process hands the job and its database handle straight to fn.
func (ep exampleProcessor) Process(tx *gorm.DB, job Job[example]) error {
	handler := ep.fn
	return handler(tx, job)
}
// TestQueue exercises Queue against a disposable PostgreSQL 12.11 container
// started through dockertest. Set TEST_DISABLE_POSTGRES to any non-empty value
// to skip it (e.g. on machines without Docker).
func TestQueue(t *testing.T) {
	if os.Getenv("TEST_DISABLE_POSTGRES") != "" {
		t.Skip("Skipping pg.Queue tests")
	}

	pool, err := dockertest.NewPool("")
	if err != nil {
		t.Fatalf("Could not connect to docker: %s", err)
	}
	pool.MaxWait = 120 * time.Second

	res, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "postgres",
		Tag:        "12.11",
		Env: []string{
			"POSTGRES_PASSWORD=secret",
			"POSTGRES_USER=user",
			"POSTGRES_DB=dbname",
			"listen_addresses = '*'",
		},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	if err != nil {
		t.Fatalf("Could not start resource: %s", err)
	}
	// Register the purge immediately so the container is removed even when a
	// later setup step fails the test with t.Fatalf.
	t.Cleanup(func() {
		if err := pool.Purge(res); err != nil {
			t.Errorf("could not purge resources: %v", err)
		}
	})
	// Backstop in case cleanup never runs (e.g. the test binary is killed).
	if err := res.Expire(120); err != nil {
		t.Logf("could not set container expiry: %v", err)
	}

	var db *gorm.DB
	err = pool.Retry(func() error {
		dsn := fmt.Sprintf("postgres://user:secret@%s/dbname?sslmode=disable", res.GetHostPort("5432/tcp"))
		conn, openErr := gorm.Open(postgres.Open(dsn), &gorm.Config{
			Logger: &gormLogger{t: t, lvl: logger.Info},
		})
		if openErr != nil {
			return openErr
		}
		db = conn
		return nil
	})
	if err != nil {
		t.Fatalf("Could not connect to database: %s", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("Could not get database handle: %s", err)
	}
	// The subtest enqueues from many goroutines at once; cap the pool so a burst
	// cannot exceed PostgreSQL's default max_connections (100).
	sqlDB.SetMaxOpenConns(20)
	// Cleanups run LIFO, so connections close before the container is purged.
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Logf("could not close database: %v", err)
		}
	})

	if err := db.AutoMigrate(&Job[example]{}); err != nil { // Create the queue table
		t.Fatalf("Could not migrate database: %s", err)
	}

	t.Run("Process all jobs", func(t *testing.T) {
		const (
			N = 100 // jobs enqueued
			W = 3   // concurrent workers
		)
		var k atomic.Int32 // Process invocations
		q := Queue[example]{DB: db, MaxTries: 3}

		// Overall safety timeout only: a healthy run stops much earlier because
		// the workers stop themselves once every job has been committed.
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		// Workers get a child context, so stopping them cannot interrupt an
		// Enqueue call that is still returning.
		workCtx, stopWorkers := context.WithCancel(ctx)
		defer stopWorkers()

		var wg sync.WaitGroup
		wg.Add(N)
		for i := 0; i < N; i++ { // Enqueue
			go func(i int) {
				defer wg.Done()
				if err := q.Enqueue(ctx, example{Text: fmt.Sprintf("test_%d", i)}); err != nil {
					t.Errorf("Could not enqueue job: %s", err)
				}
			}(i)
		}

		proc := exampleProcessor{
			fn: func(db *gorm.DB, j Job[example]) error {
				k.Add(1)
				return nil
			},
		}

		var committed atomic.Int32 // Work calls whose transaction committed
		wg.Add(W)
		for i := 0; i < W; i++ {
			go func(w int) {
				defer wg.Done()
				for workCtx.Err() == nil {
					err := q.Work(workCtx, proc)
					switch {
					case err == nil:
						// Count after Work returns, i.e. after the commit, so
						// workers never stop while a job is still mid-transaction.
						if committed.Add(1) == N {
							stopWorkers()
						}
					case workCtx.Err() != nil:
						// Errors after shutdown began (wrapped context errors,
						// sql.ErrTxDone, ...) are a consequence of stopping.
						return
					case errors.Is(err, ErrNoJobs):
						select {
						case <-workCtx.Done():
							return
						case <-time.After(100 * time.Millisecond):
						}
					default:
						t.Errorf("Worker %d: %s", w, err)
					}
				}
			}(i)
		}
		wg.Wait()

		if got := k.Load(); got != N {
			t.Errorf("expected job processed %d, got %d", N, got)
		}
	})
}
// gormLogger adapts GORM's logger.Interface to testing.T, so database activity
// appears in the output of the test that triggered it.
type gormLogger struct {
	t   *testing.T
	lvl logger.LogLevel // most verbose level that is emitted
}

// LogMode returns a copy of tl that emits messages up to level.
func (tl gormLogger) LogMode(level logger.LogLevel) logger.Interface {
	return &gormLogger{
		t:   tl.t,
		lvl: level,
	}
}

// Info logs s (a format string) with args i when the level is Info or more
// verbose.
func (tl gormLogger) Info(ctx context.Context, s string, i ...interface{}) {
	if tl.lvl < logger.Info {
		return
	}
	tl.t.Logf("[DB:INFO] "+s, i...)
}

// Warn logs s (a format string) with args i when the level is Warn or more
// verbose.
func (tl gormLogger) Warn(ctx context.Context, s string, i ...interface{}) {
	if tl.lvl < logger.Warn {
		return
	}
	tl.t.Logf("[DB:WARN] "+s, i...)
}

// Error logs s (a format string) with args i when the level is Error or more
// verbose.
func (tl gormLogger) Error(ctx context.Context, s string, i ...interface{}) {
	if tl.lvl < logger.Error {
		return
	}
	tl.t.Logf("[DB:ERROR] "+s, i...)
}

// Trace is called by GORM after each statement. At Info level every statement
// is logged. At Error and Warn levels only failed statements are logged, so
// errors stay visible when tracing is turned down. Each line includes the time
// elapsed since begin.
func (tl gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	verbose := tl.lvl >= logger.Info
	failed := err != nil && tl.lvl >= logger.Error
	if !verbose && !failed {
		return
	}
	sql, rowsAffected := fc()
	tl.t.Logf("[DB:TRACE] sql: %s, rowsAffected: %d, err: %v, elapsed: %s", sql, rowsAffected, err, time.Since(begin))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment