Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Code for the Dask cluster on demand blog post
import functools
import inspect
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import requests
import yaml
from distributed import Client
from distributed.security import Security
from googleapiclient import discovery
LOGGER = logging.getLogger(__name__)

# Define some constants here, directly in the class,
# or import from another module
GCP_PROJECT_ID: str = "default project id"
GCP_PROJECT_NUMBER: str = "default project number"
GCP_CLUSTER_ZONE: str = "default cluster zone"
GCP_INSTANCE_NAME: str = "default instance name"
GCP_DOCKER_IMAGE: str = "default docker image"
DASK_CERT_FILEPATH: str = "path to default dask certificate you want to use"
DASK_KEY_FILEPATH: str = "path to default dask key you want to use"
# NOTE: no trailing commas on the values below. A trailing comma would turn
# each constant into a 1-tuple (e.g. ("e2-standard-16",)), which fails the
# machine type validation and renders as "(16,)" in the container env vars.
MACHINE_TYPE: str = "e2-standard-16"

# Defaults for single node Dask cluster docker image
# see https://gist.github.com/ian-whitestone/d3b876e77743923b112d7d004d86480c
# or https://ianwhitestone.work/single-node-dask-cluster-on-gcp/ for more details
NUM_WORKERS: int = 16
THREADS_PER_WORKER: int = 1
MEMORY_PER_WORKER_GB: float = 4
@dataclass
class Cluster:
    """
    A single node Dask cluster running in a container on a GCP compute instance.

    Instantiating the class validates the requested machine type, creates the
    compute instance, waits for the Dask dashboard to respond, and connects a
    TLS-secured ``distributed.Client`` (available as ``self.client``).
    Call ``teardown()`` to close the client and delete the instance.
    """

    gcp_project_id: str = GCP_PROJECT_ID
    gcp_project_number: str = GCP_PROJECT_NUMBER
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE
    gcp_instance_name: str = GCP_INSTANCE_NAME
    gcp_docker_image: str = GCP_DOCKER_IMAGE
    # Only needed if you're using a Dask cluster with SSL security
    # See https://ianwhitestone.work/dask-cluster-security/ for more details
    dask_cert_filepath: str = DASK_CERT_FILEPATH
    dask_key_filepath: str = DASK_KEY_FILEPATH
    machine_type: str = MACHINE_TYPE
    num_workers: int = NUM_WORKERS
    threads_per_worker: int = THREADS_PER_WORKER
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB

    def __post_init__(self):
        """
        Validate the configuration, then create the instance and connect to it.

        The steps are order dependent: the disk image must be resolved before
        the instance is created, and the IP address is only available once
        the instance exists.
        """
        self._validate_machine_type()
        self.compute = discovery.build("compute", "v1", cache_discovery=False)
        self.disk_image_name, self.disk_image_link = self._get_latest_image()
        self.create()
        self.cluster_host_ip_address = self._get_cluster_ip_address()
        self._wait_until_cluster_is_ready()
        self.client = self.create_client()

    def _validate_machine_type(self):
        """
        Check that ``machine_type`` is a known or well-formed custom E2 type.

        For the predefined types, also check that ``num_workers`` does not
        exceed the number of vCPUs. Custom types (e.g. ``e2-custom-32-49152``)
        are only checked for format; the worker count is not compared against
        their CPU count.

        Raises:
            ValueError: if the machine type is unknown or malformed, or if more
                workers are requested than the predefined type has vCPUs.
        """
        gcp_machine_types = {
            # shared core
            "e2-micro": {"vCPU": 2, "memory_gb": 1},
            "e2-small": {"vCPU": 2, "memory_gb": 2},
            "e2-medium": {"vCPU": 2, "memory_gb": 4},
            # standard
            "e2-standard-2": {"vCPU": 2, "memory_gb": 8},
            "e2-standard-4": {"vCPU": 4, "memory_gb": 16},
            "e2-standard-8": {"vCPU": 8, "memory_gb": 32},
            "e2-standard-16": {"vCPU": 16, "memory_gb": 64},
            "e2-standard-32": {"vCPU": 32, "memory_gb": 128},
            # high memory
            "e2-highmem-2": {"vCPU": 2, "memory_gb": 16},
            "e2-highmem-4": {"vCPU": 4, "memory_gb": 32},
            "e2-highmem-8": {"vCPU": 8, "memory_gb": 64},
            "e2-highmem-16": {"vCPU": 16, "memory_gb": 128},
            # high compute
            "e2-highcpu-2": {"vCPU": 2, "memory_gb": 2},
            "e2-highcpu-4": {"vCPU": 4, "memory_gb": 4},
            "e2-highcpu-8": {"vCPU": 8, "memory_gb": 8},
            "e2-highcpu-16": {"vCPU": 16, "memory_gb": 16},
            "e2-highcpu-32": {"vCPU": 32, "memory_gb": 32},
        }

        # Example custom machine spec: e2-custom-32-49152
        if "custom" in self.machine_type:
            parts = self.machine_type.split("-")
            if len(parts) != 4:  # TODO: replace with regex validation
                raise ValueError(
                    "Custom machine type must be formatted like 'e2-custom-32-49152'"
                )
            num_cpus = int(parts[2])
            memory = int(parts[3])
            if memory % 256:
                raise ValueError("Memory must be a multiple of 256")
            if num_cpus < 2 or (num_cpus % 2):
                raise ValueError("# of CPUs must be greater than 2 and a multiple of 2")
            return

        if self.machine_type not in gcp_machine_types:
            raise ValueError(
                f"'{self.machine_type}' is not a valid machine type. "
                f"Expecting one of {list(gcp_machine_types.keys())}"
            )

        num_cores_available = gcp_machine_types[self.machine_type]["vCPU"]
        if self.num_workers > num_cores_available:
            raise ValueError(
                f"{self.machine_type} has {num_cores_available} cores available and "
                f"you requested {self.num_workers}. Try specifying a machine_type with "
                "more vCPUs or reduce num_workers."
            )

    def _get_latest_image(self):
        """
        https://googleapis.github.io/google-api-python-client/docs/dyn/compute_v1.images.html#getFromFamily

        Returns the latest image that is part of an image family and is not deprecated.

        The result is a ``(name, selfLink)`` tuple for the ``cos-stable``
        family of the ``cos-cloud`` project.
        """
        image_response = (
            self.compute.images()
            .getFromFamily(project="cos-cloud", family="cos-stable")
            .execute()
        )
        return image_response["name"], image_response["selfLink"]

    def _get_cluster_ip_address(self):
        """
        Return the ``natIP`` of the first access config of the instance.

        Raises:
            Exception: if no instance, or more than one instance, matches
                ``gcp_instance_name`` in the configured project and zone.
        """
        instances_list = (
            self.compute.instances()
            .list(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                filter=f"name = {self.gcp_instance_name}",
            )
            .execute()
        )
        items = instances_list.get("items")
        if not items:
            raise Exception("Instance not found")
        if len(items) > 1:
            raise Exception("More than 1 instance returned with search criteria")
        return items[0]["networkInterfaces"][0]["accessConfigs"][0]["natIP"]

    def _wait_until_cluster_is_ready(self):
        """
        Block until the page served on port 8787 responds and mentions "dask".

        Polls every 30 seconds. Connection failures and timeouts are expected
        while the container is still starting and are retried.
        """
        cluster_url = f"http://{self.cluster_host_ip_address}:8787/"
        LOGGER.info(f"Waiting until cluster {cluster_url} is ready")
        while True:
            try:
                # The timeout stops a silently dropped connection from
                # blocking this loop forever.
                r = requests.get(cluster_url, timeout=30)
                if r.ok and "dask" in r.text.lower():
                    LOGGER.info("Cluster is ready 🟢")
                    return
            except (
                # requests raises its own ConnectionError, which is NOT a
                # subclass of the builtin ConnectionError.
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
            ):
                pass
            # Also sleep when a response came back but was not ready yet,
            # rather than hammering the endpoint in a tight loop.
            time.sleep(30)

    def _wait_for_operation(self, operation_name: str):
        """
        Poll a zone operation once per second until its status is ``DONE``.

        Raises:
            Exception: carrying the operation's ``error`` payload if the
                finished operation reports one.
        """
        while True:
            result = (
                self.compute.zoneOperations()
                .get(
                    project=self.gcp_project_id,
                    zone=self.gcp_cluster_zone,
                    operation=operation_name,
                )
                .execute()
            )
            if result["status"] == "DONE":
                if "error" in result:
                    raise Exception(result["error"])
                return
            time.sleep(1)

    @property
    def gce_container_spec(self):
        """
        YAML container declaration passed to the instance as metadata.

        The worker settings are handed to the container as string-valued
        environment variables.
        """
        container_spec = {
            "spec": {
                "containers": [
                    {
                        "name": self.gcp_instance_name,
                        "image": self.gcp_docker_image,
                        "env": [
                            {
                                "name": "MEMORY_PER_WORKER",
                                "value": f"{self.memory_per_worker_gb}",
                            },
                            {
                                "name": "THREADS_PER_WORKER",
                                "value": f"{self.threads_per_worker}",
                            },
                            {"name": "NUM_WORKERS", "value": f"{self.num_workers}"},
                        ],
                        "stdin": False,
                        "tty": False,
                    }
                ],
                "restartPolicy": "Always",
            }
        }
        return yaml.dump(container_spec)

    @property
    def machine_type_full_name(self):
        """Machine type path built from the project id, zone and machine type."""
        return (
            f"projects/{self.gcp_project_id}/zones/"
            f"{self.gcp_cluster_zone}/machineTypes/{self.machine_type}"
        )

    @property
    def instance_config(self):
        """Request body used by ``create()`` to insert the compute instance."""
        return {
            "kind": "compute#instance",
            "name": self.gcp_instance_name,
            "zone": self.gcp_cluster_zone,
            "machineType": self.machine_type_full_name,
            "metadata": {
                "kind": "compute#metadata",
                "items": [
                    {
                        "key": "gce-container-declaration",
                        "value": self.gce_container_spec,
                    },
                    {"key": "google-logging-enabled", "value": "true"},
                ],
            },
            "tags": {"items": ["http-server"]},
            "disks": [
                {
                    "boot": True,
                    "autoDelete": True,
                    "initializeParams": {"sourceImage": self.disk_image_link},
                }
            ],
            # Specify a network interface with NAT to access the public
            # internet.
            "networkInterfaces": [
                {
                    "network": "global/networks/default",
                    "accessConfigs": [
                        {"type": "ONE_TO_ONE_NAT", "name": "External NAT"}
                    ],
                }
            ],
            "labels": {"container-vm": self.disk_image_name},
            "serviceAccounts": [
                {
                    "email": f"{self.gcp_project_number}-compute@developer.gserviceaccount.com",  # noqa
                    "scopes": [
                        "https://www.googleapis.com/auth/devstorage.read_only",
                        "https://www.googleapis.com/auth/logging.write",
                        "https://www.googleapis.com/auth/monitoring.write",
                        "https://www.googleapis.com/auth/servicecontrol",
                        "https://www.googleapis.com/auth/service.management.readonly",
                        "https://www.googleapis.com/auth/trace.append",
                    ],
                }
            ],
        }

    def create(self):
        """Insert the compute instance and wait for the operation to finish."""
        LOGGER.info("Creating new compute instance")
        operation = (
            self.compute.instances()
            .insert(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                body=self.instance_config,
            )
            .execute()
        )
        self._wait_for_operation(operation["name"])

    def create_client(self):
        """
        Connect a ``distributed.Client`` to port 8786 of the instance over TLS.

        The configured certificate file is used both as the CA file and as
        the client certificate.
        """
        cluster_host_url = f"tls://{self.cluster_host_ip_address}:8786"
        LOGGER.info(f"Connecting new client to {cluster_host_url}")
        sec = Security(
            tls_ca_file=self.dask_cert_filepath,
            tls_client_cert=self.dask_cert_filepath,
            tls_client_key=self.dask_key_filepath,
            require_encryption=True,
        )
        return Client(cluster_host_url, security=sec)

    def teardown(self):
        """Close the client, delete the compute instance and wait for completion."""
        LOGGER.info("Shutting down client & tearing down cluster")
        self.client.close()
        operation = (
            self.compute.instances()
            .delete(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                instance=self.gcp_instance_name,
            )
            .execute()
        )
        LOGGER.info("Waiting for teardown to finish...")
        self._wait_for_operation(operation["name"])
def inspect_requires_cluster_function(func) -> Tuple[int, Optional[int]]:
    """
    Validate a function decorated with @requires_cluster
    has client specified as an arg or kwarg.

    Return the position of the client argument in func, together with the
    position of the ``teardown_cluster`` argument (``None`` when func does
    not declare one).

    Raises:
        ValueError: if func has no parameter named ``client``.
    """
    # Parameter names in declaration order; the list index is the position.
    param_names = list(inspect.signature(func).parameters)

    if "client" not in param_names:
        raise ValueError(
            "Functions using @requires_cluster must accept client as an arg or kwarg "
            "For example: \n"
            "\n"
            "@requires_cluster\n"
            "def do_stuff(a, b, client=None):\n"
            "# do stuff with a, b and client\n"
        )

    client_arg_position = param_names.index("client")
    teardown_cluster_arg_position = (
        param_names.index("teardown_cluster")
        if "teardown_cluster" in param_names
        else None
    )
    return client_arg_position, teardown_cluster_arg_position
def requires_cluster(
    num_workers: int = NUM_WORKERS,
    threads_per_worker: int = THREADS_PER_WORKER,
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB,
    machine_type: str = MACHINE_TYPE,
    gcp_instance_name: str = GCP_INSTANCE_NAME,
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE,
    teardown_cluster=True,
):
    """
    A decorator to automatically provide a function with a ready to use
    dask cluster client

    The decorated function must declare a ``client`` parameter. If the caller
    already supplies a ``Client`` (positionally or as ``client=``), the
    function runs as is and no cluster is created. Otherwise a ``Cluster`` is
    created with the settings given here and its client is injected.

    ``teardown_cluster`` is the default for whether the cluster is deleted
    after the call. A boolean ``teardown_cluster`` argument passed to the
    decorated function at call time takes precedence.

    Raises (at call time):
        ValueError: if a ``teardown_cluster`` argument is passed that is not
            a boolean.
    """

    def decorator(func):
        (
            client_arg_position,
            teardown_cluster_arg_position,
        ) = inspect_requires_cluster_function(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cluster = None

            # When client is provided, just run the function as is
            client_provided_in_args = any(isinstance(arg, Client) for arg in args)
            if isinstance(kwargs.get("client"), Client) or client_provided_in_args:
                return func(*args, **kwargs)

            # Resolve the teardown flag: call-time kwarg, then call-time
            # positional arg, then the decorator's default.
            if kwargs.get("teardown_cluster") is not None:
                if not isinstance(kwargs["teardown_cluster"], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = kwargs["teardown_cluster"]
            elif (
                teardown_cluster_arg_position is not None
                and len(args) > teardown_cluster_arg_position
            ):
                if not isinstance(args[teardown_cluster_arg_position], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = args[teardown_cluster_arg_position]
            else:
                _teardown_cluster = teardown_cluster

            try:
                cluster = Cluster(
                    num_workers=num_workers,
                    threads_per_worker=threads_per_worker,
                    memory_per_worker_gb=memory_per_worker_gb,
                    machine_type=machine_type,
                    gcp_instance_name=gcp_instance_name,
                    gcp_cluster_zone=gcp_cluster_zone,
                )

                # update the args/kwargs with the newly created client
                new_args = [
                    cluster.client if i == client_arg_position else arg
                    for i, arg in enumerate(args)
                ]
                if len(args) <= client_arg_position:
                    # client was not passed in as an arg, so update the kwargs
                    kwargs["client"] = cluster.client

                return func(*new_args, **kwargs)
            finally:
                # `cluster` is still None when Cluster(...) itself raised.
                # Touching it then would mask the real error with an
                # AttributeError.
                # NOTE(review): in that case an instance may already have
                # been created and is left running; confirm whether cleanup
                # is wanted.
                if cluster is not None:
                    if _teardown_cluster:
                        cluster.teardown()
                    else:
                        LOGGER.info(
                            f"Leaving cluster running at {cluster.cluster_host_ip_address}"
                        )

        return wrapper

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment