Skip to content

Instantly share code, notes, and snippets.

@lajosdeme
Forked from stsievert/Centralized-PS.ipynb
Created December 6, 2023 23:44
Show Gist options
  • Save lajosdeme/e6a5b465938d64e391cb8b16cb15dda5 to your computer and use it in GitHub Desktop.
Save lajosdeme/e6a5b465938d64e391cb8b16cb15dda5 to your computer and use it in GitHub Desktop.
PyTorch MNIST parameter server
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a centralized parameter server built on PyTorch's MNIST example. It is centralized because the model is stored in one place, the parameter server (PS). This requires communicating models in addition to gradients.\n",
"\n",
"The next few cells are mostly copied from [PyTorch's MNIST example](https://github.com/pytorch/examples/tree/master/mnist)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 1,
"metadata": {
"image/png": {
"width": 400
}
},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Image\n",
"\n",
"# Diagram of the centralized setup; the path is relative to the kernel's working directory\n",
"Image('./centralized.png', width=400)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import argparse\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
"    \"\"\"Small CNN classifier for 28x28 single-channel MNIST digits.\n",
"\n",
"    Two conv + max-pool stages feed two fully connected layers. `forward`\n",
"    returns per-class log-probabilities over 10 classes, the input format\n",
"    expected by `F.nll_loss`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
"        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
"        self.conv2_drop = nn.Dropout2d()\n",
"        self.fc1 = nn.Linear(320, 50)\n",
"        self.fc2 = nn.Linear(50, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
"        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
"        # For 28x28 inputs: 20 channels * 4 * 4 spatial positions = 320 features\n",
"        x = x.view(-1, 320)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.dropout(x, training=self.training)\n",
"        x = self.fc2(x)\n",
"        return F.log_softmax(x, dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This depends on `torch.device` objects being serializable, because `device` is sent to the Dask workers below: https://github.com/pytorch/pytorch/pull/7713"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from types import SimpleNamespace\n",
"\n",
"# Hyperparameters in the style of the argparse options from PyTorch's MNIST example\n",
"args = SimpleNamespace(\n",
"    batch_size=64,\n",
"    test_batch_size=1000,\n",
"    epochs=2,\n",
"    lr=0.01,\n",
"    momentum=0.5,\n",
"    no_cuda=True,\n",
"    seed=42,\n",
"    log_interval=80,\n",
")\n",
"\n",
"use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
"torch.manual_seed(args.seed)\n",
"\n",
"# Extra DataLoader options are only passed when CUDA is in use\n",
"if use_cuda:\n",
"    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
"else:\n",
"    kwargs = {}\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
"    datasets.MNIST(\n",
"        '../data', train=True, download=True,\n",
"        transform=transforms.Compose([\n",
"            transforms.ToTensor(),\n",
"            transforms.Normalize((0.1307,), (0.3081,)),\n",
"        ]),\n",
"    ),\n",
"    batch_size=args.batch_size,\n",
"    shuffle=True,\n",
"    **kwargs\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/ssievert/Developer/mrocklin/distributed/distributed/__init__.py\n"
]
},
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:49717\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>8</li>\n",
" <li><b>Cores: </b>8</li>\n",
" <li><b>Memory: </b>17.18 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://127.0.0.1:49717' processes=8 cores=8>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from distributed import Client\n",
"import distributed as d\n",
"\n",
"# Record the `distributed` version rather than a local install path: the PS\n",
"# below is created with `actor=True`, which needs a release supporting actors.\n",
"print(d.__version__)\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This next cell is *almost* copied from the MNIST example. It takes a single `(data, target)` batch instead of `train_loader`, and does not use the optimizer: it only computes gradients."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def train(model, device, data, target):\n",
"    \"\"\"Compute gradients of the NLL loss for one `(data, target)` batch.\n",
"\n",
"    Unlike PyTorch's MNIST example, no optimizer step is taken here: the\n",
"    gradients are left in each parameter's `.grad` for the caller to read\n",
"    and push to the parameter server. Existing gradients are cleared first,\n",
"    because `backward()` accumulates into `.grad`; the result therefore\n",
"    reflects only this batch.\n",
"\n",
"    Returns the same `model` object that was passed in.\n",
"    \"\"\"\n",
"    model.train()\n",
"    model.zero_grad()\n",
"    data, target = data.to(device), target.to(device)\n",
"    output = model(data)\n",
"    loss = F.nll_loss(output, target)\n",
"    loss.backward()\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from time import sleep\n",
"import copy\n",
"\n",
"def clone(model):\n",
"    \"\"\"Return an independent deep copy of `model`, parameters included.\"\"\"\n",
"    return copy.deepcopy(model)\n",
"\n",
"def test_clone_no_modification():\n",
"    \"\"\"Mutating the original model's parameters must leave its clone unchanged.\"\"\"\n",
"    model = Net()\n",
"    m2 = clone(model)\n",
"\n",
"    m2_params = dict(m2.named_parameters())\n",
"    for name, param in model.named_parameters():\n",
"        # `detach()` shares storage with the parameter, so this edits `model` in place\n",
"        param = param.detach()\n",
"        param.data += 1\n",
"        assert (m2_params[name].data != param.data).all()\n",
"\n",
"test_clone_no_modification()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
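"Here's the definition of our parameter server, which runs as a Dask actor. It sends out models and receives gradients. Once every worker has pushed gradients for a model version, it applies them to produce the next version."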
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class PS:\n",
"    \"\"\"Centralized parameter server, run as a Dask actor in this notebook.\n",
"\n",
"    Keeps a versioned history of the model in `self.models` (version number ->\n",
"    model). Workers pull version `key`, compute gradients on it and push them\n",
"    back under the same `key`. Once `num_workers` gradient dicts have arrived\n",
"    for a version, they are summed and applied with SGD, producing version\n",
"    `key + 1`.\n",
"\n",
"    `lr` and `momentum` default to `args.lr` and `args.momentum` when omitted.\n",
"    \"\"\"\n",
"    def __init__(self, model, num_workers=1, lr=None, momentum=None):\n",
"        self.models = {0: model}\n",
"        # version -> list of gradient dicts received so far for that version\n",
"        self._grads = {}\n",
"        # The live model; always the object stored under the newest version\n",
"        self.model = model\n",
"        lr = args.lr if lr is None else lr\n",
"        momentum = args.momentum if momentum is None else momentum\n",
"        self.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=momentum)\n",
"        self.num_workers = num_workers\n",
"\n",
"    def pull(self, key):\n",
"        \"\"\"\n",
"        For a worker to pull a model from this PS.\n",
"\n",
"        Returns the model for version `key`, or None if that version does\n",
"        not exist yet (callers poll until it appears).\n",
"        \"\"\"\n",
"        return self.models.get(key)\n",
"\n",
"    def pull_latest(self):\n",
"        \"\"\"Return `(key, model)` for the newest model version.\"\"\"\n",
"        key = max(self.models)\n",
"        return key, self.pull(key)\n",
"\n",
"    def push(self, key, grads):\n",
"        \"\"\"\n",
"        For a worker to push some gradients to this PS.\n",
"\n",
"        `grads` maps parameter names to gradient tensors computed on model\n",
"        version `key`. Pushes for a version that has already been aggregated\n",
"        are ignored.\n",
"        \"\"\"\n",
"        if key + 1 in self.models:\n",
"            # Stale push: version `key` was already aggregated into `key + 1`\n",
"            return\n",
"        self._grads.setdefault(key, []).append(grads)\n",
"\n",
"        # have we collected enough gradients?\n",
"        if len(self._grads[key]) == self.num_workers:\n",
"            old_model = clone(self.model)\n",
"            # pop() frees the consumed gradients instead of keeping them forever\n",
"            self.aggregate(self._grads.pop(key))\n",
"            # `self.model` was updated in place: it becomes version `key + 1`,\n",
"            # and a snapshot of its previous state is kept as version `key`.\n",
"            self.models[key + 1] = self.model\n",
"            self.models[key] = old_model\n",
"\n",
"    def aggregate(self, grads):\n",
"        \"\"\"Sum per-worker gradients into `.grad` and take one optimizer step.\"\"\"\n",
"        for name, param in self.model.named_parameters():\n",
"            worker_grads = [grad[name] for grad in grads]\n",
"            # NOTE(review): gradients are summed, not averaged, so the effective\n",
"            # step size grows with num_workers -- confirm this is intended.\n",
"            param.grad = sum(worker_grads)\n",
"        self.optimizer.step()\n",
"        self.optimizer.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"from time import sleep, time\n",
"import toolz\n",
"import numpy as np\n",
"\n",
"def _batches_forever(loader):\n",
"    \"\"\"Yield batches from `loader` indefinitely, starting a new pass when one ends.\n",
"\n",
"    Stops (so `next()` raises StopIteration) only if `loader` yields nothing.\n",
"    \"\"\"\n",
"    while True:\n",
"        empty = True\n",
"        for batch in loader:\n",
"            empty = False\n",
"            yield batch\n",
"        if empty:\n",
"            return\n",
"\n",
"def worker(ps, device, train_loader,\n",
"           worker_id=0, num_workers=1,\n",
"           iters=5):\n",
"    \"\"\"Run `iters` synchronous training steps against the parameter server `ps`.\n",
"\n",
"    Each step pulls model version `step` (polling until it exists), computes\n",
"    gradients on one batch with `train`, and pushes them back under `step`,\n",
"    waiting for the push to complete. `num_workers` is accepted but not used.\n",
"\n",
"    Returns a dict of per-step average timings in seconds ('comm_model',\n",
"    'compute_grad', 'comm_grad', 'avg_step_time') plus 'params', the total\n",
"    number of model parameters.\n",
"\n",
"    Raises ValueError if `iters` is less than 1.\n",
"    \"\"\"\n",
"    if iters < 1:\n",
"        raise ValueError(\"iters must be at least 1, got {}\".format(iters))\n",
"    meta = {'comm_model': 0, 'compute_grad': 0, 'comm_grad': 0}\n",
"    step_start, _model = ps.pull_latest().result()\n",
"    whole_start = time()\n",
"    params = [np.prod(tuple(p.size())) for p in _model.parameters()]\n",
"    # Build the (shuffling) batch iterator once instead of once per step\n",
"    batches = _batches_forever(train_loader)\n",
"\n",
"    for step in range(step_start, step_start + iters):\n",
"        start = time()\n",
"        # Poll until the PS has produced model version `step`\n",
"        while _model is None:\n",
"            _model = ps.pull(key=step).result()\n",
"            sleep(1e-4)\n",
"        meta['comm_model'] += time() - start\n",
"        model, _model = _model, None\n",
"\n",
"        param_check = toolz.last(model.parameters())\n",
"        check = param_check.detach().numpy().flat[:4]\n",
"        print(\"worker {} iter {}, last params = {}\".format(worker_id, step, check))\n",
"\n",
"        data, target = next(batches)\n",
"        start = time()\n",
"        model = train(model, device, data, target)\n",
"        grads = {name: p.grad.data for name, p in model.named_parameters()}\n",
"        meta['compute_grad'] += time() - start\n",
"\n",
"        start = time()\n",
"        # Wait for the actor call, like the pulls above, so `comm_grad` measures\n",
"        # the real round trip and no push is still in flight when we return\n",
"        ps.push(step, grads).result()\n",
"        meta['comm_grad'] += time() - start\n",
"    meta = {k: v / iters for k, v in meta.items()}\n",
"    meta['avg_step_time'] = (time() - whole_start) / iters\n",
"    meta['params'] = sum(params)\n",
"    return meta"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Actor: PS, key=PS-362f3f6b-5654-4665-ad6a-79f69aa95d59>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
"model = Net().to(device)\n",
"num_workers = 4\n",
"\n",
"# Scatter the initial model and the data loader to the cluster; the\n",
"# resulting futures are what the PS actor and the worker tasks receive.\n",
"model = client.scatter(model)\n",
"train_loader = client.scatter(train_loader)\n",
"\n",
"# Create the parameter server as a Dask actor and get a handle to it\n",
"ps_future = client.submit(PS, model, num_workers=num_workers, actor=True)\n",
"ps = client.gather(ps_future)\n",
"ps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now launch the workers. This cell can be run repeatedly, since `worker` starts from the latest model on the parameter server."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# One worker task per expected gradient contribution; gather waits for all of them\n",
"futures = [\n",
"    client.submit(worker, ps, device, train_loader,\n",
"                  worker_id=wid, num_workers=num_workers)\n",
"    for wid in range(num_workers)\n",
"]\n",
"meta = client.gather(futures)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>avg_step_time</th>\n",
" <th>comm_grad</th>\n",
" <th>comm_model</th>\n",
" <th>compute_grad</th>\n",
" <th>params</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.334983</td>\n",
" <td>0.000134</td>\n",
" <td>0.259987</td>\n",
" <td>0.054862</td>\n",
" <td>21840</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.245956</td>\n",
" <td>0.000338</td>\n",
" <td>0.175483</td>\n",
" <td>0.054756</td>\n",
" <td>21840</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.112934</td>\n",
" <td>0.000140</td>\n",
" <td>0.035176</td>\n",
" <td>0.063261</td>\n",
" <td>21840</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.113489</td>\n",
" <td>0.000145</td>\n",
" <td>0.043078</td>\n",
" <td>0.055989</td>\n",
" <td>21840</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" avg_step_time comm_grad comm_model compute_grad params\n",
"0 0.334983 0.000134 0.259987 0.054862 21840\n",
"1 0.245956 0.000338 0.175483 0.054756 21840\n",
"2 0.112934 0.000140 0.035176 0.063261 21840\n",
"3 0.113489 0.000145 0.043078 0.055989 21840"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# One row per worker: per-step average timings and the model's parameter count\n",
"df = pd.DataFrame(data=meta)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x1245cc470>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Every worker reports the same `params` count, so grouping by it averages\n",
"# the per-step timings across workers\n",
"show = df.groupby('params').mean()\n",
"ax = show.plot.bar()\n",
"ax.set(title='Average per-step time breakdown across workers',\n",
"       xlabel='model parameters', ylabel='seconds per step')\n",
"ax"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment