Skip to content

Instantly share code, notes, and snippets.

@jacobtomlinson
Created May 11, 2023 13:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jacobtomlinson/071d652692a146e7d5ee835dab9085de to your computer and use it in GitHub Desktop.
Save jacobtomlinson/071d652692a146e7d5ee835dab9085de to your computer and use it in GitHub Desktop.
Apache Beam Dask Limitation MRE
import warnings
import time
from contextlib import contextmanager
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
from dask.distributed import Client
from distributed.versions import VersionMismatchWarning
from dask_kubernetes.operator import KubeCluster
# Silence warning categories that would otherwise clutter the demo output.
for warn_type in (VersionMismatchWarning, FutureWarning):
    warnings.filterwarnings(action="ignore", category=warn_type)
@contextmanager
def daskcluster():
    """Yield a Dask ``Client`` attached to a ``KubeCluster`` named "beam-test".

    The cluster is requested with 256 workers and explicitly re-scaled to that
    size after creation, the dashboard link is printed, and the client is only
    yielded once ``wait_for_workers`` reports that all workers are present.
    Swap the body out for whatever way of obtaining a Dask cluster you prefer.
    """
    n_workers = 256
    worker_env = {"EXTRA_PIP_PACKAGES": "apache-beam"}
    worker_resources = {
        "requests": {"cpu": "500m", "memory": "1Gi"},
        "limits": {"cpu": "1000m", "memory": "1.85Gi"},
    }

    with KubeCluster(
        name="beam-test",
        n_workers=n_workers,
        env=worker_env,
        resources=worker_resources,
        # Leave the cluster running so the next invocation can reuse it.
        shutdown_on_close=False,
    ) as kube_cluster:
        # A reused cluster may have a different size; force the wanted count.
        kube_cluster.scale(n_workers)
        print(f"Dashboard at: {kube_cluster.dashboard_link}")

        with Client(kube_cluster) as dask_client:
            print(f"Waiting for all {n_workers} workers")
            dask_client.wait_for_workers(n_workers=n_workers)
            yield dask_client
class NoopDoFn(beam.DoFn):
    """Beam ``DoFn`` that waits ten seconds and passes its input through."""

    def process(self, item):
        """Sleep for 10 seconds, then return ``item`` in a one-element list."""
        time.sleep(10)
        passthrough = [item]
        return passthrough
def main() -> None:
    """Run a three-stage no-op Beam pipeline on the Dask cluster.

    Builds a collection of ``n_items`` integers, pushes it through three
    ``NoopDoFn`` stages using ``DaskRunner``, and blocks until the pipeline
    has finished.
    """
    # Author's observation (the point of this MRE): with 199 items there is
    # one task per file per stage; with 200 items there are at most 100 tasks
    # per stage.
    n_items = 200

    with daskcluster() as client:
        # Build a Beam pipeline that targets the Dask scheduler we just got.
        print("Running Pipeline")
        runner = DaskRunner()
        options = PipelineOptions(
            ["--dask_client_address", client.cluster.scheduler_address]
        )
        pipeline = beam.Pipeline(runner=runner, options=options)

        collection = pipeline | "Create collection" >> beam.Create(range(n_items))
        for stage_label in ("Noop 1", "Noop 2", "Noop 3"):
            collection = collection | stage_label >> beam.ParDo(NoopDoFn())

        pipeline.run().wait_until_finish()
# Only run the pipeline when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment