@FZambia
Created October 15, 2020 16:53
Goroutine (worker) pool for the Go language
package gpool

import "context"

// Job represents a function to be executed by a worker.
type Job func()

type worker struct {
	jobs chan Job
	stop chan struct{}
	done chan struct{}
}

func newWorker(jobs chan Job) *worker {
	return &worker{
		jobs: jobs,
		stop: make(chan struct{}, 1),
		done: make(chan struct{}, 1),
	}
}

// start launches the worker goroutine: it runs incoming jobs until asked to stop.
func (w *worker) start() {
	go func() {
		for {
			select {
			case job := <-w.jobs:
				job()
			case <-w.stop:
				w.done <- struct{}{}
				return
			}
		}
	}()
}

// Pool of worker goroutines.
type Pool struct {
	workers []*worker
	Jobs    chan Job
}

// NewPool makes a pool of worker goroutines.
// Send jobs to the returned Pool's Jobs channel for execution.
func NewPool(numWorkers int) *Pool {
	jobs := make(chan Job)
	workers := make([]*worker, 0, numWorkers)
	for i := 0; i < numWorkers; i++ {
		worker := newWorker(jobs)
		worker.start()
		workers = append(workers, worker)
	}
	return &Pool{
		Jobs:    jobs,
		workers: workers,
	}
}

// Close releases resources used by the pool. It waits for workers to stop
// or for ctx to be cancelled, whichever happens first.
func (p *Pool) Close(ctx context.Context) error {
	// Signal every worker to stop.
	for i := 0; i < len(p.workers); i++ {
		worker := p.workers[i]
		select {
		case <-ctx.Done():
			return ctx.Err()
		case worker.stop <- struct{}{}:
		}
	}
	// Wait for every worker to confirm it has stopped.
	for i := 0; i < len(p.workers); i++ {
		worker := p.workers[i]
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-worker.done:
		}
	}
	return nil
}
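A minimal usage sketch of the API above. The import path, pool size, and the printed message are illustrative assumptions, not part of the gist; the key points are that Jobs is an unbuffered channel (each send blocks until a worker picks the job up) and that Close takes a context to bound shutdown time.

package main

import (
	"context"
	"fmt"
	"time"

	"gpool" // assumed import path for the package above
)

func main() {
	// Start a pool with 8 workers (arbitrary example size).
	pool := gpool.NewPool(8)

	// Submit work by sending closures to the exposed Jobs channel.
	for i := 0; i < 100; i++ {
		i := i
		pool.Jobs <- func() {
			fmt.Println("processing item", i)
		}
	}

	// Give workers a bounded amount of time to stop.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := pool.Close(ctx); err != nil {
		fmt.Println("pool did not stop in time:", err)
	}
}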
package gpool

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestWorker_New(t *testing.T) {
	jobQueue := make(chan Job)
	worker := newWorker(jobQueue)
	worker.start()
	require.NotNil(t, worker)
	called := false
	done := make(chan bool)
	job := func() {
		called = true
		done <- true
	}
	worker.jobs <- job
	<-done
	require.Equal(t, true, called)
}

func TestPool_New(t *testing.T) {
	pool := NewPool(1000)
	defer func() { _ = pool.Close(context.Background()) }()
	numJobs := 10000
	var wg sync.WaitGroup
	wg.Add(numJobs)
	var counter uint64
	for i := 0; i < numJobs; i++ {
		arg := uint64(1)
		job := func() {
			defer wg.Done()
			atomic.AddUint64(&counter, arg)
			require.Equal(t, uint64(1), arg)
		}
		pool.Jobs <- job
	}
	wg.Wait()
	require.Equal(t, uint64(numJobs), atomic.LoadUint64(&counter))
}

func TestPool_Close(t *testing.T) {
	pool := NewPool(100)
	numJobs := 1000
	for i := 0; i < numJobs; i++ {
		job := func() {}
		pool.Jobs <- job
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	_ = pool.Close(ctx)
}

func TestPool_CloseContext(t *testing.T) {
	pool := NewPool(1)
	pool.Jobs <- func() {
		time.Sleep(5 * time.Second)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	err := pool.Close(ctx)
	require.Equal(t, context.DeadlineExceeded, err)
}

func BenchmarkPool_RawPerformance(b *testing.B) {
	pool := NewPool(1)
	defer func() { _ = pool.Close(context.Background()) }()
	ch := make(chan struct{}, 1)
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		pool.Jobs <- func() {
			ch <- struct{}{}
		}
		<-ch
	}
}

func BenchmarkPool_Sequential(b *testing.B) {
	pool := NewPool(16)
	defer func() { _ = pool.Close(context.Background()) }()
	for n := 0; n < b.N; n++ {
		var wg sync.WaitGroup
		wg.Add(1)
		pool.Jobs <- func() {
			time.Sleep(100 * time.Millisecond)
			wg.Done()
		}
		wg.Wait()
	}
}

func BenchmarkPool_Parallel(b *testing.B) {
	pool := NewPool(4096)
	defer func() { _ = pool.Close(context.Background()) }()
	b.SetParallelism(4096)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var wg sync.WaitGroup
			wg.Add(1)
			pool.Jobs <- func() {
				time.Sleep(100 * time.Millisecond)
				wg.Done()
			}
			wg.Wait()
		}
	})
}
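The tests and benchmarks above can be run with the standard Go toolchain; the commands below use only stock go test flags (the -race detector is optional but useful for a concurrency-heavy package like this one):

go test -race ./...
go test -bench=. -benchmem ./...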