Created Feb 25, 2021
I would recommend Dask for this sort of work, as it makes it easy to build a task graph and avoid redundant computation.

import dask

def A(argument):
    print('Running A with', argument)
    return argument

def B(argument_one, argument_two):
    print('Running B with', argument_one, argument_two)
    return argument_one*argument_two

# I produce a mapping between A's input argument and the dask Delayed object
# that will compute its result.
a_results = {
    a_argument: dask.delayed(A)(a_argument)
    for a_argument in range(1, 4)
}

# Build an iterable of all the outputs you want, then compute them in one call.
# This ensures that no result is computed more than once.
b_results = [
    dask.delayed(B)(a_results[i - 1], a_results[i])
    for i in range(2, 4)
]

print('Results:', dask.compute(*b_results))


Running A with 1
Running A with 2
Running A with 3
Running B with 1 2
Running B with 2 3
Results: (2, 6)

Result one is A(1) * A(2) = 1 * 2 = 2.
Result two is A(2) * A(3) = 2 * 3 = 6.

From the printed output, you can see that A was executed only once per input, even though A(2) feeds into both calls to B.
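
To make the sharing easier to verify than by reading stdout, here is a minimal sketch that records each call to A in a list and compares one combined `dask.compute` call against computing each result separately. The `calls` list and the variable names `together`, `separately`, `calls_together` and `calls_separately` are my own additions for illustration, and I pass `scheduler="synchronous"` only so the run is deterministic.

```python
import dask

calls = []  # hypothetical helper: records every argument A is called with

def A(argument):
    calls.append(argument)
    return argument

def B(argument_one, argument_two):
    return argument_one * argument_two

a_results = {n: dask.delayed(A)(n) for n in range(1, 4)}
b_results = [
    dask.delayed(B)(a_results[i - 1], a_results[i])
    for i in range(2, 4)
]

# One compute call: both graphs are merged, so the shared A(2) task runs once.
together = dask.compute(*b_results, scheduler="synchronous")
calls_together = sorted(calls)

# Separate compute calls: each call evaluates its own graph, so A(2) runs twice.
calls.clear()
separately = tuple(b.compute(scheduler="synchronous") for b in b_results)
calls_separately = sorted(calls)

print(together, calls_together)      # (2, 6) [1, 2, 3]
print(separately, calls_separately)  # (2, 6) [1, 2, 2, 3]
```

Both approaches give the same results, but only the single `dask.compute(*b_results)` call avoids repeating A(2), which is why the outputs are collected into one iterable first.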

Note: as Dask parallelises the tasks, your stdout may be scrambled or in a different order to mine.
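
If the interleaved printing is a problem, for example while debugging, one option is to ask Dask to run everything in the calling thread. The sketch below reuses the functions from above and passes `scheduler="synchronous"` to `dask.compute`; the tasks then run one at a time, so the printed lines are not interleaved, although I would not rely on any particular ordering of the A calls. The variable name `results` is my own.

```python
import dask

def A(argument):
    print('Running A with', argument)
    return argument

def B(argument_one, argument_two):
    print('Running B with', argument_one, argument_two)
    return argument_one * argument_two

a_results = {n: dask.delayed(A)(n) for n in range(1, 4)}
b_results = [
    dask.delayed(B)(a_results[i - 1], a_results[i])
    for i in range(2, 4)
]

# Single-threaded scheduler: no parallelism, so stdout is not scrambled.
results = dask.compute(*b_results, scheduler="synchronous")
print('Results:', results)
```

This trades away the parallelism, so it is best kept for debugging rather than for the real workload.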
