# -*- coding: utf-8 -*-
# @Author : Lin Lan (ryan.linlan@gmail.com)
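# This script drives a Ray Tune Trainable ("foo") whose trials each
# create three kinds of Ray actors: a SeedHolder serving seeded numpy
# RandomState generators, a ParameterServer holding a 128x128 weight
# matrix, and 20 Workers that produce random weight diffs. Each
# training iteration broadcasts the weights to the workers, averages
# their diffs, and applies the average on the parameter server.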
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import ray
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources
from ray.tune import register_trainable, run_experiments

@ray.remote(num_cpus=1)
class ParameterServer(object):
    """Actor holding the shared 128x128 weight matrix."""

    def __init__(self):
        self.weights = np.random.rand(128, 128).astype(np.float64)

    def get(self):
        return self.weights

    def update(self, diff):
        self.weights += diff

@ray.remote(num_cpus=1)
class Worker(object):
    def __init__(self, seed_holder):
        self.weights = None
        self.seed_holder = seed_holder

    def set_weights(self, weights):
        self.weights = weights

    def calculate_diff(self):
        # Fetch 100 RandomState objects from the SeedHolder actor (each
        # get() returns a generator, not an integer seed), pick one, and
        # draw a random diff with the same shape as the weights.
        rngs = ray.get(
            [self.seed_holder.get.remote() for _ in range(100)])
        rng = np.random.choice(rngs)
        return rng.rand(*self.weights.shape)

@ray.remote(num_cpus=1)
class SeedHolder(object):
    """Actor holding a pool of seeded RandomState generators."""

    def __init__(self):
        self.rngs = [
            np.random.RandomState(seed) for seed in range(10)]

    def get(self):
        # Return one of the generators at random.
        return np.random.choice(self.rngs)

class Foo(Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # 1 CPU for the Trainable itself, plus 22 extra CPUs for the
        # 20 Workers, the ParameterServer, and the SeedHolder.
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=20 + 2,
            extra_gpu=0)

    def _setup(self, config):
        self.seed_holder = SeedHolder.remote()
        self.ps = ParameterServer.remote()
        self.workers = [
            Worker.remote(self.seed_holder) for _ in range(20)]

    def _train(self):
        # Broadcast the current weights through a single object store
        # entry so all 20 workers share one copy.
        weights = ray.get(self.ps.get.remote())
        weights_id = ray.put(weights)
        ray.get([w.set_weights.remote(weights_id)
                 for w in self.workers])
        all_diffs = ray.get(
            [w.calculate_diff.remote() for w in self.workers])
        diff = np.mean(all_diffs, axis=0)
        # Tasks on the same actor run in submission order, so the get()
        # below observes this update even without blocking on it.
        self.ps.update.remote(diff)
        weights = ray.get(self.ps.get.remote())
        return {"weight_norm": np.linalg.norm(weights)}

register_trainable("foo", Foo)

# Connect to an existing cluster (adjust the redis address as needed).
ray.init(redis_address="localhost:32222")

# An optional GCS flush policy, left disabled here:
# gcs_policy = ray.experimental.SimpleGcsFlushPolicy(
#     flush_when_at_least_bytes=10000000000,
#     flush_period_secs=10,
#     flush_num_entries_each_time=70000)
# ray.experimental.set_flushing_policy(gcs_policy)
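# run_experiments launches num_samples=1000 trials of "foo"; each trial
# creates its own SeedHolder, ParameterServer, and 20 Workers, and stops
# after 1000 training iterations. Results go under /tmp/ray_results/test.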
run_experiments(
    {
        "test": {
            "run": "foo",
            "stop": {"training_iteration": 1000},
            "num_samples": 1000,
            "local_dir": "/tmp/ray_results"
        }
    }
)
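# Assuming a cluster head was started locally with, e.g.:
#     ray start --head --redis-port=32222
# this script can then be run directly with Python.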