Skip to content

Instantly share code, notes, and snippets.

@cideM
Created October 18, 2022 08:45
Show Gist options
  • Save cideM/6714c283c42f94299772e813a3a3b9c5 to your computer and use it in GitHub Desktop.
Save cideM/6714c283c42f94299772e813a3a3b9c5 to your computer and use it in GitHub Desktop.
Final main.go
package main
import (
"context"
"errors"
"log"
"strings"
"sync"
"time"
"golang.org/x/sync/semaphore"
)
// producer is the source stage of the pipeline. It starts a goroutine that
// sends each element of strings, in order, on the returned channel and closes
// that channel once every element has been sent or ctx is cancelled,
// whichever happens first. The returned error is always nil.
func producer(ctx context.Context, strings []string) (<-chan string, error) {
	stream := make(chan string)

	emit := func() {
		// The sender owns the channel, so it is the one to close it.
		defer close(stream)
		for i := range strings {
			select {
			case stream <- strings[i]:
			case <-ctx.Done():
				// Stop early; nobody is expected to read any more.
				return
			}
		}
	}
	go emit()

	return stream, nil
}
// sink is the final stage of the pipeline. It logs every value received from
// values, and for every non-nil error received from errors it logs the error
// and calls cancelFunc so that the rest of the pipeline shuts down.
//
// sink returns when ctx is done (after logging ctx.Err()), or, after logging
// "done", once values has been closed and errors has been closed (a nil
// errors channel counts as already closed). Waiting for both streams ensures
// that an error still in flight when values closes is not silently dropped.
func sink(ctx context.Context, cancelFunc context.CancelFunc, values <-chan string, errors <-chan error) {
	valuesOpen := true
	errorsOpen := errors != nil

	for valuesOpen || errorsOpen {
		select {
		case <-ctx.Done():
			log.Print(ctx.Err().Error())
			return
		case err, ok := <-errors:
			if !ok {
				// A closed channel is always ready to receive; setting it to
				// nil disables this case so the loop does not busy-spin.
				errors = nil
				errorsOpen = false
				continue
			}
			if err != nil {
				log.Println("error: ", err.Error())
				cancelFunc()
			}
		case val, ok := <-values:
			if !ok {
				// Same trick as above: stop selecting on the closed channel
				// but keep draining errors until that stream ends too.
				values = nil
				valuesOpen = false
				continue
			}
			log.Printf("sink: %s", val)
		}
	}

	log.Print("done")
}
// step is a generic pipeline stage. It reads values from inputChannel, applies
// fn to each of them on a bounded number of concurrent worker goroutines, and
// returns two channels: one carrying the successful results and one carrying
// the errors returned by fn. Results are not guaranteed to keep input order.
//
// The stage stops when inputChannel is closed or ctx is cancelled. In both
// cases it first waits for every in-flight worker to finish and only then
// closes both returned channels, so a worker can never send on a closed
// channel. After cancellation, pending results and errors may be discarded.
func step[In any, Out any](
	ctx context.Context,
	inputChannel <-chan In,
	fn func(In) (Out, error),
) (chan Out, chan error) {
	outputChannel := make(chan Out)
	errorChannel := make(chan error)

	// The limit is fixed at 2 so that the log output visibly shows values
	// being processed two at a time. To use every core instead, replace it
	// with: limit := int64(runtime.NumCPU())
	limit := int64(2)
	sem1 := semaphore.NewWeighted(limit)

	go func() {
		// Deferred calls run last-in-first-out: wait for the workers first,
		// then close the error channel, then the output channel.
		defer close(outputChannel)
		defer close(errorChannel)
		defer func() {
			// Acquiring the full weight blocks until every worker has
			// released its slot. A background context is used on purpose:
			// this wait must not be cut short by cancellation, otherwise the
			// channels would be closed underneath running workers. It
			// terminates because workers stop waiting once ctx is cancelled.
			if err := sem1.Acquire(context.Background(), limit); err != nil {
				log.Printf("Failed to acquire semaphore: %v", err)
			}
		}()

		for {
			select {
			case <-ctx.Done():
				// A bare break here would only leave the select, not the
				// loop, so return explicitly.
				return
			case s, ok := <-inputChannel:
				if !ok {
					return
				}
				if err := sem1.Acquire(ctx, 1); err != nil {
					log.Printf("Failed to acquire semaphore: %v", err)
					return
				}
				go func(s In) {
					defer sem1.Release(1)

					// Simulated work, interruptible by cancellation.
					delay := time.NewTimer(time.Second * 3)
					defer delay.Stop()
					select {
					case <-delay.C:
					case <-ctx.Done():
						return
					}

					result, err := fn(s)
					if err != nil {
						// Never block forever on a send nobody will receive.
						select {
						case errorChannel <- err:
						case <-ctx.Done():
						}
						return
					}
					select {
					case outputChannel <- result:
					case <-ctx.Done():
					}
				}(s)
			}
		}
	}()

	return outputChannel, errorChannel
}
// Merge fans in any number of input channels into a single output channel.
// One goroutine per input forwards its values until that input is closed or
// ctx is cancelled. The returned channel is closed after all forwarding
// goroutines have finished. Values from different inputs are interleaved in
// no particular order.
func Merge[T any](ctx context.Context, cs ...<-chan T) <-chan T {
	merged := make(chan T)

	var pending sync.WaitGroup
	pending.Add(len(cs))

	for _, src := range cs {
		go func(src <-chan T) {
			defer pending.Done()
			for {
				v, open := <-src
				if !open {
					return
				}
				select {
				case <-ctx.Done():
					return
				case merged <- v:
				}
			}
		}(src)
	}

	// Close the merged channel only after every forwarder has stopped
	// sending on it.
	go func() {
		pending.Wait()
		close(merged)
	}()

	return merged
}
// transformA is the first demo transformation: it logs its input and returns
// the lower-cased string. It never returns a non-nil error.
func transformA(s string) (string, error) {
	log.Println("transformA input: ", s)
	lowered := strings.ToLower(s)
	return lowered, nil
}
// transformB is the second demo transformation: it logs its input and returns
// the string in title case. The input "foo" deliberately fails with the error
// "oh no" to demonstrate error propagation through the pipeline.
func transformB(s string) (string, error) {
	log.Println("transformB input: ", s)
	switch s {
	case "foo":
		// Remove this case to see the pipeline finish successfully.
		return "", errors.New("oh no")
	default:
		// NOTE(review): strings.Title is deprecated since Go 1.18; kept
		// because the standard library offers no drop-in replacement.
		return strings.Title(s), nil
	}
}
// main wires up the demo pipeline: producer -> step(transformA) ->
// step(transformB) -> sink. The error channels of both steps are merged into
// one stream, and sink is handed the cancel function of the shared context.
func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	readStream, err := producer(ctx, []string{"FOO", "BAR", "BAX"})
	if err != nil {
		log.Fatal(err)
	}

	lowered, loweredErrs := step(ctx, readStream, transformA)
	titled, titledErrs := step(ctx, lowered, transformB)

	// sink blocks until the pipeline is drained or cancelled.
	sink(ctx, cancel, titled, Merge(ctx, loweredErrs, titledErrs))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment