Skip to content

Instantly share code, notes, and snippets.

@hyfather
Created April 6, 2018 23:13
Show Gist options
  • Save hyfather/eedb8a50ee2f3e32cdfb1443e0f98dd5 to your computer and use it in GitHub Desktop.
Save hyfather/eedb8a50ee2f3e32cdfb1443e0f98dd5 to your computer and use it in GitHub Desktop.
Go Pipelines
package main
import (
"log"
"sync"
"time"
)
// main builds a three-stage pipeline (add one, square, slow print), feeds it
// the integers 0 through 20, and blocks until the pipeline reports completion.
func main() {
	log.Println("start")

	p := NewPipeline()
	p.AddStageWithFanOut(Add1Stage, 20)
	p.AddStageWithFanOut(SquareStage, 20)
	// Each SlowPrintStage worker sleeps one second per value, so its fan size
	// controls how many values are printed per second; change it to see the
	// output pace and interleaving vary.
	p.AddStageWithFanOut(SlowPrintStage, 5)

	source := make(chan interface{})
	finished := p.Run(source)

	// Feed the pipeline from a separate goroutine; closing source is what
	// eventually lets every stage (and the final drain) terminate.
	go func(out chan<- interface{}) {
		defer close(out)
		for n := 0; n < 21; n++ {
			out <- n
		}
	}(source)

	log.Println("waiting for done")
	<-finished
	log.Println("done")
}
// SquareStage emits the square of every int read from inChan. Values that are
// not ints are silently discarded. The returned channel is closed once inChan
// has been closed and fully consumed.
func SquareStage(inChan <-chan interface{}) (outChan chan interface{}) {
	results := make(chan interface{})
	go func() {
		defer close(results)
		for v := range inChan {
			n, isInt := v.(int)
			if !isInt {
				continue
			}
			results <- n * n
		}
	}()
	return results
}
// Add1Stage emits n+1 for every int n read from inChan; any other value is
// dropped. The returned channel is closed after inChan is closed and drained.
func Add1Stage(inChan <-chan interface{}) (outChan chan interface{}) {
	incremented := make(chan interface{})
	go func() {
		defer close(incremented)
		for item := range inChan {
			switch n := item.(type) {
			case int:
				incremented <- n + 1
			}
		}
	}()
	return incremented
}
// SlowPrintStage logs each int read from inChan, waits one second, and then
// forwards the value unchanged. Non-int values are dropped without logging or
// waiting. The returned channel is closed after inChan is closed and drained.
func SlowPrintStage(inChan <-chan interface{}) (outChan chan interface{}) {
	out := make(chan interface{})
	worker := func() {
		defer close(out)
		for item := range inChan {
			n, ok := item.(int)
			if !ok {
				continue
			}
			log.Println(n)
			time.Sleep(time.Second)
			out <- n
		}
	}
	go worker()
	return out
}
// PipelineStage is one step of a Pipeline. It consumes values from inChan and
// returns a channel on which it emits its results. The pipeline relies on each
// stage closing its returned channel once inChan has been closed and drained;
// otherwise downstream stages and Run's final drain never finish.
type PipelineStage func(inChan <-chan interface{}) (outChan chan interface{})

// Pipeline is an ordered list of stages. When Run is called, the output
// channel of each stage becomes the input channel of the next.
type Pipeline struct {
	stages []PipelineStage
}

// NewPipeline returns an empty Pipeline with no stages.
func NewPipeline() Pipeline {
	return Pipeline{}
}

// AddStage appends stage to the end of the pipeline.
func (p *Pipeline) AddStage(stage PipelineStage) {
	p.stages = append(p.stages, stage)
}

// AddStageWithFanOut appends a stage that runs fanSize copies of stage
// concurrently. All copies read from the same input channel, and their
// outputs are merged into one channel. Ordering across copies is not
// preserved. A fanSize below 1 is treated as 1.
func (p *Pipeline) AddStageWithFanOut(stage PipelineStage, fanSize int) {
	p.AddStage(p.fanningStageFactory(stage, fanSize))
}

// Run chains the stages in order, starting from channel, and returns a channel
// that is closed once the last stage's output has been closed and fully
// drained. The caller must close channel when it has no more input; otherwise
// the pipeline never completes.
func (p *Pipeline) Run(channel <-chan interface{}) chan struct{} {
	// the output channel of the preceding stage is the input channel of the next stage
	for _, stage := range p.stages {
		channel = stage(channel)
	}
	return p.drainAndDone(channel)
}

// fanningStageFactory wraps inputFunc so that fanSize instances of it share a
// single input channel, with their outputs merged via fanIn.
func (p *Pipeline) fanningStageFactory(inputFunc PipelineStage, fanSize int) (outputFunc PipelineStage) {
	// With zero workers nothing would ever read the input channel. fanIn would
	// close the output at once, so the pipeline would report completion without
	// processing anything, and upstream senders would block forever.
	if fanSize < 1 {
		fanSize = 1
	}
	return func(inChan <-chan interface{}) (outChan chan interface{}) {
		channels := make([]chan interface{}, 0, fanSize)
		for i := 0; i < fanSize; i++ {
			channels = append(channels, inputFunc(inChan))
		}
		return p.fanIn(channels)
	}
}

// fanIn merges every channel in channels into a single output channel. The
// output is closed only after all input channels have been closed and drained.
func (p *Pipeline) fanIn(channels []chan interface{}) (outChan chan interface{}) {
	var wg sync.WaitGroup
	wg.Add(len(channels))
	outChan = make(chan interface{})
	for _, ch := range channels {
		// ch is passed as an argument so each goroutine forwards its own
		// channel (required before Go 1.22's per-iteration loop variables).
		go func(ch <-chan interface{}) {
			defer wg.Done()
			for obj := range ch {
				outChan <- obj
			}
		}(ch)
	}
	// Close outChan only after every forwarder has returned, so no goroutine
	// ever sends on a closed channel.
	go func() {
		wg.Wait()
		close(outChan)
	}()
	return outChan
}

// drainAndDone reads and discards everything from inChan so that the final
// stage, and transitively every upstream stage, never blocks on a send. The
// returned channel is closed once inChan has been closed and emptied. Closing,
// rather than sending a single value, lets any number of receivers observe
// completion.
func (p *Pipeline) drainAndDone(inChan <-chan interface{}) (doneChan chan struct{}) {
	doneChan = make(chan struct{})
	go func() {
		for range inChan {
		}
		log.Println("pipeline complete, closing doneChan")
		close(doneChan)
	}()
	return doneChan
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment