Skip to content

Instantly share code, notes, and snippets.

@owulveryck
Created July 22, 2020 06:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save owulveryck/e1ec898ee29521a18e3ace9528858e84 to your computer and use it in GitHub Desktop.
Save owulveryck/e1ec898ee29521a18e3ace9528858e84 to your computer and use it in GitHub Desktop.
Concurrence
// broadcast forwards every value received on ch to each channel in cs, in
// slice order. Each send blocks until that receiver accepts the value (or its
// buffer has room), so a slow receiver holds back delivery to the receivers
// after it and the reading of further values from ch.
//
// broadcast returns when ctx is cancelled or when ch is closed. It does not
// close the channels in cs: they are supplied by the caller.
func broadcast(ctx context.Context, ch <-chan gorgonia.Value, cs []chan gorgonia.Value) {
	for {
		select {
		case msg, ok := <-ch:
			if !ok {
				// A closed ch yields the zero value forever; without this
				// check the loop would keep broadcasting nil values.
				return
			}
			for _, c := range cs {
				select {
				case c <- msg:
				case <-ctx.Done():
					return
				}
			}
		case <-ctx.Done():
			return
		}
	}
}
func Test_broadcast(t *testing.T) {
	// newOutputs returns n unbuffered channels, so every send made by
	// broadcast blocks until the test receives it.
	newOutputs := func(n int) []chan gorgonia.Value {
		cs := make([]chan gorgonia.Value, n)
		for i := range cs {
			cs[i] = make(chan gorgonia.Value)
		}
		return cs
	}
	// start runs broadcast in its own goroutine and returns a channel that is
	// closed once broadcast has returned, so subtests can assert it exits.
	start := func(ctx context.Context, ch <-chan gorgonia.Value, cs []chan gorgonia.Value) <-chan struct{} {
		done := make(chan struct{})
		go func() {
			defer close(done)
			broadcast(ctx, ch, cs)
		}()
		return done
	}

	t.Run("context cancel without value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		c := make(chan gorgonia.Value)
		done := start(ctx, c, newOutputs(1))
		cancel()
		<-done // broadcast must return once ctx is cancelled
	})
	t.Run("context cancel with one value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		c := make(chan gorgonia.Value)
		done := start(ctx, c, newOutputs(1))
		// Nobody reads the output, so broadcast stays blocked forwarding this
		// value until the cancellation below releases it.
		c <- nil
		cancel()
		<-done
	})
	t.Run("broadcast", func(t *testing.T) {
		fortyTwo := gorgonia.F32(42.0)
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		cs := newOutputs(2)
		c := make(chan gorgonia.Value)
		done := start(ctx, c, cs)
		c <- &fortyTwo
		// broadcast sends in slice order, so receive in the same order.
		for i, out := range cs {
			if v := <-out; !reflect.DeepEqual(v, &fortyTwo) {
				t.Errorf("broadcast on cs[%d]: want %v, got %v", i, &fortyTwo, v)
			}
		}
		cancel()
		<-done
	})
}
// fanOut creates size output channels and starts a goroutine that copies
// every value received on ch to each of them, in slice order. It returns the
// output channels immediately.
//
// lag is the buffer size of every output channel. It bounds how many values a
// slow receiver can fall behind the others before sends to it block, which in
// turn holds back the remaining outputs and further reads from ch.
//
// The goroutine stops when ctx is cancelled or when ch is closed, and in both
// cases closes every output channel. A value in flight when ctx is cancelled
// may have been delivered to only some of the outputs.
func fanOut(ctx context.Context, ch <-chan gorgonia.Value, size, lag int) []chan gorgonia.Value {
	cs := make([]chan gorgonia.Value, size)
	for i := range cs {
		cs[i] = make(chan gorgonia.Value, lag)
	}
	go func() {
		// fanOut owns the outputs, so it closes all of them on every exit
		// path; receivers can rely on the close to stop ranging.
		defer func() {
			for _, c := range cs {
				close(c)
			}
		}()
		for {
			select {
			case msg, ok := <-ch:
				if !ok {
					// Input exhausted. Without this check a closed ch would
					// yield nil forever and be fanned out endlessly.
					return
				}
				for _, c := range cs {
					select {
					case c <- msg:
					case <-ctx.Done():
						return
					}
				}
			case <-ctx.Done():
				return
			}
		}
	}()
	return cs
}
func Test_fanOut(t *testing.T) {
	t.Run("context cancel without value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		c := make(chan gorgonia.Value)
		cs := fanOut(ctx, c, 1, 0)
		cancel()
		// Nothing was sent, so the output must be closed without a value.
		if v, ok := <-cs[0]; ok {
			t.Errorf("expected cs[0] to be closed, got %v", v)
		}
	})
	t.Run("context cancel with one value", func(t *testing.T) {
		fortyTwo := gorgonia.F32(42.0)
		ctx, cancel := context.WithCancel(context.Background())
		c := make(chan gorgonia.Value)
		cs := fanOut(ctx, c, 1, 0)
		c <- &fortyTwo
		out := <-cs[0]
		if !reflect.DeepEqual(fortyTwo.Data(), out.Data()) {
			t.Errorf("Expected %v, got %v", fortyTwo, out)
		}
		// The second value is taken by fanOut, which then blocks forwarding
		// it because nobody is reading cs[0] when the context is cancelled.
		c <- &fortyTwo
		cancel()
		// Depending on scheduling the pending value may still be delivered;
		// either way the output must end up closed, ending this loop.
		for range cs[0] {
		}
	})
	t.Run("two chans", func(t *testing.T) {
		fortyTwo := gorgonia.F32(42.0)
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel() // stop the fanOut goroutine when the subtest ends
		c := make(chan gorgonia.Value)
		cs := fanOut(ctx, c, 2, 0)
		c <- &fortyTwo
		// fanOut sends in slice order, so receive in the same order.
		for i, ch := range cs {
			out := <-ch
			if !reflect.DeepEqual(fortyTwo.Data(), out.Data()) {
				t.Errorf("cs[%d]: expected %v, got %v", i, fortyTwo, out)
			}
		}
	})
}
// merge starts one goroutine per channel in cs. Each goroutine forwards every
// value it receives to out as ioValue{pos: i, v: value}, where i is the index
// of its source channel in cs. Values from different inputs reach out in no
// guaranteed order. merge itself returns immediately.
//
// Each forwarding goroutine stops when ctx is cancelled or when its input
// channel is closed. out is supplied by the caller and is not closed here.
func merge(ctx context.Context, cs []chan gorgonia.Value, out chan ioValue) {
	var wg sync.WaitGroup
	// forward copies values from c to out until c is closed or ctx is done,
	// then calls wg.Done.
	forward := func(c <-chan gorgonia.Value, pos int) {
		defer wg.Done()
		for {
			select {
			case n, ok := <-c:
				if !ok {
					// A closed input yields nil forever; stop instead of
					// flooding out with zero values.
					return
				}
				select {
				case out <- ioValue{pos: pos, v: n}:
				case <-ctx.Done():
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}
	wg.Add(len(cs))
	for i, c := range cs {
		go forward(c, i)
	}
	// Wait for the forwarders in the background. close(out) is deliberately
	// not called because out is supplied by the caller.
	// NOTE(review): with the close disabled this goroutine only waits; drop it
	// (and wg) unless closing out is meant to be reinstated.
	go func() {
		wg.Wait()
	}()
}
func Test_merge(t *testing.T) {
	fortyTwo := gorgonia.F32(42.0)
	fortyThree := gorgonia.F32(43.0)
	// newInputs returns n unbuffered input channels for merge.
	newInputs := func(n int) []chan gorgonia.Value {
		cs := make([]chan gorgonia.Value, n)
		for i := range cs {
			cs[i] = make(chan gorgonia.Value)
		}
		return cs
	}
	t.Run("context cancel without value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		output := make(chan ioValue)
		merge(ctx, newInputs(1), output)
		cancel()
	})
	t.Run("context cancel with one value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		c := newInputs(1)
		output := make(chan ioValue)
		merge(ctx, c, output)
		c[0] <- &fortyTwo
		out := <-output
		if !reflect.DeepEqual(fortyTwo.Data(), out.v.Data()) {
			t.Errorf("Expected %v, got %v", fortyTwo, out)
		}
		// Nobody reads output, so this value stays in flight until cancel
		// releases the forwarding goroutine.
		c[0] <- &fortyTwo
		cancel()
	})
	t.Run("with one value", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel() // stop the forwarders when the subtest ends
		c := newInputs(1)
		output := make(chan ioValue)
		merge(ctx, c, output)
		c[0] <- &fortyTwo
		out := <-output
		if !reflect.DeepEqual(fortyTwo.Data(), out.v.Data()) {
			t.Errorf("Expected %v, got %v", fortyTwo, out)
		}
	})
	t.Run("2 channels with two values", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		const lenChan = 2
		cs := newInputs(lenChan)
		output := make(chan ioValue)
		merge(ctx, cs, output)
		// Each send completes once its forwarder has taken the value; the
		// forwarders then block on output until the loop below drains it.
		cs[1] <- &fortyThree
		cs[0] <- &fortyTwo
		missFortyTwo, missFortyThree := true, true
		for i := 0; i < lenChan; i++ {
			out := <-output
			// comma-ok: an unexpected data type is reported, not a panic.
			f, _ := out.v.Data().(float32)
			switch {
			case out.pos == 0 && f == 42.0:
				missFortyTwo = false
			case out.pos == 1 && f == 43.0:
				missFortyThree = false
			default:
				t.Errorf("bad combination %v/%v", out.pos, out.v.Data())
			}
		}
		if missFortyThree || missFortyTwo {
			t.Error("Missing value")
		}
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment