# Code for the Dask cluster on demand blog post
import functools
import inspect
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import requests
import yaml
from distributed import Client
from distributed.security import Security
from googleapiclient import discovery
LOGGER = logging.getLogger(__name__)

# Define some constants here, directly in the class,
# or import from another module.
# Override any of these by passing explicit values to Cluster(...) or
# requires_cluster(...).
GCP_PROJECT_ID: str = "default project id"
GCP_PROJECT_NUMBER: str = "default project number"
GCP_CLUSTER_ZONE: str = "default cluster zone"
GCP_INSTANCE_NAME: str = "default instance name"
GCP_DOCKER_IMAGE: str = "default docker image"
DASK_CERT_FILEPATH: str = "path to default dask certificate you want to use"
DASK_KEY_FILEPATH: str = "path to default dask key you want to use"
# BUG FIX: the four constants below originally ended with stray trailing
# commas, which silently turned each one into a one-element tuple instead
# of the annotated scalar type.
MACHINE_TYPE: str = "e2-standard-16"
# Defaults for single node Dask cluster docker image
# see https://gist.github.com/ian-whitestone/d3b876e77743923b112d7d004d86480c
# or https://ianwhitestone.work/single-node-dask-cluster-on-gcp/ for more details
NUM_WORKERS: int = 16
THREADS_PER_WORKER: int = 1
MEMORY_PER_WORKER_GB: float = 4.0
@dataclass
class Cluster:
    """A single-node Dask cluster running in a container on a GCP VM.

    Instantiation provisions the cluster eagerly: ``__post_init__``
    validates the machine type, creates the compute instance running
    ``gcp_docker_image``, waits until the Dask dashboard responds, and
    connects a TLS-secured ``distributed.Client`` (``self.client``).
    Call :meth:`teardown` when done to close the client and delete the VM.
    """

    gcp_project_id: str = GCP_PROJECT_ID
    gcp_project_number: str = GCP_PROJECT_NUMBER
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE
    gcp_instance_name: str = GCP_INSTANCE_NAME
    gcp_docker_image: str = GCP_DOCKER_IMAGE
    # Only needed if you're using a Dask cluster with SSL security
    # See https://ianwhitestone.work/dask-cluster-security/ for more details
    dask_cert_filepath: str = DASK_CERT_FILEPATH
    dask_key_filepath: str = DASK_KEY_FILEPATH
    machine_type: str = MACHINE_TYPE
    num_workers: int = NUM_WORKERS
    threads_per_worker: int = THREADS_PER_WORKER
    # BUG FIX: was annotated ``int`` although the module-level default is a
    # float (memory may be a fractional number of GB).
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB

    def __post_init__(self):
        # NOTE: all of this performs network I/O at construction time.
        self._validate_machine_type()
        self.compute = discovery.build("compute", "v1", cache_discovery=False)
        self.disk_image_name, self.disk_image_link = self._get_latest_image()
        self.create()
        self.cluster_host_ip_address = self._get_cluster_ip_address()
        self._wait_until_cluster_is_ready()
        self.client = self.create_client()

    def _validate_machine_type(self):
        """Validate ``machine_type`` and that it can host ``num_workers``.

        Raises:
            ValueError: for an unknown machine type, a malformed custom
                machine spec, or when ``num_workers`` exceeds the vCPUs
                available on the chosen predefined machine type.
        """
        gcp_machine_types = {
            # shared core
            "e2-micro": {"vCPU": 2, "memory_gb": 1},
            "e2-small": {"vCPU": 2, "memory_gb": 2},
            "e2-medium": {"vCPU": 2, "memory_gb": 4},
            # standard
            "e2-standard-2": {"vCPU": 2, "memory_gb": 8},
            "e2-standard-4": {"vCPU": 4, "memory_gb": 16},
            "e2-standard-8": {"vCPU": 8, "memory_gb": 32},
            "e2-standard-16": {"vCPU": 16, "memory_gb": 64},
            "e2-standard-32": {"vCPU": 32, "memory_gb": 128},
            # high memory
            "e2-highmem-2": {"vCPU": 2, "memory_gb": 16},
            "e2-highmem-4": {"vCPU": 4, "memory_gb": 32},
            "e2-highmem-8": {"vCPU": 8, "memory_gb": 64},
            "e2-highmem-16": {"vCPU": 16, "memory_gb": 128},
            # high compute
            "e2-highcpu-2": {"vCPU": 2, "memory_gb": 2},
            "e2-highcpu-4": {"vCPU": 4, "memory_gb": 4},
            "e2-highcpu-8": {"vCPU": 8, "memory_gb": 8},
            "e2-highcpu-16": {"vCPU": 16, "memory_gb": 16},
            "e2-highcpu-32": {"vCPU": 32, "memory_gb": 32},
        }
        # Example custom machine spec: e2-custom-32-49152
        if "custom" in self.machine_type:
            parts = self.machine_type.split("-")
            if len(parts) != 4:  # TODO: replace with regex validation
                raise ValueError(
                    "Custom machine type must be formatted like 'e2-custom-32-49152'"
                )
            num_cpus = int(parts[2])
            memory = int(parts[3])  # memory is expressed in MB here
            if memory % 256:
                raise ValueError("Memory must be a multiple of 256")
            if num_cpus < 2 or (num_cpus % 2):
                # BUG FIX: the message previously said "greater than 2"
                # although the check accepts exactly 2 CPUs.
                raise ValueError("# of CPUs must be at least 2 and a multiple of 2")
            # NOTE(review): num_workers is not validated against the custom
            # CPU count — TODO consider applying the same check as below.
            return
        if self.machine_type not in gcp_machine_types:
            raise ValueError(
                f"'{self.machine_type}' is not a valid machine type. "
                f"Expecting one of {list(gcp_machine_types.keys())}"
            )
        num_cores_available = gcp_machine_types[self.machine_type]["vCPU"]
        if self.num_workers > num_cores_available:
            raise ValueError(
                f"{self.machine_type} has {num_cores_available} cores available and "
                f"you requested {self.num_workers}. Try specifying a machine_type with "
                "more vCPUs or reduce num_workers."
            )

    def _get_latest_image(self):
        """
        https://googleapis.github.io/google-api-python-client/docs/dyn/compute_v1.images.html#getFromFamily
        Returns the latest image that is part of an image family and is not deprecated.
        """
        image_response = (
            self.compute.images()
            .getFromFamily(project="cos-cloud", family="cos-stable")
            .execute()
        )
        return image_response["name"], image_response["selfLink"]

    def _get_cluster_ip_address(self):
        """Return the external (NAT) IP address of the cluster instance.

        Raises:
            Exception: if no instance, or more than one instance, matches
                ``gcp_instance_name`` in the configured zone.
        """
        instances_list = (
            self.compute.instances()
            .list(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                filter=f"name = {self.gcp_instance_name}",
            )
            .execute()
        )
        if not instances_list.get("items"):
            raise Exception("Instance not found")
        if len(instances_list.get("items", [])) > 1:
            raise Exception("More than 1 instance returned with search criteria")
        return instances_list["items"][0]["networkInterfaces"][0]["accessConfigs"][0][
            "natIP"
        ]

    def _wait_until_cluster_is_ready(self):
        """Block until the Dask dashboard (port 8787) serves a page.

        Polls every 30 seconds. Note: this loops forever if the container
        never becomes healthy — TODO consider adding a timeout.
        """
        cluster_url = f"http://{self.cluster_host_ip_address}:8787/"
        LOGGER.info(f"Waiting until cluster {cluster_url} is ready")
        while True:
            try:
                r = requests.get(cluster_url)
                if r.ok and "dask" in r.text.lower():
                    LOGGER.info("Cluster is ready 🟢")
                    break
            # BUG FIX: the original caught the builtin ConnectionError,
            # which requests does not raise — connection failures while the
            # VM boots would crash this loop instead of being retried.
            except requests.exceptions.ConnectionError:
                pass
            # BUG FIX: the original only slept inside the except clause, so
            # a reachable-but-not-ready dashboard caused a busy poll loop.
            time.sleep(30)

    def _wait_for_operation(self, operation_name: str):
        """Poll a zonal GCE operation every second until it reports DONE.

        Raises:
            Exception: with the operation's error payload if it failed.
        """
        while True:
            result = (
                self.compute.zoneOperations()
                .get(
                    project=self.gcp_project_id,
                    zone=self.gcp_cluster_zone,
                    operation=operation_name,
                )
                .execute()
            )
            if result["status"] == "DONE":
                if "error" in result:
                    raise Exception(result["error"])
                return
            time.sleep(1)

    @property
    def gce_container_spec(self):
        """YAML container declaration consumed by the container-optimized OS.

        Worker settings are forwarded to the container as environment
        variables (MEMORY_PER_WORKER, THREADS_PER_WORKER, NUM_WORKERS).
        """
        container_spec = {
            "spec": {
                "containers": [
                    {
                        "name": self.gcp_instance_name,
                        "image": self.gcp_docker_image,
                        "env": [
                            {
                                "name": "MEMORY_PER_WORKER",
                                "value": f"{self.memory_per_worker_gb}",
                            },
                            {
                                "name": "THREADS_PER_WORKER",
                                "value": f"{self.threads_per_worker}",
                            },
                            {"name": "NUM_WORKERS", "value": f"{self.num_workers}"},
                        ],
                        "stdin": False,
                        "tty": False,
                    }
                ],
                "restartPolicy": "Always",
            }
        }
        return yaml.dump(container_spec)

    @property
    def machine_type_full_name(self):
        """Fully qualified machine type URL path expected by the GCE API."""
        return (
            f"projects/{self.gcp_project_id}/zones/"
            f"{self.gcp_cluster_zone}/machineTypes/{self.machine_type}"
        )

    @property
    def instance_config(self):
        """Request body for ``compute.instances().insert``."""
        return {
            "kind": "compute#instance",
            "name": self.gcp_instance_name,
            "zone": self.gcp_cluster_zone,
            "machineType": self.machine_type_full_name,
            "metadata": {
                "kind": "compute#metadata",
                "items": [
                    {
                        "key": "gce-container-declaration",
                        "value": self.gce_container_spec,
                    },
                    {"key": "google-logging-enabled", "value": "true"},
                ],
            },
            "tags": {"items": ["http-server"]},
            "disks": [
                {
                    "boot": True,
                    "autoDelete": True,
                    "initializeParams": {"sourceImage": self.disk_image_link},
                }
            ],
            # Specify a network interface with NAT to access the public
            # internet.
            "networkInterfaces": [
                {
                    "network": "global/networks/default",
                    "accessConfigs": [
                        {"type": "ONE_TO_ONE_NAT", "name": "External NAT"}
                    ],
                }
            ],
            "labels": {"container-vm": self.disk_image_name},
            "serviceAccounts": [
                {
                    "email": f"{self.gcp_project_number}-compute@developer.gserviceaccount.com",  # noqa
                    "scopes": [
                        "https://www.googleapis.com/auth/devstorage.read_only",
                        "https://www.googleapis.com/auth/logging.write",
                        "https://www.googleapis.com/auth/monitoring.write",
                        "https://www.googleapis.com/auth/servicecontrol",
                        "https://www.googleapis.com/auth/service.management.readonly",
                        "https://www.googleapis.com/auth/trace.append",
                    ],
                }
            ],
        }

    def create(self):
        """Create the compute instance and block until the insert completes."""
        LOGGER.info("Creating new compute instance")
        operation = (
            self.compute.instances()
            .insert(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                body=self.instance_config,
            )
            .execute()
        )
        self._wait_for_operation(operation["name"])

    def create_client(self):
        """Return a TLS-secured ``distributed.Client`` for the scheduler."""
        cluster_host_url = f"tls://{self.cluster_host_ip_address}:8786"
        LOGGER.info(f"Connecting new client to {cluster_host_url}")
        sec = Security(
            tls_ca_file=self.dask_cert_filepath,
            tls_client_cert=self.dask_cert_filepath,
            tls_client_key=self.dask_key_filepath,
            require_encryption=True,
        )
        return Client(cluster_host_url, security=sec)

    def teardown(self):
        """Close the Dask client and delete the compute instance."""
        LOGGER.info("Shutting down client & tearing down cluster")
        self.client.close()
        operation = (
            self.compute.instances()
            .delete(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                instance=self.gcp_instance_name,
            )
            .execute()
        )
        LOGGER.info("Waiting for teardown to finish...")
        self._wait_for_operation(operation["name"])
def inspect_requires_cluster_function(func) -> Tuple[int, Optional[int]]:
    """
    Validate that a function decorated with @requires_cluster
    has client specified as an arg or kwarg.

    Args:
        func: the function being decorated.

    Returns:
        A tuple of (position of the ``client`` parameter, position of the
        ``teardown_cluster`` parameter or None if absent).

    Raises:
        ValueError: if ``func`` does not accept a ``client`` parameter.

    NOTE: requires ``import inspect`` at module level (the original file
    used ``inspect`` without importing it).
    """
    func_sig = inspect.signature(func)
    client_arg_position = None
    teardown_cluster_arg_position = None
    for param_pos, param in enumerate(func_sig.parameters.values()):
        if param.name == "client":
            client_arg_position = param_pos
        elif param.name == "teardown_cluster":
            teardown_cluster_arg_position = param_pos
    if client_arg_position is None:
        raise ValueError(
            "Functions using @requires_cluster must accept client as an arg or kwarg "
            "For example: \n"
            """
            @requires_cluster
            def do_stuff(a, b, client=None):
                # do stuff with a, b and client
            """
        )
    return client_arg_position, teardown_cluster_arg_position
def requires_cluster(
    num_workers: int = NUM_WORKERS,
    threads_per_worker: int = THREADS_PER_WORKER,
    # BUG FIX: the original default referenced the undefined name
    # MEMORY_PER_WORKER, which raised NameError at import time.
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB,
    machine_type: str = MACHINE_TYPE,
    gcp_instance_name: str = GCP_INSTANCE_NAME,
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE,
    teardown_cluster=True,
):
    """
    A decorator to automatically provide a function with a ready to use
    dask cluster client.

    If the caller already passes a ``distributed.Client`` (positionally or
    as ``client=``), the function runs as-is. Otherwise a new Cluster is
    provisioned, its client injected into the call, and — unless
    ``teardown_cluster`` is falsy (via decorator default or call-time
    arg/kwarg) — torn down afterwards.
    """

    def decorator(func):
        (
            client_arg_position,
            teardown_cluster_arg_position,
        ) = inspect_requires_cluster_function(func)

        # BUG FIX: the original wrapper did not use functools.wraps, so the
        # decorated function lost its name/docstring/signature metadata.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cluster = None
            client_provided_in_args = any(isinstance(arg, Client) for arg in args)
            # When client is provided, just run the function as is
            if isinstance(kwargs.get("client"), Client) or client_provided_in_args:
                return func(*args, **kwargs)
            # Resolve the effective teardown flag: call-time kwarg beats
            # call-time positional arg beats the decorator-level default.
            if kwargs.get("teardown_cluster") is not None:
                if not isinstance(kwargs["teardown_cluster"], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = kwargs["teardown_cluster"]
            elif (
                teardown_cluster_arg_position is not None
                and len(args) > teardown_cluster_arg_position
            ):
                if not isinstance(args[teardown_cluster_arg_position], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = args[teardown_cluster_arg_position]
            else:
                _teardown_cluster = teardown_cluster
            try:
                cluster = Cluster(
                    num_workers=num_workers,
                    threads_per_worker=threads_per_worker,
                    memory_per_worker_gb=memory_per_worker_gb,
                    machine_type=machine_type,
                    gcp_instance_name=gcp_instance_name,
                    gcp_cluster_zone=gcp_cluster_zone,
                )
                # update the args/kwargs with the newly created client
                new_args = []
                for i, arg in enumerate(args):
                    if i == client_arg_position:
                        new_args.append(cluster.client)
                    else:
                        new_args.append(arg)
                if len(args) <= client_arg_position:
                    # client was not passed in as an arg, so update the kwargs
                    kwargs["client"] = cluster.client
                return func(*new_args, **kwargs)
            finally:
                # BUG FIX: the original logged cluster.cluster_host_ip_address
                # in the "leave running" branch even when cluster was still
                # None (e.g. Cluster() raised), masking the real error with
                # an AttributeError.
                if cluster is not None:
                    if _teardown_cluster:
                        cluster.teardown()
                    else:
                        LOGGER.info(
                            f"Leaving cluster running at {cluster.cluster_host_ip_address}"
                        )

        return wrapper

    return decorator