Skip to content

Instantly share code, notes, and snippets.

@brunsgaard
Created August 16, 2022 16:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brunsgaard/11b200a5137d9fe67b86e646727b531b to your computer and use it in GitHub Desktop.
Save brunsgaard/11b200a5137d9fe67b86e646727b531b to your computer and use it in GitHub Desktop.
Ordered Workpool
package concurrentprotojson
import (
"container/heap"
"sync"
"github.com/twmb/franz-go/pkg/kgo"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
)
// Task is the result of unmarshalling a single record. It pairs the
// input record with the outcome of decoding its value as protobuf JSON.
type Task struct {
// record is the input record; its Value is what gets passed to protojson.Unmarshal.
record *kgo.Record
// protoMessage is the message the worker decoded record.Value into.
// It is only meaningful when err is nil.
protoMessage proto.Message
// err is the error returned by protojson.Unmarshal, or nil on success.
err error
}
// task is an internal container wrapping a Task. It carries the
// bookkeeping needed to run the Task on a worker and to restore the
// original input order afterwards.
type task struct {
// index is the position of the record in the input stream, counted from 0.
index int
// result is the caller-visible Task that the worker fills in.
result *Task
// wg is marked Done by the worker once the task has been sent on outChan.
wg *sync.WaitGroup
// outChan is the channel on which the worker delivers the finished task.
outChan chan *task
}
// WorkerPool is a set of worker goroutines that unmarshal JSON record
// values into protobuf messages of a single message descriptor. The
// original note says "per topic": presumably one pool is created for each
// topic — confirm against callers.
type WorkerPool struct {
// inputChan feeds tasks to the worker goroutines started by NewWorkerPool.
inputChan chan *task
// pq is a heap of tasks ordered by index, allocated by NewWorkerPool.
pq *TaskPriorityQueue
}
// NewWorkerPool creates a WorkerPool and starts a fixed number of worker
// goroutines (currently 8). Each worker receives tasks from the pool's
// input channel, decodes the record value as protobuf JSON into a new
// dynamic message described by msgdec, sends the finished task on the
// task's outChan and then marks the task's WaitGroup as done.
//
// The workers run until Close is called on the returned pool.
func NewWorkerPool(msgdec protoreflect.MessageDescriptor) *WorkerPool {
	const concurrency = 8

	pq := make(TaskPriorityQueue, 0, concurrency)
	heap.Init(&pq)
	wp := &WorkerPool{
		inputChan: make(chan *task),
		pq:        &pq,
	}
	for i := 0; i < concurrency; i++ {
		go func() {
			// The loop variable is deliberately not called "task" so that it
			// does not shadow the task type inside the worker body.
			for t := range wp.inputChan {
				msg := dynamicpb.NewMessage(msgdec)
				if err := protojson.Unmarshal(t.result.record.Value, msg); err != nil {
					// Do not expose a partially decoded message next to
					// the error; leave protoMessage nil on failure.
					t.result.err = err
				} else {
					t.result.protoMessage = msg
				}
				// Deliver first, then signal completion: the producer closes
				// outChan only after every worker has called Done, so the
				// send below can never hit a closed channel.
				t.outChan <- t
				t.wg.Done()
			}
		}()
	}
	return wp
}
// Close closes the pool's input channel, which ends the range loops of the
// worker goroutines started by NewWorkerPool. Call it at most once, and
// only after every UnmarshalFromChan feed has finished: closing an
// already-closed channel or sending on a closed channel panics.
func (wp *WorkerPool) Close() {
close(wp.inputChan)
}
// UnmarshalFromChan unmarshals every record received on inChan using the
// pool's workers and returns a channel that yields one *Task per record,
// in the same order the records were received.
//
// The returned channel is closed after inChan has been closed and all
// results have been delivered. The caller must drain the returned channel;
// otherwise the internal goroutines block on their sends.
//
// Each call uses its own reorder buffer, so several calls may run at the
// same time on one pool without corrupting each other's ordering (task
// indexes restart at 0 for every call).
func (wp *WorkerPool) UnmarshalFromChan(inChan chan *kgo.Record) chan *Task {
	wg := &sync.WaitGroup{}
	outChan := make(chan *task)
	outOrderedChan := make(chan *Task)

	// Producer: tag each record with its input position and hand it to the
	// workers. outChan is closed only after every submitted task is done.
	go func() {
		index := 0
		for record := range inChan {
			wg.Add(1)
			wp.inputChan <- &task{
				index:   index,
				result:  &Task{record: record},
				wg:      wg,
				outChan: outChan,
			}
			index++
		}
		wg.Wait()
		close(outChan)
	}()

	// Reorderer: forward results in index order, parking tasks that finish
	// early in a min-heap that is private to this call.
	go func() {
		defer close(outOrderedChan)

		pending := &TaskPriorityQueue{}
		cursor := 0 // index of the next result to emit
		for t := range outChan {
			if t.index == cursor {
				outOrderedChan <- t.result
				cursor++
			} else {
				heap.Push(pending, t)
			}
			// Flush every parked task that has now become the next in line.
			for pending.PeakIndex() == cursor {
				next := heap.Pop(pending).(*task)
				outOrderedChan <- next.result
				cursor++
			}
		}
	}()
	return outOrderedChan
}
// Unmarshal decodes all records in rs on the pool and blocks until every
// result has been collected. It returns the successfully decoded messages
// and, separately, the Tasks whose decoding failed. Both slices keep the
// relative order of rs; the failure slice is empty but non-nil when
// nothing failed.
func (wp *WorkerPool) Unmarshal(rs []*kgo.Record) ([]proto.Message, []*Task) {
	records := make(chan *kgo.Record)
	results := wp.UnmarshalFromChan(records)

	// Feed the records from a separate goroutine so this one is free to
	// drain the results channel below.
	go func() {
		defer close(records)
		for i := range rs {
			records <- rs[i]
		}
	}()

	decoded := make([]proto.Message, 0, len(rs))
	failed := []*Task{}
	for res := range results {
		if res.err == nil {
			decoded = append(decoded, res.protoMessage)
			continue
		}
		failed = append(failed, res)
	}
	return decoded, failed
}
// TaskPriorityQueue is a min-heap of tasks ordered by their index. It
// implements heap.Interface; Push and Pop are meant to be called through
// container/heap (heap.Push / heap.Pop), which maintains the heap order.
type TaskPriorityQueue []*task

// Len returns the number of queued tasks.
func (h TaskPriorityQueue) Len() int {
	return len(h)
}

// Cap returns the capacity of the underlying slice.
func (h TaskPriorityQueue) Cap() int {
	return cap(h)
}

// Less reports whether the task at position i has a smaller index than the
// task at position j, which makes the heap a min-heap on index.
func (h TaskPriorityQueue) Less(i, j int) bool {
	return h[i].index < h[j].index
}

// Swap exchanges the tasks at positions i and j.
func (h TaskPriorityQueue) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}

// Push appends x, which must be a *task, to the end of the slice.
func (h *TaskPriorityQueue) Push(x any) {
	*h = append(*h, x.(*task))
}

// Pop removes and returns the last element of the slice. The queue must
// not be empty.
func (h *TaskPriorityQueue) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	// Clear the vacated slot so the backing array does not keep the popped
	// task (and the record it references) reachable.
	old[n-1] = nil
	*h = old[:n-1]
	return x
}

// PeakIndex returns the index of the task at the front of the queue (the
// smallest index when heap order is maintained), or -1 if the queue is
// empty.
func (h *TaskPriorityQueue) PeakIndex() int {
	if len(*h) == 0 {
		return -1
	}
	return (*h)[0].index
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment