Skip to content

Instantly share code, notes, and snippets.

@seveas
Created November 26, 2020 08:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seveas/372195e1235e6c4269b7efba8b241fd8 to your computer and use it in GitHub Desktop.
Save seveas/372195e1235e6c4269b7efba8b241fd8 to your computer and use it in GitHub Desktop.
package scattergather
import (
"context"
"sync"
)
// ScatteredFunction is the signature of the work a ScatterGather runs
// concurrently. It is called with the context and the extra arguments that
// were passed to Run, and returns a result value and an error.
type ScatteredFunction func(ctx context.Context, args ...interface{}) (interface{}, error)
// ScatterGather runs ScatteredFunctions concurrently ("scatter") and collects
// their return values and errors ("gather").
//
// Typical use:
//
//	sg := New()
//	for _, x := range inputs {
//		sg.Run(fn, ctx, x)
//	}
//	results, err := sg.Wait()
//
// The zero value is usable: Run and Wait both initialize it lazily. Every Run
// call must happen before Wait is called. Run after Wait panics because the
// internal result channel has already been closed.
type ScatterGather struct {
	// waitGroup counts Run goroutines that have not yet delivered their result.
	waitGroup *sync.WaitGroup
	// results and errors are written only by the gatherer goroutine and are
	// read by Wait only after doneChan has been closed.
	results []interface{}
	errors  *ScatteredError
	// resultChan carries outcomes from Run goroutines to the gatherer.
	resultChan chan scatterResult
	// doneChan is closed by the gatherer once resultChan is closed and drained.
	doneChan chan interface{}

	initOnce   sync.Once // guards init
	gatherOnce sync.Once // guards starting the gatherer goroutine
	waitOnce   sync.Once // guards the close-and-drain step of Wait
}

// scatterResult carries one ScatteredFunction's outcome from its Run
// goroutine to the gatherer.
type scatterResult struct {
	val interface{}
	err error
}

// New returns an initialized ScatterGather. It is equivalent to using the
// zero value, which is initialized lazily on first use.
func New() *ScatterGather {
	sg := &ScatterGather{}
	sg.init()
	return sg
}

// init allocates the internal state exactly once.
func (sg *ScatterGather) init() {
	sg.initOnce.Do(func() {
		sg.waitGroup = &sync.WaitGroup{}
		sg.results = make([]interface{}, 0)
		sg.errors = &ScatteredError{}
		sg.errors.Errors = make([]error, 0)
		sg.resultChan = make(chan scatterResult, 10)
		sg.doneChan = make(chan interface{})
	})
}

// gather starts the gatherer goroutine exactly once.
func (sg *ScatterGather) gather() {
	sg.gatherOnce.Do(func() {
		go sg.gatherer()
	})
}

// gatherer drains resultChan until it is closed. It appends non-nil values to
// results and non-nil errors to errors, then closes doneChan to signal Wait.
func (sg *ScatterGather) gatherer() {
	for res := range sg.resultChan {
		if res.val != nil {
			sg.results = append(sg.results, res.val)
		}
		if res.err != nil {
			sg.errors.AddError(res.err)
		}
	}
	close(sg.doneChan)
}

// Run starts callable(ctx, args...) in a new goroutine. A non-nil return
// value is added to the results and a non-nil error to the errors reported by
// Wait. Results are collected in the order the gatherer receives them, which
// depends on scheduling rather than on the order of Run calls.
//
// Run must not be called after, or concurrently with, Wait.
//
// NOTE: ctx is not the first parameter, contrary to Go convention. The order
// is kept for compatibility with existing callers.
func (sg *ScatterGather) Run(callable ScatteredFunction, ctx context.Context, args ...interface{}) {
	sg.init()
	sg.gather()
	sg.waitGroup.Add(1)
	go func() {
		// Done runs only after the send, so Wait never closes resultChan while
		// this goroutine could still send on it.
		defer sg.waitGroup.Done()
		ret, err := callable(ctx, args...)
		sg.resultChan <- scatterResult{val: ret, err: err}
	}()
}

// Wait blocks until every function started by Run has returned. It returns
// the collected non-nil results and a *ScatteredError holding every non-nil
// error, or nil if there were none.
//
// Wait may be used on a ScatterGather on which Run was never called,
// including the zero value. It may also be called more than once, and later
// calls return the same results and error as the first.
//
// NOTE: the error is returned as a concrete *ScatteredError. Assigning a nil
// *ScatteredError to a variable of type error produces a non-nil interface
// value, so check it as *ScatteredError before converting.
func (sg *ScatterGather) Wait() ([]interface{}, *ScatteredError) {
	// Make sure the state exists and the gatherer is running even when Run was
	// never called. Otherwise nothing would ever close doneChan and this
	// method would block forever (or panic on a nil waitGroup for a zero value).
	sg.init()
	sg.gather()
	sg.waitOnce.Do(func() {
		sg.waitGroup.Wait()
		// Every Run goroutine has finished sending, so closing is safe and
		// lets the gatherer's range loop end.
		close(sg.resultChan)
		// Receiving from doneChan also makes the gatherer's writes to results
		// and errors visible to this goroutine.
		<-sg.doneChan
	})
	if !sg.errors.HasErrors() {
		return sg.results, nil
	}
	return sg.results, sg.errors
}
// ScatteredError aggregates the errors returned by the functions run through
// a ScatterGather. It implements the error interface.
type ScatteredError struct {
	// Errors holds the collected errors in the order they were added.
	Errors []error
}

// HasErrors reports whether e is non-nil and holds at least one error.
func (e *ScatteredError) HasErrors() bool {
	// len of a nil slice is 0, so e.Errors needs no separate nil check.
	return e != nil && len(e.Errors) > 0
}

// AddError appends err to e.Errors. A nil err is ignored, as errors.Join
// does, because a stored nil entry carries no information. e must not be nil.
func (e *ScatteredError) AddError(err error) {
	if err == nil {
		return
	}
	// append handles a nil slice, so no special case is needed.
	e.Errors = append(e.Errors, err)
}

// Error returns the messages of all collected errors, one per line.
// A nil receiver yields "(nil error)" and an empty one yields
// "(empty scattered error)". A nil entry placed directly in Errors is
// rendered as "(nil error)" instead of causing a nil-interface panic.
func (e *ScatteredError) Error() string {
	if e == nil {
		return "(nil error)"
	}
	if len(e.Errors) == 0 {
		return "(empty scattered error)"
	}
	// Accumulate into a single buffer. Repeated string += would copy the
	// growing message on every iteration.
	var buf []byte
	for i, err := range e.Errors {
		if i > 0 {
			buf = append(buf, '\n')
		}
		if err == nil {
			buf = append(buf, "(nil error)"...)
			continue
		}
		buf = append(buf, err.Error()...)
	}
	return string(buf)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment