@robertnishihara
Last active July 16, 2018 00:38
Implementing a Parameter Server in 15 Lines of Python with Ray
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementing A Parameter Server in 15 Lines of Python with Ray\n",
"\n",
"Parameter servers are a core part of many machine learning applications. Their\n",
"role is to store the *parameters* of a machine learning model (e.g., the weights\n",
"of a neural network) and to *serve* them to clients (clients are often workers\n",
"that process data and compute updates to the parameters).\n",
"\n",
"Parameter servers (like databases) are normally built and shipped as standalone\n",
"systems. This post describes how to use [Ray][1] to implement a parameter server\n",
"in a few lines of code.\n",
"\n",
"By turning the parameter server from a \"system\" into an \"application\", this\n",
"approach makes it orders of magnitude simpler to deploy parameter server\n",
"applications. Similarly, by allowing applications and libraries to implement\n",
"their own parameter servers, this approach makes the behavior of the parameter\n",
"server much more configurable and flexible (since the application can simply\n",
"modify the implementation with a few lines of Python).\n",
"\n",
"**What is Ray?** [Ray][1] is a general-purpose framework for parallel and\n",
"distributed Python. Ray provides a unified task-parallel and actor abstraction\n",
"and achieves high performance through shared memory, zero-copy serialization,\n",
"and distributed scheduling. Ray also includes high-performance libraries\n",
"targeting AI applications, for example [hyperparameter tuning][5] and\n",
"[reinforcement learning][4].\n",
"\n",
"## What is a Parameter Server?\n",
"\n",
"A parameter server is a key-value store used for training machine learning\n",
"models on a cluster. The **values** are the parameters of a machine-learning\n",
"model (e.g., a neural network). The **keys** index the model parameters.\n",
"\n",
"For example, in a movie **recommendation system**, there may be one key per user\n",
"and one key per movie. For each user and movie, there are corresponding\n",
"user-specific and movie-specific parameters. In a **language-modeling**\n",
"application, words may act as keys and their embeddings may be the values. In\n",
"its simplest form, a parameter server may implicitly have a single key and allow\n",
"all of the parameters to be retrieved and updated at once. We show how such a\n",
"parameter server can be implemented as a Ray actor (15 lines) below.\n",
"\n",
"[1]: https://github.com/ray-project/ray\n",
"[2]: http://ray.readthedocs.io/en/latest/resources.html\n",
"[3]: http://www.sysml.cc/doc/206.pdf\n",
"[4]: http://ray.readthedocs.io/en/latest/rllib.html\n",
"[5]: http://ray.readthedocs.io/en/latest/tune.html\n",
"[6]: http://ray.readthedocs.io/en/latest\n",
"[7]: http://ray.readthedocs.io/en/latest/api.html\n",
"[8]: https://github.com/modin-project/modin\n",
"[9]: https://ray-project.github.io/2017/10/15/fast-python-serialization-with-ray-and-arrow.html\n",
"[10]: https://ray-project.github.io/2017/08/08/plasma-in-memory-object-store.html\n",
"[11]: https://arxiv.org/abs/1712.05889\n",
"[12]: http://spark.apache.org\n",
"[13]: https://arxiv.org/abs/1712.09381\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import ray\n",
"\n",
"\n",
"@ray.remote\n",
"class ParameterServer(object):\n",
" def __init__(self, dim):\n",
" # Alternatively, params could be a dictionary mapping keys to arrays.\n",
" self.params = np.zeros(dim)\n",
"\n",
" def get_params(self):\n",
" return self.params\n",
"\n",
" def update_params(self, grad):\n",
" self.params += grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**The `@ray.remote` decorator defines a service.** It takes the\n",
"`ParameterServer` class and allows it to be instantiated as a remote service or\n",
"actor.\n",
"\n",
"Here, we assume that the update is a gradient which should be added to the\n",
"parameter vector. This is just the simplest possible example, and many different\n",
"choices could be made.\n",
"\n",
"**A parameter server typically exists as a remote process or service** and\n",
"interacts with clients through remote procedure calls. To instantiate the\n",
"parameter server as a remote actor, we can do the following."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We need to start Ray first.\n",
"ray.init()\n",
"\n",
"# Create a parameter server process.\n",
"ps = ParameterServer.remote(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Actor method invocations return futures.** If we want to retrieve the actual\n",
"values, we can use a blocking `ray.get` call. For example,\n",
"\n",
"```python\n",
">>> params_id = ps.get_params.remote() # This returns a future.\n",
"\n",
">>> params_id\n",
"ObjectID(7268cb8d345ef26632430df6f18cc9690eb6b300)\n",
"\n",
">>> ray.get(params_id) # This blocks until the task finishes.\n",
"array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n",
"```\n",
"\n",
"Now, suppose we want to start some worker tasks that continuously compute\n",
"gradients and update the model parameters. Each worker will run in a loop that\n",
"does three things:\n",
"1. Get the latest parameters.\n",
"2. Compute an update to the parameters.\n",
"3. Update the parameters.\n",
"\n",
"As a Ray remote function (though the worker could also be an actor), this looks\n",
"like the following.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Note that the worker function takes a handle to the parameter server as an\n",
"# argument, which allows the worker task to invoke methods on the parameter\n",
"# server actor.\n",
"\n",
"@ray.remote\n",
"def worker(ps):\n",
" for _ in range(100):\n",
" # Get the latest parameters.\n",
" params_id = ps.get_params.remote() # This method call is non-blocking\n",
" # and returns a future.\n",
" params = ray.get(params_id) # This is a blocking call which waits for\n",
" # the task to finish and gets the results.\n",
"\n",
" # Compute a gradient update. Here we just make a fake update, but in\n",
" # practice this would use a library like TensorFlow and would also take\n",
" # in a batch of data.\n",
" grad = np.ones(10)\n",
" time.sleep(0.2) # This is a fake placeholder for some computation.\n",
"\n",
" # Update the parameters.\n",
" ps.update_params.remote(grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can start several worker tasks as follows.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start 2 workers.\n",
"for _ in range(2):\n",
" worker.remote(ps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can retrieve the parameters from the driver process and see that they\n",
"are being updated by the workers.\n",
"\n",
"```python\n",
">>> ray.get(ps.get_params.remote())\n",
"array([64., 64., 64., 64., 64., 64., 64., 64., 64., 64.])\n",
">>> ray.get(ps.get_params.remote())\n",
"array([78., 78., 78., 78., 78., 78., 78., 78., 78., 78.])\n",
"```\n",
"\n",
"Part of the value that Ray adds here is that *Ray makes it as easy to start up a\n",
"remote service or actor as it is to define a Python class*. Handles to the actor\n",
"can be passed around to other actors and tasks to allow arbitrary and intuitive\n",
"messaging and communication patterns. Current alternatives are much more\n",
"involved. For example, [consider how the equivalent runtime service creation and\n",
"service handle passing would be done with GRPC][14].\n",
"\n",
"## Additional Extensions\n",
"\n",
"Here we describe some important modifications to the above design. We describe\n",
"additional natural extensions in [this paper][3].\n",
"\n",
"**Sharding Across Multiple Parameter Servers:** When your parameters are large and your cluster is large, a single parameter\n",
"server may not suffice because the application could be bottlenecked by the\n",
"network bandwidth into and out of the machine that the parameter server is on\n",
"(especially if there are many workers).\n",
"\n",
"A natural solution in this case is to shard the parameters across multiple\n",
"parameter servers. This can be achieved by simply starting up multiple parameter\n",
"server actors. An example of how to do this is shown in the code example at the\n",
"bottom.\n",
"\n",
"**Controlling Actor Placement:** The placement of specific actors and tasks on different machines can be\n",
"specified by using Ray's support for arbitrary [resource requirements][2].\n",
"For example, if the worker requires a GPU, then its remote decorator can be\n",
"declared with `@ray.remote(num_gpus=1)`. Arbitrary custom resources can be defined\n",
"as well.\n",
"\n",
"## Unifying Tasks and Actors\n",
"\n",
"Ray supports parameter server applications efficiently in large part due to its\n",
"unified task-parallel and actor abstraction.\n",
"\n",
"Popular data processing systems such as [Apache Spark][12] allow stateless tasks\n",
"(functions with no side effects) to operate on immutable data. This assumption\n",
"simplifies the overall system design and makes it easier for applications to\n",
"reason about correctness.\n",
"\n",
"However, mutable state that is shared between many tasks is a recurring theme in\n",
"machine learning applications. That state could be the weights of a neural\n",
"network, the state of a third-party simulator, or an encapsulation of an\n",
"interaction with the physical world.\n",
"\n",
"To support these kinds of applications, Ray introduces an actor abstraction. An\n",
"actor will execute methods serially (so there are no concurrency issues), and\n",
"each method can arbitrarily mutate the actor's internal state. Methods can be\n",
"invoked by other actors and tasks (and even by other applications on the same\n",
"cluster).\n",
"\n",
"One thing that makes Ray so powerful is that it *unifies the actor abstraction\n",
"with the task-parallel abstraction* inheriting the benefits of both approaches.\n",
"Ray uses an underlying dynamic task graph to implement both actors and stateless\n",
"tasks in the same framework. As a consequence, these two abstractions are\n",
"completely interoperable. Tasks and actors can be created from within other\n",
"tasks and actors. Both return futures, which can be passed into other tasks or\n",
"actor methods to introduce scheduling and data dependencies. As a result, Ray\n",
"applications inherit the best features of both tasks and actors.\n",
"\n",
"## Under the Hood\n",
"\n",
"**Dynamic Task Graphs:** Under the hood, remote function invocations and actor\n",
"method invocations create tasks that are added to a dynamically growing graph of\n",
"tasks. The Ray backend is in charge of scheduling and executing these tasks\n",
"across a cluster (or a single multi-core machine). Tasks can be created by the\n",
"\"driver\" application or by other tasks.\n",
"\n",
"**Data:** Ray efficiently serializes data using the [Apache Arrow][9] data\n",
"layout. Objects are shared between workers and actors on the same machine\n",
"through [shared memory][10], which avoids the need for copies or\n",
"deserialization. This optimization is absolutely critical for achieving good\n",
"performance.\n",
"\n",
"**Scheduling:** Ray uses a distributed scheduling approach. Each machine has its\n",
"own scheduler, which manages the workers and actors on that machine. Tasks are\n",
"submitted by applications and workers to the scheduler on the same machine. From\n",
"there, they can be reassigned to other workers or passed to other local\n",
"schedulers. This allows Ray to achieve substantially higher task throughput than\n",
"what can be achieved with a centralized scheduler, which is important for\n",
"machine learning applications.\n",
"\n",
"## Conclusion\n",
"\n",
"A parameter server is normally implemented and shipped as a standalone system.\n",
"The thing that makes this approach so powerful is that we're able to implement a\n",
"parameter server with a few lines of code as an application. *This approach\n",
"makes it much simpler to deploy applications that use parameter servers and to\n",
"modify the behavior of the parameter server.* For example, if we want to shard\n",
"the parameter server, change the update rule, switch between asynchronous and\n",
"synchronous updates, ignore straggler workers, or any number of other\n",
"customizations, we can do each of these things with a few extra lines of code.\n",
"\n",
"This post describes how to use Ray actors to implement a parameter server.\n",
"However, actors are a much more general concept and can be useful for many\n",
"applications that involve stateful computation. Examples include logging,\n",
"streaming, simulation, model serving, graph processing, and many others.\n",
"\n",
"## Running this Code\n",
"\n",
"To run the complete application, first install Ray with `pip install ray`. Then\n",
"you should be able to run the code below, which implements a sharded parameter\n",
"server.\n",
"\n",
"[1]: https://github.com/ray-project/ray\n",
"[2]: http://ray.readthedocs.io/en/latest/resources.html\n",
"[3]: http://www.sysml.cc/doc/206.pdf\n",
"[4]: http://ray.readthedocs.io/en/latest/rllib.html\n",
"[5]: http://ray.readthedocs.io/en/latest/tune.html\n",
"[6]: http://ray.readthedocs.io/en/latest\n",
"[7]: http://ray.readthedocs.io/en/latest/api.html\n",
"[8]: https://github.com/modin-project/modin\n",
"[9]: https://ray-project.github.io/2017/10/15/fast-python-serialization-with-ray-and-arrow.html\n",
"[10]: https://ray-project.github.io/2017/08/08/plasma-in-memory-object-store.html\n",
"[11]: https://arxiv.org/abs/1712.05889\n",
"[12]: http://spark.apache.org\n",
"[13]: https://arxiv.org/abs/1712.09381\n",
"[14]: https://grpc.io/docs/tutorials/basic/python.html#defining-the-service"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import ray\n",
"import time\n",
"\n",
"# Start Ray.\n",
"ray.init()\n",
"\n",
"\n",
"@ray.remote\n",
"class ParameterServer(object):\n",
" def __init__(self, dim):\n",
" # Alternatively, params could be a dictionary mapping keys to arrays.\n",
" self.params = np.zeros(dim)\n",
"\n",
" def get_params(self):\n",
" return self.params\n",
"\n",
" def update_params(self, grad):\n",
" self.params += grad\n",
"\n",
"\n",
"@ray.remote\n",
"def worker(*parameter_servers):\n",
" for _ in range(100):\n",
" # Get the latest parameters.\n",
" parameter_shards = ray.get(\n",
" [ps.get_params.remote() for ps in parameter_servers])\n",
" params = np.concatenate(parameter_shards)\n",
"\n",
" # Compute a gradient update. Here we just make a fake\n",
" # update, but in practice this would use a library like\n",
" # TensorFlow and would also take in a batch of data.\n",
" grad = np.ones(10)\n",
" time.sleep(0.2) # This is a fake placeholder for some computation.\n",
" grad_shards = np.split(grad, len(parameter_servers))\n",
"\n",
" # Send the gradient updates to the parameter servers.\n",
" for ps, grad in zip(parameter_servers, grad_shards):\n",
" ps.update_params.remote(grad)\n",
"\n",
"\n",
"# Start two parameter servers, each with half of the parameters.\n",
"parameter_servers = [ParameterServer.remote(5) for _ in range(2)]\n",
"\n",
"# Start 2 workers.\n",
"workers = [worker.remote(*parameter_servers) for _ in range(2)]\n",
"\n",
"# Inspect the parameters at regular intervals.\n",
"for _ in range(5):\n",
" time.sleep(1)\n",
" print(ray.get([ps.get_params.remote() for ps in parameter_servers]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this example focuses on simplicity and that more can be done to\n",
"optimize this code.\n",
"\n",
"## Read More\n",
"\n",
"For more information about Ray, take a look at the following links:\n",
"1. The Ray [documentation][6]\n",
"2. The Ray [API][7]\n",
"3. Fast [serialization][9] with Ray and Apache Arrow\n",
"4. A [paper][11] describing the Ray system\n",
"5. Efficient [hyperparameter][5] tuning with Ray\n",
"6. Scalable [reinforcement][4] learning with Ray and [the RLlib paper][13]\n",
"7. Speeding up [Pandas][8] with Ray\n",
"\n",
"Questions should be directed to *ray-dev@googlegroups.com*.\n",
"\n",
"\n",
"[1]: https://github.com/ray-project/ray\n",
"[2]: http://ray.readthedocs.io/en/latest/resources.html\n",
"[3]: http://www.sysml.cc/doc/206.pdf\n",
"[4]: http://ray.readthedocs.io/en/latest/rllib.html\n",
"[5]: http://ray.readthedocs.io/en/latest/tune.html\n",
"[6]: http://ray.readthedocs.io/en/latest\n",
"[7]: http://ray.readthedocs.io/en/latest/api.html\n",
"[8]: https://github.com/modin-project/modin\n",
"[9]: https://ray-project.github.io/2017/10/15/fast-python-serialization-with-ray-and-arrow.html\n",
"[10]: https://ray-project.github.io/2017/08/08/plasma-in-memory-object-store.html\n",
"[11]: https://arxiv.org/abs/1712.05889\n",
"[12]: http://spark.apache.org\n",
"[13]: https://arxiv.org/abs/1712.09381\n",
"[14]: https://grpc.io/docs/tutorials/basic/python.html#defining-the-service"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
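
The notebook's conclusion says that changing the update rule takes only a few extra lines, and its first code cell notes that `params` could be a dictionary mapping keys to arrays. The following is a minimal sketch of both ideas together, under stated assumptions: the class name `KeyValueParameterServer`, the `learning_rate` argument, and the example keys are hypothetical and do not appear in the notebook. The class is plain Python (lists instead of NumPy arrays, no `@ray.remote` decorator) so that it runs without Ray, and its update is a gradient-descent step rather than the notebook's `params += grad`.

```python
class KeyValueParameterServer:
    """Plain-Python sketch of a keyed parameter server (hypothetical class)."""

    def __init__(self, shapes, learning_rate=0.1):
        # `shapes` maps each key to the length of its parameter vector.
        self.params = {key: [0.0] * dim for key, dim in shapes.items()}
        self.learning_rate = learning_rate

    def get_params(self, keys):
        # Return copies so callers cannot mutate the server's state.
        return {key: list(self.params[key]) for key in keys}

    def update_params(self, grads):
        # Apply a gradient-descent step only to the keys that were sent.
        for key, grad in grads.items():
            vector = self.params[key]
            for i, g in enumerate(grad):
                vector[i] -= self.learning_rate * g


# Example keys in the style of the recommendation-system description.
ps = KeyValueParameterServer({"user:1": 2, "movie:7": 3}, learning_rate=0.5)
ps.update_params({"user:1": [1.0, -2.0]})
print(ps.get_params(["user:1", "movie:7"]))
# {'user:1': [-0.5, 1.0], 'movie:7': [0.0, 0.0, 0.0]}
```

To use this in the notebook's setting, the class would presumably be decorated with `@ray.remote` and instantiated with `.remote(...)`, in the same way as `ParameterServer`; that step is left out here so the sketch stays runnable on its own.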