Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Learned RELU
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from exp.nb_06 import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ConvNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get the data and training interface from where we left off in the last notebook."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"x_train,y_train,x_valid,y_valid = get_data()\n",
"\n",
"x_train,x_valid = normalize_to(x_train,x_valid)\n",
"train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n",
"\n",
"nh,bs = 50,512\n",
"c = y_train.max().item()+1\n",
"loss_func = F.cross_entropy\n",
"\n",
"data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mnist_view = view_tfm(1,28,28)\n",
"cbfs = [Recorder,\n",
" partial(AvgStatsCallback,accuracy),\n",
" CudaCallback,\n",
" partial(BatchTransformXCallback, mnist_view)]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"nfs = [8,16,32,64,64]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learned RELU"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"param = nn.Parameter(torch.tensor(0.1))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.10000000149011612"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"param.item()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class LearnedRelu(nn.Module):\n",
"    \"\"\"Leaky ReLU with a learnable negative slope, output shift and upper clamp.\n",
"\n",
"    Computes `min(leaky_relu(x, leak) - sub, maxv)`. `leak`, `sub` and `maxv`\n",
"    are one-element `nn.Parameter`s initialised from the constructor arguments.\n",
"    The input tensor is not modified; a new tensor is returned.\n",
"    \"\"\"\n",
"    def __init__(self, leak=0.1, sub=0.25, maxv=100):\n",
"        super().__init__()\n",
"        self.leak = nn.Parameter(torch.ones(1)*leak)\n",
"        self.sub  = nn.Parameter(torch.zeros(1)+sub)\n",
"        self.maxv = nn.Parameter(torch.ones(1)*maxv)\n",
"\n",
"    def forward(self, x):\n",
"        # Everything is expressed with tensor ops on the parameters themselves.\n",
"        # Passing `param.item()` to `F.leaky_relu` / `clamp_max_` would hand over a\n",
"        # plain Python float: the parameter drops out of the autograd graph (it\n",
"        # never receives a gradient, so it is never learned) and every call has\n",
"        # to copy the value back to the host.\n",
"        x = torch.where(x > 0, x, x * self.leak)  # leaky ReLU with slope `leak`\n",
"        x = x - self.sub                          # learnable downward shift\n",
"        return torch.min(x, self.maxv)            # learnable upper clamp (broadcast)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Batchnorm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by building our own `BatchNorm` layer from scratch."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class BatchNorm(nn.Module):\n",
"    \"\"\"From-scratch batch norm over the (batch, height, width) axes of a 4-d input.\n",
"\n",
"    Learnable per-channel scale (`mults`) and shift (`adds`); running `means`\n",
"    and `vars` are kept as buffers and used in place of the batch statistics\n",
"    when the module is in eval mode.\n",
"    \"\"\"\n",
"    def __init__(self, nf, mom=0.1, eps=1e-5):\n",
"        super().__init__()\n",
"        # NB: as in PyTorch, `mom` is the weight given to the *new* batch\n",
"        # statistic in the running average (see the `lerp_` calls below).\n",
"        self.mom = mom\n",
"        self.eps = eps\n",
"        self.mults = nn.Parameter(torch.ones(nf, 1, 1))\n",
"        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))\n",
"        self.register_buffer('vars', torch.ones(1, nf, 1, 1))\n",
"        self.register_buffer('means', torch.zeros(1, nf, 1, 1))\n",
"\n",
"    def update_stats(self, x):\n",
"        \"\"\"Fold the statistics of `x` into the running buffers and return them.\"\"\"\n",
"        reduce_dims = (0, 2, 3)\n",
"        batch_mean = x.mean(reduce_dims, keepdim=True)\n",
"        batch_var = x.var(reduce_dims, keepdim=True)\n",
"        self.means.lerp_(batch_mean, self.mom)\n",
"        self.vars.lerp_(batch_var, self.mom)\n",
"        return batch_mean, batch_var\n",
"\n",
"    def forward(self, x):\n",
"        if not self.training:\n",
"            mean, var = self.means, self.vars\n",
"        else:\n",
"            # Batch statistics are computed without tracking gradients.\n",
"            with torch.no_grad():\n",
"                mean, var = self.update_stats(x)\n",
"        normed = (x - mean) / (var + self.eps).sqrt()\n",
"        return normed * self.mults + self.adds"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def conv_layer_gen(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> GeneralRelu [-> BatchNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `GeneralRelu`; `BatchNorm(nf)` is appended only\n",
"    when `bn` is true.\n",
"    \"\"\"\n",
"    # The conv bias is switched off whenever a norm layer follows (`bias=not bn`).\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn)\n",
"    block = [conv, GeneralRelu(**kwargs)]\n",
"    if bn: block += [BatchNorm(nf)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def conv_layer_learn(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> LearnedRelu [-> BatchNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `LearnedRelu`; `BatchNorm(nf)` is appended only\n",
"    when `bn` is true.\n",
"    \"\"\"\n",
"    # The conv bias is switched off whenever a norm layer follows (`bias=not bn`).\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn)\n",
"    block = [conv, LearnedRelu(**kwargs)]\n",
"    if bn: block += [BatchNorm(nf)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def init_cnn_(m, f):\n",
"    \"\"\"Walk `m` recursively; for every `nn.Conv2d` call `f(weight, a=0.1)` and zero its bias.\"\"\"\n",
"    if isinstance(m, nn.Conv2d):\n",
"        f(m.weight, a=0.1)\n",
"        bias = getattr(m, 'bias', None)\n",
"        if bias is not None: bias.data.zero_()\n",
"    for child in m.children(): init_cnn_(child, f)\n",
"\n",
"def init_cnn(m, uniform=False):\n",
"    \"\"\"Initialise the convs of `m` with Kaiming uniform (`uniform=True`) or Kaiming normal init.\"\"\"\n",
"    if uniform: initializer = init.kaiming_uniform_\n",
"    else: initializer = init.kaiming_normal_\n",
"    init_cnn_(m, initializer)\n",
"\n",
"def get_learn_run(nfs, data, lr, layer, cbs=None, opt_func=None, uniform=False, **kwargs):\n",
"    \"\"\"Create a CNN with `get_cnn_model`, initialise it with `init_cnn`, and wrap it with `get_runner`.\n",
"\n",
"    `kwargs` are forwarded to `get_cnn_model`; the value returned by\n",
"    `get_runner` is passed straight back to the caller.\n",
"    \"\"\"\n",
"    cnn = get_cnn_model(data, nfs, layer, **kwargs)\n",
"    init_cnn(cnn, uniform=uniform)\n",
"    return get_runner(cnn, data, lr=lr, cbs=cbs, opt_func=opt_func)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then use it in training and see how it helps keep the activation means close to 0 and the standard deviations close to 1."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.9, conv_layer_gen, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.2491483984375, tensor(0.9221, device='cuda:0')]\n",
"valid: [0.11270089111328126, tensor(0.9654, device='cuda:0')]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"with Hooks(learn.model, append_stats) as hooks:\n",
" run.fit(1, learn)\n",
" fig,(ax0,ax1) = plt.subplots(1,2, figsize=(10,4))\n",
" for h in hooks[:-1]:\n",
" ms,ss = h.stats\n",
" ax0.plot(ms[:10])\n",
" ax1.plot(ss[:10])\n",
" plt.legend(range(6));\n",
" \n",
" fig,(ax0,ax1) = plt.subplots(1,2, figsize=(10,4))\n",
" for h in hooks[:-1]:\n",
" ms,ss = h.stats\n",
" ax0.plot(ms)\n",
" ax1.plot(ss)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 1.0, conv_layer_gen, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.26710783203125, tensor(0.9157, device='cuda:0')]\n",
"valid: [0.17995037841796874, tensor(0.9414, device='cuda:0')]\n",
"train: [0.091767626953125, tensor(0.9712, device='cuda:0')]\n",
"valid: [0.15938792724609374, tensor(0.9472, device='cuda:0')]\n",
"train: [0.0640931494140625, tensor(0.9801, device='cuda:0')]\n",
"valid: [0.15299569091796875, tensor(0.9506, device='cuda:0')]\n",
"CPU times: user 2.73 s, sys: 272 ms, total: 3 s\n",
"Wall time: 3.06 s\n"
]
}
],
"source": [
"%time run.fit(3, learn)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.9, conv_layer_learn, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.26300181640625, tensor(0.9198, device='cuda:0')]\n",
"valid: [0.2810866943359375, tensor(0.9055, device='cuda:0')]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"with Hooks(learn.model, append_stats) as hooks:\n",
" run.fit(1, learn)\n",
" fig,(ax0,ax1) = plt.subplots(1,2, figsize=(10,4))\n",
" for h in hooks[:-1]:\n",
" ms,ss = h.stats\n",
" ax0.plot(ms[:10])\n",
" ax1.plot(ss[:10])\n",
" plt.legend(range(6));\n",
" \n",
" fig,(ax0,ax1) = plt.subplots(1,2, figsize=(10,4))\n",
" for h in hooks[:-1]:\n",
" ms,ss = h.stats\n",
" ax0.plot(ms)\n",
" ax1.plot(ss)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 1.0, conv_layer_learn, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.253928984375, tensor(0.9200, device='cuda:0')]\n",
"valid: [0.165762109375, tensor(0.9478, device='cuda:0')]\n",
"train: [0.088923076171875, tensor(0.9723, device='cuda:0')]\n",
"valid: [0.3850069091796875, tensor(0.9094, device='cuda:0')]\n",
"train: [0.0636737890625, tensor(0.9805, device='cuda:0')]\n",
"valid: [0.07958184814453124, tensor(0.9754, device='cuda:0')]\n",
"CPU times: user 3.47 s, sys: 208 ms, total: 3.68 s\n",
"Wall time: 3.74 s\n"
]
}
],
"source": [
"%time run.fit(3, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### With scheduler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's add the usual warm-up/annealing."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"sched = combine_scheds([0.3, 0.7], [sched_lin(0.6, 2.), sched_lin(2., 0.1)]) "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.9, conv_layer_gen, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched)])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.301419296875, tensor(0.9094, device='cuda:0')]\n",
"valid: [0.2026967529296875, tensor(0.9417, device='cuda:0')]\n",
"train: [0.10052498046875, tensor(0.9687, device='cuda:0')]\n",
"valid: [0.1126065673828125, tensor(0.9666, device='cuda:0')]\n",
"train: [0.070996572265625, tensor(0.9776, device='cuda:0')]\n",
"valid: [0.13142176513671874, tensor(0.9617, device='cuda:0')]\n",
"train: [0.0482486083984375, tensor(0.9854, device='cuda:0')]\n",
"valid: [0.07715589599609375, tensor(0.9771, device='cuda:0')]\n",
"train: [0.0344202392578125, tensor(0.9894, device='cuda:0')]\n",
"valid: [0.06021231689453125, tensor(0.9828, device='cuda:0')]\n",
"train: [0.0255533837890625, tensor(0.9928, device='cuda:0')]\n",
"valid: [0.055064190673828124, tensor(0.9840, device='cuda:0')]\n",
"train: [0.01967509033203125, tensor(0.9944, device='cuda:0')]\n",
"valid: [0.04997901000976562, tensor(0.9847, device='cuda:0')]\n",
"train: [0.016732589111328126, tensor(0.9958, device='cuda:0')]\n",
"valid: [0.047377349853515625, tensor(0.9860, device='cuda:0')]\n"
]
}
],
"source": [
"run.fit(8, learn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.9, conv_layer_learn, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched)])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.3009943359375, tensor(0.9079, device='cuda:0')]\n",
"valid: [0.2146206298828125, tensor(0.9312, device='cuda:0')]\n",
"train: [0.100962646484375, tensor(0.9694, device='cuda:0')]\n",
"valid: [0.3427990234375, tensor(0.8986, device='cuda:0')]\n",
"train: [0.20606509765625, tensor(0.9405, device='cuda:0')]\n",
"valid: [0.09719484252929687, tensor(0.9709, device='cuda:0')]\n",
"train: [0.06491736328125, tensor(0.9801, device='cuda:0')]\n",
"valid: [0.17026732177734374, tensor(0.9478, device='cuda:0')]\n",
"train: [0.045457685546875, tensor(0.9863, device='cuda:0')]\n",
"valid: [0.0574800048828125, tensor(0.9833, device='cuda:0')]\n",
"train: [0.0349907958984375, tensor(0.9900, device='cuda:0')]\n",
"valid: [0.05554625244140625, tensor(0.9841, device='cuda:0')]\n",
"train: [0.02864280029296875, tensor(0.9917, device='cuda:0')]\n",
"valid: [0.045075537109375, tensor(0.9876, device='cuda:0')]\n",
"train: [0.02446511474609375, tensor(0.9931, device='cuda:0')]\n",
"valid: [0.042788345336914065, tensor(0.9875, device='cuda:0')]\n"
]
}
],
"source": [
"run.fit(8, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## More norms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Layer norm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From [the paper](https://arxiv.org/abs/1607.06450): \"*batch normalization cannot be applied to online learning tasks or to extremely large distributed models where the minibatches have to be small*\"."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"General equation for a norm layer with learnable affine:\n",
"\n",
"$$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta$$\n",
"\n",
"The differences from BatchNorm are:\n",
"1. we don't keep a moving average\n",
"2. we don't average over the batch dimension but over the hidden dimension, so it's independent of the batch size"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"class LayerNorm(nn.Module):\n",
"    \"\"\"Normalise each sample of a 4-d input over axes (1, 2, 3), then apply a scalar affine.\n",
"\n",
"    No running statistics are kept; `mult` and `add` are scalar parameters.\n",
"    \"\"\"\n",
"    __constants__ = ['eps']\n",
"    def __init__(self, eps=1e-5):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        self.mult = nn.Parameter(tensor(1.))\n",
"        self.add = nn.Parameter(tensor(0.))\n",
"\n",
"    def forward(self, x):\n",
"        stat_dims = (1, 2, 3)\n",
"        mean = x.mean(stat_dims, keepdim=True)\n",
"        var = x.var(stat_dims, keepdim=True)\n",
"        normed = (x - mean) / (var + self.eps).sqrt()\n",
"        return normed * self.mult + self.add"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def conv_ln(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> GeneralRelu [-> LayerNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `GeneralRelu`; `LayerNorm()` is appended only when\n",
"    `bn` is true. Unlike the batch-norm variants, the conv always keeps its bias.\n",
"    \"\"\"\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=True)\n",
"    block = [conv, GeneralRelu(**kwargs)]\n",
"    if bn: block += [LayerNorm()]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.8, conv_ln, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [nan, tensor(0.1355, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"train: [nan, tensor(0.0986, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"train: [nan, tensor(0.0986, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"CPU times: user 3.84 s, sys: 221 ms, total: 4.07 s\n",
"Wall time: 4.09 s\n"
]
}
],
"source": [
"%time run.fit(3, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Thought experiment*: can this distinguish foggy days from sunny days (assuming you're using it before the first conv)?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Instance norm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From [the paper](https://arxiv.org/abs/1607.08022): "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key difference between **contrast** and batch normalization is that the latter applies the normalization to a whole batch of images instead for single ones:\n",
"\n",
"\\begin{equation}\\label{eq:bnorm}\n",
" y_{tijk} = \\frac{x_{tijk} - \\mu_{i}}{\\sqrt{\\sigma_i^2 + \\epsilon}},\n",
" \\quad\n",
" \\mu_i = \\frac{1}{HWT}\\sum_{t=1}^T\\sum_{l=1}^W \\sum_{m=1}^H x_{tilm},\n",
" \\quad\n",
" \\sigma_i^2 = \\frac{1}{HWT}\\sum_{t=1}^T\\sum_{l=1}^W \\sum_{m=1}^H (x_{tilm} - \\mu_i)^2.\n",
"\\end{equation}\n",
"\n",
"In order to combine the effects of instance-specific normalization and batch normalization, we propose to replace the latter by the *instance normalization* (also known as *contrast normalization*) layer:\n",
"\n",
"\\begin{equation}\\label{eq:inorm}\n",
" y_{tijk} = \\frac{x_{tijk} - \\mu_{ti}}{\\sqrt{\\sigma_{ti}^2 + \\epsilon}},\n",
" \\quad\n",
" \\mu_{ti} = \\frac{1}{HW}\\sum_{l=1}^W \\sum_{m=1}^H x_{tilm},\n",
" \\quad\n",
" \\sigma_{ti}^2 = \\frac{1}{HW}\\sum_{l=1}^W \\sum_{m=1}^H (x_{tilm} - \\mu_{ti})^2.\n",
"\\end{equation}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class InstanceNorm(nn.Module):\n",
"    \"\"\"Normalise every (sample, channel) plane of a 4-d input over axes (2, 3).\n",
"\n",
"    Followed by a learnable per-channel scale (`mults`) and shift (`adds`).\n",
"    \"\"\"\n",
"    __constants__ = ['eps']\n",
"    # NOTE(review): the default `eps` here is 1.0 (`1e-0`), far larger than the\n",
"    # 1e-5 used by the other norm layers in this notebook - confirm it is intended.\n",
"    def __init__(self, nf, eps=1e-0):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        self.mults = nn.Parameter(torch.ones(nf, 1, 1))\n",
"        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))\n",
"\n",
"    def forward(self, x):\n",
"        spatial_dims = (2, 3)\n",
"        mean = x.mean(spatial_dims, keepdim=True)\n",
"        var = x.var(spatial_dims, keepdim=True)\n",
"        normed = (x - mean) / (var + self.eps).sqrt()\n",
"        return normed * self.mults + self.adds"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def conv_in(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> GeneralRelu [-> InstanceNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `GeneralRelu`; `InstanceNorm(nf)` is appended only\n",
"    when `bn` is true. The conv always keeps its bias.\n",
"    \"\"\"\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=True)\n",
"    block = [conv, GeneralRelu(**kwargs)]\n",
"    if bn: block += [InstanceNorm(nf)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.1, conv_in, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [nan, tensor(0.0986, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"train: [nan, tensor(0.0986, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"train: [nan, tensor(0.0986, device='cuda:0')]\n",
"valid: [nan, tensor(0.0991, device='cuda:0')]\n",
"CPU times: user 3.78 s, sys: 233 ms, total: 4.02 s\n",
"Wall time: 4.05 s\n"
]
}
],
"source": [
"%time run.fit(3, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Question*: why can't this classify anything?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lost in all those norms? The authors from the [group norm paper](https://arxiv.org/pdf/1803.08494.pdf) have you covered:\n",
"\n",
"![Various norms](images/norms.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Group norm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*From the PyTorch docs:*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GroupNorm(num_groups, num_channels, eps=1e-5, affine=True)`\n",
"\n",
"The input channels are separated into `num_groups` groups, each containing\n",
"``num_channels / num_groups`` channels. The mean and standard-deviation are calculated\n",
"separately over the each group. $\\gamma$ and $\\beta$ are learnable\n",
"per-channel affine transform parameter vectors of size `num_channels` if\n",
"`affine` is `True`.\n",
"\n",
"This layer uses statistics computed from input data in both training and\n",
"evaluation modes.\n",
"\n",
"Args:\n",
"- `num_groups (int)`: number of groups to separate the channels into\n",
"- `num_channels (int)`: number of channels expected in input\n",
"- `eps`: a value added to the denominator for numerical stability. Default: `1e-5`\n",
"- `affine`: a boolean value that when set to ``True``, this module\n",
" has learnable per-channel affine parameters initialized to ones (for weights)\n",
" and zeros (for biases). Default: ``True``.\n",
"\n",
"Shape:\n",
"- Input: `(N, num_channels, *)`\n",
"- Output: `(N, num_channels, *)` (same shape as input)\n",
"\n",
"Examples::\n",
"\n",
" >>> input = torch.randn(20, 6, 10, 10)\n",
" >>> # Separate 6 channels into 3 groups\n",
" >>> m = nn.GroupNorm(3, 6)\n",
" >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)\n",
" >>> m = nn.GroupNorm(6, 6)\n",
" >>> # Put all 6 channels into a single group (equivalent with LayerNorm)\n",
" >>> m = nn.GroupNorm(1, 6)\n",
" >>> # Activating the module\n",
" >>> output = m(input)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fix small batch sizes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What's the problem?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we compute the statistics (mean and std) for a BatchNorm layer on a small batch, we can get a standard deviation very close to 0 because there aren't many samples (the variance of a single value is 0, since it is equal to its mean)."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"data = DataBunch(*get_dls(train_ds, valid_ds, 2), c)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def conv_layer(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> GeneralRelu [-> nn.BatchNorm2d]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `GeneralRelu`; PyTorch's own `nn.BatchNorm2d`\n",
"    (eps=1e-5, momentum=0.1) is appended only when `bn` is true.\n",
"    \"\"\"\n",
"    # The conv bias is switched off whenever a norm layer follows (`bias=not bn`).\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn)\n",
"    block = [conv, GeneralRelu(**kwargs)]\n",
"    if bn: block += [nn.BatchNorm2d(nf, eps=1e-5, momentum=0.1)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.4, conv_layer, cbs=cbfs)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [2.33690859375, tensor(0.1798, device='cuda:0')]\n",
"valid: [21556.2768, tensor(0.2378, device='cuda:0')]\n",
"CPU times: user 1min 23s, sys: 3.5 s, total: 1min 27s\n",
"Wall time: 1min 29s\n"
]
}
],
"source": [
"%time run.fit(1, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running Batch Norm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To solve this problem we introduce a Running BatchNorm that normalizes with smoothed running estimates of the mean and variance instead of each batch's own statistics."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"class RunningBatchNorm(nn.Module):\n",
"    \"\"\"Batch norm that normalises with smoothed running sums, in training as well as eval.\n",
"\n",
"    Buffers: `sums` / `sqrs` hold running per-channel sums of `x` and `x*x`,\n",
"    `count` the running number of elements per channel, `dbias` the debiasing\n",
"    factor applied during the first 100 steps, and `batch` / `step` count the\n",
"    samples and batches seen so far. `mults` / `adds` are the learnable affine.\n",
"    \"\"\"\n",
"    def __init__(self, nf, mom=0.1, eps=1e-5):\n",
"        super().__init__()\n",
"        self.mom,self.eps = mom,eps\n",
"        self.mults = nn.Parameter(torch.ones (nf,1,1))\n",
"        self.adds  = nn.Parameter(torch.zeros(nf,1,1))\n",
"        self.register_buffer('sums', torch.zeros(1,nf,1,1))\n",
"        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))\n",
"        self.register_buffer('batch', tensor(0.))\n",
"        self.register_buffer('count', tensor(0.))\n",
"        self.register_buffer('step', tensor(0.))\n",
"        self.register_buffer('dbias', tensor(0.))\n",
"\n",
"    def update_stats(self, x):\n",
"        \"\"\"Fold the sums, squared sums and element count of batch `x` into the running buffers.\"\"\"\n",
"        bs,nc,*_ = x.shape\n",
"        # Cut the running buffers loose from the previous batch's autograd graph.\n",
"        self.sums.detach_()\n",
"        self.sqrs.detach_()\n",
"        dims = (0,2,3)\n",
"        s  = x    .sum(dims, keepdim=True)\n",
"        ss = (x*x).sum(dims, keepdim=True)\n",
"        c  = self.count.new_tensor(x.numel()/nc)\n",
"        # The momentum grows with the batch size. `max(bs-1, 1)` keeps a batch of\n",
"        # one (e.g. a ragged final batch) from dividing by sqrt(0); it then\n",
"        # uses the same momentum as a batch of two.\n",
"        mom1 = 1 - (1-self.mom)/math.sqrt(max(bs-1, 1))\n",
"        self.mom1 = self.dbias.new_tensor(mom1)\n",
"        self.sums.lerp_(s, self.mom1)\n",
"        self.sqrs.lerp_(ss, self.mom1)\n",
"        self.count.lerp_(c, self.mom1)\n",
"        self.dbias = self.dbias*(1-self.mom1) + self.mom1\n",
"        self.batch += bs\n",
"        self.step += 1\n",
"\n",
"    def forward(self, x):\n",
"        if self.training: self.update_stats(x)\n",
"        sums = self.sums\n",
"        sqrs = self.sqrs\n",
"        c = self.count\n",
"        # Debias the running values while the averages are still warming up.\n",
"        if self.step<100:\n",
"            sums = sums / self.dbias\n",
"            sqrs = sqrs / self.dbias\n",
"            c    = c    / self.dbias\n",
"        means = sums/c\n",
"        vars = (sqrs/c).sub_(means*means)   # Var[x] = E[x^2] - E[x]^2\n",
"        # Until 20 samples have been seen, floor the variance at 0.01.\n",
"        if bool(self.batch < 20): vars.clamp_min_(0.01)\n",
"        x = (x-means).div_((vars.add_(self.eps)).sqrt())\n",
"        return x.mul_(self.mults).add_(self.adds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"NB: the calculation of `self.dbias` in the version in the lesson video was incorrect. The correct version is in the cell above. Also, we changed how we calculated `self.mom1` to something that is more mathematically appropriate. These two changes improved the accuracy from 91% (in the video) to 97%+ (shown below)!"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"def conv_rbn_gen(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> GeneralRelu [-> RunningBatchNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `GeneralRelu`; `RunningBatchNorm(nf)` is appended\n",
"    only when `bn` is true.\n",
"    \"\"\"\n",
"    # The conv bias is switched off whenever a norm layer follows (`bias=not bn`).\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn)\n",
"    block = [conv, GeneralRelu(**kwargs)]\n",
"    if bn: block += [RunningBatchNorm(nf)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def conv_rbn_learn(ni, nf, ks=3, stride=2, bn=True, **kwargs):\n",
"    \"\"\"Build a `Conv2d -> LearnedRelu [-> RunningBatchNorm]` block as an `nn.Sequential`.\n",
"\n",
"    `kwargs` are forwarded to `LearnedRelu`; `RunningBatchNorm(nf)` is appended\n",
"    only when `bn` is true.\n",
"    \"\"\"\n",
"    # The conv bias is switched off whenever a norm layer follows (`bias=not bn`).\n",
"    conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn)\n",
"    block = [conv, LearnedRelu(**kwargs)]\n",
"    if bn: block += [RunningBatchNorm(nf)]\n",
"    return nn.Sequential(*block)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What can we do in a single epoch?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's see with a decent batch size what result we can get."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"data = DataBunch(*get_dls(train_ds, valid_ds, 32), c)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.8, conv_rbn_gen, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched_lin(1., 0.2))])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.16527302734375, tensor(0.9498, device='cuda:0')]\n",
"valid: [0.07317039794921874, tensor(0.9797, device='cuda:0')]\n",
"CPU times: user 13.6 s, sys: 558 ms, total: 14.2 s\n",
"Wall time: 14.5 s\n"
]
}
],
"source": [
"%time run.fit(1, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try LearnedRelu"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.8, conv_rbn_learn, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched_lin(1., 0.2))])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.16582439453125, tensor(0.9488, device='cuda:0')]\n",
"valid: [0.06944829711914062, tensor(0.9790, device='cuda:0')]\n",
"CPU times: user 15.9 s, sys: 858 ms, total: 16.8 s\n",
"Wall time: 17.1 s\n"
]
}
],
"source": [
"%time run.fit(1, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simplified RunningBatchNorm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It turns out we don't actually need to debias - because, for instance, dividing a debiased sum by a debiased count is the same as dividing a *biased* sum by a *biased* count! So we can remove all the debiasing stuff and end up with a simpler class. Also, we should save `eps` as a buffer since it impacts the calculation. (Thanks to Stas Bekman for noticing these.) Also we can slightly change the final calculation in `forward` with one that uses `factor` and `offset` to reduce the amount of broadcasting required. (Thanks to Tom Viehmann for this suggestion.)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class RunningBatchNorm(nn.Module):\n",
"    \"\"\"Simplified running batch norm: `forward` is a single `x*factor + offset`.\n",
"\n",
"    Buffers: `sums` / `sqrs` hold running per-channel sums of `x` and `x*x`,\n",
"    `count` the running number of elements per channel, and `factor` /\n",
"    `offset` the scale and shift recomputed by `update_stats` on every\n",
"    training batch. `batch` is a plain Python counter of samples seen.\n",
"    `mults` / `adds` are the learnable affine parameters.\n",
"    \"\"\"\n",
"    def __init__(self, nf, mom=0.1, eps=1e-5):\n",
"        super().__init__()\n",
"        self.mom, self.eps = mom, eps\n",
"        self.mults = nn.Parameter(torch.ones (nf,1,1))\n",
"        self.adds  = nn.Parameter(torch.zeros(nf,1,1))\n",
"        self.register_buffer('sums', torch.zeros(1,nf,1,1))\n",
"        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))\n",
"        self.register_buffer('count', tensor(0.))\n",
"        self.register_buffer('factor', tensor(0.))\n",
"        self.register_buffer('offset', tensor(0.))\n",
"        self.batch = 0\n",
"\n",
"    def update_stats(self, x):\n",
"        \"\"\"Fold batch `x` into the running sums and refresh `factor` / `offset`.\"\"\"\n",
"        bs,nc,*_ = x.shape\n",
"        # Cut the running buffers loose from the previous batch's autograd graph.\n",
"        self.sums.detach_()\n",
"        self.sqrs.detach_()\n",
"        dims = (0,2,3)\n",
"        s    = x    .sum(dims, keepdim=True)\n",
"        ss   = (x*x).sum(dims, keepdim=True)\n",
"        c    = s.new_tensor(x.numel()/nc)\n",
"        # The momentum grows with the batch size. `max(bs-1, 1)` keeps a batch of\n",
"        # one (e.g. a ragged final batch) from dividing by sqrt(0); it then\n",
"        # uses the same momentum as a batch of two.\n",
"        mom1 = s.new_tensor(1 - (1-self.mom)/math.sqrt(max(bs-1, 1)))\n",
"        self.sums .lerp_(s , mom1)\n",
"        self.sqrs .lerp_(ss, mom1)\n",
"        self.count.lerp_(c , mom1)\n",
"        self.batch += bs\n",
"        means = self.sums/self.count\n",
"        varns = (self.sqrs/self.count).sub_(means*means)   # Var[x] = E[x^2] - E[x]^2\n",
"        # Until 20 samples have been seen, floor the variance at 0.01.\n",
"        if bool(self.batch < 20): varns.clamp_min_(0.01)\n",
"        self.factor = self.mults / (varns+self.eps).sqrt()\n",
"        self.offset = self.adds - means*self.factor\n",
"\n",
"    def forward(self, x):\n",
"        if self.training: self.update_stats(x)\n",
"        return x*self.factor + self.offset"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.8, conv_rbn_gen, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched_lin(1., 0.2))])"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.157169111328125, tensor(0.9513, device='cuda:0')]\n",
"valid: [0.10296422119140625, tensor(0.9797, device='cuda:0')]\n",
"CPU times: user 12.1 s, sys: 585 ms, total: 12.7 s\n",
"Wall time: 12.9 s\n"
]
}
],
"source": [
"%time run.fit(1, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try LearnedRelu"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"learn,run = get_learn_run(nfs, data, 0.8, conv_rbn_learn, cbs=cbfs\n",
" +[partial(ParamScheduler,'lr', sched_lin(1., 0.2))])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.160661044921875, tensor(0.9506, device='cuda:0')]\n",
"valid: [0.06689930419921875, tensor(0.9810, device='cuda:0')]\n",
"CPU times: user 14.3 s, sys: 711 ms, total: 15 s\n",
"Wall time: 15.3 s\n"
]
}
],
"source": [
"%time run.fit(1, learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nb_auto_export()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment