Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Created September 27, 2019 14:30
Show Gist options
  • Save thomasbrandon/2e8e365086ed20cef655dac11cd8c365 to your computer and use it in GitHub Desktop.
Save thomasbrandon/2e8e365086ed20cef655dac11cd8c365 to your computer and use it in GitHub Desktop.
NB for MNIST Stats update in Fastai
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"# MNIST Stats"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from fastai.vision import *\n",
"from itertools import islice"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"DATA = untar_data(URLs.MNIST)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"src = (ImageList.from_folder(DATA)\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_stats"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"So the stats look odd: mean and std are both 0.15 for every channel, which looks like a copy-paste error."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"data = src.databunch(bs=4, num_workers=0).normalize(mnist_stats)\n",
"data.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"[(tensor(0.1296), tensor(0.3091)),\n",
" (tensor(0.1280), tensor(0.3092)),\n",
" (tensor(0.2026), tensor(0.3676)),\n",
" (tensor(0.1413), tensor(0.3152)),\n",
" (tensor(0.1778), tensor(0.3480))]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(im.px.mean(), im.px.std()) for im in src.train.x[:5]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"Not the same as the fastai stats: the per-image means are around 0.13 to 0.20 and the stds around 0.31 to 0.37, rather than 0.15."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean= 0.024, std=2.227\n",
"mean= 0.007, std=2.202\n",
"mean=-0.153, std=2.062\n",
"mean=-0.178, std=2.000\n",
"mean=-0.050, std=2.137\n"
]
}
],
"source": [
"for xb,_ in islice(data.train_dl, 5):\n",
" print(f\"mean={xb.mean().item(): 0.3f}, std={xb.std().item():0.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"The resulting mean is close to 0, but the std is around 2 rather than the 1 we're going for."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"## Calculate stats"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"class RunningStatistics:\n",
"    '''Accumulates the mean and variance (and optionally min/max) of items over their final `n_dims` dimensions.\n",
"\n",
"    Each item passed to `update` has its final `n_dims` dimensions flattened and reduced, keeping the leading\n",
"    dimensions. So collecting across `(l,m,n,o)` sized items with `n_dims=1` will collect `(l,m,n)` sized\n",
"    statistics while with `n_dims=2` the collected statistics will be of size `(l,m)`.\n",
"\n",
"    Blocks are merged with the pairwise update from Chan, Golub, and LeVeque in \"Algorithms for computing the\n",
"    sample variance: analysis and recommendations\":\n",
"\n",
"    `S = S1 + S2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2), 2)`\n",
"\n",
"    This combines the sums of squared deviations from the mean (`S1`, `S2`) of 2 blocks: block 1 having `n`\n",
"    elements and a sum of `t1` and block 2 having `m` elements and a sum of `t2`. Each block's sum of squares\n",
"    is taken from its population (divide-by-count) variance, so merging adds no error beyond floating-point\n",
"    rounding. `var`/`std` report the population variance/std of all elements seen.\n",
"\n",
"    Note that collecting minimum and maximum values is reasonably inefficient, adding about 80% to the running\n",
"    time, and hence is disabled by default.\n",
"    '''\n",
"    def __init__(self, n_dims:int=2, record_range=False):\n",
"        self._n_dims,self._range = n_dims,record_range\n",
"        self.n,self.sum,self.min,self.max = 0,None,None,None\n",
"\n",
"    def update(self, data:Tensor):\n",
"        '''Merge the statistics of `data` into the running totals.\n",
"\n",
"        The leading dimensions of `data` (all but the final `n_dims`) must match those of earlier items.\n",
"        Non-floating-point input is converted to float; an item with nothing to reduce is ignored.\n",
"        '''\n",
"        with torch.no_grad():\n",
"            # `reshape` rather than `view` so non-contiguous tensors are accepted as well.\n",
"            data = data.reshape(*list(data.shape[:-self._n_dims]) + [-1])\n",
"            if not data.is_floating_point(): data = data.float()  # `var` needs a floating-point tensor\n",
"            new_n = data.shape[-1]\n",
"            if new_n == 0: return  # nothing to merge, and `ratio` below would divide by zero\n",
"            # Biased (divide-by-n) variance so `new_var * new_n` is exactly this block's sum of squared\n",
"            # deviations; the default unbiased estimator would inflate it by n/(n-1) (~0.13% for 28x28 images).\n",
"            new_var,new_sum = data.var(-1, unbiased=False),data.sum(-1)\n",
"            if self.n == 0:\n",
"                self.n = new_n\n",
"                self._shape = data.shape[:-1]\n",
"                self.sum = new_sum\n",
"                self._nvar = new_var.mul_(new_n)\n",
"                if self._range:\n",
"                    self.min = data.min(-1)[0]\n",
"                    self.max = data.max(-1)[0]\n",
"            else:\n",
"                assert data.shape[:-1] == self._shape, f\"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}.\"\n",
"                # Chan et al. merge with n1 = self.n, n2 = new_n:\n",
"                #   S = S1 + S2 + n1/(n2*(n1+n2)) * ((n2/n1)*t1 - t2)^2\n",
"                ratio = self.n / new_n\n",
"                t = (self.sum / ratio).sub_(new_sum).pow_(2)\n",
"                self._nvar.add_(new_var.mul_(new_n)).add_(t.mul_(ratio / (self.n + new_n)))\n",
"                self.sum.add_(new_sum)\n",
"                self.n += new_n\n",
"                if self._range:\n",
"                    self.min = torch.min(self.min, data.min(-1)[0])\n",
"                    self.max = torch.max(self.max, data.max(-1)[0])\n",
"\n",
"    @property\n",
"    def mean(self):\n",
"        '''Mean over all elements seen so far, or `None` before the first `update`.'''\n",
"        return self.sum / self.n if self.n > 0 else None\n",
"    @property\n",
"    def var(self):\n",
"        '''Population (divide-by-count) variance of all elements seen, or `None` before the first `update`.'''\n",
"        return self._nvar / self.n if self.n > 0 else None\n",
"    @property\n",
"    def std(self):\n",
"        '''Population standard deviation of all elements seen, or `None` before the first `update`.'''\n",
"        return self.var.sqrt() if self.n > 0 else None\n",
"\n",
"    def __repr__(self):\n",
"        def _fmt_t(t):\n",
"            # Small tensors are printed element-wise to 3 significant figures, larger ones by shape only.\n",
"            if t is None: return \"None\"\n",
"            if t.numel() > 5: return f\"tensor of ({','.join(map(str,t.shape))})\"\n",
"            def _fmt_nested(v):\n",
"                if v.dim() == 0: return f\"{v.item():.3g}\"\n",
"                return '[' + ','.join(_fmt_nested(x) for x in v) + ']'\n",
"            return _fmt_nested(t)\n",
"        rng_str = f\", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}\" if self._range else \"\"\n",
"        return f\"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})\"\n",
"\n",
"def collect_stats(items:Iterable, n_dims:int=2, record_range:bool=False):\n",
"    '''Feed every item of `items` to a new `RunningStatistics` (showing a progress bar) and return it.'''\n",
"    stats = RunningStatistics(n_dims, record_range)\n",
"    for it in progress_bar(items):\n",
"        it = getattr(it, 'data', it) # Use data from fastai Image\n",
"        stats.update(it)\n",
"    return stats"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='60000' class='' max='60000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [60000/60000 00:15<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"RunningStatistics(n=47040000, mean=[0.131,0.131,0.131], std=[0.308,0.308,0.308])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats = collect_stats(src.train.x)\n",
"stats"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"The algorithm seems to work: there is a slight inaccuracy, but it seems numerically stable (in my local version I `assert_allclose(..., rtol=0.001, atol=0.01)` in unit tests on `torch.randn` data)."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"## Apply stats"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"([0.131, 0.131, 0.131], [0.308, 0.308, 0.308])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"true_mnist_stats = ([0.131]*3, [0.308]*3)\n",
"true_mnist_stats"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean= 0.021, std=1.016\n",
"mean= 0.032, std=1.022\n",
"mean= 0.046, std=1.050\n",
"mean= 0.060, std=1.058\n",
"mean= 0.047, std=1.045\n"
]
}
],
"source": [
"data = src.databunch(bs=4, num_workers=0).normalize(true_mnist_stats)\n",
"for xb,_ in islice(data.train_dl, 5):\n",
" print(f\"mean={xb.mean().item(): 0.3f}, std={xb.std().item():0.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"Looks right: the batch means are close to 0 and the stds close to 1."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"However, for grayscale:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 3, 28, 28])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA, convert_mode='L')\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(true_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"The 3-channel stats have broadcast the single-channel image up to 3 channels. So use single-channel stats instead:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"good_mnist_stats = ([0.131], [0.308])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 1, 28, 28])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA, convert_mode='L')\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(good_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"These work for RGB as well, where the single-channel stats are broadcast across all 3 channels:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 3, 28, 28])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA)\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(good_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-fastai-dev]",
"language": "python",
"name": "conda-env-.conda-fastai-dev-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment