Skip to content

Instantly share code, notes, and snippets.

@maxsei
Last active May 23, 2024 22:36
Show Gist options
  • Save maxsei/fef76d3306056fbdfea66fcd7fc30c49 to your computer and use it in GitHub Desktop.
Save maxsei/fef76d3306056fbdfea66fcd7fc30c49 to your computer and use it in GitHub Desktop.
streaming for EnTeRpRiSe
package stream
import (
"context"
"errors"
"sync"
)
// CancelableMessage is the request envelope that the public Stream methods
// (Next, Subscribe, Unsubscribe, Deref) send to the Start event loop via
// sendMsg. ctx and cancel are created by sendMsg from the caller's context;
// once the envelope has been handed off, Start is responsible for calling
// cancel, and a non-nil cause marks the request as failed.
type CancelableMessage[T any] struct {
// ctx is derived from the caller's context; Start selects on ctx.Done() to
// stop waiting on a subscriber that is not receiving.
ctx context.Context
// cancel cancels ctx, optionally recording a cause (e.g. "subscriber not
// found").
cancel context.CancelCauseFunc
// data is the request payload: the value to publish (Next) or the
// subscriber channel the request is about (Subscribe/Unsubscribe/Deref).
data T
}
// New creates a Stream whose lifetime is bound to ctx: the stream's own
// context (see Ctx) is derived from ctx, so canceling ctx has the same effect
// as calling Close.
//
// All request channels are unbuffered and only Start receives from them, so
// Next, Subscribe, Unsubscribe and Deref block until Start is running.
func New[T any](ctx context.Context) *Stream[T] {
	// Every request channel must be allocated. A nil channel blocks forever on
	// both send and receive, so leaving deref nil made Deref hang until its
	// context was done and made Start's deref case unreachable.
	s := &Stream[T]{
		next:        make(chan CancelableMessage[T]),
		deref:       make(chan CancelableMessage[(<-chan T)]),
		subscribe:   make(chan CancelableMessage[chan T]),
		unsubscribe: make(chan CancelableMessage[(<-chan T)]),
	}
	s.ctx, s.cancel = context.WithCancel(ctx)
	return s
}
// Stream broadcasts values of type T to every registered subscriber channel.
// A single event loop (Start) owns the subscriber list and the current state;
// the other methods communicate with it only by sending requests over the
// channels below, so these fields are accessed without a lock.
type Stream[T any] struct {
// subscribers are the channels handed out by Subscribe, in registration
// order. Accessed only from the Start event loop.
subscribers []chan T
// ctx bounds the stream's lifetime (derived in New, canceled by Close);
// when it is done, Start closes every subscriber channel and returns.
ctx context.Context
cancel context.CancelFunc
// state is the value Start sends to a subscriber in response to a Deref
// request.
state T
// Request channels; Start is the only receiver on each of them.
next chan CancelableMessage[T]
deref chan CancelableMessage[(<-chan T)]
subscribe chan CancelableMessage[chan T]
unsubscribe chan CancelableMessage[(<-chan T)]
}
// Start runs the stream's event loop. It blocks until the stream's context is
// done (Close was called or the context passed to New was canceled), at which
// point it closes every subscriber channel and returns. Run at most one Start
// per Stream: it reads and writes s.subscribers and s.state without locking.
//
// Requests are handled one at a time. A published value is delivered to all
// subscribers before the next request is considered, so a subscriber that
// stops receiving stalls the loop until the publisher's context is done.
//
// Every request carries a context created by sendMsg. Start owns it and
// cancels it once the request has been handled (with a cause on failure).
// This releases the context from its parent instead of leaving it
// registered for as long as the parent lives.
func (s *Stream[T]) Start() {
	for {
		select {
		case <-s.ctx.Done():
			s.closeSubscribers()
			return
		case req := <-s.unsubscribe:
			s.removeSubscriber(req)
		case req := <-s.subscribe:
			// Duplicates are not checked: Subscribe always sends a freshly
			// made channel, so it cannot already be registered.
			s.subscribers = append(s.subscribers, req.data)
			req.cancel(nil)
		case req := <-s.next:
			s.broadcast(req)
		case req := <-s.deref:
			s.sendState(req)
		}
	}
}

// closeSubscribers closes every subscriber channel (ending consumers' range
// loops) and drops the references so the channels can be collected.
func (s *Stream[T]) closeSubscribers() {
	for _, sub := range s.subscribers {
		close(sub)
	}
	s.subscribers = nil
}

// removeSubscriber closes and forgets the subscriber named by req, or fails
// req with "subscriber not found" when it is not registered.
func (s *Stream[T]) removeSubscriber(req CancelableMessage[(<-chan T)]) {
	i := s.findSubscriberIndex(req.data)
	if i == -1 {
		req.cancel(errors.New("subscriber not found"))
		return
	}
	close(s.subscribers[i])
	// Shift the tail left and clear the vacated last slot so the backing
	// array does not keep a stale reference to a channel.
	last := len(s.subscribers) - 1
	copy(s.subscribers[i:], s.subscribers[i+1:])
	s.subscribers[last] = nil
	s.subscribers = s.subscribers[:last]
	req.cancel(nil)
}

// broadcast records req.data as the current state and sends it to every
// subscriber concurrently. It waits until each subscriber has received the
// value or req's context (derived from the publisher's ctx) is done.
func (s *Stream[T]) broadcast(req CancelableMessage[T]) {
	// Remember the latest value so Deref has something to report.
	s.state = req.data
	var wg sync.WaitGroup
	wg.Add(len(s.subscribers))
	for _, sub := range s.subscribers {
		go func(sub chan T) {
			defer wg.Done()
			select {
			case sub <- req.data:
			case <-req.ctx.Done():
				// TODO: deal with slow consumers here... <16-05-24, Max Schulte> //
				// TODO: I feel like consumers must have their own context too
				// so that we can "continue" to other consumers instead of just
				// returning when the message context has run out. We can have both.
				// Perhaps the unsubscribe method can be part of a consumer object
				// as well as context expiration for explicit and implicit
				// unsubscriptions. <17-05-24, Max Schulte> //
			}
		}(sub)
	}
	// Cancel only after every sender has finished, so no delivery is cut
	// short by Start's own cleanup.
	wg.Wait()
	req.cancel(nil)
}

// sendState sends the current state to the subscriber named by req, or fails
// req with "subscriber not found" when it is not registered. It blocks until
// the subscriber receives the value or req's context is done.
func (s *Stream[T]) sendState(req CancelableMessage[(<-chan T)]) {
	i := s.findSubscriberIndex(req.data)
	if i == -1 {
		req.cancel(errors.New("subscriber not found"))
		return
	}
	select {
	case s.subscribers[i] <- s.state:
	case <-req.ctx.Done():
		// TODO: deal with slow consumers here... <16-05-24, Max Schulte> //
	}
	req.cancel(nil)
}
// Ctx returns the context that bounds the stream's lifetime. It is done once
// Close has been called or the context passed to New has been canceled.
func (s *Stream[T]) Ctx() context.Context {
	return s.ctx
}
// Close stops the stream by canceling its context and returns immediately.
// The Start loop observes the cancellation the next time it waits for a
// request, closes every subscriber channel and returns. Subsequent calls do
// nothing.
func (s *Stream[T]) Close() {
	s.cancel()
}
// findSubscriberIndex reports the position of subscriber in s.subscribers,
// or -1 when it is not registered. Channels compare by identity, so only the
// exact channel that was registered matches.
func (s *Stream[T]) findSubscriberIndex(subscriber <-chan T) int {
	for idx, candidate := range s.subscribers {
		if candidate == subscriber {
			return idx
		}
	}
	return -1
}
// sendMsg wraps message in a CancelableMessage and hands it to the Start loop
// over ch. The envelope's context is derived from parent with
// context.WithCancelCause. After a successful hand-off, the receiver (Start)
// owns that context and is responsible for calling its cancel function.
//
// It returns nil as soon as the message has been received, which does not
// mean it has been processed. It returns parent's error if parent is already
// done on entry, or becomes done before the hand-off. If nothing receives from
// ch (for example, Start has returned), it blocks until parent is done.
func sendMsg[T any](parent context.Context, message T, ch chan CancelableMessage[T]) error {
	// Fail fast on an already-done context. Without this check both select
	// cases below can be ready at once; select then picks one at random and
	// may still hand a request with a dead context to Start.
	if err := parent.Err(); err != nil {
		return err
	}
	ctx, cancel := context.WithCancelCause(parent)
	select {
	case <-ctx.Done():
		// The hand-off failed, so Start never received cancel; release the
		// derived context here instead.
		cancel(nil)
		return ctx.Err()
	case ch <- CancelableMessage[T]{ctx: ctx, cancel: cancel, data: message}:
		return nil
	}
}
// Next publishes message to every current subscriber. It returns once the
// Start loop has accepted the message; delivery happens afterwards inside
// Start. If ctx is done first, Next returns ctx's error instead. ctx also
// bounds delivery: Start stops waiting on a subscriber once ctx is done.
func (s *Stream[T]) Next(ctx context.Context, message T) error {
	if err := sendMsg(ctx, message, s.next); err != nil {
		return err
	}
	return nil
}
// Deref asks the Start loop to send the stream's current state to subscriber,
// which should be a channel obtained from Subscribe.
//
// Deref returns when Start accepts the request (or with ctx's error if ctx is
// done first); the value is sent afterwards, so the caller must keep
// receiving from subscriber. While that send is pending, Start handles no
// other requests; it gives up once ctx is done.
//
// For an unknown subscriber, Start cancels the request with
// "subscriber not found". That error is not reported through Deref's return
// value.
func (s *Stream[T]) Deref(ctx context.Context, subscriber <-chan T) error {
	err := sendMsg(ctx, subscriber, s.deref)
	return err
}
// Subscribe registers a new unbuffered channel with the stream and returns
// it, or returns ctx's error if ctx is done before Start accepts the request.
// The channel receives values published by Next after registration. It is
// closed when the stream stops or the subscriber is removed with Unsubscribe.
//
// Because the channel is unbuffered and Start waits on every subscriber
// while delivering, the caller must keep draining it.
func (s *Stream[T]) Subscribe(ctx context.Context) (<-chan T, error) {
	ch := make(chan T)
	err := sendMsg(ctx, ch, s.subscribe)
	if err != nil {
		return nil, err
	}
	return ch, nil
}
// Unsubscribe asks the Start loop to close and forget subscriber. It returns
// once Start has accepted the request, not once it has been processed, or
// with ctx's error if ctx is done first.
//
// An unknown subscriber makes Start cancel the request with "subscriber not
// found", which is not surfaced through the returned error. Keep draining
// subscriber until it is closed: Start cannot accept this request while it
// is blocked delivering a value to subscriber.
func (s *Stream[T]) Unsubscribe(ctx context.Context, subscriber <-chan T) error {
	if err := sendMsg(ctx, subscriber, s.unsubscribe); err != nil {
		return err
	}
	return nil
}
package stream
import (
"context"
"sync"
"sync/atomic"
"testing"
)
// TestSingleProducerSingleConsumer checks that a lone subscriber receives
// every published value exactly once, in publish order, and that Close ends
// its range loop.
func TestSingleProducerSingleConsumer(t *testing.T) {
	const messageCount = 32
	// Setup stream. Close is also deferred so the Start goroutine exits if
	// the test stops early; a second Close is a no-op.
	s := New[int](context.Background())
	defer s.Close()
	go s.Start()
	// Consumer. On error consumer is nil, and ranging over a nil channel
	// blocks forever, so stop the test instead of merely reporting.
	consumer, err := s.Subscribe(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	// Producer. t.Error (not t.Fatal) is safe to call from this goroutine.
	go func() {
		for i := 0; i < messageCount; i++ {
			if err := s.Next(context.Background(), i); err != nil {
				t.Error(err)
			}
		}
		s.Close()
	}()
	// Listen to consumer. Start finishes delivering each value before taking
	// the next request, so values must arrive in order.
	received := 0
	for message := range consumer {
		if message != received {
			t.Errorf("message %d: got %d, want %d", received, message, received)
		}
		received++
	}
	if received != messageCount {
		t.Errorf("expected %d messages, got %d", messageCount, received)
	}
}
// TestSingleProducerMultipeConsumer checks that every one of several
// subscribers receives all published values before Close ends its range loop.
func TestSingleProducerMultipeConsumer(t *testing.T) {
	const ExpectedMessageCount = 32
	// Setup stream. Close is also deferred so the Start goroutine exits if
	// the test stops early; a second Close is a no-op.
	s := New[int](context.Background())
	defer s.Close()
	go s.Start()
	// Setup consumers. A failed Subscribe leaves a nil channel; a consumer
	// ranging over it would block forever and hang wg.Wait, so stop here.
	consumers := make([]<-chan int, 16)
	for i := range consumers {
		var err error
		consumers[i], err = s.Subscribe(context.Background())
		if err != nil {
			t.Fatal(err)
		}
	}
	var consumerReports atomic.Int64
	// Listen to consumers and make sure they get all the messages.
	var wg sync.WaitGroup
	wg.Add(len(consumers))
	for id, consumer := range consumers {
		go func(id int, consumer <-chan int) {
			defer wg.Done()
			var actualMessageCount int
			for range consumer {
				actualMessageCount++
			}
			if actualMessageCount != ExpectedMessageCount {
				t.Errorf("consumer %03d: expected %d got %d", id, ExpectedMessageCount, actualMessageCount)
			}
			consumerReports.Add(1)
		}(id, consumer)
	}
	// Produce all values, then close the stream. Start completes delivery of
	// the in-flight value before it observes the cancellation, so every
	// consumer still sees all of them.
	for i := 0; i < ExpectedMessageCount; i++ {
		if err := s.Next(context.Background(), i); err != nil {
			t.Error(err)
		}
	}
	s.Close()
	wg.Wait()
	// Ensure that all consumers have reported their counts.
	if got := consumerReports.Load(); int(got) != len(consumers) {
		t.Fatalf("did not get a report from all consumers: got %d of %d", got, len(consumers))
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment