Last active
July 4, 2017 18:34
-
-
Save malisetti/fbc7d936abc5769e8cd49308601c8384 to your computer and use it in GitHub Desktop.
Supervises n goroutines, restarting each worker when it finishes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package hartman supervises a pool of n workers, restarting each worker whenever it finishes. https://en.wikipedia.org/wiki/Full_Metal_Jacket
package hartman | |
import ( | |
"context" | |
"errors" | |
"log" | |
"sync" | |
) | |
// AllDone is the sentinel a Worker returns to report that it has no more work
// and must not be restarted. Supervise recognises it with errors.Is, so an
// error that wraps AllDone (for example via fmt.Errorf("...: %w", AllDone))
// has the same effect.
var AllDone = errors.New("no more work")

// Worker is a unit of work that Supervise runs repeatedly.
//
// Work is called concurrently from several goroutines, so implementations must
// be safe for concurrent use. Its result controls what the calling goroutine
// does next:
//   - nil: call Work again while the supervisor's context is still live;
//   - AllDone (or an error wrapping it): stop this goroutine for good;
//   - any other error: hand it to the error handler, then call Work again.
type Worker interface {
	Work(ctx context.Context) error
}

// WorkerFunc adapts an ordinary function to the Worker interface, in the same
// way http.HandlerFunc adapts a function to http.Handler.
type WorkerFunc func(ctx context.Context) error

// Work calls f(ctx).
func (f WorkerFunc) Work(ctx context.Context) error {
	return f(ctx)
}

// supervisor runs numberOfWorkers goroutines that each call worker.Work in a
// loop until the context is done or Work returns AllDone.
//
// The embedded mutex guards errorHandler and doneHandler, which may be set
// from other goroutines through SetErrorHandler and SetDoneHandler.
type supervisor struct {
	sync.Mutex
	ctx             context.Context // cancelling it stops every worker goroutine
	numberOfWorkers int             // how many goroutines Supervise starts
	worker          Worker          // the same value is used by every goroutine
	errorHandler    func(errs <-chan error)
	doneHandler     func()
}

// NewSupervisor returns a supervisor that, when Supervise is called, runs
// numWorkers goroutines calling worker.Work until ctx is done or each goroutine
// receives AllDone. A nil ctx is treated as context.Background() rather than
// causing a nil-pointer panic later in Supervise.
func NewSupervisor(ctx context.Context, numWorkers int, worker Worker) *supervisor {
	if ctx == nil {
		ctx = context.Background()
	}
	return &supervisor{
		ctx:             ctx,
		numberOfWorkers: numWorkers,
		worker:          worker,
	}
}

// SetErrorHandler installs the function that receives worker errors.
//
// Supervise starts the handler in its own goroutine and passes it a channel
// that is closed once every worker has stopped, so a handler that ranges over
// the channel returns on its own. If no handler is set, errors are logged.
func (s *supervisor) SetErrorHandler(errHandler func(errs <-chan error)) {
	s.Lock()
	defer s.Unlock()
	s.errorHandler = errHandler
}

// SetDoneHandler installs a function that Supervise calls after every worker
// goroutine has stopped. A nil handler means nothing is called.
func (s *supervisor) SetDoneHandler(doneHandler func()) {
	s.Lock()
	defer s.Unlock()
	s.doneHandler = doneHandler
}

// Supervise starts numberOfWorkers goroutines, each calling worker.Work in a
// loop (see Worker for how results are handled), and blocks until all of them
// have stopped. It then calls the done handler, if any, and closes the error
// channel. A negative worker count is treated as zero.
func (s *supervisor) Supervise() {
	// Copy the handlers under the lock so that a concurrent SetErrorHandler
	// or SetDoneHandler call does not race with these reads.
	s.Lock()
	errorHandler, doneHandler := s.errorHandler, s.doneHandler
	s.Unlock()
	if errorHandler == nil {
		errorHandler = logErrors
	}

	// A fresh, local channel per call keeps concurrent Supervise calls from
	// sharing one channel.
	errs := make(chan error)
	// Deferred so it runs after the done handler: the error handler sees the
	// close only once all workers have stopped and nothing can send on errs.
	defer close(errs)
	go errorHandler(errs)

	n := s.numberOfWorkers
	if n < 0 {
		n = 0 // sync.WaitGroup.Add panics if the counter goes negative
	}
	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			s.runWorker(s.ctx, errs)
		}()
	}
	wg.Wait()

	if doneHandler != nil {
		doneHandler()
	}
}

// runWorker calls s.worker.Work repeatedly until ctx is done or Work reports
// AllDone. Other errors are forwarded on errs. If the context is cancelled while
// the send is waiting, the goroutine gives up instead of blocking forever, so
// an error produced during cancellation may be dropped.
func (s *supervisor) runWorker(ctx context.Context, errs chan<- error) {
	for ctx.Err() == nil {
		err := s.worker.Work(ctx)
		switch {
		case err == nil:
			// Finished one round cleanly; loop around to restart it.
		case errors.Is(err, AllDone):
			return
		default:
			select {
			case errs <- err:
			case <-ctx.Done():
				return
			}
		}
	}
}

// logErrors is the default error handler: it logs every error it receives
// until the channel is closed.
func logErrors(errs <-chan error) {
	for err := range errs {
		log.Printf("supervisor received error with: %v", err)
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package hartman | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"strconv" | |
"testing" | |
"time" | |
) | |
func TestSupervise(t *testing.T) { | |
ctx, cancel := context.WithCancel(context.Background()) | |
time.AfterFunc(20*time.Second, func() { | |
cancel() | |
}) | |
numWorkers := 2 | |
work := func(ctx context.Context) error { | |
timer := time.NewTimer(1 * time.Second) | |
for { | |
select { | |
case <-time.After(2 * time.Second): | |
return fmt.Errorf("Done my work") | |
case <-ctx.Done(): | |
return nil | |
case <-timer.C: | |
time.Sleep(1 * time.Second) | |
} | |
} | |
} | |
s := NewSupervisor(ctx, numWorkers, work) | |
s.SetDoneHandler(func() { | |
log.Println("Done supervising") | |
}) | |
s.SetErrorHandler(func(errors <-chan error) { | |
for err := range errors { | |
log.Printf("in error handler %v\n", err) | |
} | |
}) | |
s.Supervise() | |
} | |
func TestSupervise2(t *testing.T) { | |
ctx, cancel := context.WithCancel(context.Background()) | |
time.AfterFunc(20*time.Second, func() { | |
cancel() | |
}) | |
numWorkers := 2 | |
workChan := make(chan string) | |
go func() { | |
defer close(workChan) | |
for i := 0; i < 20; i++ { | |
workChan <- strconv.Itoa(i) | |
} | |
}() | |
work := func(ctx context.Context) error { | |
for { | |
select { | |
case s, ok := <-workChan: | |
if !ok { | |
return AllDone | |
} | |
time.Sleep(2 * time.Second) | |
log.Println("got work", s) | |
case <-ctx.Done(): | |
return nil | |
} | |
} | |
} | |
s := NewSupervisor(ctx, numWorkers, work) | |
s.SetDoneHandler(func() { | |
log.Println("Done supervising") | |
}) | |
s.SetErrorHandler(func(errors <-chan error) { | |
for err := range errors { | |
log.Printf("in error handler %v\n", err) | |
} | |
}) | |
s.Supervise() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment