ClusterLoadBalancer for balancing anything (WebSocket) across a cluster
# We use this ClusterLoadBalancer to prevent hot nodes in our load-balanced setup.
# The WebSockets are provided by Phoenix through the PushEx application: https://github.com/pushex-project/pushex/
# The load balancer's `Worker` module is where the bulk of the cluster orchestration happens, using pg2 for cross-node communication.

defmodule ClusterLoadBalancer.Behavior do
  @moduledoc """
  Behavior for implementing a ClusterLoadBalancer compatible tool.
  """

  @callback count() :: number
  @callback kill_processes(number) :: {:ok, number}
end
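
# As an illustration (not part of the original gist), a minimal no-op implementation of the
# behaviour could look like the sketch below. The module name is hypothetical; it only shows
# the contract that `count/0` and `kill_processes/1` must satisfy, e.g. for tests.
defmodule NoopClusterLoadBalancer do
  @behaviour ClusterLoadBalancer.Behavior

  # Always reports zero resources, so the worker never treats this node as the hottest one.
  @impl true
  def count(), do: 0

  # Pretends the requested number of processes was shed without touching anything.
  @impl true
  def kill_processes(number), do: {:ok, number}
end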

defmodule PushExClusterLoadBalancer do
  require Logger

  @behaviour ClusterLoadBalancer.Behavior

  def count() do
    PushEx.connected_socket_count()
  end

  def kill_processes(number) do
    killed_count =
      PushEx.connected_transport_pids()
      |> Enum.shuffle()
      |> Enum.take(number)
      |> Enum.map(&send(&1, %Phoenix.Socket.Broadcast{event: "disconnect"}))
      |> length()

    {:ok, killed_count}
  end
end

defmodule ClusterLoadBalancer.Config do
  @moduledoc """
  Configuration for a load balancer instance
  """

  @typedoc """
  * impl: Module implementing ClusterLoadBalancer.Behavior, used to power the load balancer worker
  * allowed_average_deviation_percent: How far (in percent) the highest-count node can be above the cluster average before it is load balanced
  * shed_percentage: The percentage of the gap between this node's count and the cluster average that will be shed
  * max_shed_count: The most that can be shed at one time. An amount above this will be reduced to it
  * min_shed_count: The minimum that can be shed at one time. An amount below this will *not* shed
  * round_duration_seconds: How long each round lasts
  """
  @type t :: %__MODULE__{
          impl: any(),
          allowed_average_deviation_percent: non_neg_integer(),
          shed_percentage: non_neg_integer(),
          max_shed_count: non_neg_integer(),
          min_shed_count: non_neg_integer(),
          round_duration_seconds: non_neg_integer()
        }

  @enforce_keys [
    :impl,
    :allowed_average_deviation_percent,
    :shed_percentage,
    :max_shed_count,
    :min_shed_count,
    :round_duration_seconds
  ]
  defstruct @enforce_keys
end
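
# For reference, the application.ex options shown at the bottom of this gist are turned by
# ClusterLoadBalancer.Worker.init/1 into a config equivalent to (values copied from that entry):
#
#   %ClusterLoadBalancer.Config{
#     impl: PushExClusterLoadBalancer,
#     allowed_average_deviation_percent: 25,
#     shed_percentage: 50,
#     max_shed_count: 100,
#     min_shed_count: 20,
#     round_duration_seconds: 15
#   }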

defmodule ClusterLoadBalancer.Calculator do
  @moduledoc false
  require Logger

  alias ClusterLoadBalancer.{Collection, Config}

  @doc """
  Provides the number of processes that should be corrected by the load balancer. This value will be 0
  if the current node doesn't meet the criteria for shedding or if it would shed an amount less than
  the configured minimum.
  """
  def amount_to_correct_by(collection = %Collection{topic: topic}, config = %Config{}) do
    if should_shed?(collection, config) do
      shed_amount(collection, config)
    else
      Logger.debug("#{topic} should_shed?=false")
      0
    end
  end

  @doc false
  def should_shed?(collection = %Collection{}, config = %Config{}) do
    highest_count? = Collection.self_has_highest_count?(collection)
    self_count = Collection.self_count(collection)
    average_count = Collection.average_count(collection)
    max_allowed_count = average_count + config.allowed_average_deviation_percent / 100 * average_count
    outside_allowed_deviation? = self_count > max_allowed_count

    Logger.debug(
      "#{collection.topic} highest_count?=#{inspect(highest_count?)} avg=#{average_count} self=#{self_count} max_allowed_count=#{max_allowed_count}"
    )

    highest_count? && outside_allowed_deviation?
  end

  @doc false
  def shed_amount(collection = %Collection{}, config = %Config{}) do
    self_count = Collection.self_count(collection)
    average_count = Collection.average_count(collection)
    calc_shed_amount = trunc((self_count - average_count) * (config.shed_percentage / 100))
    shed_amount = min(config.max_shed_count, calc_shed_amount)
    will_shed? = shed_amount >= config.min_shed_count

    Logger.debug(
      "#{collection.topic} highest=true shed_amount=#{shed_amount} [calc,min,max]=[#{calc_shed_amount}, #{config.min_shed_count}, #{config.max_shed_count}] will_shed=#{will_shed?}"
    )

    if will_shed? do
      shed_amount
    else
      0
    end
  end
end
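
# A worked example of the math above, assuming the config used at the bottom of this gist
# (allowed_average_deviation_percent: 25, shed_percentage: 50, max_shed_count: 100,
# min_shed_count: 20) and that this node also reports the highest count in the round:
#
#   self_count = 600, average_count = 400
#   max_allowed_count = 400 + 25 / 100 * 400 = 500    -> 600 > 500, so should_shed? is true
#   calc_shed_amount  = trunc((600 - 400) * 0.5) = 100
#   shed_amount       = min(100, 100) = 100            -> >= min_shed_count (20), so 100 sockets are shed
#
# With self_count = 430 the node is within the allowed deviation (430 <= 500) and nothing is shed.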

defmodule ClusterLoadBalancer.Collection do
  @moduledoc false

  defmodule Result do
    @moduledoc false
    @enforce_keys [:tick, :count, :rand]
    defstruct @enforce_keys
  end

  @enforce_keys [:expected_results_count, :self_result, :tick, :topic]
  defstruct @enforce_keys ++ [collected: []]

  def init(topic, tick, node_count, self_result = %Result{}) do
    %__MODULE__{expected_results_count: node_count, self_result: self_result, tick: tick, topic: topic}
  end

  def init_result(tick, count, rand) do
    %Result{tick: tick, count: count, rand: rand}
  end

  def add_result(state = %__MODULE__{collected: collected, tick: state_tick}, result = %Result{tick: tick}) when state_tick == tick do
    new_collected = [result | collected]
    %{state | collected: new_collected}
  end

  def add_result(state, _), do: state

  def finalized?(%{expected_results_count: expected_count, collected: collected}) do
    length(collected) == expected_count
  end

  def participant_count(%{expected_results_count: count}), do: count + 1

  def self_has_highest_count?(%{self_result: self, collected: collected}) do
    Enum.all?(collected, fn other ->
      self.count > other.count || (self.count == other.count && self.rand > other.rand)
    end)
  end

  def self_count(%{self_result: self}) do
    self.count
  end

  def average_count(%{self_result: self, collected: collected}) do
    collected
    |> Enum.map(& &1.count)
    |> Enum.sum()
    |> Kernel.+(self.count)
    |> Kernel./(length(collected) + 1)
  end
end
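
# A small illustrative round with one remote node (values are made up; `rand` is the tie breaker
# used when two nodes report the same count):
#
#   self_result   = ClusterLoadBalancer.Collection.init_result(1, 600, 0.82)
#   remote_result = ClusterLoadBalancer.Collection.init_result(1, 400, 0.31)
#
#   collection =
#     ClusterLoadBalancer.Collection.init(:demo_topic, 1, 1, self_result)
#     |> ClusterLoadBalancer.Collection.add_result(remote_result)
#
#   ClusterLoadBalancer.Collection.finalized?(collection)              # => true (1 of 1 expected results)
#   ClusterLoadBalancer.Collection.self_has_highest_count?(collection) # => true
#   ClusterLoadBalancer.Collection.average_count(collection)           # => 500.0
#   ClusterLoadBalancer.Collection.participant_count(collection)       # => 2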

defmodule ClusterLoadBalancer.Worker do
  @moduledoc """
  Watcher process that monitors for discrepancies in resource count across the cluster
  and kills processes on the current node if it is the node with the highest count and
  falls outside of an allowed deviation.

  All of the values are configurable to either allow the cluster to be more loosely
  load balanced or to be more tightly controlled.

  An assumption is made that the resources will recreate themselves once killed (or your
  implementation has to handle that). Phoenix.Channels, the primary use case for this,
  usually involve the client reconnecting when disconnected. This means that the processes
  will reform, but in a load-balanced way.

  TODO:
  * pg2 is used to broadcast a request for process count collection across the cluster.
    Change this to Phoenix.PubSub, as it would then work for any standard Phoenix.Channel
    installation. This would add challenges because there would be no equivalent of the
    "node count" currently derived from pg2 membership.
  """
  use GenServer
  require Logger

  alias ClusterLoadBalancer.{Calculator, Collection, Config}

  def start_link(opts) when is_list(opts) do
    GenServer.start_link(__MODULE__, opts)
  end

  def init(opts) when is_list(opts) do
    namespace = Keyword.fetch!(opts, :namespace)
    topic = subscribe_to_pg2(namespace)

    config = %Config{
      impl: Keyword.fetch!(opts, :implementation),
      allowed_average_deviation_percent: Keyword.get(opts, :allowed_average_deviation_percent, 25),
      shed_percentage: Keyword.get(opts, :shed_percentage, 50),
      max_shed_count: Keyword.get(opts, :max_shed_count, 100),
      min_shed_count: Keyword.get(opts, :min_shed_count, 10),
      round_duration_seconds: Keyword.get(opts, :round_duration_seconds, 10)
    }

    schedule_tick(topic, config, nil)

    {:ok,
     %{
       rand: :rand.uniform(),
       topic: topic,
       tick: 0,
       tick_state: nil,
       config: config
     }}
  end

  @doc false
  def handle_info(:tick, state = %{config: config, rand: rand, tick: prev_tick, topic: topic, tick_state: prev_state}) do
    schedule_tick(topic, config, prev_state)
    tick = prev_tick + 1
    remote_node_count = collect_counts_in_cluster(tick, topic)
    tick_state = Collection.init(topic, tick, remote_node_count, new_result(tick, config.impl, rand))
    {:noreply, %{state | tick: tick, tick_state: tick_state}}
  end

  @doc false
  def handle_cast({:collect_request, tick, from_pid}, state = %{config: config, rand: rand}) do
    GenServer.cast(from_pid, {:collect_result, new_result(tick, config.impl, rand)})
    {:noreply, state}
  end

  @doc false
  def handle_cast({:collect_result, result = %Collection.Result{tick: tick}}, state = %{config: config, tick: self_tick, tick_state: tick_state, topic: topic})
      when tick == self_tick do
    new_state = Collection.add_result(tick_state, result)

    new_state =
      case Collection.finalized?(new_state) do
        true ->
          count = Collection.participant_count(new_state)
          amount_to_correct_by = Calculator.amount_to_correct_by(new_state, config)
          correct_deviation(config, amount_to_correct_by)
          Logger.debug("#{topic} round finalized with #{count} participants, amount_to_correct_by=#{amount_to_correct_by}")
          nil

        _ ->
          new_state
      end

    {:noreply, %{state | tick_state: new_state}}
  end

  @doc false
  def handle_cast({:collect_result, _}, state = %{topic: topic}) do
    Logger.error("#{topic} collect_result delivered too slow")
    {:noreply, state}
  end

  # private

  defp correct_deviation(%{impl: impl}, kill_count) when kill_count > 0 do
    {:ok, killed_count} = impl.kill_processes(kill_count)
    Logger.error("#{impl} kill_processes requested_count=#{kill_count} killed_count=#{killed_count}")
    :ok
  end

  defp correct_deviation(_, _), do: :ok

  defp collect_counts_in_cluster(tick, topic) do
    on_remote_nodes(topic, fn pid ->
      GenServer.cast(pid, {:collect_request, tick, self()})
    end)
  end

  defp on_remote_nodes(topic, func) do
    topic
    |> :pg2.get_members()
    |> Kernel.--(:pg2.get_local_members(topic))
    |> Enum.map(func)
    |> length()
  end

  defp new_result(tick, impl_mod, tie_breaker) do
    Collection.init_result(tick, impl_mod.count(), tie_breaker)
  end

  defp subscribe_to_pg2(namespace) do
    topic = String.to_atom("#{__MODULE__}.#{namespace}")
    :ok = :pg2.create(topic)
    :ok = :pg2.join(topic, self())
    topic
  end

  defp schedule_tick(topic, %{round_duration_seconds: timeout}, prev_state) do
    if prev_state != nil && prev_state.expected_results_count > 0,
      do: Logger.error("#{topic} round occurred without being finalized")

    Process.send_after(self(), :tick, trunc(timeout * 1000))
  end
end
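
# Note: :pg2 was deprecated in OTP 23 and removed in OTP 24. On newer OTP releases the same
# subscription could be expressed with the replacement :pg module, roughly as sketched below
# (untested; :pg must be started somewhere in your supervision tree, e.g. via :pg.start_link()):
#
#   defp subscribe_to_pg(namespace) do
#     topic = String.to_atom("#{__MODULE__}.#{namespace}")
#     :ok = :pg.join(topic, self())
#     topic
#   end
#
#   defp on_remote_nodes(topic, func) do
#     topic
#     |> :pg.get_members()
#     |> Kernel.--(:pg.get_local_members(topic))
#     |> Enum.map(func)
#     |> length()
#   end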

# Started with the following application.ex entry:
{
  ClusterLoadBalancer.Worker,
  [
    implementation: PushExClusterLoadBalancer,
    namespace: :pushex_websocket,
    allowed_average_deviation_percent: 25,
    shed_percentage: 50,
    max_shed_count: 100,
    min_shed_count: 20,
    round_duration_seconds: 15
  ]
}
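
# For context, a sketch of how that child spec might sit inside a (hypothetical) MyApp.Application
# module; names other than ClusterLoadBalancer.Worker and PushExClusterLoadBalancer are illustrative:
defmodule MyApp.Application do
  use Application

  @impl true
  def start(_type, _args) do
    children = [
      # `use GenServer` gives the worker a default child_spec/1 that calls start_link(opts)
      {ClusterLoadBalancer.Worker,
       [
         implementation: PushExClusterLoadBalancer,
         namespace: :pushex_websocket,
         allowed_average_deviation_percent: 25,
         shed_percentage: 50,
         max_shed_count: 100,
         min_shed_count: 20,
         round_duration_seconds: 15
       ]}
    ]

    Supervisor.start_link(children, strategy: :one_for_one, name: MyApp.Supervisor)
  end
end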