ClusterLoadBalancer for balancing anything (WebSocket) across a cluster
# We use this ClusterLoadBalancer to prevent hot nodes in our load-balanced setup.
# The WebSockets are provided by Phoenix through the PushEx application https://github.com/pushex-project/pushex/
# The load balancer's `Worker` module is where the bulk of the cluster orchestration happens, using pg2 for cross-node communication.

defmodule ClusterLoadBalancer.Behavior do
  @moduledoc """
  Behavior for implementing a ClusterLoadBalancer-compatible tool.
  """

  @callback count() :: number
  @callback kill_processes(number) :: {:ok, number}
end

defmodule PushExClusterLoadBalancer do
  require Logger

  @behaviour ClusterLoadBalancer.Behavior

  def count() do
    PushEx.connected_socket_count()
  end

  def kill_processes(number) do
    killed_count =
      PushEx.connected_transport_pids()
      |> Enum.shuffle()
      |> Enum.take(number)
      |> Enum.map(&send(&1, %Phoenix.Socket.Broadcast{event: "disconnect"}))
      |> length()

    {:ok, killed_count}
  end
end
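
# Note on the mechanism above: sending %Phoenix.Socket.Broadcast{event: "disconnect"} to a
# transport pid mirrors what Phoenix does when `Endpoint.broadcast(socket_id, "disconnect", %{})`
# is called, so the transport closes the connection. The assumption is that the client then
# reconnects through the load balancer and lands on a less-loaded node.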

defmodule ClusterLoadBalancer.Config do
  @moduledoc """
  Configuration for a load balancer instance
  """

  @typedoc """
  * impl: ClusterLoadBalancer.Behavior implementation module used to power the load balancer worker
  * allowed_average_deviation_percent: How far (as a percent) the highest-count node can be above the cluster average before it is load balanced
  * shed_percentage: The percentage of the gap between the average and the highest count that will be shed
  * max_shed_count: The most that can be shed at one time. An amount above this is reduced to it
  * min_shed_count: The minimum that can be shed at one time. An amount below this will *not* be shed
  * round_duration_seconds: How long each round lasts
  """
  @type t :: %__MODULE__{
          impl: any(),
          allowed_average_deviation_percent: non_neg_integer(),
          shed_percentage: non_neg_integer(),
          max_shed_count: non_neg_integer(),
          min_shed_count: non_neg_integer(),
          round_duration_seconds: non_neg_integer()
        }

  @enforce_keys [
    :impl,
    :allowed_average_deviation_percent,
    :shed_percentage,
    :max_shed_count,
    :min_shed_count,
    :round_duration_seconds
  ]
  defstruct @enforce_keys
end

defmodule ClusterLoadBalancer.Calculator do
  @moduledoc false
  require Logger

  alias ClusterLoadBalancer.{Collection, Config}

  @doc """
  Provides the number of processes that should be corrected by the load balancer. This value will be 0
  if the current node doesn't meet the criteria for shedding or if it would shed an amount less than
  the configured minimum.
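
  Illustrative example (numbers are assumed, purely for illustration): with a cluster average of
  100 connections, allowed_average_deviation_percent of 25, and a self count of 150, this node is
  above the allowed maximum of 125, so it sheds. With shed_percentage of 50 it sheds
  trunc((150 - 100) * 0.5) = 25 processes, subject to the configured min/max shed counts.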
""" | |
def amount_to_correct_by(collection = %Collection{topic: topic}, config = %Config{}) do | |
if should_shed?(collection, config) do | |
shed_amount(collection, config) | |
else | |
Logger.debug("#{topic} should_shed?=false") | |
0 | |
end | |
end | |
@doc false | |
def should_shed?(collection = %Collection{}, config = %Config{}) do | |
highest_count? = Collection.self_has_highest_count?(collection) | |
self_count = Collection.self_count(collection) | |
average_count = Collection.average_count(collection) | |
max_allowed_count = average_count + config.allowed_average_deviation_percent / 100 * average_count | |
outside_allowed_deviation? = self_count > max_allowed_count | |
Logger.debug("#{collection.topic} highest_count?=#{inspect(highest_count?)} avg=#{average_count} self=#{self_count} max_allowed_count=#{max_allowed_count}") | |
highest_count? && outside_allowed_deviation? | |
end | |
@doc false | |
def shed_amount(collection = %Collection{}, config = %Config{}) do | |
self_count = Collection.self_count(collection) | |
average_count = Collection.average_count(collection) | |
calc_shed_amount = trunc((self_count - average_count) * (config.shed_percentage / 100)) | |
shed_amount = min(config.max_shed_count, calc_shed_amount) | |
will_shed? = shed_amount >= config.min_shed_count | |
Logger.debug( | |
"#{collection.topic} highest=true shed_amount=#{shed_amount} [calc,min,max]=[#{calc_shed_amount}, #{config.min_shed_count}, #{config.max_shed_count}] will_shed=#{ | |
will_shed? | |
}" | |
) | |
if will_shed? do | |
shed_amount | |
else | |
0 | |
end | |
end | |
end | |

defmodule ClusterLoadBalancer.Collection do
  @moduledoc false

  defmodule Result do
    @moduledoc false
    @enforce_keys [:tick, :count, :rand]
    defstruct @enforce_keys
  end

  @enforce_keys [:expected_results_count, :self_result, :tick, :topic]
  defstruct @enforce_keys ++ [collected: []]

  def init(topic, tick, node_count, self_result = %Result{}) do
    %__MODULE__{expected_results_count: node_count, self_result: self_result, tick: tick, topic: topic}
  end

  def init_result(tick, count, rand) do
    %Result{tick: tick, count: count, rand: rand}
  end

  def add_result(state = %__MODULE__{collected: collected, tick: state_tick}, result = %Result{tick: tick}) when state_tick == tick do
    new_collected = [result | collected]
    %{state | collected: new_collected}
  end

  def add_result(state, _), do: state

  def finalized?(%{expected_results_count: expected_count, collected: collected}) do
    length(collected) == expected_count
  end

  def participant_count(%{expected_results_count: count}), do: count + 1

  def self_has_highest_count?(%{self_result: self, collected: collected}) do
    Enum.all?(collected, fn other ->
      self.count > other.count || (self.count == other.count && self.rand > other.rand)
    end)
  end

  def self_count(%{self_result: self}) do
    self.count
  end

  def average_count(%{self_result: self, collected: collected}) do
    collected
    |> Enum.map(& &1.count)
    |> Enum.sum()
    |> Kernel.+(self.count)
    |> Kernel./(length(collected) + 1)
  end
end
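
# Illustrative walk-through of Collection (values here are assumed, purely for illustration):
#
#   self = ClusterLoadBalancer.Collection.init_result(1, 150, 0.42)
#   collection = ClusterLoadBalancer.Collection.init(:my_topic, 1, 2, self)
#
# After two remote results with counts 90 and 110 arrive via add_result/2, the round is
# finalized?/1, average_count/1 returns (150 + 90 + 110) / 3 ≈ 116.7, and
# self_has_highest_count?/1 returns true for this node.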

defmodule ClusterLoadBalancer.Worker do
  @moduledoc """
  Watcher process that monitors for discrepancies in resource count across the cluster
  and kills processes on the current node if it is the highest-count node and falls outside of
  an allowed deviation.

  All of the values are configurable to either allow the cluster to be more loosely
  load balanced or to be more tightly controlled.

  An assumption is made that the resources will recreate themselves once killed (or your
  implementation has to handle that). Phoenix.Channels, the primary use case for this,
  usually involve the client reconnecting when disconnected. This means that the processes
  will re-form, but in a load-balanced way.

  TODO:
  * pg2 is used to broadcast a request for process count collection across the cluster.
    Change this to Phoenix.PubSub, as it would then work for any standard Phoenix.Channel installation.
    This would add challenges because Phoenix.PubSub has no equivalent of the "node count" feature
    that is currently used.
  """

  use GenServer
  require Logger

  alias ClusterLoadBalancer.{Calculator, Collection, Config}

  def start_link(opts) when is_list(opts) do
    GenServer.start_link(__MODULE__, opts)
  end

  def init(opts) when is_list(opts) do
    namespace = Keyword.fetch!(opts, :namespace)
    topic = subscribe_to_pg2(namespace)

    config = %Config{
      impl: Keyword.fetch!(opts, :implementation),
      allowed_average_deviation_percent: Keyword.get(opts, :allowed_average_deviation_percent, 25),
      shed_percentage: Keyword.get(opts, :shed_percentage, 50),
      max_shed_count: Keyword.get(opts, :max_shed_count, 100),
      min_shed_count: Keyword.get(opts, :min_shed_count, 10),
      round_duration_seconds: Keyword.get(opts, :round_duration_seconds, 10)
    }

    schedule_tick(topic, config, nil)

    {:ok,
     %{
       rand: :rand.uniform(),
       topic: topic,
       tick: 0,
       tick_state: nil,
       config: config
     }}
  end

  @doc false
  def handle_info(:tick, state = %{config: config, rand: rand, tick: prev_tick, topic: topic, tick_state: prev_state}) do
    schedule_tick(topic, config, prev_state)
    tick = prev_tick + 1
    remote_node_count = collect_counts_in_cluster(tick, topic)
    tick_state = Collection.init(topic, tick, remote_node_count, new_result(tick, config.impl, rand))
    {:noreply, %{state | tick: tick, tick_state: tick_state}}
  end

  @doc false
  def handle_cast({:collect_request, tick, from_pid}, state = %{config: config, rand: rand}) do
    GenServer.cast(from_pid, {:collect_result, new_result(tick, config.impl, rand)})
    {:noreply, state}
  end

  @doc false
  def handle_cast({:collect_result, result = %Collection.Result{tick: tick}}, state = %{config: config, tick: self_tick, tick_state: tick_state, topic: topic})
      when tick == self_tick do
    new_state = Collection.add_result(tick_state, result)

    new_state =
      case Collection.finalized?(new_state) do
        true ->
          count = Collection.participant_count(new_state)
          amount_to_correct_by = Calculator.amount_to_correct_by(new_state, config)
          correct_deviation(config, amount_to_correct_by)
          Logger.debug("#{topic} round finalized with #{count} participants, amount_to_correct_by=#{amount_to_correct_by}")
          nil

        _ ->
          new_state
      end

    {:noreply, %{state | tick_state: new_state}}
  end

  @doc false
  def handle_cast({:collect_result, _}, state = %{topic: topic}) do
    Logger.error("#{topic} collect_result delivered too slow")
    {:noreply, state}
  end

  # private

  defp correct_deviation(%{impl: impl}, kill_count) when kill_count > 0 do
    {:ok, killed_count} = impl.kill_processes(kill_count)
    Logger.error("#{impl} kill_processes requested_count=#{kill_count} killed_count=#{killed_count}")
    :ok
  end

  defp correct_deviation(_, _), do: :ok

  defp collect_counts_in_cluster(tick, topic) do
    on_remote_nodes(topic, fn pid ->
      GenServer.cast(pid, {:collect_request, tick, self()})
    end)
  end

  defp on_remote_nodes(topic, func) do
    topic
    |> :pg2.get_members()
    |> Kernel.--(:pg2.get_local_members(topic))
    |> Enum.map(func)
    |> length()
  end

  defp new_result(tick, impl_mod, tie_breaker) do
    Collection.init_result(tick, impl_mod.count(), tie_breaker)
  end
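
  # NOTE: :pg2 was deprecated in OTP 23 and removed in OTP 24; on newer OTP releases the
  # :pg module is its replacement, so this subscription code would need porting there.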
  defp subscribe_to_pg2(namespace) do
    topic = String.to_atom("#{__MODULE__}.#{namespace}")
    :ok = :pg2.create(topic)
    :ok = :pg2.join(topic, self())
    topic
  end

  defp schedule_tick(topic, %{round_duration_seconds: timeout}, prev_state) do
    if prev_state != nil && prev_state.expected_results_count > 0, do: Logger.error("#{topic} round occurred without being finalized")

    Process.send_after(self(), :tick, trunc(timeout * 1000))
  end
end

# Started with the following application.ex entry:
{
  ClusterLoadBalancer.Worker,
  [
    implementation: PushExClusterLoadBalancer,
    namespace: :pushex_websocket,
    allowed_average_deviation_percent: 25,
    shed_percentage: 50,
    max_shed_count: 100,
    min_shed_count: 20,
    round_duration_seconds: 15
  ]
}
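
# A minimal sketch of where that entry could live, assuming a conventional application.ex
# (the MyApp module names below are illustrative, not from this gist):
defmodule MyApp.Application do
  use Application

  def start(_type, _args) do
    children = [
      # ... the rest of your supervision tree ...
      {ClusterLoadBalancer.Worker,
       [
         implementation: PushExClusterLoadBalancer,
         namespace: :pushex_websocket,
         round_duration_seconds: 15
       ]}
    ]

    Supervisor.start_link(children, strategy: :one_for_one, name: MyApp.Supervisor)
  end
end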