Last active
August 3, 2017 16:18
-
-
Save SamSaffron/4385544 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'redis' | |
# The heart of the message bus. It acts as two things:
#
# 1. A channel multiplexer
# 2. Backlog storage per multiplexed channel
#
# Ids are sequentially increasing integers. The counters start at 0, so the
# first message published gets id 1.
#
# A single bus message plus its wire format.
#
# global_id  - id of the message across all channels
# message_id - id of the message within its own channel
# channel    - channel name; may contain "|"
# data       - payload, transmitted as text
#
# Wire format: "<global_id>|<message_id>|<escaped channel>|<data>".
# Only the first three pipes are structural, so data may contain "|" freely.
class MessageBus::Message < Struct.new(:global_id, :message_id, :channel, :data)
  # Stand-in for "|" inside channel names on the wire.
  # NOTE: a channel name that literally contains this token will not round-trip;
  # the token is kept unchanged for compatibility with already stored messages.
  PIPE_ESCAPE = '$$123$$'.freeze

  # Parses a string produced by #encode.
  #
  # encoded - String in the wire format described on the class.
  #
  # Returns a new message with integer ids.
  # Raises ArgumentError when fewer than three "|" separators are present.
  def self.decode(encoded)
    # limit 4: everything after the third pipe belongs to data, pipes included
    global_id, message_id, channel, data = encoded.split('|', 4)
    raise ArgumentError, "malformed message: #{encoded.inspect}" if data.nil?

    new(global_id.to_i, message_id.to_i, channel.gsub(PIPE_ESCAPE, '|'), data)
  end

  # Serializes the message to its wire format.
  # The only tricky thing to encode is pipes in a channel name, which get a
  # straight replace. Interpolation is used rather than String#<< so a
  # non-String payload is written as text (<< would append an Integer as a
  # codepoint character).
  def encode
    "#{global_id}|#{message_id}|#{channel.gsub('|', PIPE_ESCAPE)}|#{data}"
  end
end
# Redis-backed publish/subscribe with a replayable backlog.
#
# Every published message is appended to a per-channel Redis list, indexed in a
# global list ("<message_id>|<channel>") and broadcast on a single Redis pub/sub
# channel shared by all logical channels. Lists are trimmed from the head once
# they exceed their maximum size; the matching "offset" key records how many
# entries have been trimmed so ids can still be mapped to list indexes.
class MessageBus::ReliablePubSub
  # per channel backlog size
  attr_writer :max_backlog_size

  # amount of global backlog we can spin through
  attr_writer :max_global_backlog_size

  # redis_config     - options hash handed to Redis.new (:db also names the pub/sub channel)
  # max_backlog_size - maximum number of messages kept per multiplexed channel
  def initialize(redis_config = {}, max_backlog_size = 1000)
    @redis_config = redis_config
    @max_backlog_size = max_backlog_size
    # we can store a ton here ...
    @max_global_backlog_size = 100000
  end

  # Opens a brand new connection using the configured options.
  def new_redis_connection
    ::Redis.new(@redis_config)
  end

  # Name of the Redis pub/sub channel every message is broadcast on.
  def redis_channel_name
    db = @redis_config[:db] || 0
    "discourse_#{db}"
  end

  # redis connection used for publishing messages (and for backlog reads)
  def pub_redis
    @pub_redis ||= new_redis_connection
  end

  # Key holding the number of entries trimmed from the channel's backlog.
  def offset_key(channel)
    "__mb_offset_#{channel}"
  end

  # Key of the list holding the channel's encoded messages.
  def backlog_key(channel)
    "__mb_backlog_#{channel}"
  end

  # Key holding the last global id handed out.
  def global_id_key
    "__mb_global_id"
  end

  # Key of the list indexing every message as "<message_id>|<channel>".
  def global_backlog_key
    "__mb_global_backlog"
  end

  # Key holding the number of entries trimmed from the global backlog.
  def global_offset_key
    "__mb_global_offset"
  end

  # use with extreme care, will nuke all of the data
  def reset!
    pub_redis.keys("__mb_*").each do |k|
      pub_redis.del k
    end
  end

  # Stores +data+ on +channel+ and broadcasts it to subscribers.
  #
  # The ids are computed from values read under WATCH and written in one
  # MULTI/EXEC. When another client touches a watched key first, EXEC is
  # aborted (multi returns nil) and the whole attempt is repeated, so a message
  # is never silently dropped.
  #
  # Returns the per-channel message id assigned to the message.
  def publish(channel, data)
    redis = pub_redis
    channel_offset_key = offset_key(channel)
    channel_backlog_key = backlog_key(channel)

    loop do
      message_id = nil
      committed = nil

      redis.watch(channel_offset_key, channel_backlog_key, global_id_key, global_backlog_key, global_offset_key) do
        offset = redis.get(channel_offset_key).to_i
        backlog = redis.llen(channel_backlog_key).to_i
        global_offset = redis.get(global_offset_key).to_i
        global_backlog = redis.llen(global_backlog_key).to_i
        global_id = redis.get(global_id_key).to_i + 1

        message_id = backlog + offset + 1

        # how many head entries must go so each list is at its maximum after the push
        overflow = (backlog + 1) - @max_backlog_size
        global_overflow = (global_backlog + 1) - @max_global_backlog_size

        payload = MessageBus::Message.new(global_id, message_id, channel, data).encode

        # queue commands on the object yielded by multi, not on the client
        committed = redis.multi do |tx|
          if overflow > 0
            tx.ltrim channel_backlog_key, overflow, -1
            tx.set channel_offset_key, offset + overflow
          end

          if global_overflow > 0
            tx.ltrim global_backlog_key, global_overflow, -1
            tx.set global_offset_key, global_offset + global_overflow
          end

          tx.set global_id_key, global_id
          tx.rpush channel_backlog_key, payload
          tx.rpush global_backlog_key, "#{message_id}|#{channel}"
          tx.publish redis_channel_name, payload
        end
      end

      return message_id if committed
    end
  end

  # Returns the decoded messages on +channel+ with a message id greater than
  # +last_id+ (nil is treated as 0) that are still retained.
  def backlog(channel, last_id = nil)
    read_after(offset_key(channel), backlog_key(channel), last_id.to_i).map do |payload|
      MessageBus::Message.decode(payload)
    end
  end

  # Returns the decoded messages from all channels with a global id greater
  # than +last_id+ (nil is treated as 0). Messages that have already been
  # trimmed from their own channel backlog are omitted.
  def global_backlog(last_id = nil)
    entries = read_after(global_offset_key, global_backlog_key, last_id.to_i)

    messages = entries.map do |entry|
      # entry is "<message_id>|<channel>"; the channel itself may contain pipes
      message_id, channel = entry.split("|", 2)
      get_message(channel, message_id.to_i)
    end

    messages.compact
  end

  # Looks up one message by its per-channel id.
  #
  # Returns the decoded message, or nil when it was trimmed or never existed.
  def get_message(channel, message_id)
    redis = pub_redis
    channel_offset_key = offset_key(channel)
    channel_backlog_key = backlog_key(channel)

    loop do
      trimmed = false
      replies = nil

      redis.watch(channel_offset_key, channel_backlog_key) do
        idx = (message_id - 1) - redis.get(channel_offset_key).to_i
        if idx < 0
          # nothing to read: release the watch so it does not linger on the connection
          redis.unwatch
          trimmed = true
        else
          replies = redis.multi { |tx| tx.lindex(channel_backlog_key, idx) }
        end
      end

      return nil if trimmed
      next unless replies # a watched key changed between the two reads, try again

      raw = replies.first
      return raw && MessageBus::Message.decode(raw)
    end
  end

  # Yields every message published on +channel+; see #global_subscribe for the
  # meaning of +last_id+ and the blocking behaviour.
  def subscribe(channel, last_id = nil)
    # trivial implementation for now,
    # can cut down on connections if we only have one global subscriber
    raise ArgumentError unless block_given?

    global_subscribe(last_id) do |m|
      yield m if m.channel == channel
    end
  end

  # Yields every message on every channel, blocking inside Redis#subscribe.
  #
  # last_id - global id of the last message the caller has seen. When given,
  #           retained messages after it are replayed first, and any gap
  #           detected later is filled from the global backlog. When nil, only
  #           messages arriving after the subscription are yielded.
  #
  # Any StandardError (including one raised by the caller's block) is logged
  # and followed by a re-subscribe after one second, indefinitely.
  # NOTE(review): each reconnect opens a new connection via new_redis_connection
  # and the previous one is not explicitly closed here.
  def global_subscribe(last_id = nil, &blk)
    raise ArgumentError unless block_given?

    highest_id = last_id

    # replays everything after highest_id and advances it as it goes
    clear_backlog = lambda do
      global_backlog(highest_id).each do |old|
        highest_id = old.global_id
        blk.call(old)
      end
    end

    begin
      redis = new_redis_connection

      clear_backlog.call if highest_id

      redis.subscribe(redis_channel_name) do |on|
        on.subscribe do
          # catch anything published between the replay above and the subscription
          clear_backlog.call if highest_id
        end

        on.message do |_redis_channel, payload|
          m = MessageBus::Message.decode(payload)

          # a gap (or a stale message) means the backlog may hold something we missed
          clear_backlog.call if highest_id && m.global_id != highest_id + 1

          if highest_id.nil? || m.global_id > highest_id
            blk.call(m)
            # only ever move forward: a message already covered by a replay must
            # not pull highest_id back, or its successor would be delivered twice
            highest_id = m.global_id
          end
        end
      end
    rescue => error
      MessageBus.logger.warn "#{error} subscribe failed, reconnecting in 1 second. Call stack #{error.backtrace}"
      sleep 1
      retry
    end
  end

  private

  # Reads, as one consistent snapshot, the entries of +list_name+ that follow
  # id +last_id+, where +offset_name+ holds the number of trimmed head entries.
  #
  # The offset is read under WATCH and the range inside MULTI/EXEC, so a
  # concurrent trim aborts the read (multi returns nil) and it is retried.
  # EXEC also clears the watch. The start index is clamped at 0 because a
  # negative LRANGE start counts from the tail of the list.
  def read_after(offset_name, list_name, last_id)
    redis = pub_redis

    loop do
      replies = nil

      redis.watch(offset_name, list_name) do
        offset = redis.get(offset_name).to_i
        start_at = [last_id - offset, 0].max
        replies = redis.multi { |tx| tx.lrange(list_name, start_at, -1) }
      end

      return replies.first if replies
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'spec_helper' | |
require 'message_bus' | |
# Integration specs for MessageBus::ReliablePubSub.
# They run against a real Redis server (db 10); every example starts from a
# bus whose "__mb_*" keys have been wiped by reset!.
describe MessageBus::ReliablePubSub do

  # A fresh bus, and therefore a fresh publishing connection, per call.
  def new_test_bus
    MessageBus::ReliablePubSub.new(:db => 10)
  end

  before do
    @bus = new_test_bus
    @bus.reset!
  end

  it "should be able to access the backlog" do
    @bus.publish "/foo", "bar"
    @bus.publish "/foo", "baz"

    @bus.backlog("/foo", 0).to_a.should == [
      MessageBus::Message.new(1,1,'/foo','bar'),
      MessageBus::Message.new(2,2,'/foo','baz')
    ]
  end

  it "should truncate channels correctly" do
    @bus.max_backlog_size = 2
    4.times do |t|
      @bus.publish "/foo", t.to_s
    end

    @bus.backlog("/foo").to_a.should == [
      MessageBus::Message.new(3,3,'/foo','2'),
      MessageBus::Message.new(4,4,'/foo','3'),
    ]
  end

  it "should be able to grab a message by id" do
    id1 = @bus.publish "/foo", "bar"
    id2 = @bus.publish "/foo", "baz"
    @bus.get_message("/foo", id2).should == MessageBus::Message.new(2, 2, "/foo", "baz")
    @bus.get_message("/foo", id1).should == MessageBus::Message.new(1, 1, "/foo", "bar")
  end

  it "should be able to access the global backlog" do
    @bus.publish "/foo", "bar"
    @bus.publish "/hello", "world"
    @bus.publish "/foo", "baz"
    @bus.publish "/hello", "planet"

    @bus.global_backlog.to_a.should == [
      MessageBus::Message.new(1, 1, "/foo", "bar"),
      MessageBus::Message.new(2, 1, "/hello", "world"),
      MessageBus::Message.new(3, 2, "/foo", "baz"),
      MessageBus::Message.new(4, 2, "/hello", "planet")
    ]
  end

  it "should correctly omit dropped messages from the global backlog" do
    @bus.max_backlog_size = 1
    @bus.publish "/foo", "a"
    @bus.publish "/foo", "b"
    @bus.publish "/bar", "a"
    @bus.publish "/bar", "b"

    @bus.global_backlog.to_a.should == [
      MessageBus::Message.new(2, 2, "/foo", "b"),
      MessageBus::Message.new(4, 2, "/bar", "b")
    ]
  end

  it "should have the correct number of messages for multi threaded access" do
    threads = []
    4.times do
      threads << Thread.new do
        # each thread publishes through its own bus instance
        bus = new_test_bus
        25.times {
          bus.publish "/foo", "."
        }
      end
    end

    threads.each{|t| t.join}
    # previously a bare comparison with no expectation, so it could never fail
    @bus.backlog("/foo").length.should == 100
  end

  it "should be able to subscribe globally with recovery" do
    @bus.publish("/foo", "1")
    @bus.publish("/bar", "2")

    got = []

    t = Thread.new do
      new_test_bus.global_subscribe(0) do |msg|
        got << msg
      end
    end

    @bus.publish("/bar", "3")

    # wait_for is not defined in this file (presumably it comes from spec_helper)
    wait_for(100) do
      got.length == 3
    end

    t.kill

    got.length.should == 3
    got.map{|m| m.data}.should == ["1","2","3"]
  end

  it "should be able to encode and decode messages properly" do
    m = MessageBus::Message.new 1,2,'||','||'
    MessageBus::Message.decode(m.encode).should == m
  end

  it "should handle subscribe on single channel, with recovery" do
    @bus.publish("/foo", "1")
    @bus.publish("/bar", "2")

    got = []

    t = Thread.new do
      new_test_bus.subscribe("/foo",0) do |msg|
        got << msg
      end
    end

    @bus.publish("/foo", "3")

    wait_for(100) do
      got.length == 2
    end

    t.kill

    got.map{|m| m.data}.should == ["1","3"]
  end

  it "should not get backlog if subscribe is called without params" do
    @bus.publish("/foo", "1")

    got = []

    t = Thread.new do
      new_test_bus.subscribe("/foo") do |msg|
        got << msg
      end
    end

    # sleep 50ms to allow the bus to correctly subscribe,
    # I thought about adding a subscribed callback, but outside of testing it matters less
    sleep 0.05

    @bus.publish("/foo", "2")

    wait_for(100) do
      got.length == 1
    end

    t.kill

    got.map{|m| m.data}.should == ["2"]
  end

end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment