Created
November 17, 2022 18:25
-
-
Save sharonjl/9f7f2aa60b7ca73ba06231a724ad7d1f to your computer and use it in GitHub Desktop.
PostgreSQL job queue built with GORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package pg | |
import ( | |
"context" | |
"errors" | |
"fmt" | |
"time" | |
"github.com/google/uuid" | |
"gorm.io/gorm" | |
"gorm.io/gorm/clause" | |
) | |
var (
	// ErrNoJobs is returned by Queue.Work when no eligible job is waiting in
	// the queue table.
	ErrNoJobs = errors.New("no jobs")
)
// Namer is the constraint on job parameter types. Job[T].TableName calls Name
// on T's zero value to build the table name "queue_<name>", so Name must not
// depend on instance state.
type Namer interface {
	// Name returns the short queue name for this parameter type.
	Name() string
}
type Job[T Namer] struct { | |
ID uuid.UUID `gorm:"primaryKey;type:uuid"` | |
Status int `gorm:"index;default:0"` | |
Tries int `gorm:"default:0"` | |
MaxTries int `gorm:"default:3"` | |
Params T `gorm:"params"` | |
FailReason string | |
CreatedAt time.Time `gorm:"default:current_timestamp"` | |
UpdatedAt time.Time `gorm:"default:current_timestamp"` | |
} | |
func (j Job[T]) TableName() string { | |
var t T | |
return "queue_" + t.Name() | |
} | |
type Processor[T Namer] interface { | |
Process(*gorm.DB, Job[T]) error | |
} | |
type Queue[T Namer] struct { | |
DB *gorm.DB | |
MaxTries int | |
} | |
func (q Queue[T]) Enqueue(ctx context.Context, params T) error { | |
job := Job[T]{ | |
ID: uuid.New(), | |
MaxTries: q.MaxTries, | |
Params: params, | |
CreatedAt: time.Now(), | |
UpdatedAt: time.Now(), | |
} | |
if err := q.DB.WithContext(ctx).Create(&job).Error; err != nil { | |
return fmt.Errorf("error pushing job to queue: %w", err) | |
} | |
return nil | |
} | |
func (q Queue[T]) Work(ctx context.Context, p Processor[T]) error { | |
if err := next[T](ctx, q.DB, p); err != nil { | |
if errors.Is(err, gorm.ErrRecordNotFound) { | |
return ErrNoJobs | |
} | |
return fmt.Errorf("error getting next job: %w", err) | |
} | |
return nil | |
} | |
func next[T Namer](ctx context.Context, db *gorm.DB, p Processor[T]) error { | |
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { | |
job, err := dequeue[T](tx) | |
if err != nil { | |
return fmt.Errorf("error getting next job: %w", err) | |
} | |
if err := p.Process(tx, *job); err != nil { | |
incTries(tx, job) | |
setStatus(tx, job, -1, err.Error()) | |
return nil | |
} | |
incTries(tx, job) | |
setStatus(tx, job, 1, "") | |
return nil | |
}) | |
if err != nil { | |
return fmt.Errorf("error processing job: %w", err) | |
} | |
return nil | |
} | |
func dequeue[T Namer](tx *gorm.DB) (*Job[T], error) { | |
var job Job[T] | |
err := tx. | |
Clauses(clause.Locking{Strength: "UPDATE", Options: "SKIP LOCKED"}). | |
Where("status = ?", 0). | |
Where("tries <= max_tries"). | |
Order("created_at"). | |
First(&job). | |
Limit(1).Error | |
if err != nil { | |
return nil, fmt.Errorf("error selecting job: %w", err) | |
} | |
return &job, nil | |
} | |
func incTries[T Namer](tx *gorm.DB, job *Job[T]) error { | |
err := tx.Model(job).Where("id", job.ID).Update("tries", gorm.Expr("tries + 1")).Error | |
if err != nil { | |
return fmt.Errorf("error updating job try count: %w", err) | |
} | |
return nil | |
} | |
func setStatus[T Namer](tx *gorm.DB, job *Job[T], status int, reason string) error { | |
q := tx.Model(job).Where("id", job.ID).Update("status", status) | |
if reason != "" { | |
q = q.Update("fail_reason", reason) | |
} | |
if err := q.Error; err != nil { | |
return fmt.Errorf("error updating job status: %w", err) | |
} | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package pg | |
import ( | |
"context" | |
"database/sql/driver" | |
"encoding/json" | |
"errors" | |
"fmt" | |
"os" | |
"sync" | |
"sync/atomic" | |
"testing" | |
"time" | |
"github.com/ory/dockertest/v3" | |
"github.com/ory/dockertest/v3/docker" | |
"gorm.io/driver/postgres" | |
"gorm.io/gorm" | |
"gorm.io/gorm/logger" | |
) | |
// example is the job parameter type used by these tests. Its Value/Scan
// methods persist it as JSON; Text carries a per-job label.
type example struct {
	Text string `json:"text"`
}
func (example) Name() string { | |
return "example" | |
} | |
func (ex *example) Scan(value interface{}) error { | |
bytes, ok := value.([]byte) | |
if !ok { | |
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) | |
} | |
return json.Unmarshal(bytes, ex) | |
} | |
func (ex example) Value() (driver.Value, error) { | |
b, err := json.Marshal(ex) | |
return b, err | |
} | |
type exampleProcessor struct { | |
fn func(db *gorm.DB, j Job[example]) error | |
} | |
func (e exampleProcessor) Process(db *gorm.DB, j Job[example]) error { | |
return e.fn(db, j) | |
} | |
func TestQueue(t *testing.T) { | |
if os.Getenv("TEST_DISABLE_POSTGRES") != "" { | |
t.Skip("Skipping pg.Queue tests") | |
return | |
} | |
pool, err := dockertest.NewPool("") | |
if err != nil { | |
t.Fatalf("Could not connect to docker: %s", err) | |
} | |
pool.MaxWait = 120 * time.Second | |
res, err := pool.RunWithOptions(&dockertest.RunOptions{ | |
Repository: "postgres", | |
Tag: "12.11", | |
Env: []string{ | |
"POSTGRES_PASSWORD=secret", | |
"POSTGRES_USER=user", | |
"POSTGRES_DB=dbname", | |
"listen_addresses = '*'", | |
}, | |
}, func(config *docker.HostConfig) { | |
config.AutoRemove = true | |
config.RestartPolicy = docker.RestartPolicy{Name: "no"} | |
}) | |
if err != nil { | |
t.Fatalf("Could not start resource: %s", err) | |
} | |
res.Expire(120) | |
var db *gorm.DB | |
err = pool.Retry(func() error { | |
dsn := fmt.Sprintf("postgres://user:secret@%s/dbname?sslmode=disable", res.GetHostPort("5432/tcp")) | |
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ | |
Logger: &gormLogger{t: t, lvl: logger.Info}, | |
}) | |
if err != nil { | |
return err | |
} | |
return nil | |
}) | |
if err != nil { | |
t.Fatalf("Could not connect to database: %s", err) | |
} | |
if err := db.AutoMigrate(&Job[example]{}); err != nil { // Create the queue table | |
t.Fatalf("Could not migrate database: %s", err) | |
} | |
t.Run("Process all jobs", func(t *testing.T) { | |
N := 100 | |
W := 3 | |
k := atomic.Int32{} | |
q := Queue[example]{DB: db, MaxTries: 3} | |
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(3*time.Second)) | |
defer func() { | |
for { | |
if k.Load() == int32(N) { | |
t.Log("All jobs processed") | |
cancel() | |
return | |
} | |
} | |
}() | |
wg := sync.WaitGroup{} | |
wg.Add(N) | |
for i := 0; i < N; i++ { // Enqueue | |
go func(i int) { | |
defer wg.Done() | |
if err := q.Enqueue(ctx, example{Text: fmt.Sprintf("test_%d", i)}); err != nil { | |
t.Errorf("Could not enqueue job: %s", err) | |
} | |
}(i) | |
} | |
proc := exampleProcessor{ | |
fn: func(db *gorm.DB, j Job[example]) error { | |
k.Add(1) | |
return nil | |
}, | |
} | |
wg.Add(W) | |
for i := 0; i < W; i++ { | |
go func(w int) { | |
defer wg.Done() | |
for { | |
select { | |
case <-ctx.Done(): | |
return | |
default: | |
if err := q.Work(ctx, proc); err != nil { | |
if err == context.Canceled || err == context.DeadlineExceeded { | |
return | |
} | |
if err == ErrNoJobs { | |
time.Sleep(time.Second) | |
continue | |
} | |
t.Errorf("Worker %d: %s", w, err) | |
} | |
} | |
} | |
}(i) | |
} | |
wg.Wait() | |
if k.Load() != int32(N) { | |
t.Errorf("expected job processed %d, got %d", N, k.Load()) | |
} | |
}) | |
t.Cleanup(func() { | |
if err := pool.Purge(res); err != nil { | |
t.Errorf("could not purge resources: %v", err) | |
} | |
}) | |
} | |
type gormLogger struct { | |
t *testing.T | |
lvl logger.LogLevel | |
} | |
func (tl gormLogger) LogMode(level logger.LogLevel) logger.Interface { | |
return &gormLogger{ | |
t: tl.t, | |
lvl: level, | |
} | |
} | |
func (tl gormLogger) Info(ctx context.Context, s string, i ...interface{}) { | |
if tl.lvl < logger.Info { | |
return | |
} | |
tl.t.Logf("DB:INFO] "+s, i...) | |
} | |
func (tl gormLogger) Warn(ctx context.Context, s string, i ...interface{}) { | |
if tl.lvl < logger.Warn { | |
return | |
} | |
tl.t.Logf("DB:WARN] "+s, i...) | |
} | |
func (tl gormLogger) Error(ctx context.Context, s string, i ...interface{}) { | |
if tl.lvl < logger.Error { | |
return | |
} | |
tl.t.Logf("DB:ERROR] "+s, i...) | |
} | |
func (tl gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { | |
if tl.lvl < logger.Info { | |
return | |
} | |
sql, rowsAffected := fc() | |
tl.t.Logf("DB:TRACE] sql: %s, rowsAffected: %d, err: %v", sql, rowsAffected, err) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment