Skip to content

Instantly share code, notes, and snippets.

@Seanny123
Created September 2, 2022 21:45
Show Gist options
  • Save Seanny123/7a3066a84698ee097c31a53f2525bd55 to your computer and use it in GitHub Desktop.
Save Seanny123/7a3066a84698ee097c31a53f2525bd55 to your computer and use it in GitHub Desktop.
An example of nested fanout using Ray Workflows.
"""
An example of nested fanout using Ray Workflows.
The full workflow creates several batches given a single input, then each of those batches is fanned-out and evaluated.
"""
import time
import ray
from ray import workflow
@ray.remote
def get_range(start: int) -> list[int]:
    """
    Produce one batch: the five consecutive integers beginning at ``start``.

    Sleeps for two seconds before building the batch.
    """
    time.sleep(2)
    stop = start + 5
    return [*range(start, stop)]
@ray.remote
def get_square(data: float) -> float:
    """
    Square a single value; the per-item operation used to exercise fanout.

    Sleeps for two seconds before computing the result.
    """
    time.sleep(2)
    squared = data**2
    return squared
@ray.remote
def range_and_square(data):
    """
    Entry point of the workflow (it is the node handed to `workflow.run`
    in the `__main__` section of this file).

    Binds one `get_range` node per element of ``data`` (the batch stage),
    then continues into `expand_ranges`, which starts the fanout over each batch.
    """
    # Batch stage: one bound `get_range` node per starting value.
    batch_nodes = []
    for start in data:
        batch_nodes.append(get_range.bind(start))

    # Hand the batch nodes to the next link in the chain.
    next_step = expand_ranges.bind(batch_nodes)
    return workflow.continuation(next_step)
@ray.remote
def expand_ranges(my_ranges):
    """
    Resolve every batch in ``my_ranges`` and continue into `finalize_range`.

    Author's rationale: Ray Workflows do not allow multiple `.bind` operations
    to occur simultaneously, so the steps are chained through separate functions.
    That lets the `workflow.continuation` call in `range_and_square` construct
    the full DAG before it runs.

    ``my_ranges`` is presumably a list of object references to `get_range`
    results (each element is passed to `ray.get`) -- confirm against the
    Ray Workflows argument-passing rules.
    """

    def _fetch(batch_ref):
        # Printed before each fetch. The author added it as evidence that this
        # loop does not become a bottleneck.
        # NOTE(review): `ray.get` is documented as blocking until the object is
        # ready, so confirm what this print actually demonstrates.
        print("get range!")
        return ray.get(batch_ref)

    resolved_batches = [_fetch(batch_ref) for batch_ref in my_ranges]
    return workflow.continuation(finalize_range.bind(resolved_batches))
@ray.remote
def finalize_range(my_ranges):
    """
    Fan out: bind one `get_square` node for every item of every batch in
    ``my_ranges``, then continue into `finalize_results` with the flat list
    of those nodes (batch order first, item order within each batch).
    """
    square_nodes = [
        get_square.bind(item)
        for batch in my_ranges
        for item in batch
    ]
    return workflow.continuation(finalize_results.bind(square_nodes))
@ray.remote
def finalize_results(squares):
    """
    Last step of the chain: call `ray.get` on each element of ``squares``
    and return the fetched values as a list, in the same order.
    """
    collected = []
    for square in squares:
        collected.append(ray.get(square))
    return collected
if __name__ == "__main__":
    # Ten starting values (0..9); each one becomes a batch in `range_and_square`.
    starts = [*range(10)]
    dag = range_and_square.bind(starts)
    result = workflow.run(dag)

    # Pause briefly so Ray's stderr messages can flow past
    # before the final output is printed.
    time.sleep(1)
    print(f"{result=}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment