Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Parallel Executor for Elixir
defmodule ParallelExecutor do
  @moduledoc """
  Executes a function over the elements of an enumerable/stream in parallel,
  up to a provided parallelization factor `F`. Together with its
  implementation of the `Collectable` protocol, it applies a backpressure
  mechanism so that no more than `F` tasks are ever running at once.

  Note: results are emitted in completion order, not input order.
  """

  alias __MODULE__

  # Sentinel appended to the input stream so `map/3` can flush the
  # still-running tasks once the source is exhausted. Without it, up to
  # `procs` results would be silently dropped at end of input, because
  # `Stream.transform/4`'s after-fun cannot emit elements.
  @flush :__parallel_executor_flush__

  defstruct procs: 4, func: nil, tasks: []

  @doc """
  Builds a new executor that runs `func` with at most `parallel_factor`
  concurrent tasks (default: 4). `func` must be a 1-arity function.
  """
  def new(parallel_factor \\ 4, func) when is_function(func, 1) do
    %ParallelExecutor{procs: parallel_factor, func: func}
  end

  @doc """
  Polls the given tasks until at least one has yielded a result.
  Returns a list of `{task, result}` pairs for every task that has
  completed so far. If a task exited abnormally, the exit is propagated
  to the caller (a crashed task can never produce a result).
  """
  def wait_any([]), do: []

  def wait_any(tasks) when is_list(tasks) do
    case wait_any_impl(tasks, []) do
      [] -> wait_any(tasks)
      finished when is_list(finished) -> finished
    end
  end

  # One polling pass: briefly yield on each task, collecting finished ones.
  defp wait_any_impl([task | rest], acc) do
    case Task.yield(task, 10) do
      nil -> wait_any_impl(rest, acc)
      {:ok, result} -> wait_any_impl(rest, [{task, result} | acc])
      # Propagate abnormal task termination instead of crashing with a
      # CaseClauseError (Task.yield/2 may return {:exit, reason}).
      {:exit, reason} -> exit(reason)
    end
  end

  defp wait_any_impl([], acc), do: acc

  @doc """
  Waits until all tasks have yielded and returns the list of results
  (in no particular order).

  With a `timeout` of `0`, polls indefinitely until every task has
  completed (abnormal exits are propagated). Any other value uses
  `Task.await/2` with that timeout on each task.
  """
  def wait_all(tasks, timeout \\ 5_000) do
    wait_all_impl(tasks, timeout, [], [])
  end

  def wait_all_impl([], _timeout, _remaining, acc), do: acc

  def wait_all_impl(tasks, 0, remaining, acc) when is_list(tasks) do
    # Keep polling until no task remains. The 3-arity helper performs one
    # pass and returns a {still_running, results} tuple.
    case wait_all_impl(tasks, remaining, acc) do
      {[], results} -> results
      {still_running, results} -> wait_all_impl(still_running, 0, [], results)
    end
  end

  def wait_all_impl(tasks, timeout, _remaining, _acc) when is_list(tasks) do
    for task <- tasks, do: Task.await(task, timeout)
  end

  # One polling pass used by the 0-timeout path.
  # Returns {tasks_still_running, results_so_far}.
  defp wait_all_impl([task | rest], remaining, acc) do
    case Task.yield(task, 10) do
      nil -> wait_all_impl(rest, [task | remaining], acc)
      {:ok, result} -> wait_all_impl(rest, remaining, [result | acc])
      {:exit, reason} -> exit(reason)
    end
  end

  defp wait_all_impl([], remaining, acc), do: {remaining, acc}

  @doc """
  Maps the elements of `stream` through `func`, running up to
  `parallel_factor` tasks concurrently (default: 4), and returns a stream
  of the results in completion order.
  """
  def map(stream, func) when is_function(func, 1), do: map(stream, 4, func)

  def map(stream, parallel_factor, func) when is_function(func, 1) do
    stream
    |> Stream.concat([@flush])
    |> Stream.transform(
      fn -> ParallelExecutor.new(parallel_factor, func) end,
      &map_impl/2,
      fn _state -> :ok end
    )
  end

  # End of input: drain every outstanding task and emit its result.
  defp map_impl(@flush, %ParallelExecutor{} = state) do
    {ParallelExecutor.wait_all(state.tasks, 0), %{state | tasks: []}}
  end

  defp map_impl(item, %ParallelExecutor{} = state) do
    if length(state.tasks) < state.procs do
      # Below the parallelism limit: just start a task for this element.
      task = Task.async(fn -> state.func.(item) end)
      {[], %{state | tasks: [task | state.tasks]}}
    else
      # At the limit: wait for one or more tasks to finish before starting
      # another. This is our backpressure mechanism.
      finished = ParallelExecutor.wait_any(state.tasks)
      results = Enum.map(finished, fn {_task, result} -> result end)
      finished_tasks = Enum.map(finished, fn {task, _result} -> task end)
      incomplete = Enum.reject(state.tasks, &(&1 in finished_tasks))
      # Start a new task for the incoming element.
      task = Task.async(fn -> state.func.(item) end)
      {results, %{state | tasks: [task | incomplete]}}
    end
  end
end
# NOTE(review): the impl previously targeted CatalogEtl.Extract.ParallelExecutor
# (a leftover namespace), so it could never match the %ParallelExecutor{} struct
# defined in this file. It now targets the local module.
defimpl Collectable, for: ParallelExecutor do
  @moduledoc false

  # Collecting into an executor hands each element to the executor's `func`
  # as an async task, never running more than `procs` tasks at once.
  # Task results are discarded, so collecting is only useful for
  # side-effecting functions. `:done`/`:halt` drain all outstanding tasks
  # (with the default 5s per-task timeout) before returning `:ok`.
  def into(original) do
    collector = fn
      state, {:cont, item} ->
        if length(state.tasks) < state.procs do
          # Below the parallelism limit: just start a task for this element.
          task = Task.async(fn -> state.func.(item) end)
          %{state | tasks: [task | state.tasks]}
        else
          # Backpressure: block until at least one running task completes.
          finished = ParallelExecutor.wait_any(state.tasks)

          incomplete =
            Enum.reject(state.tasks, fn t ->
              Enum.any?(finished, fn {ft, _res} -> ft == t end)
            end)

          # Start a new task for the incoming element.
          task = Task.async(fn -> state.func.(item) end)
          %{state | tasks: [task | incomplete]}
        end

      state, :done ->
        ParallelExecutor.wait_all(state.tasks)
        :ok

      state, :halt ->
        ParallelExecutor.wait_all(state.tasks)
        :ok
    end

    {original, collector}
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment