Skip to content

Instantly share code, notes, and snippets.

@zgiber
Created February 21, 2017 00:36
Show Gist options
  • Save zgiber/247705a2405e56577c78dcd1f6e3bdfd to your computer and use it in GitHub Desktop.
Save zgiber/247705a2405e56577c78dcd1f6e3bdfd to your computer and use it in GitHub Desktop.
Sticky Worker Pool example
package pool
import (
"errors"
"fmt"
"log"
"sort"
"time"
)
type (
	// action is a request that runs on the pool's dispatcher goroutine (the
	// select loop in Pool.start); a non-nil error it returns is logged there.
	action func(*Pool) error

	// Message is a unit of work submitted to a Pool. Its id is used to steer
	// it towards the worker that most recently handled the same id.
	Message struct {
		id   string
		body string
	}

	// worker is drained by its own goroutine (worker.start).
	worker struct {
		msgIn      chan *Message // pending messages waiting for this worker
		lastMsgID  string        // id of the message this worker most recently dequeued
		stopSignal chan struct{} // carries both the stop request and its acknowledgement
	}

	// workerGroup implements sort.Interface, ordering workers by the number
	// of messages currently queued in msgIn (fewest first).
	workerGroup []*worker

	// Pool routes Messages to its workers. All routing decisions are made by
	// a single goroutine that consumes actionsIn.
	Pool struct {
		workers    workerGroup
		actionsIn  chan action
		stopSignal chan struct{}
	}
)

// Len reports how many workers are in the group.
func (g workerGroup) Len() int { return len(g) }

// Swap exchanges the workers at positions i and j.
func (g workerGroup) Swap(i, j int) { g[j], g[i] = g[i], g[j] }

// Less orders workers by queue length so that sorting puts the least busy
// worker first.
func (g workerGroup) Less(i, j int) bool {
	left, right := len(g[i].msgIn), len(g[j].msgIn)
	return left < right
}
// NewPool creates a Pool with numWorkers workers and starts its dispatcher
// goroutine, which in turn starts one goroutine per worker. Each worker can
// buffer up to 1024 pending messages.
func NewPool(numWorkers int) *Pool {
	pool := &Pool{
		actionsIn:  make(chan action),
		stopSignal: make(chan struct{}),
	}
	for i := 0; i < numWorkers; i++ {
		// stopSignal must be non-nil: worker.stop sends on it and
		// worker.start receives from it, and operations on a nil channel
		// block forever, so the pool could never be stopped.
		pool.workers = append(pool.workers, &worker{
			msgIn:      make(chan *Message, 1024),
			stopSignal: make(chan struct{}),
		})
	}
	go pool.start()
	return pool
}
// start is the pool's dispatcher loop. It first launches one goroutine per
// worker, then executes submitted actions one at a time, logging any error
// they return, until a receive from stopSignal succeeds.
//
// Because every action runs on this goroutine, actions never run
// concurrently with each other and may touch p.workers without locking.
func (p *Pool) start() {
	for i := range p.workers {
		go p.workers[i].start()
	}
	for {
		select {
		case <-p.stopSignal:
			return
		case act := <-p.actionsIn:
			if err := act(p); err != nil {
				log.Println(err)
			}
		}
	}
}
// getWorker chooses the worker that should receive a message with the given
// id. A worker whose lastMsgID equals id is preferred, so messages sharing an
// id keep going to the worker that most recently dequeued that id. Otherwise
// the workers are sorted by queue length and the least busy one is returned.
// It returns nil when the pool has no workers, instead of panicking on
// p.workers[0].
//
// sort.Sort reorders p.workers in place, so calls must not run concurrently
// with any other access to p.workers; here it is reached through
// processMessage actions, which execute inside Pool.start's loop.
//
// NOTE(review): w.lastMsgID is written by each worker goroutine
// (worker.start) and read here without synchronization, which is a data
// race. Fixing it needs a synchronized field on worker (e.g. a mutex).
func (p *Pool) getWorker(id string) *worker {
	if len(p.workers) == 0 {
		return nil
	}
	for _, w := range p.workers {
		if w.lastMsgID == id {
			return w
		}
	}
	sort.Sort(p.workers)
	return p.workers[0]
}
// stop asks the worker's goroutine to exit and blocks until it acknowledges.
// The request and the acknowledgement travel over the same channel:
// worker.start receives the request and sends a value back before returning.
// If stopSignal is nil this blocks forever.
func (w *worker) stop() {
	var signal struct{}
	w.stopSignal <- signal
	<-w.stopSignal // wait for the worker to echo the signal back
}
// start is the worker's main loop. It handles messages from msgIn one at a
// time, recording each message's id in lastMsgID, until a value arrives on
// stopSignal. It then echoes that value back to acknowledge the request made
// by worker.stop, and returns.
//
// NOTE(review): lastMsgID is written here on the worker goroutine while
// getWorker reads it on the pool goroutine with no synchronization, which is
// a data race.
func (w *worker) start() {
	for {
		select {
		case <-w.stopSignal:
			w.stopSignal <- struct{}{}
			return
		case msg := <-w.msgIn:
			w.lastMsgID = msg.id
			// Simulated workload: message "7" is deliberately slow and
			// prints its id; every other message prints a single dot.
			if msg.id != "7" {
				fmt.Print(".")
				continue
			}
			time.Sleep(200 * time.Millisecond)
			fmt.Print(msg.id)
		}
	}
}
// ProcessMessage hands m to the pool's dispatcher, which routes it to a
// worker. It blocks until the dispatcher accepts the request.
//
// If a receive from p.stopSignal succeeds first (the pool is shutting down or
// has shut down), m is dropped and ProcessMessage returns. Without this, a
// call made after the dispatcher loop has exited would block forever, because
// nothing receives from actionsIn any more.
func (p *Pool) ProcessMessage(m *Message) {
	select {
	case p.actionsIn <- processMessage(m):
	case <-p.stopSignal:
	}
}
// Stop submits a shutdown request to the pool's dispatcher. It returns once
// the dispatcher has accepted the request; it does not wait for the workers
// to finish.
//
// If a receive from p.stopSignal succeeds first (the pool is already shutting
// down), Stop returns without submitting anything. This keeps a repeated Stop
// from blocking forever on a dispatcher loop that has already exited.
func (p *Pool) Stop() {
	select {
	case p.actionsIn <- stop():
	case <-p.stopSignal:
	}
}
// addWorker is only an example of how further pool operations could be
// expressed as actions. It is not implemented: the returned action always
// fails with a "not implemented" error.
func addWorker() action {
	notImplemented := func(*Pool) error {
		return errors.New("not implemented")
	}
	return notImplemented
}
// processMessage returns an action that enqueues msg on the worker chosen by
// Pool.getWorker. The enqueue never blocks the dispatcher: if that worker's
// buffer is full, the message is dropped and an error is returned (and logged
// by Pool.start).
//
// It also returns an error, instead of panicking on the dispatcher goroutine,
// when msg is nil or when no worker is available.
func processMessage(msg *Message) action {
	return func(p *Pool) error {
		if msg == nil {
			return errors.New("nil message")
		}
		w := p.getWorker(msg.id)
		if w == nil {
			return fmt.Errorf("no workers available, dropping message: %s", msg.id)
		}
		select {
		case w.msgIn <- msg:
			return nil
		default:
			return fmt.Errorf("internal buffer is full dropping message: %s", msg.id)
		}
	}
}
// stop returns an action that shuts the pool down. It stops every worker,
// waiting for each one's acknowledgement, and then closes p.stopSignal so
// that Pool.start's select loop returns.
//
// This action runs inside Pool.start's loop, the same goroutine that
// receives from p.stopSignal. A blocking send on p.stopSignal from here would
// wait on that busy goroutine and deadlock. Closing never blocks, and every
// later receive observes it.
//
// Actions run one at a time on that loop, so the check below is enough to
// make a repeated stop harmless. It prevents a double close, and it prevents
// blocking on workers that have already exited.
func stop() action {
	return func(p *Pool) error {
		select {
		case <-p.stopSignal:
			return nil // already stopped
		default:
		}
		for _, w := range p.workers {
			w.stop()
		}
		close(p.stopSignal)
		return nil
	}
}
// NewMessage builds a Message with the given body. The id is used by the pool
// to route the message, so messages sharing an id tend to land on the same
// worker.
func NewMessage(id, body string) *Message {
	m := Message{id: id, body: body}
	return &m
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment