Skip to content

Instantly share code, notes, and snippets.

@Linuxpizi
Created May 25, 2020 16:03
Show Gist options
  • Save Linuxpizi/ef2afe78bc97a9d0618916df6931d933 to your computer and use it in GitHub Desktop.
Save Linuxpizi/ef2afe78bc97a9d0618916df6931d933 to your computer and use it in GitHub Desktop.
golang实现参考Redux实现的pub/sub
package pubsub

import (
	"errors"
	"fmt"
	"reflect"
	"sync"
	"time"
)

// unsubscribeFunc cancels the single subscription returned by
// MessageBus.Subscribe.
type unsubscribeFunc func()

// MessageBus is a topic-based publish/subscribe bus.
type MessageBus interface {
	// Publish sends args to every current subscriber of topic.
	Publish(topic string, args interface{})

	// Close removes every subscription registered on topic.
	Close(topic string)

	// Subscribe registers fn, which must be a function, as a callback for
	// messages published on topic. It returns a function that cancels this
	// subscription, or an error if fn is not an acceptable handler.
	//
	// There is no separate Unsubscribe method; call the returned function
	// instead.
	Subscribe(topic string, fn interface{}) (unsubscribeFunc, error)
}

// Internal bookkeeping for the message bus.
type (
	// handlersMap indexes subscribers by topic name; each topic maps to its
	// subscribers in the order they subscribed.
	handlersMap map[string][]*handler

	// handler is one subscription. callback is the user-supplied function,
	// held as a reflect.Value so it can be invoked dynamically, and queue
	// buffers the messages waiting to be passed to it.
	handler struct {
		callback reflect.Value
		queue    chan reflect.Value
	}

	// messageBus is the MessageBus implementation returned by New. The
	// embedded RWMutex guards handlers; handlerQueueSize is the buffer
	// length given to every new subscriber queue.
	messageBus struct {
		sync.RWMutex
		handlerQueueSize int
		handlers         handlersMap
	}
)

// Compile-time assertion that *messageBus satisfies MessageBus.
var _ MessageBus = (*messageBus)(nil)

// Publish delivers args to every current subscriber of topic.
//
// The message is enqueued on each subscriber's buffered queue in subscription
// order. Callbacks are invoked later by each subscriber's own goroutine, not by
// Publish. If a subscriber's queue is full, Publish blocks until that
// subscriber frees a slot. Publishing to a topic with no subscribers is a
// no-op.
func (b *messageBus) Publish(topic string, args interface{}) {
	rArgs := reflect.ValueOf(args)

	// A read lock is sufficient: Publish only reads b.handlers, and Close holds
	// the write lock while closing queues, so no queue can be closed while a
	// publisher is sending on it. RLock also lets concurrent publishers proceed
	// in parallel instead of serialising on the exclusive lock.
	b.RLock()
	defer b.RUnlock()

	// Ranging over a missing topic's nil slice is a no-op.
	for _, h := range b.handlers[topic] {
		h.queue <- rArgs
	}
}

// Subscribe registers fn as a callback for topic and returns a function that
// cancels this particular subscription.
//
// If isValidHandler rejects fn, its error is returned and nothing is
// registered. Each subscription gets its own queue, buffered to
// handlerQueueSize, and its own goroutine. That goroutine calls fn once per
// published message, in the order the messages were queued, until the queue
// is closed by the returned function or by Close. Messages already buffered
// when the queue is closed are still delivered.
//
// The returned function removes only this subscription, even if the same fn
// was subscribed several times. It may be called more than once and after
// Close; only the first call on a still-registered subscription has an
// effect.
//
// NOTE: calling the returned function from inside fn can deadlock if a
// publisher is blocked on this subscription's full queue, because the
// publisher holds the bus lock that unsubscribing needs.
func (b *messageBus) Subscribe(topic string, fn interface{}) (unsubscribeFunc, error) {
	if err := isValidHandler(fn); err != nil {
		return nil, err
	}

	rv := reflect.ValueOf(fn)
	fnType := rv.Type()

	h := &handler{
		callback: rv,
		queue:    make(chan reflect.Value, b.handlerQueueSize),
	}

	go func() {
		for arg := range h.queue {
			// Publish(topic, nil) enqueues reflect.ValueOf(nil), an invalid
			// Value that Call rejects with a panic (which would crash the whole
			// process from this goroutine). Substitute the zero value of the
			// callback's parameter type instead.
			if !arg.IsValid() && fnType.NumIn() > 0 {
				in := fnType.In(0)
				if fnType.IsVariadic() && fnType.NumIn() == 1 {
					// For func(...T) the declared parameter is []T, but Call
					// expects individual T values for the variadic part.
					in = in.Elem()
				}
				arg = reflect.Zero(in)
			}
			h.callback.Call([]reflect.Value{arg})
		}
	}()

	b.Lock()
	defer b.Unlock()

	b.handlers[topic] = append(b.handlers[topic], h)

	unsubscribe := func() {
		b.Lock()
		defer b.Unlock()

		// Match by handler identity rather than by callback value, so that
		// other subscriptions using the same fn, including ones created after
		// a Close, are left alone.
		hs := b.handlers[topic]
		for i, other := range hs {
			if other != h {
				continue
			}
			close(h.queue)

			last := len(hs) - 1
			copy(hs[i:], hs[i+1:])
			hs[last] = nil // drop the stale pointer so the handler can be collected
			if last == 0 {
				delete(b.handlers, topic)
			} else {
				b.handlers[topic] = hs[:last]
			}
			return
		}
	}

	return unsubscribe, nil
}

// Close tears down topic. Every subscriber queue registered for it is closed,
// so each subscriber goroutine exits once it has drained what is already
// buffered, and the topic is removed from the bus. Closing an unknown topic
// does nothing.
func (b *messageBus) Close(topic string) {
	b.Lock()
	defer b.Unlock()

	subscribers, found := b.handlers[topic]
	if !found {
		return
	}

	for idx := range subscribers {
		close(subscribers[idx].queue)
	}
	delete(b.handlers, topic)
}

// isValidHandler reports whether fn can be used as a subscriber callback.
//
// Each published message is delivered by calling fn with exactly one
// argument. fn must therefore be a non-nil function that can be called with a
// single argument: func(T), func(...T), or func(T, ...U). Checking this here
// makes Subscribe return an error instead of the callback panicking later in
// a background goroutine.
func isValidHandler(fn interface{}) error {
	if fn == nil {
		return errors.New("args must be a non-nil func")
	}

	rv := reflect.ValueOf(fn)
	if rv.Kind() != reflect.Func {
		return errors.New("args must be a func")
	}
	if rv.IsNil() {
		// A typed nil func (e.g. var f func(int)) panics when called.
		return errors.New("args must be a non-nil func")
	}

	ft := rv.Type()
	if n := ft.NumIn(); n != 1 && !(ft.IsVariadic() && n == 2) {
		return fmt.Errorf("args must be a func callable with a single argument, got %T", fn)
	}
	return nil
}

// New creates a MessageBus in which each subscriber gets its own buffered
// queue of handlerQueueSize messages.
//
// handlerQueueSize bounds how many published messages may wait for a single
// subscriber before Publish blocks. It must be positive; New panics otherwise.
func New(handlerQueueSize int) MessageBus {
	// Zero would make every subscriber queue unbuffered, and a negative size
	// would only surface later as a panic from make(chan ...) inside Subscribe,
	// so reject both here.
	if handlerQueueSize <= 0 {
		panic("handlerQueueSize has to be greater than 0")
	}

	return &messageBus{
		handlerQueueSize: handlerQueueSize,
		handlers:         make(handlersMap),
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment