Skip to content

Instantly share code, notes, and snippets.

@kvu787
Created February 19, 2016 21:01
Show Gist options
  • Save kvu787/f3c95033e9f4211aa152 to your computer and use it in GitHub Desktop.
Save kvu787/f3c95033e9f4211aa152 to your computer and use it in GitHub Desktop.
// Package queue implements a multi-producer, multi-consumer, unbuffered
// concurrent queue using atomic compare-and-swap operations.
// It has equivalent semantics to an unbuffered Go channel.
package queue
import (
"runtime"
"sync/atomic"
)
/*
Pseudocode:
compare and swap = cas(state *int, old, new)
states:
Waiting // no consumers are ready
Ready // some consumer C is ready to get a value
Writing // some producer P is writing a value
Written // producer P finished writing
Reading // consumer C is reading the value
queue init:
state = Waiting // start state
data = alloc()
enq(newdata):
while (!cas(&state, Ready, Writing)) {
yield()
}
copy(newdata to data)
cas(&state, Writing, Written)
deq():
retval = alloc()
while (!cas(&state, Waiting, Ready)) {
yield()
}
while (!cas(&state, Written, Reading)) {
yield()
}
copy(data to retval)
cas(&state, Reading, Waiting)
return retval
*/
// States of the hand-off protocol. A queue starts in Waiting and cycles
// Waiting -> Ready -> Writing -> Written -> Reading -> Waiting, with every
// transition performed by an atomic compare-and-swap.
const (
	// Waiting is the start state: no consumer has announced itself yet.
	Waiting = int32(iota)
	// Ready means a consumer has announced itself and awaits a value.
	Ready
	// Writing means a producer has claimed the slot and is storing a value.
	Writing
	// Written means the producer has finished storing the value.
	Written
	// Reading means the consumer has claimed the stored value and is copying it out.
	Reading
)
// NewQueue returns a pointer to a fresh queue whose state is Waiting and whose
// data slot holds zero.
func NewQueue() *queue {
	return &queue{
		state: Waiting,
		data:  0,
	}
}
// queue is a single-slot hand-off point between goroutines calling Enq
// (producers) and goroutines calling Deq (consumers).
type queue struct {
// state is the current protocol state (Waiting, Ready, Writing, Written or
// Reading). Enq and Deq advance it with atomic compare-and-swap.
state int32
// data is the value in transit: stored by Enq, copied out by Deq.
data int
}
// Enq passes data to a consumer. It spins until some consumer has moved the
// queue into the Ready state, claims the slot, stores data, and marks the slot
// Written. It returns as soon as the value is stored, without waiting for the
// consumer to copy it out.
func (q *queue) Enq(data int) {
	for {
		// Only one producer can win the Ready -> Writing transition.
		if atomic.CompareAndSwapInt32(&q.state, Ready, Writing) {
			break
		}
		// Yield so that other goroutines (in particular a consumer that
		// must publish Ready) get a chance to run.
		runtime.Gosched()
	}

	q.data = data

	// Publish the value to the waiting consumer.
	atomic.CompareAndSwapInt32(&q.state, Writing, Written)
}
// Deq receives one value from a producer. It first spins until it can move the
// queue from Waiting to Ready (announcing itself as the active consumer), then
// spins until a producer has marked the slot Written, copies the value out, and
// resets the queue to Waiting for the next consumer.
func (q *queue) Deq() int {
	// Phase 1: become the single active consumer.
	for {
		if atomic.CompareAndSwapInt32(&q.state, Waiting, Ready) {
			break
		}
		runtime.Gosched()
	}

	// Phase 2: wait for a producer to finish storing a value, then claim it.
	for {
		if atomic.CompareAndSwapInt32(&q.state, Written, Reading) {
			break
		}
		runtime.Gosched()
	}

	received := q.data

	// Release the queue so another consumer can announce itself.
	atomic.CompareAndSwapInt32(&q.state, Reading, Waiting)
	return received
}
package queue
import (
"math/rand"
"sort"
"sync"
"testing"
"time"
)
// TestBasic checks that a single producer and a single consumer exchange the
// values 0..iters-1 in order.
//
// The consumer keeps draining after a mismatch instead of aborting: the
// producer goroutine spins inside Enq until a consumer shows up, so stopping
// early would leave it busy-looping for the rest of the test binary's run.
func TestBasic(t *testing.T) {
	const iters = 100000
	q := NewQueue()

	go func() {
		for i := 0; i < iters; i++ {
			q.Enq(i)
		}
	}()

	mismatches := 0
	for want := 0; want < iters; want++ {
		got := q.Deq()
		if got == want {
			continue
		}
		// Report only the first mismatch in detail to keep the output readable.
		if mismatches == 0 {
			t.Errorf("Deq() = %d, want %d", got, want)
		}
		mismatches++
	}
	if mismatches > 0 {
		t.Fatalf("%d of %d values were received out of order", mismatches, iters)
	}
}
// TestMany runs nthreads producers and nthreads consumers concurrently, each
// transferring nitems values, and checks that the multiset of received values
// equals the multiset of sent values. Ordering across goroutines is not
// checked.
func TestMany(t *testing.T) {
	const nthreads = 50
	const nitems = 50
	q := NewQueue()

	// One rng per producer, since a *rand.Rand must not be shared between
	// goroutines. The index is mixed into the seed so that producers get
	// distinct streams even if consecutive clock reads return the same value.
	rngs := [nthreads]*rand.Rand{}
	for i := range rngs {
		rngs[i] = rand.New(rand.NewSource(time.Now().UnixNano() + int64(i)))
	}

	// Per-goroutine rows, so no two goroutines write to the same element.
	var sent, recv [nthreads][nitems]int

	var wg sync.WaitGroup
	wg.Add(nthreads * 2)

	// producers
	for i := 0; i < nthreads; i++ {
		go func(n int) {
			defer wg.Done()
			for j := 0; j < nitems; j++ {
				sent[n][j] = rngs[n].Int()
				q.Enq(sent[n][j])
			}
		}(i)
	}

	// consumers
	for i := 0; i < nthreads; i++ {
		go func(n int) {
			defer wg.Done()
			for j := 0; j < nitems; j++ {
				recv[n][j] = q.Deq()
			}
		}(i)
	}

	// wait for all producer and consumer goroutines to finish
	wg.Wait()

	// flatten the per-goroutine rows
	allsent := make([]int, 0, nthreads*nitems)
	allrecv := make([]int, 0, nthreads*nitems)
	for i := 0; i < nthreads; i++ {
		allsent = append(allsent, sent[i][:]...)
		allrecv = append(allrecv, recv[i][:]...)
	}

	// sort and compare: equal sorted slices mean equal multisets
	sort.Ints(allsent)
	sort.Ints(allrecv)
	if !intsEq(allsent, allrecv) {
		t.Fatalf("received values do not match sent values (%d values exchanged)", len(allsent))
	}
}
// intsEq reports whether a and b have the same length and hold the same
// values at every index. A nil slice and an empty slice compare equal.
func intsEq(a, b []int) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if a[idx] != b[idx] {
			return false
		}
	}
	return true
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment