# Adaptive subclass to support minimum cluster size and grouped scaling
# Cluster modification via gcloud calls
from __future__ import division

from math import ceil
import click
import subprocess
import distributed.deploy.adaptive as adaptive
import logging
import json
import toolz
from collections import defaultdict
from urllib.parse import urlparse
import tornado.process as process
import tornado.gen as gen

logger = logging.getLogger("distributed.preload." + __name__)

class AdaptiveByHost(adaptive.Adaptive):
    """Adaptive policy that enforces a minimum worker count and closes
    workers host-by-host, so whole instances can be removed cleanly."""

    @classmethod
    def group_key(cls, wname, winfo=None):
        # Group workers by the host portion of their address, e.g.
        # "tcp://10.0.0.1:40000" -> "10.0.0.1".
        return urlparse(wname).hostname

    def __init__(self, scheduler, min_workers=0, max_workers=None, *args, **kwargs):
        self.min_workers = min_workers
        self.max_workers = max_workers
        adaptive.Adaptive.__init__(self, scheduler, *args, **kwargs)

    def workers_to_close(self):
        num_workers = len(self.scheduler.workers)
        want_to_close = self.scheduler.workers_to_close(group_key=self.group_key)
        can_close = num_workers - self.min_workers

        if len(want_to_close) <= can_close:
            to_close = want_to_close
        else:
            # Closing everything would drop below min_workers; only close
            # whole host groups that fit within the allowed budget.
            to_close = []
            by_group = defaultdict(list)
            for w in want_to_close:
                by_group[self.group_key(w)].append(w)
            for g, members in by_group.items():
                if len(to_close) + len(members) <= can_close:
                    to_close.extend(members)

        if to_close:
            logger.info("want_to_close: %s can_close: %s closing: %s",
                        len(want_to_close), can_close, len(to_close))
        else:
            logger.debug("want_to_close: %s can_close: %s closing: %s",
                         len(want_to_close), can_close, len(to_close))

        return to_close

    def should_scale_up(self):
        # max_workers may be None (unbounded); guard before comparing.
        if self.max_workers is not None and len(self.scheduler.ncores) >= self.max_workers:
            logger.info("Will not scale scheduler.ncores: %i max_workers: %i.",
                        len(self.scheduler.ncores), self.max_workers)
            return False
        return adaptive.Adaptive.should_scale_up(self)

    def get_scale_up_kwargs(self):
        """
        Get the arguments to be passed to ``self.cluster.scale_up``.
        """
        instances = max(1, len(self.scheduler.ncores) * self.scale_factor)
        if self.max_workers is not None and instances > self.max_workers:
            logger.info("limiting from target: %s to self.max_workers: %s",
                        instances, self.max_workers)
            instances = self.max_workers
        logger.info("Scaling up to %d workers", instances)
        return {'n': instances}
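
# Illustrative note (added for clarity, not part of the original gist): with
# workers at tcp://10.0.0.1:40000, tcp://10.0.0.1:40001 and tcp://10.0.0.2:40000,
# group_key maps each address to its host IP, so workers_to_close retires either
# all workers on a host or none of them. That keeps worker-level retirement
# aligned with the instance-level scale-down of the managed instance group below.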

class GCloudInstanceCluster(object):
    """Minimal cluster adapter that resizes a gcloud managed instance group."""

    @classmethod
    @gen.coroutine
    def check_json_output(cls, cmd):
        cmd = cmd + ["--format", "json"]
        output = yield cls.check_output(cmd)
        output = output.decode()
        raise gen.Return(json.loads(output))

    @classmethod
    @gen.coroutine
    def check_output(cls, cmd):
        # Run a gcloud command without blocking the scheduler's IO loop.
        logger.info("command: %s", cmd)
        proc = process.Subprocess(cmd, stdout=process.Subprocess.STREAM)
        output = yield proc.stdout.read_until_close()
        ret = yield proc.wait_for_exit()
        logger.info("ret: %s output: %s", ret, output)
        raise gen.Return(output)

    def __init__(self, instance_group, worker_per_instance):
        self.instance_group = instance_group
        self.worker_per_instance = worker_per_instance

        instance_groups = {
            g["name"] for g in json.loads(subprocess.check_output(
                "gcloud compute instance-groups managed list --format json".split(" ")
            ).decode())
        }

        if not self.instance_group:
            raise ValueError("instance_group not provided.")
        elif self.instance_group not in instance_groups:
            raise ValueError("Unknown instance group: %r Known groups: %s" %
                             (self.instance_group, instance_groups))

        logger.info("Initialized cluster instance_group: %r worker_per_instance: %s",
                    self.instance_group, self.worker_per_instance)

    @gen.coroutine
    def scale_up(self, n):
        """
        Bring the total count of workers up to at least ``n``.
        """
        scale_up_instances = ceil(n / self.worker_per_instance)
        logger.info("scale_up n: %s target instance count: %s", n, scale_up_instances)

        logger.info("checking instance group state")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])

        instances = yield self.check_json_output(
            "gcloud compute instance-groups managed list-instances".split(" ")
            + [self.instance_group])

        if len(instances) >= scale_up_instances:
            logger.info("scale_up n: %s required instance count: %s but already had: %s",
                        n, scale_up_instances, len(instances))
            return

        logger.info("performing resize scale_up n: %s target instance count: %s",
                    n, scale_up_instances)
        yield self.check_output(
            "gcloud compute instance-groups managed resize".split(" ")
            + [self.instance_group, "--size", str(scale_up_instances)]
        )

    @gen.coroutine
    def scale_down(self, workers):
        """
        Remove ``workers`` from the cluster.
        """
        logger.info("scale_down: %s", workers)
        worker_ips = {urlparse(w).hostname for w in workers}
        logger.info("scale down target ips: %s", worker_ips)

        logger.info("checking instance group state")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])

        # Map worker IPs back to instance names via the instance listing.
        instances = yield self.check_json_output(
            "gcloud compute instances list".split(" "))
        name_by_ip = {
            i["networkInterfaces"][0]["networkIP"]: i["name"]
            for i in instances
        }

        worker_names = set()
        for ip in worker_ips:
            name = name_by_ip.get(ip, None)
            if name:
                worker_names.add(name)
            else:
                logger.warning("Could not locate worker name for target ip: %s", ip)

        if not worker_names:
            # Avoid calling delete-instances with an empty instance list.
            logger.warning("No instance names resolved for scale_down targets; skipping.")
            return

        logger.info("scale down worker names: %s", worker_names)
        yield self.check_output(
            "gcloud compute instance-groups managed delete-instances".split(" ")
            + [self.instance_group, "--instances=" + ",".join(worker_names)])

        logger.info("waiting until stable")
        yield self.check_output(
            "gcloud compute instance-groups managed wait-until-stable".split(" ")
            + [self.instance_group])

@click.command()
@click.option("--instance_group", type=str)
@click.option("--min_workers", type=click.IntRange(min=0), default=0)
@click.option("--max_workers", type=click.IntRange(min=1), default=None)
@click.option("--worker_per_instance", type=click.IntRange(min=1), default=1)
@click.option("--interval", type=click.IntRange(min=1000), default=60 * 1000)
def dask_command(**_config):
    # Stash the parsed preload options for dask_setup below.
    global adaptive_config
    adaptive_config = _config


def dask_setup(scheduler):
    # Preload hook invoked by the dask scheduler with the Scheduler instance.
    cluster = AdaptiveByHost(
        scheduler,
        min_workers=adaptive_config["min_workers"],
        max_workers=adaptive_config["max_workers"],
        interval=adaptive_config["interval"],
        cluster=GCloudInstanceCluster(
            adaptive_config["instance_group"],
            adaptive_config["worker_per_instance"]
        )
    )
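
# --- Usage sketch (added for illustration, not part of the original gist) ---
# Assuming this file is saved as adaptive_gcloud.py, that a managed instance
# group named "dask-workers" already exists, and that the installed version of
# distributed forwards trailing scheduler CLI options to this preload module's
# click command, a launch might look like:
#
#   dask-scheduler --preload adaptive_gcloud.py \
#       --instance_group dask-workers \
#       --min_workers 1 --max_workers 16 \
#       --worker_per_instance 2 --interval 60000
#
# The file name, the group name, and the numeric values above are placeholders.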