# Adaptive subclass to support minimum cluster size and grouped scaling.
# Cluster modification via gcloud calls.
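#
# Intended to run as a dask-scheduler preload script. One plausible
# invocation, assuming this file is saved as adaptive_gcloud.py (the exact
# way preload arguments are forwarded varies across distributed versions):
#
#   dask-scheduler --preload adaptive_gcloud.py \
#       --instance_group my-managed-group --worker_per_instance 4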
from math import ceil
from collections import defaultdict
from urllib.parse import urlparse
import json
import logging
import subprocess

import click
import tornado.gen as gen
import tornado.process as process

import distributed.deploy.adaptive as adaptive

logger = logging.getLogger("distributed.preload." + __name__)

class AdaptiveByHost(adaptive.Adaptive):
    @classmethod
    def group_key(cls, wname, winfo=None):
        # Group workers by the host component of their address, so that all
        # workers on a given instance are retired together.
        return urlparse(wname).hostname

    def __init__(self, scheduler, min_workers=0, max_workers=None, *args, **kwargs):
        self.min_workers = min_workers
        self.max_workers = max_workers
        adaptive.Adaptive.__init__(self, scheduler, *args, **kwargs)
    def workers_to_close(self):
        num_workers = len(self.scheduler.workers)
        want_to_close = self.scheduler.workers_to_close(group_key=self.group_key)
        can_close = num_workers - self.min_workers
        if len(want_to_close) <= can_close:
            to_close = want_to_close
        else:
            # Closing every candidate would drop below min_workers; admit
            # whole host groups only while they fit within the budget.
            to_close = []
            by_group = defaultdict(list)
            for w in want_to_close:
                by_group[self.group_key(w)].append(w)
            for g, members in by_group.items():
                if len(to_close) + len(members) <= can_close:
                    to_close.extend(members)
        if to_close:
            logger.info("want_to_close: %s can_close: %s closing: %s",
                        len(want_to_close), can_close, len(to_close))
        else:
            logger.debug("want_to_close: %s can_close: %s closing: %s",
                         len(want_to_close), can_close, len(to_close))
        return to_close
    def should_scale_up(self):
        # max_workers may be None (unbounded), so guard before comparing.
        if self.max_workers is not None and len(self.scheduler.ncores) >= self.max_workers:
            logger.info("Will not scale up. scheduler.ncores: %i max_workers: %i.",
                        len(self.scheduler.ncores), self.max_workers)
            return False
        return adaptive.Adaptive.should_scale_up(self)
    def get_scale_up_kwargs(self):
        """
        Get the arguments to be passed to ``self.cluster.scale_up``.
        """
        instances = max(1, ceil(len(self.scheduler.ncores) * self.scale_factor))
        if self.max_workers is not None and instances > self.max_workers:
            logger.info("limiting from target: %s to self.max_workers: %s",
                        instances, self.max_workers)
            instances = self.max_workers
        logger.info("Scaling up to %d workers", instances)
        return {'n': instances}
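
# A hypothetical manual wiring, useful for experimenting outside of the
# preload machinery (the names and numbers below are illustrative only):
#
#   adapt = AdaptiveByHost(
#       scheduler,
#       min_workers=2, max_workers=16, interval=60 * 1000,
#       cluster=GCloudInstanceCluster("my-managed-group", 4),
#   )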

class GCloudInstanceCluster(object):
    @classmethod
    @gen.coroutine
    def check_json_output(cls, cmd):
        # Run a gcloud command with JSON output and parse the result.
        cmd = cmd + ["--format", "json"]
        output = yield cls.check_output(cmd)
        raise gen.Return(json.loads(output.decode()))

    @classmethod
    @gen.coroutine
    def check_output(cls, cmd):
        # Run a subprocess on the event loop without blocking the scheduler.
        logger.info("command: %s", cmd)
        proc = process.Subprocess(cmd, stdout=process.Subprocess.STREAM)
        output = yield proc.stdout.read_until_close()
        ret = yield proc.wait_for_exit()
        logger.info("ret: %s output: %s", ret, output)
        raise gen.Return(output)
    def __init__(self, instance_group, worker_per_instance):
        if not instance_group:
            raise ValueError("instance_group not provided.")
        self.instance_group = instance_group
        self.worker_per_instance = worker_per_instance
        instance_groups = {
            g["name"] for g in json.loads(subprocess.check_output(
                "gcloud compute instance-groups managed list --format json".split(" ")
            ).decode())
        }
        if self.instance_group not in instance_groups:
            raise ValueError("Unknown instance group: %r Known groups: %s" %
                             (self.instance_group, instance_groups))
        logger.info("Initialized cluster instance_group: %r worker_per_instance: %s",
                    self.instance_group, self.worker_per_instance)
    @gen.coroutine
    def scale_up(self, n):
        """
        Bring the total count of workers up to at least ``n``.
        """
        scale_up_instances = ceil(n / self.worker_per_instance)
        logger.info("scale_up n: %s target instance count: %s", n, scale_up_instances)
        logger.info("checking instance group state")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])
        instances = yield self.check_json_output(
            "gcloud compute instance-groups managed list-instances".split(" ")
            + [self.instance_group])
        if len(instances) >= scale_up_instances:
            logger.info("scale_up n: %s required instance count: %s but already had: %s",
                        n, scale_up_instances, len(instances))
            return
        logger.info("performing resize scale_up n: %s target instance count: %s",
                    n, scale_up_instances)
        yield self.check_output(
            "gcloud compute instance-groups managed resize".split(" ")
            + [self.instance_group, "--size", str(scale_up_instances)])
    @gen.coroutine
    def scale_down(self, workers):
        """
        Remove ``workers`` from the cluster.
        """
        logger.info("scale_down: %s", workers)
        worker_ips = {urlparse(w).hostname for w in workers}
        logger.info("scale down target ips: %s", worker_ips)
        logger.info("checking instance group state")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])
        instances = yield self.check_json_output(
            "gcloud compute instances list".split(" "))
        name_by_ip = {
            i["networkInterfaces"][0]["networkIP"]: i["name"]
            for i in instances
        }
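        # The mapping above assumes the JSON shape emitted by
        # ``gcloud compute instances list --format json``, roughly:
        #   [{"name": "worker-abcd",
        #     "networkInterfaces": [{"networkIP": "10.0.0.5", ...}, ...]}, ...]
        # Only the first interface is consulted, so workers are expected to
        # be addressed by their primary NIC's internal IP.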
        worker_names = []
        for ip in worker_ips:
            name = name_by_ip.get(ip, None)
            if name:
                worker_names.append(name)
            else:
                logger.warning("Could not locate worker name for target ip: %s", ip)
        worker_names = set(worker_names)
        logger.info("scale down worker names: %s", worker_names)
        if not worker_names:
            # Nothing resolved to an instance name; an empty --instances
            # list would make delete-instances fail.
            return
        yield self.check_output(
            "gcloud compute instance-groups managed delete-instances".split(" ")
            + [self.instance_group, "--instances=" + ",".join(worker_names)])
        logger.info("waiting until stable")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])

@click.command()
@click.option("--instance_group", type=str)
@click.option("--min_workers", type=click.IntRange(min=0), default=0)
@click.option("--max_workers", type=click.IntRange(min=1), default=None)
@click.option("--worker_per_instance", type=click.IntRange(min=1), default=1)
@click.option("--interval", type=click.IntRange(min=1000), default=60 * 1000)
def dask_command(**_config):
    # Stash the parsed preload options for dask_setup below.
    global adaptive_config
    adaptive_config = _config

def dask_setup(scheduler):
    cluster = AdaptiveByHost(
        scheduler,
        min_workers=adaptive_config["min_workers"],
        max_workers=adaptive_config["max_workers"],
        interval=adaptive_config["interval"],
        cluster=GCloudInstanceCluster(
            adaptive_config["instance_group"],
            adaptive_config["worker_per_instance"],
        ),
    )
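    # No reference to `cluster` is kept here; this assumes (as in distributed
    # releases of this era) that Adaptive registers a periodic callback on the
    # scheduler's event loop in its __init__, which keeps the instance alive.
    # If that ever changes, stash it instead, e.g. scheduler._adaptive = cluster.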