Skip to content

Instantly share code, notes, and snippets.

@mrocklin
Created June 28, 2017 20:56
Show Gist options
  • Save mrocklin/355dd4a72efcecb160c7e5641d4428a3 to your computer and use it in GitHub Desktop.
Save mrocklin/355dd4a72efcecb160c7e5641d4428a3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dask: Multi-Client Workloads and Machine Learning Parameter Servers\n",
"====================\n",
"\n",
"Usually in Dask there is a single client that submits a static task graph to the scheduler. The scheduler then coordinates a set of workers to do that work. When they finish, they send the results back to the client.\n",
"\n",
"However, for advanced algorithms it is sometimes useful to have the workers themselves drive the computation. Dask supports starting many clients throughout the cluster that coordinate with each other. This small change makes it possible to build significantly more advanced systems.\n",
"\n",
"To show this off, we'll build a lightweight stochastic gradient descent (SGD) solver using a [Parameter Server](https://www.quora.com/What-is-the-Parameter-Server). One of our workers will take gradient updates from the rest and publish a set of parameters for them to collect when needed."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"import dask\n",
"import dask.array as da\n",
"from dask import delayed\n",
"from dask_glm import families\n",
"from dask_glm.algorithms import lbfgs\n",
"from distributed import LocalCluster, Client, worker_client, Variable, Queue\n",
"import numpy as np\n",
"from time import time, sleep\n",
"from sklearn import datasets\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import roc_auc_score\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from contextlib import contextmanager\n",
"import sys\n",
"\n",
"@contextmanager\n",
"def duration(s):\n",
" start = time()\n",
" try:\n",
" yield\n",
" finally:\n",
" end = time()\n",
" print(s, '%.2e' % (end - start))\n",
" sys.stdout.flush()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def local_update(X, y, beta):\n",
" return pointwise_gradient(beta, X, y)\n",
"\n",
"def parameter_server():\n",
" beta = np.zeros(D)\n",
" gti = np.zeros(D)\n",
" with worker_client() as c:\n",
" betas = Variable('betas', client=c)\n",
" stop = Variable('stop', client=c)\n",
" updates = Queue('updates', client=c, maxsize=30)\n",
"\n",
" future_beta = c.scatter(beta)\n",
" betas.set(future_beta)\n",
" \n",
" gather_times = []\n",
" compute_times = []\n",
" scatter_times = []\n",
"\n",
" while not stop.get():\n",
" start = time()\n",
" futures = updates.get(batch=True)\n",
" batch = c.gather(futures)\n",
" end = time()\n",
" gather_times.append(end - start)\n",
" # print(\"PS: received %d updates\" % len(batch))\n",
" \n",
" start = time()\n",
" for update in batch:\n",
" gti += update ** 2\n",
" adj_grad = update / (1e-6 + np.sqrt(gti))\n",
" beta = beta - STEP_SIZE * adj_grad\n",
" end = time()\n",
" compute_times.append(end - start)\n",
" \n",
" start = time()\n",
" future_beta = c.scatter(beta)\n",
" betas.set(future_beta)\n",
" end = time()\n",
" scatter_times.append(end - start)\n",
" \n",
" # print(\"PS: pushed beta: %s\" % beta)\n",
" \n",
" return {'gather': sum(gather_times) / len(gather_times),\n",
" 'compute': sum(compute_times) / len(compute_times),\n",
" 'scatter': sum(scatter_times) / len(scatter_times)}\n",
"\n",
"\n",
"def worker(X, y, i):\n",
" with worker_client(separate_thread=False) as c:\n",
" betas = Variable('betas', client=c)\n",
" stop = Variable('stop', client=c)\n",
" updates = Queue('updates', client=c, maxsize=30)\n",
" \n",
" gather_times = []\n",
" compute_times = []\n",
" scatter_times = []\n",
"\n",
" while not stop.get():\n",
" start = time()\n",
" future = betas.get()\n",
" beta = c.gather(future)\n",
" end = time()\n",
" gather_times.append(end - start)\n",
" # print(\"Worker %d: received beta\" % i)\n",
"\n",
" start = time()\n",
" update = local_update(X, y, beta) #.compute()\n",
" end = time()\n",
" compute_times.append(end - start)\n",
" \n",
" start = time()\n",
" update_future = c.scatter(update)\n",
" updates.put(update_future)\n",
" end = time()\n",
" scatter_times.append(end - start)\n",
" # print(\"Worker %d: pushed update\" % i)\n",
" \n",
" return {'gather': sum(gather_times) / len(gather_times),\n",
" 'compute': sum(compute_times) / len(compute_times),\n",
" 'scatter': sum(scatter_times) / len(scatter_times)}\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"STEP_SIZE = 1.0\n",
"N = 100000\n",
"D = 10\n",
"ITER = 10\n",
"\n",
"X_local, y_local = datasets.make_classification(n_classes=2, n_samples=N, n_features=D)\n",
"X = da.from_array(X_local, N // 8)\n",
"y = da.from_array(y_local, N // 8)\n",
"\n",
"XD = X.to_delayed().flatten().tolist() # a list of numpy arrays, one for each chunk\n",
"yD = y.to_delayed().flatten().tolist()\n",
"\n",
"STEP_SIZE /= len(XD) # need to adjust based on parallelism for convergence?\n",
"\n",
"family = families.Logistic()\n",
"pointwise_gradient = family.pointwise_gradient\n",
"pointwise_loss = family.pointwise_loss\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"cluster = LocalCluster(n_workers=0, processes=True)\n",
"cluster.start_worker(1, name=\"ps\")\n",
"cluster.start_worker(4, name=\"w1\")\n",
"cluster.start_worker(4, name=\"w2\")\n",
"client = Client(cluster)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"updates queue size: 1\n",
"================\n",
" 1.0716931819915771\n",
"updates queue size: 1\n",
"================\n",
" 2.1429290771484375\n",
"updates queue size: 1\n",
"================\n",
" 3.2133562564849854\n",
"updates queue size: 1\n",
"================\n",
" 4.276886463165283\n",
"updates queue size: 1\n",
"================\n",
" 5.366945266723633\n",
"updates queue size: 4\n",
"================\n",
" 6.4310901165008545\n",
"updates queue size: 4\n",
"================\n",
" 7.516523361206055\n",
"updates queue size: 1\n",
"================\n",
" 8.668256521224976\n",
"updates queue size: 5\n",
"================\n",
" 9.724462032318115\n",
"updates queue size: 6\n",
"================\n",
" 10.772139072418213\n",
"Converged to [-0.62378075 0.00243964 -0.00749007 1.2745907 -0.49762186 1.34883934\n",
" 0.00507104 -0.00692682 -0.01895825 -0.00855459]\n"
]
}
],
"source": [
"betas = Variable('betas')\n",
"stop = Variable('stop')\n",
"updates = Queue('updates', maxsize=30)\n",
"stop.set(False)\n",
"\n",
"# start parameter server\n",
"res_ps = client.submit(parameter_server, workers=['ps'], pure=False)\n",
"# start workers computing\n",
"XD = client.compute(XD)\n",
"yD = client.compute(yD)\n",
"res_workers = [client.submit(worker, xx, yy, i)\n",
" for i, (xx, yy) in enumerate(zip(XD, yD))]\n",
"\n",
"# Stop if beta converges or after ten seconds\n",
"last = -1\n",
"count_same = 0\n",
"start = time()\n",
"while count_same < 5:\n",
" beta = betas.get().result()\n",
" if np.allclose(last, beta, atol=1e-5):\n",
" count_same += 1\n",
" else:\n",
" count_same = 0\n",
" last = beta\n",
" sleep(1)\n",
" print('updates queue size:', updates.qsize())\n",
" print(\"================\\n\", time() - start)\n",
" if time() - start > 10:\n",
" break\n",
"\n",
"stop.set(True)\n",
"print(\"Converged to\", beta)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'compute': 0.00010444972436752678,\n",
" 'gather': 0.02480465145737912,\n",
" 'scatter': 0.018120493687374492}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res_ps.result()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"compute 0.001567\n",
"gather 0.026303\n",
"scatter 0.022501\n",
"dtype: float64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame(client.gather(res_workers))\n",
"df.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Most time is spent in overhead\n",
"\n",
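"We see that the workers and the parameter server spend the majority of their time gathering and scattering data rather than computing: each gather or scatter takes roughly 18-26 ms, while the compute step takes at most about 1.6 ms. Presumably most of this is overhead? We should be able to get these numbers down.\n",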
"\n",
"How much of an issue is this? How low does the latency need to be for this approach to be competitive?"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 0 ns, sys: 4 ms, total: 4 ms\n",
"Wall time: 5.43 ms\n"
]
}
],
"source": [
"%time future = betas.get()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4 ms, sys: 0 ns, total: 4 ms\n",
"Wall time: 8.02 ms\n"
]
}
],
"source": [
"%time _ = client.gather(future, direct=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment