Skip to content

Instantly share code, notes, and snippets.

@leolara
Last active April 29, 2023 04:30
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save leolara/f6fb5dfc04d64947487f16764d6b37b6 to your computer and use it in GitHub Desktop.
Save leolara/f6fb5dfc04d64947487f16764d6b37b6 to your computer and use it in GitHub Desktop.
Example in Go of how to close a channel written by several goroutines
// Package gochannels is an example of how to close a channel that is written to by several goroutines.
package gochannels
import (
"math/big"
"sync"
)
// Publisher publishes big.Int values on a single channel from many writer
// goroutines, and coordinates closing that channel once those writers have
// stopped sending.
type Publisher struct {
	// ch carries published values; consumers obtain it through Read.
	ch chan big.Int

	// closingCh is closed to tell writer goroutines to stop sending on ch.
	closingCh chan interface{}

	// writersWG counts writer goroutines that may still send on ch.
	writersWG sync.WaitGroup

	// writersWGMutex serializes writersWG.Add against writersWG.Wait.
	writersWGMutex sync.Mutex
}
// New returns a Publisher with an unbuffered value channel and an open
// closing channel. Its writer WaitGroup and mutex start at their zero values.
func New() *Publisher {
	pub := new(Publisher)
	pub.ch = make(chan big.Int)
	pub.closingCh = make(chan interface{})
	return pub
}
// Run publishes the integers 0 through n-1 by calling write once per value.
// Each value is sent from its own goroutine, so consumers may receive them in
// any order. Run returns without waiting for any value to be delivered.
// A non-positive n publishes nothing.
func (p *Publisher) Run(n int) {
	for v := int64(0); v < int64(n); v++ {
		p.write(*big.NewInt(v))
	}
}
// Read returns the receive-only channel on which published values are
// delivered. The channel is closed by Close or CloseWithoutDraining; after
// that, receives yield the zero big.Int with ok == false.
func (p *Publisher) Read() <-chan big.Int {
	return p.ch
}
// write publishes data on the Publisher's channel from a new goroutine, so the
// caller never blocks. That goroutine gives up without sending once closingCh
// has been closed. This is what lets the close methods close ch safely even
// though many goroutines send on it.
//
// NOTE(review): if nobody reads from Read() and the Publisher is never closed,
// each call leaves a goroutine blocked in the second select below.
func (p *Publisher) write(data big.Int) {
	go func(data big.Int) {
		// Register as an active writer. The close methods hold the same mutex
		// while calling writersWG.Wait, so this Add never races with that Wait.
		p.writersWGMutex.Lock()
		p.writersWG.Add(1)
		p.writersWGMutex.Unlock()
		defer p.writersWG.Done()
		// Non-blocking check. A goroutine that registers only after a closer
		// has released the mutex (when ch may already be closed) always sees
		// closingCh closed here, because closingCh is closed before the closer
		// releases that mutex. It then returns without touching ch.
		select {
		case <-p.closingCh:
			return
		default:
		}
		// Block until a reader takes the value or closingCh is closed. If both
		// are ready, select picks one pseudo-randomly, so a value may still be
		// delivered after closing has begun. ch itself stays open until this
		// goroutine's deferred Done runs, because the closer waits on writersWG
		// before closing ch.
		select {
		case <-p.closingCh:
		case p.ch <- data:
		}
	}(data)
}
// Close stops all writers and closes the channel returned by Read.
//
// Close performs these steps in order:
//   - It closes closingCh, so every writer goroutine blocked in write gives up
//     its pending send.
//   - It starts a goroutine that receives and discards values from the channel
//     until the channel is closed.
//   - It waits for every registered writer goroutine to return.
//   - It closes the channel.
//
// Values that no consumer has received by then are dropped. A writer goroutine
// that has not registered when Close begins sends nothing, so closing the
// channel cannot cause a send-on-closed-channel panic.
//
// Close and CloseWithoutDraining may be called more than once and from several
// goroutines. Only the first call performs the shutdown; later calls return
// once the channel has been closed.
func (p *Publisher) Close() {
	p.shutdown(true)
}

// CloseWithoutDraining closes the channel returned by Read, like Close, but
// does not start a goroutine that discards in-flight values.
//
// It does not wait for consumers. closingCh is closed before waiting, so writer
// goroutines blocked on a send are released whether or not anyone reads, and
// their pending values are dropped. It returns once every registered writer
// goroutine has exited and the channel has been closed. Like Close, it is safe
// to call more than once.
func (p *Publisher) CloseWithoutDraining() {
	p.shutdown(false)
}

// shutdown performs the one-time close sequence shared by Close and
// CloseWithoutDraining. drain selects whether values still offered on ch are
// received and discarded while waiting for writers.
//
// writersWGMutex is held for the whole sequence, for two reasons:
//   - Repeated or concurrent calls are serialized, so closingCh and ch are each
//     closed exactly once.
//   - Writers cannot call writersWG.Add while Wait is running.
//
// Writers that block on the mutex in the meantime register only after ch is
// closed. They then find closingCh already closed and return without sending.
func (p *Publisher) shutdown(drain bool) {
	p.writersWGMutex.Lock()
	defer p.writersWGMutex.Unlock()

	select {
	case <-p.closingCh:
		// An earlier call already shut the Publisher down.
		return
	default:
	}

	close(p.closingCh)
	if drain {
		// Consume values offered on ch. This goroutine exits as soon as
		// close(p.ch) below runs.
		go func() {
			for range p.ch {
			}
		}()
	}
	p.writersWG.Wait()
	close(p.ch)
}
package gochannels
// In order to detect race conditions run the test with:
// go test -cpu=1,9,55,99 -race -count=100 -failfast
import (
"math/big"
"sync"
"testing"
)
// TestSimple publishes 0..99 from a single producer and has a single consumer
// read exactly that many values. It checks that their sum matches the expected
// total and that the channel is closed after CloseWithoutDraining.
func TestSimple(t *testing.T) {
	const count = 100
	consumer := func(pub *Publisher, n int, wg *sync.WaitGroup, result chan *big.Int) {
		ch := pub.Read()
		acc := big.NewInt(0)
		for i := 0; i < n; i++ {
			val := <-ch
			t.Log(&val)
			acc.Add(acc, &val)
		}
		wg.Done()
		result <- acc
	}
	producer := func(pub *Publisher, n int, wg *sync.WaitGroup) {
		pub.Run(n)
		wg.Done()
	}
	// precalc returns 0 + 1 + ... + (n-1), the sum of what Run(n) publishes.
	precalc := func(n int) *big.Int {
		acc := big.NewInt(0)
		for i := 0; i < n; i++ {
			acc.Add(acc, big.NewInt(int64(i)))
		}
		return acc
	}
	p := New()
	var wg sync.WaitGroup
	resultCh := make(chan *big.Int)
	wg.Add(2)
	go consumer(p, count, &wg, resultCh)
	go producer(p, count, &wg)
	// Once both goroutines are done, every published value has been received,
	// so no write is pending when the channel is closed.
	wg.Wait()
	p.CloseWithoutDraining()
	result := <-resultCh
	want := precalc(count)
	t.Log(result)
	t.Log(want)
	if result.Cmp(want) != 0 {
		t.Errorf("sum of received values = %v, want %v", result, want)
	}
	if _, ok := <-p.Read(); ok {
		t.Error("channel returned by Read is still open after CloseWithoutDraining")
	}
}
// TestIntermediate has two producers publish 0..99 each while one consumer
// reads only 100 of the 200 values. It then calls Close while writes are still
// pending and while Run may still be issuing new ones. It checks that Close
// neither panics nor deadlocks and that the channel is closed afterwards.
func TestIntermediate(t *testing.T) {
	consumer := func(pub *Publisher, n int, wg *sync.WaitGroup, result chan *big.Int) {
		ch := pub.Read()
		acc := big.NewInt(0)
		for i := 0; i < n; i++ {
			val := <-ch
			t.Log(&val)
			acc.Add(acc, &val)
		}
		wg.Done()
		result <- acc
	}
	producer := func(pub *Publisher, n int, wg *sync.WaitGroup) {
		pub.Run(n)
		wg.Done()
	}
	p := New()
	var wg sync.WaitGroup
	resultCh := make(chan *big.Int)
	wg.Add(3)
	go consumer(p, 100, &wg, resultCh)
	go producer(p, 100, &wg)
	go producer(p, 100, &wg)
	result := <-resultCh
	t.Log(result)
	// Close while about half of the writes are still blocked or not yet issued.
	p.Close()
	wg.Wait()
	if _, ok := <-p.Read(); ok {
		t.Error("channel returned by Read is still open after Close")
	}
}
@fourdim
Copy link

fourdim commented Jun 14, 2022

Hi, @leolara,
Thanks for your excellent post. Would you be so kind as to add a license to this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment