Skip to content

Instantly share code, notes, and snippets.

@leolara
Last active April 29, 2023 04:30
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save leolara/f6fb5dfc04d64947487f16764d6b37b6 to your computer and use it in GitHub Desktop.
Save leolara/f6fb5dfc04d64947487f16764d6b37b6 to your computer and use it in GitHub Desktop.
Example in Go of how to close a channel written by several goroutines
// Package gochannels example of how to close a channel written by several goroutines
package gochannels
import (
"math/big"
"sync"
)
// Publisher writes sequences of big.Int values into a channel from many
// goroutines, and can close that channel while writes are still in flight.
//
// Each write runs in its own goroutine. Close and CloseWithoutDraining first
// signal pending writers through closingCh, then wait for every registered
// writer to return, and only then close ch, so no goroutine ever sends on a
// closed channel.
type Publisher struct {
	ch chan big.Int
	// closingCh is closed (never sent on) to tell writers to give up.
	closingCh chan struct{}
	// writersWG counts writer goroutines that may still send on ch.
	writersWG sync.WaitGroup
	// writersWGMutex orders writersWG.Add against the shutdown sequence and
	// serializes concurrent or repeated close calls.
	writersWGMutex sync.Mutex
}

// New creates a Publisher ready for use.
func New() *Publisher {
	return &Publisher{
		ch:        make(chan big.Int),
		closingCh: make(chan struct{}),
	}
}

// Run writes the sequence 0..n-1 into the channel. Every value is sent from
// its own goroutine, so Run does not wait for the values to be received and
// the delivery order is not guaranteed.
func (p *Publisher) Run(n int) {
	for i := 0; i < n; i++ {
		// Shallow-copying a big.Int is normally unsafe (see math/big docs);
		// it is fine here because the freshly allocated original is never
		// used again.
		p.write(*big.NewInt(int64(i)))
	}
}

// Read returns the channel from which published values are received. It is
// closed once Close or CloseWithoutDraining has completed.
func (p *Publisher) Read() <-chan big.Int {
	return p.ch
}

// write sends data on the channel from a new goroutine. The send is abandoned
// if the Publisher has started closing.
func (p *Publisher) write(data big.Int) {
	go func() {
		// shutdown holds the mutex from before it closes closingCh until after
		// it closes ch. A writer therefore either registers before shutdown
		// starts and is waited for, or registers after shutdown has finished.
		// In the second case it sees closingCh closed in the fast path below
		// and never touches ch.
		p.writersWGMutex.Lock()
		p.writersWG.Add(1)
		p.writersWGMutex.Unlock()
		defer p.writersWG.Done()

		// Fast path: do not attempt a send once closing has begun.
		select {
		case <-p.closingCh:
			return
		default:
		}

		// Block until a reader takes the value or closing begins. Once
		// closingCh is closed this select can no longer block.
		select {
		case <-p.closingCh:
		case p.ch <- data:
		}
	}()
}

// Close closes the channel returned by Read. Pending writers are told to give
// up, and a background goroutine reads (and discards) any value they still
// manage to send before the channel is closed. Close blocks until every writer
// that registered before it has returned. Calling Close or
// CloseWithoutDraining more than once is safe; later calls do nothing.
func (p *Publisher) Close() {
	p.shutdown(true)
}

// CloseWithoutDraining closes the channel returned by Read without starting a
// goroutine that reads from it. Writers that have not sent yet give up as soon
// as they observe the close signal, so this call does not depend on readers to
// return, and their values are dropped. Calling Close or CloseWithoutDraining
// more than once is safe; later calls do nothing.
func (p *Publisher) CloseWithoutDraining() {
	p.shutdown(false)
}

// shutdown implements Close and CloseWithoutDraining. It holds the mutex for
// the whole sequence, for two reasons: a second call cannot close a channel
// twice, and no writer can call writersWG.Add while Wait is in progress.
func (p *Publisher) shutdown(drain bool) {
	p.writersWGMutex.Lock()
	defer p.writersWGMutex.Unlock()

	select {
	case <-p.closingCh:
		return // already closed by an earlier call
	default:
	}

	close(p.closingCh)
	if drain {
		// Exits when ch is closed below.
		go func() {
			for range p.ch {
			}
		}()
	}
	p.writersWG.Wait()
	close(p.ch)
}
package gochannels
// In order to detect race conditions run the test with:
// go test -cpu=1,9,55,99 -race -count=100 -failfast
import (
"math/big"
"sync"
"testing"
)
func TestSimple(t *testing.T) {
	// sumReceived reads n values from pub, logging each one, then signals done
	// and hands the running total to out.
	sumReceived := func(pub *Publisher, n int, done *sync.WaitGroup, out chan *big.Int) {
		values := pub.Read()
		total := big.NewInt(0)
		for received := 0; received < n; received++ {
			v := <-values
			t.Log(&v)
			total.Add(total, &v)
		}
		done.Done()
		out <- total
	}

	// publish asks pub to emit 0..n-1 and signals done once Run returns.
	publish := func(pub *Publisher, n int, done *sync.WaitGroup) {
		pub.Run(n)
		done.Done()
	}

	// expectedSum computes 0 + 1 + ... + (n-1) as the reference total.
	expectedSum := func(n int) *big.Int {
		total := big.NewInt(0)
		for k := 0; k < n; k++ {
			total.Add(total, big.NewInt(int64(k)))
		}
		return total
	}

	pub := New()
	var group sync.WaitGroup
	totals := make(chan *big.Int)

	group.Add(2)
	go sumReceived(pub, 100, &group, totals)
	go publish(pub, 100, &group)
	group.Wait()

	pub.CloseWithoutDraining()

	got := <-totals
	t.Log(got)
	t.Log(expectedSum(100))
	if want := expectedSum(100); got.Cmp(want) != 0 {
		t.Error("wrong result")
	}
}
// TestIntermediate has two producers each publish 0..99 (200 writes) while a
// single consumer reads only 100 values. At least 100 writes are therefore
// still unsent when Close is called. Close must release those writers without
// panicking, and the read channel must be closed once Close returns.
func TestIntermediate(t *testing.T) {
	consumer := func(pub *Publisher, n int, wg *sync.WaitGroup, result chan *big.Int) {
		ch := pub.Read()
		acc := big.NewInt(0)
		for i := 0; i < n; i++ {
			val := <-ch
			t.Log(&val)
			acc.Add(acc, &val)
		}
		wg.Done()
		result <- acc
	}
	producer := func(pub *Publisher, n int, wg *sync.WaitGroup) {
		pub.Run(n)
		wg.Done()
	}

	p := New()
	var wg sync.WaitGroup
	resultCh := make(chan *big.Int)
	wg.Add(3)
	go consumer(p, 100, &wg, resultCh)
	go producer(p, 100, &wg)
	go producer(p, 100, &wg)

	<-resultCh
	p.Close()
	wg.Wait()

	// Close only returns after closing the channel, so a receive must report
	// !ok instead of blocking or yielding a value.
	if _, ok := <-p.Read(); ok {
		t.Error("Read channel still open after Close")
	}
}
@leolara
Copy link
Author

leolara commented Jun 21, 2020

Hi @anirbanroydas

I wrote this a long time ago, so I do not remember all details now.

I think close is non-blocking, so I think it would not make a difference, but as I said, it was a long time ago.

@rueian
Copy link

rueian commented Aug 22, 2021

Hi @leolara,

I also came here from https://www.leolara.me/blog/closing_a_go_channel_written_by_several_goroutines/. It is a really great article.

And since you mentioned the use of a WaitGroup, I came up with an idea to avoid using a mutex: we could probably use an atomic counter of in-progress writes to do similar work.

// Publisher is a mutex-free variant: an atomic counter of in-progress writes
// replaces the WaitGroup+mutex pair.
type Publisher struct {
	ch chan big.Int

	state   int32 // 0 = open, 1 = closing, 2 = closed
	writing int32 // number of write calls currently in progress
}

// write sends n on p.ch unless Close has already started. The counter is
// incremented before state is checked, so Close's wait loop either sees this
// writer or this writer sees a non-zero state and skips the send.
func (p *Publisher) write(n big.Int) {
	atomic.AddInt32(&p.writing, 1)
	if atomic.LoadInt32(&p.state) == 0 {
		p.ch <- n
	}
	atomic.AddInt32(&p.writing, -1)
}

// Close stops new sends, drains any in-flight ones, and closes p.ch.
// Repeated calls are safe: only the first starts the drain goroutine, and
// only the call that wins the 1 -> 2 transition closes the channel.
func (p *Publisher) Close() {
	if atomic.CompareAndSwapInt32(&p.state, 0, 1) {
		// Unblocks writers that are already waiting to send; exits when
		// p.ch is closed below.
		go func() {
			for range p.ch {
			}
		}()
	}
	for atomic.LoadInt32(&p.writing) != 0 {
		runtime.Gosched()
	}
	if atomic.CompareAndSwapInt32(&p.state, 1, 2) {
		close(p.ch)
	}
}

@leolara
Copy link
Author

leolara commented Aug 23, 2021

Hi @rueian,

I guess you do atomic.CompareAndSwapInt32(&p.state, 1, 2) in case Close is called more than once?

@rueian
Copy link

rueian commented Aug 24, 2021

> Hi @rueian,
>
> I guess you do atomic.CompareAndSwapInt32(&p.state, 1, 2) in case Close is called more than once?

Yes, you are right. Otherwise it may panic.

@fourdim
Copy link

fourdim commented Jun 14, 2022

Hi, @leolara,
Thanks for your excellent post. Would you be so kind as to add a license to this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment