@deanwampler · Last active December 30, 2020
The full example, in one listing, from "Ray for the Curious": a sharded parameter server built with Ray actors and tasks.
import numpy as np
import ray
import time

ray.init()  # Start Ray.
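# ray.init() with no arguments starts a local, single-node Ray runtime in
# this process; to join an existing cluster you would pass its address
# instead (e.g., ray.init(address=...)).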
@ray.remote
class ParameterServer(object):
    def __init__(self, dim):
        self.params = np.zeros(dim)

    def get_params(self):
        return self.params

    def update_params(self, grad):
        self.params += grad
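# Because ParameterServer is decorated with @ray.remote, each
# ParameterServer.remote(...) call below creates an actor: a stateful
# worker process that holds self.params and executes its method calls
# one at a time, so concurrent update_params calls never race on the array.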
@ray.remote
def sharded_worker(*parameter_servers):
    for _ in range(100):
        # Get the latest parameters.
        parameter_shards = ray.get(
            [ps.get_params.remote() for ps in parameter_servers])
        params = np.concatenate(parameter_shards)
        # Compute a gradient update as before in `worker`, but
        # with additional logic for sharding.
        grad = np.ones(10)
        # A placeholder for some expensive computation:
        time.sleep(0.2)
        grad_shards = np.split(grad, len(parameter_servers))
        # Send the gradient updates to the parameter servers.
        for ps, grad_shard in zip(parameter_servers, grad_shards):
            ps.update_params.remote(grad_shard)
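# Note that update_params.remote(...) above is fire-and-forget: the
# ObjectRef it returns is discarded, so the worker does not wait for the
# update to be applied before starting its next iteration.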
# Start two parameter servers, each with half of the parameters.
parameter_servers = [ParameterServer.remote(5) for _ in range(2)]

# Start two workers.
workers = [
    sharded_worker.remote(*parameter_servers) for _ in range(2)]
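# `workers` holds the ObjectRefs of the two still-running tasks. We never
# call ray.get on them; instead, the loop below polls the parameter
# servers directly to observe progress.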
# Inspect the parameters at regular intervals until we've reached the
# end (i.e., each parameter equals 2 workers * 100 updates = 200).
while True:
    time.sleep(1)
    results = ray.get(
        [ps.get_params.remote() for ps in parameter_servers])
    print(results)
    if results[0][0] >= 200.0:
        break
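To try the listing, install Ray (pip install ray) and run the file as an
ordinary Python script. The numbers printed each second depend on
scheduling, but every run should end with all ten parameters at 200.0,
along these lines (illustrative output, not captured from a real run):

    [array([10., 10., 10., 10., 10.]), array([10., 10., 10., 10., 10.])]
    ...
    [array([200., 200., 200., 200., 200.]), array([200., 200., 200., 200., 200.])]

When experimenting interactively, calling ray.shutdown() afterwards
releases the local Ray runtime.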