Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Created September 23, 2019 10:40
Show Gist options
  • Save thomasbrandon/67992d2ca8d0421fb765214abcf46bfe to your computer and use it in GitHub Desktop.
Save thomasbrandon/67992d2ca8d0421fb765214abcf46bfe to your computer and use it in GitHub Desktop.
Profiling and experiments on Mish Performance
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.vision import *"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DATA = untar_data(URLs.IMAGENETTE_160)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ImageDataBunch;\n",
"\n",
"Train: LabelList (3798 items)\n",
"x: ImageList\n",
"Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160)\n",
"y: CategoryList\n",
"n03028079,n03028079,n03028079,n03028079,n03028079\n",
"Path: /home/user/.fastai/data/imagenette-160;\n",
"\n",
"Valid: LabelList (161 items)\n",
"x: ImageList\n",
"Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160)\n",
"y: CategoryList\n",
"n03028079,n03028079,n03028079,n03028079,n03028079\n",
"Path: /home/user/.fastai/data/imagenette-160;\n",
"\n",
"Test: None"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"src = (ImageList.from_folder(DATA).filter_by_rand(0.3, seed=42)\n",
" .split_by_folder(valid='val')\n",
" .label_from_folder()\n",
" .transform(([flip_lr(p=0.5)], []), size=160))\n",
"data = (src.databunch(bs=64, num_workers=6)\n",
" .normalize(imagenet_stats))\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from fastai import layers"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,\n",
" norm_type:Optional[NormType]=NormType.Batch, use_activ:bool=True, activ_fn:Callable=None, leaky:float=None,\n",
" transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):\n",
" \"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.\"\n",
" activ_fn = ifnone(activ_fn, partial(relu, inplace=True, leaky=leaky))\n",
" if padding is None: padding = (ks-1)//2 if not transpose else 0\n",
" bn = norm_type in (NormType.Batch, NormType.BatchZero)\n",
" if bias is None: bias = not bn\n",
" conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d\n",
" conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)\n",
" if norm_type==NormType.Weight: conv = weight_norm(conv)\n",
" elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n",
" layers = [conv]\n",
" if use_activ: layers.append(activ_fn())\n",
" if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))\n",
" if self_attention: layers.append(SelfAttention(nf))\n",
" return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def simple_cnn(data, actns:Collection[int], kernel_szs:Collection[int]=None,\n",
" strides:Collection[int]=None, bn=False, activ_fn=None,\n",
" lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,\n",
" concat_pool:bool=True, bn_final:bool=False) -> nn.Sequential:\n",
" \"CNN with `conv_layer` defined by `actns`, `kernel_szs` and `strides`, plus batchnorm if `bn`.\"\n",
" nl = len(actns)-1\n",
" kernel_szs = ifnone(kernel_szs, [3]*nl)\n",
" strides = ifnone(strides , [2]*nl)\n",
" layers = [conv_layer(actns[i], actns[i+1], kernel_szs[i], stride=strides[i],\n",
" norm_type=(NormType.Batch if bn and i<(len(strides)-1) else None), activ_fn=activ_fn) for i in range_of(strides)]\n",
" nf_head = actns[-1] * (2 if concat_pool else 1)\n",
" head = create_head(nf_head, data.c, lin_ftrs=lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)\n",
" return nn.Sequential(*layers, head)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"actns = [3,64,64,128,128,256,256,512,512]\n",
"strides = [1,2]*(len(actns)//2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Relu"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" )\n",
" (8): Sequential(\n",
" (0): AdaptiveConcatPool2d(\n",
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (1): Flatten()\n",
" (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.25, inplace=False)\n",
" (4): Linear(in_features=1024, out_features=512, bias=True)\n",
" (5): ReLU(inplace=True)\n",
" (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): Dropout(p=0.5, inplace=False)\n",
" (8): Linear(in_features=512, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mdl_relu = simple_cnn(data, actns=actns, strides=strides)\n",
"mdl_relu"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"lrn = Learner(data, mdl_relu, metrics=[accuracy,top_k_accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>top_k_accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.142710</td>\n",
" <td>2.980025</td>\n",
" <td>0.223602</td>\n",
" <td>0.739130</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.928795</td>\n",
" <td>1.946811</td>\n",
" <td>0.347826</td>\n",
" <td>0.813665</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.725688</td>\n",
" <td>1.865101</td>\n",
" <td>0.391304</td>\n",
" <td>0.826087</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.494881</td>\n",
" <td>1.939941</td>\n",
" <td>0.416149</td>\n",
" <td>0.795031</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.306941</td>\n",
" <td>1.135048</td>\n",
" <td>0.633540</td>\n",
" <td>0.950311</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrn.fit_one_cycle(5, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mish"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"class Mish(nn.Module):\n",
" def forward(self, x):\n",
" return x * torch.tanh(F.softplus(x))"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): Mish()\n",
" )\n",
" (8): Sequential(\n",
" (0): AdaptiveConcatPool2d(\n",
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (1): Flatten()\n",
" (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.25, inplace=False)\n",
" (4): Linear(in_features=1024, out_features=512, bias=True)\n",
" (5): ReLU(inplace=True)\n",
" (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): Dropout(p=0.5, inplace=False)\n",
" (8): Linear(in_features=512, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mdl_mish = simple_cnn(data, actns=actns, strides=strides, activ_fn=Mish)\n",
"mdl_mish"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"lrn = Learner(data, mdl_mish, metrics=[accuracy,top_k_accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>top_k_accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.076530</td>\n",
" <td>2.108528</td>\n",
" <td>0.304348</td>\n",
" <td>0.763975</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.843536</td>\n",
" <td>2.559659</td>\n",
" <td>0.267081</td>\n",
" <td>0.763975</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.580714</td>\n",
" <td>1.727570</td>\n",
" <td>0.360248</td>\n",
" <td>0.875776</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.308276</td>\n",
" <td>0.981897</td>\n",
" <td>0.658385</td>\n",
" <td>0.968944</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>1.075476</td>\n",
" <td>0.898451</td>\n",
" <td>0.708075</td>\n",
" <td>0.962733</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrn.fit_one_cycle(5, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mish JIT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm not sure this is the right way to create a JIT-compiled (TorchScript) module:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"class MishJit(torch.jit.ScriptModule):\n",
" \n",
" # Note: No self for forward or you get an error\n",
" @torch.jit.script\n",
" def forward(x):\n",
" return x * torch.tanh(F.softplus(x))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This seems to be the recommended way (script an instance of the eager `Mish` module with `torch.jit.script`):"
]
},
{
"cell_type": "code",
"execution_count": 419,
"metadata": {},
"outputs": [],
"source": [
"MishJit = lambda: torch.jit.script(Mish())"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" (1): MishJit()\n",
" )\n",
" (8): Sequential(\n",
" (0): AdaptiveConcatPool2d(\n",
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (1): Flatten()\n",
" (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.25, inplace=False)\n",
" (4): Linear(in_features=1024, out_features=512, bias=True)\n",
" (5): ReLU(inplace=True)\n",
" (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (7): Dropout(p=0.5, inplace=False)\n",
" (8): Linear(in_features=512, out_features=10, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mdl_mishjit = simple_cnn(data, actns=actns, strides=strides, activ_fn=MishJit)\n",
"mdl_mishjit"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"lrn = Learner(data, mdl_mishjit, metrics=[accuracy,top_k_accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>top_k_accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.008492</td>\n",
" <td>2.380090</td>\n",
" <td>0.304348</td>\n",
" <td>0.726708</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.768086</td>\n",
" <td>1.570313</td>\n",
" <td>0.472050</td>\n",
" <td>0.888199</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.452909</td>\n",
" <td>1.287506</td>\n",
" <td>0.534162</td>\n",
" <td>0.944099</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.180452</td>\n",
" <td>0.982612</td>\n",
" <td>0.664596</td>\n",
" <td>0.956522</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.939465</td>\n",
" <td>0.796275</td>\n",
" <td>0.732919</td>\n",
" <td>0.968944</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrn.fit_one_cycle(5, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There doesn't appear to be any performance gain from JIT: epochs take about 15-16s with both `Mish` and `MishJit`.\n",
"\n",
"You can, however, inspect what's happening under the hood (code taken from https://github.com/pytorch/pytorch/blob/master/test/test_jit.py#L233):"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"graph(%x.1 : Tensor):\n",
" %6 : int = prim::Constant[value=20]()\n",
" %5 : int = prim::Constant[value=1]()\n",
" %7 : Tensor = aten::softplus(%x.1, %5, %6) # <ipython-input-14-d5e1efb10fcc>:5:31\n",
" %8 : Tensor = aten::tanh(%7) # <ipython-input-14-d5e1efb10fcc>:5:20\n",
" %9 : Tensor = aten::mul(%x.1, %8) # <ipython-input-14-d5e1efb10fcc>:5:16\n",
" return (%9)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mj = MishJit()\n",
"mj.graph"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"graph(%0 : Tensor,\n",
" %1 : Tensor,\n",
" %2 : Tensor,\n",
" %3 : Tensor,\n",
" %4 : int[]?,\n",
" %5 : int[]?):\n",
" %6 : int = prim::Constant[value=1]() # <string>:154:39\n",
" %grad_self.1 : Tensor, %grad_other.1 : Tensor = prim::GradOf[name=\"aten::mul\"](%0)\n",
" block0():\n",
" %9 : Tensor = aten::mul(%0, %3) # <string>:11:30\n",
" %grad_self.2 : Tensor = aten::_grad_sum_to_size(%9, %4) # <string>:11:30\n",
" %11 : Tensor = aten::mul(%0, %2) # <string>:12:31\n",
" %grad_other.2 : Tensor = aten::_grad_sum_to_size(%11, %5) # <string>:12:31\n",
" -> (%grad_self.2, %grad_other.2)\n",
" %13 : Tensor = prim::AutogradAdd(%1, %grad_other.1)\n",
" %14 : Tensor = prim::GradOf[name=\"aten::tanh\"](%13)\n",
" block0():\n",
" %15 : Tensor = aten::mul(%3, %3) # <string>:154:43\n",
" %16 : Tensor = aten::neg(%15) # <string>:18:10\n",
" %17 : Tensor = aten::add(%16, %6, %6) # <string>:18:10\n",
" %18 : Tensor = aten::mul(%13, %17) # <string>:154:24\n",
" -> (%18)\n",
" return (%grad_self.1, %14)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds = mj.forward.get_debug_state()\n",
"fwd_plan = list(ds.execution_plans.values())[0]\n",
"ges = list(fwd_plan.code.grad_executor_states())\n",
"assert len(ges)==1\n",
"bwd_plan = ges[0]\n",
"bwd_plan.graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mish Profiling"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# Profiler doesn't like multiple workers\n",
"data_prof = (src.databunch(bs=64, num_workers=0)\n",
" .normalize(imagenet_stats))"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>top_k_accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.090991</td>\n",
" <td>2.282425</td>\n",
" <td>0.242236</td>\n",
" <td>0.776398</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.843356</td>\n",
" <td>1.696035</td>\n",
" <td>0.397516</td>\n",
" <td>0.844720</td>\n",
" <td>00:18</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.575372</td>\n",
" <td>1.370669</td>\n",
" <td>0.515528</td>\n",
" <td>0.913043</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrn = Learner(data_prof, mdl_mish, metrics=[accuracy,top_k_accuracy])\n",
"with torch.autograd.profiler.profile(use_cuda=True) as prof_mish:\n",
" lrn.fit_one_cycle(3, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls \n",
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"CudnnConvolutionBackward 0.07% 32.445ms 1.33% 593.029ms 418.806us 17.95% 18.293s 12.919ms 1416 \n",
"cudnn_convolution_backward 1.26% 560.585ms 1.26% 560.585ms 395.893us 17.94% 18.289s 12.916ms 1416 \n",
"conv2d 0.02% 8.777ms 0.52% 230.733ms 155.063us 10.68% 10.883s 7.314ms 1488 \n",
"convolution 0.02% 9.372ms 0.50% 221.956ms 149.164us 10.67% 10.878s 7.311ms 1488 \n",
"_convolution 0.04% 18.914ms 0.48% 212.584ms 142.865us 10.67% 10.874s 7.308ms 1488 \n",
"cudnn_convolution 0.42% 188.122ms 0.42% 188.122ms 126.426us 10.66% 10.865s 7.302ms 1488 \n",
"mul 0.26% 117.852ms 0.26% 117.852ms 25.123us 3.96% 4.035s 860.178us 4691 \n",
"MulBackward0 0.09% 40.653ms 0.29% 130.582ms 92.219us 2.60% 2.653s 1.874ms 1416 \n",
"add 0.12% 54.580ms 0.12% 54.580ms 30.732us 1.32% 1.343s 756.336us 1776 \n",
"TanhBackward 0.06% 24.925ms 0.17% 77.403ms 54.663us 1.30% 1.329s 938.273us 1416 \n",
"SoftplusBackward 0.06% 26.375ms 0.16% 72.244ms 51.020us 1.30% 1.328s 937.749us 1416 \n",
"tanh_backward 0.12% 52.478ms 0.12% 52.478ms 37.060us 1.30% 1.324s 935.328us 1416 \n",
"softplus_backward 0.10% 45.869ms 0.10% 45.869ms 32.393us 1.30% 1.324s 934.842us 1416 \n",
"div_ 2.90% 1.296s 2.90% 1.296s 109.706us 1.26% 1.282s 108.573us 11811 \n",
"softplus 0.06% 26.970ms 0.06% 26.970ms 18.125us 1.03% 1.046s 702.681us 1488 \n",
"stack 2.32% 1.035s 2.32% 1.035s 5.474ms 1.01% 1.030s 5.449ms 189 \n",
"tanh 0.05% 24.128ms 0.05% 24.128ms 16.215us 0.95% 967.134ms 649.956us 1488 \n",
"clone 2.05% 915.029ms 2.05% 915.029ms 76.329us 0.90% 918.066ms 76.582us 11988 \n",
"contiguous 1.48% 660.752ms 1.83% 814.909ms 53.521us 0.80% 818.778ms 53.775us 15226 \n",
"pin_memory 1.07% 478.388ms 1.09% 485.908ms 1.306ms 0.48% 487.357ms 1.310ms 372 \n",
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Self CPU time total: 44.610s\n",
"CUDA time total: 101.941s\n",
"\n"
]
}
],
"source": [
"print(prof_mish.key_averages().table(sort_by=\"cuda_time_total\", row_limit=20))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Relu Profiling"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>top_k_accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>2.144278</td>\n",
" <td>3.227488</td>\n",
" <td>0.198758</td>\n",
" <td>0.503106</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.964869</td>\n",
" <td>1.908126</td>\n",
" <td>0.310559</td>\n",
" <td>0.795031</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.797392</td>\n",
" <td>1.615858</td>\n",
" <td>0.447205</td>\n",
" <td>0.857143</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lrn = Learner(data_prof, mdl_relu, metrics=[accuracy,top_k_accuracy])\n",
"with torch.autograd.profiler.profile(use_cuda=True) as prof_relu:\n",
" lrn.fit_one_cycle(3, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls \n",
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"CudnnConvolutionBackward 0.08% 30.636ms 1.59% 590.901ms 417.303us 20.24% 18.338s 12.951ms 1416 \n",
"cudnn_convolution_backward 1.51% 560.265ms 1.51% 560.265ms 395.668us 20.24% 18.334s 12.948ms 1416 \n",
"conv2d 0.02% 8.686ms 0.64% 237.165ms 159.385us 12.05% 10.915s 7.335ms 1488 \n",
"convolution 0.03% 9.691ms 0.62% 228.479ms 153.548us 12.04% 10.910s 7.332ms 1488 \n",
"_convolution 0.06% 21.295ms 0.59% 218.788ms 147.035us 12.04% 10.906s 7.329ms 1488 \n",
"cudnn_convolution 0.52% 193.073ms 0.52% 193.073ms 129.753us 12.03% 10.897s 7.323ms 1488 \n",
"ReluBackward1 0.08% 30.172ms 0.23% 86.354ms 54.208us 1.47% 1.335s 838.307us 1593 \n",
"threshold_backward 0.15% 56.182ms 0.15% 56.182ms 35.268us 1.47% 1.328s 833.722us 1593 \n",
"div_ 3.61% 1.340s 3.61% 1.340s 113.430us 1.46% 1.327s 112.313us 11811 \n",
"stack 2.82% 1.045s 2.82% 1.045s 5.529ms 1.16% 1.046s 5.537ms 189 \n",
"clone 2.56% 949.062ms 2.56% 949.062ms 79.168us 1.05% 954.423ms 79.615us 11988 \n",
"relu_ 0.06% 24.079ms 0.06% 24.079ms 14.384us 1.05% 952.122ms 568.771us 1674 \n",
"contiguous 1.78% 661.230ms 2.26% 838.856ms 55.090us 0.93% 843.215ms 55.376us 15227 \n",
"pin_memory 1.30% 481.157ms 1.31% 487.426ms 1.310ms 0.54% 488.730ms 1.314ms 372 \n",
"to 81.89% 30.382s 81.93% 30.400s 2.374ms 0.41% 371.878ms 29.037us 12807 \n",
"empty_like 0.22% 80.294ms 0.48% 177.626ms 15.665us 0.19% 176.648ms 15.579us 11339 \n",
"add_ 0.55% 204.223ms 0.55% 204.223ms 15.606us 0.16% 142.126ms 10.861us 13086 \n",
"slice 0.34% 124.978ms 0.34% 124.978ms 5.426us 0.13% 119.899ms 5.206us 23032 \n",
"empty 0.32% 117.106ms 0.32% 117.106ms 9.110us 0.13% 115.692ms 9.000us 12854 \n",
"torch::autograd::AccumulateGrad 0.21% 79.715ms 0.53% 197.215ms 46.425us 0.12% 108.346ms 25.505us 4248 \n",
"----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Self CPU time total: 37.103s\n",
"CUDA time total: 90.599s\n",
"\n"
]
}
],
"source": [
"print(prof_relu.key_averages().table(sort_by=\"cuda_time_total\", row_limit=20))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Profile Differences"
]
},
{
"cell_type": "code",
"execution_count": 410,
"metadata": {},
"outputs": [],
"source": [
"keys = set([ev.key for prof in [prof_mish,prof_relu] for ev in prof.function_events])\n",
"ka_mish,ka_relu = prof_mish.key_averages(), prof_relu.key_averages()\n",
"ka_mish,ka_relu = [{ev.key: ev for ev in prof.key_averages()} for prof in [prof_mish,prof_relu]]\n",
"keys = set(list(ka_mish.keys()) + list(ka_relu.keys()))\n",
"keys -= {'to','contiguous','pin_memory'} # Dataloader stuff"
]
},
{
"cell_type": "code",
"execution_count": 411,
"metadata": {},
"outputs": [],
"source": [
"ev_mish,ev_relu = [],[]\n",
"for key in keys:\n",
" if ( key not in ka_mish or key not in ka_relu or \n",
" ka_mish[key].count != ka_relu[key].count or\n",
" np.abs(ka_mish[key].cuda_time - ka_relu[key].cuda_time) > 100): # cuda_time in us\n",
" if key in ka_mish: ev_mish.append(ka_mish[key])\n",
" if key in ka_relu: ev_relu.append(ka_relu[key])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mish"
]
},
{
"cell_type": "code",
"execution_count": 412,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls \n",
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"mul 12.55% 123.558ms 12.55% 123.558ms 26.339us 25.25% 4.034s 860.011us 4691 \n",
"MulBackward0 4.14% 40.766ms 13.87% 136.489ms 96.391us 16.60% 2.653s 1.874ms 1416 \n",
"add 5.54% 54.531ms 5.54% 54.531ms 30.704us 8.41% 1.343s 756.388us 1776 \n",
"TanhBackward 2.60% 25.567ms 8.09% 79.661ms 56.258us 8.33% 1.331s 940.024us 1416 \n",
"SoftplusBackward 2.82% 27.774ms 7.65% 75.281ms 53.164us 8.32% 1.329s 938.424us 1416 \n",
"tanh_backward 5.50% 54.094ms 5.50% 54.094ms 38.202us 8.30% 1.327s 937.036us 1416 \n",
"softplus_backward 4.83% 47.506ms 4.83% 47.506ms 33.550us 8.29% 1.324s 935.374us 1416 \n",
"softplus 2.80% 27.553ms 2.80% 27.553ms 18.517us 6.56% 1.048s 704.625us 1488 \n",
"tanh 2.23% 21.921ms 2.23% 21.921ms 14.732us 6.06% 967.746ms 650.367us 1488 \n",
"empty_like 9.27% 91.280ms 21.58% 212.374ms 18.731us 1.33% 211.789ms 18.680us 11338 \n",
"add_ 20.93% 205.969ms 20.93% 205.969ms 15.711us 0.92% 147.439ms 11.246us 13110 \n",
"empty 13.35% 131.339ms 13.35% 131.339ms 10.219us 0.83% 132.016ms 10.271us 12853 \n",
"slice 12.08% 118.912ms 12.08% 118.912ms 5.163us 0.71% 112.897ms 4.902us 23030 \n",
"ReluBackward1 0.35% 3.435ms 1.01% 9.969ms 56.324us 0.06% 10.011ms 56.562us 177 \n",
"threshold_backward 0.66% 6.534ms 0.66% 6.534ms 36.915us 0.04% 6.778ms 38.293us 177 \n",
"relu_ 0.35% 3.427ms 0.35% 3.427ms 18.426us 0.00% 558.594us 3.003us 186 \n",
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Self CPU time total: 984.168ms\n",
"CUDA time total: 15.979s\n",
"\n"
]
}
],
"source": [
"print(torch.autograd.profiler.EventList(ev_mish).table(sort_by=\"cuda_time_total\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Relu"
]
},
{
"cell_type": "code",
"execution_count": 413,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls \n",
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"ReluBackward1 5.15% 33.543ms 13.58% 88.343ms 55.457us 31.57% 1.335s 838.335us 1593 \n",
"threshold_backward 8.42% 54.799ms 8.42% 54.799ms 34.400us 31.40% 1.328s 833.837us 1593 \n",
"relu_ 3.73% 24.283ms 3.73% 24.283ms 14.506us 22.52% 952.441ms 568.961us 1674 \n",
"empty_like 9.78% 63.670ms 29.12% 189.489ms 16.714us 4.46% 188.768ms 16.651us 11337 \n",
"add_ 30.49% 198.388ms 30.49% 198.388ms 15.160us 3.34% 141.360ms 10.802us 13086 \n",
"empty 20.89% 135.937ms 20.89% 135.937ms 10.577us 3.23% 136.804ms 10.645us 12852 \n",
"slice 19.79% 128.792ms 19.79% 128.792ms 5.593us 2.92% 123.493ms 5.363us 23028 \n",
"add 1.15% 7.458ms 1.15% 7.458ms 20.717us 0.45% 19.196ms 53.321us 360 \n",
"mul 0.59% 3.861ms 0.59% 3.861ms 10.406us 0.09% 3.815ms 10.284us 371 \n",
"---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- \n",
"Self CPU time total: 650.733ms\n",
"CUDA time total: 4.230s\n",
"\n"
]
}
],
"source": [
"print(torch.autograd.profiler.EventList(ev_relu).table(sort_by=\"cuda_time_total\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The multiple kernel launches for Mish add up (Softplus, Tanh and some of the muls, from call count they're not all Mish)."
]
},
{
"cell_type": "code",
"execution_count": 394,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Autograd Function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can't do anything inplace as this will cause errors in gradient calculation:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"class MishInplace(nn.Module):\n",
" def forward(self, x):\n",
" return x.mul_(torch.tanh_(F.softplus(x)))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "a leaf Variable that requires grad has been used in an in-place operation.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-40-1667fcc83abe>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0minp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMishInplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmdl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/.conda/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 546\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 547\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 548\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-39-5adf038e9e50>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mMishInplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtanh_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftplus\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: a leaf Variable that requires grad has been used in an in-place operation."
]
}
],
"source": [
"inp = torch.rand(5, requires_grad=True)\n",
"mdl = MishInplace()\n",
"out = torch.sum(mdl(inp))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Autograd functions allow such inplace operations through saving stuff in forward to be used in backwards. But you have to do the gradient calculations then.\n",
"Being able to manually compute gradients would also be needed for a CUDA implementation as you can't really re-use the existing stuff for Softplus/Tanh (you'd have to have separate kernel launches and be back to where the straight Python is).\n",
"\n",
"This seems to be the idea:"
]
},
{
"cell_type": "code",
"execution_count": 311,
"metadata": {},
"outputs": [],
"source": [
"class MishFunc(torch.autograd.Function):\n",
" @staticmethod\n",
" def forward(ctx, inp):\n",
" ctx.mark_dirty(inp)\n",
" tsp = torch.tanh_(F.softplus(inp))\n",
" ctx.save_for_backward(inp, tsp)\n",
" return x.mul_(tsp)\n",
" \n",
" @staticmethod\n",
" def backward(ctx, grad_out):\n",
" inp,tsp = ctx.saved_tensors\n",
" grad_tsp = torch.autograd.grad(grad_out, tsp)\n",
" grad_inp = ...\n",
" return grad_inp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Playing around with some of the autograd stuff:"
]
},
{
"cell_type": "code",
"execution_count": 314,
"metadata": {},
"outputs": [],
"source": [
"grad = lambda o,i: torch.autograd.grad(o, i, retain_graph=True) # Need to use retain graph or you can only call once"
]
},
{
"cell_type": "code",
"execution_count": 312,
"metadata": {},
"outputs": [],
"source": [
"x = torch.rand(3)\n",
"inp = x.clone().requires_grad_(True)\n",
"tsp = torch.tanh(F.softplus(inp))\n",
"out = x.mul(tsp)\n",
"l = torch.sum(out)"
]
},
{
"cell_type": "code",
"execution_count": 316,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.0430, 0.0735, 0.1787]),)"
]
},
"execution_count": 316,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad(l, inp) # Gradient of loss w.r.t input"
]
},
{
"cell_type": "code",
"execution_count": 317,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.1375, 0.2415, 0.8298]),)"
]
},
"execution_count": 317,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad(l, tsp) # Gradient of loss w.r.t intermediate"
]
},
{
"cell_type": "code",
"execution_count": 248,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([1., 1., 1.]),)"
]
},
"execution_count": 248,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad_out = torch.autograd.grad(l, out)\n",
"grad_out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CI Graph"
]
},
{
"cell_type": "code",
"execution_count": 302,
"metadata": {},
"outputs": [],
"source": [
"txt = \"\"\"Mish,87.48%,0.3967,-\n",
"Swish-1,87.32%,0.414,-0.3975 to 0.0844\n",
"E-Swish (?=1.75),87.49%,0.411,-0.2261 to 0.2539\n",
"GELU,87.37%,0.472,-0.3682 to 0.1499\n",
"ReLU,86.66%,0.584,-1.1179 to -0.5247\n",
"ELU(?=1.0),86.41%,0.3371,-1.2931 to -0.8556\n",
"Leaky ReLU(?=0.3),86.85%,0.4569,-0.8860 to -0.3774\n",
"RReLU,86.87%,0.4478,-0.8623 to -0.3595\n",
"SELU,83.91%,0.5995,-3.8713 to -3.2670\n",
"SoftPlus(? = 1),83.00%,1.4015,-4.7778 to -4.1735\n",
"HardShrink(? = 0.5),75.03%,0.98345,-12.8948 to -12.0035\n",
"Hardtanh,82.78%,0.4491,-4.9522 to -4.4486\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 304,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Name</th>\n",
" <th>Acc</th>\n",
" <th>SD</th>\n",
" <th>CI</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>Mish</td>\n",
" <td>87.48%</td>\n",
" <td>0.3967</td>\n",
" <td>-</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>Swish-1</td>\n",
" <td>87.32%</td>\n",
" <td>0.414</td>\n",
" <td>-0.3975 to 0.0844</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>E-Swish (?=1.75)</td>\n",
" <td>87.49%</td>\n",
" <td>0.411</td>\n",
" <td>-0.2261 to 0.2539</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>GELU</td>\n",
" <td>87.37%</td>\n",
" <td>0.472</td>\n",
" <td>-0.3682 to 0.1499</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>ReLU</td>\n",
" <td>86.66%</td>\n",
" <td>0.584</td>\n",
" <td>-1.1179 to -0.5247</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>ELU(?=1.0)</td>\n",
" <td>86.41%</td>\n",
" <td>0.3371</td>\n",
" <td>-1.2931 to -0.8556</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>Leaky ReLU(?=0.3)</td>\n",
" <td>86.85%</td>\n",
" <td>0.4569</td>\n",
" <td>-0.8860 to -0.3774</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>RReLU</td>\n",
" <td>86.87%</td>\n",
" <td>0.4478</td>\n",
" <td>-0.8623 to -0.3595</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>SELU</td>\n",
" <td>83.91%</td>\n",
" <td>0.5995</td>\n",
" <td>-3.8713 to -3.2670</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>SoftPlus(? = 1)</td>\n",
" <td>83.00%</td>\n",
" <td>1.4015</td>\n",
" <td>-4.7778 to -4.1735</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>HardShrink(? = 0.5)</td>\n",
" <td>75.03%</td>\n",
" <td>0.98345</td>\n",
" <td>-12.8948 to -12.0035</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>Hardtanh</td>\n",
" <td>82.78%</td>\n",
" <td>0.4491</td>\n",
" <td>-4.9522 to -4.4486</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Name Acc SD CI\n",
"0 Mish 87.48% 0.3967 -\n",
"1 Swish-1 87.32% 0.414 -0.3975 to 0.0844\n",
"2 E-Swish (?=1.75) 87.49% 0.411 -0.2261 to 0.2539\n",
"3 GELU 87.37% 0.472 -0.3682 to 0.1499\n",
"4 ReLU 86.66% 0.584 -1.1179 to -0.5247\n",
"5 ELU(?=1.0) 86.41% 0.3371 -1.2931 to -0.8556\n",
"6 Leaky ReLU(?=0.3) 86.85% 0.4569 -0.8860 to -0.3774\n",
"7 RReLU 86.87% 0.4478 -0.8623 to -0.3595\n",
"8 SELU 83.91% 0.5995 -3.8713 to -3.2670\n",
"9 SoftPlus(? = 1) 83.00% 1.4015 -4.7778 to -4.1735\n",
"10 HardShrink(? = 0.5) 75.03% 0.98345 -12.8948 to -12.0035\n",
"11 Hardtanh 82.78% 0.4491 -4.9522 to -4.4486"
]
},
"execution_count": 304,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sp = [l.split(',') for l in txt.split('\\n')]\n",
"d = {n:v for n,v in zip(['Name','Acc','SD','CI'], zip(*sp))}\n",
"df = pd.DataFrame.from_dict(d)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"df.Name = df.Name.apply(lambda s: s.split('(')[0].strip())\n",
"df.Acc = df.Acc.str.slice(stop=5).astype(np.float)\n",
"df['ci_lo'] = df[1:].CI.apply(lambda s: s.split(' to ')[0]).astype(np.float) + df.iloc[0,1]\n",
"df['ci_hi'] = df[1:].CI.apply(lambda s: s.split(' to ')[1]).astype(np.float) + df.iloc[0,1]\n",
"df.ci_lo[0] = 87.3085\n",
"df.ci_hi[0] = 87.6515"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [],
"source": [
"df = df.drop(index=[10]) # Outlier"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Name</th>\n",
" <th>Acc</th>\n",
" <th>ci_lo</th>\n",
" <th>ci_hi</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>Mish</td>\n",
" <td>87.48</td>\n",
" <td>87.3085</td>\n",
" <td>87.6515</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>Swish-1</td>\n",
" <td>87.32</td>\n",
" <td>87.0825</td>\n",
" <td>87.5644</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>E-Swish (?=1.75)</td>\n",
" <td>87.49</td>\n",
" <td>87.2539</td>\n",
" <td>87.7339</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>GELU</td>\n",
" <td>87.37</td>\n",
" <td>87.1118</td>\n",
" <td>87.6299</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>ReLU</td>\n",
" <td>86.66</td>\n",
" <td>86.3621</td>\n",
" <td>86.9553</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>ELU(?=1.0)</td>\n",
" <td>86.41</td>\n",
" <td>86.1869</td>\n",
" <td>86.6244</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>Leaky ReLU(?=0.3)</td>\n",
" <td>86.85</td>\n",
" <td>86.5940</td>\n",
" <td>87.1026</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>RReLU</td>\n",
" <td>86.87</td>\n",
" <td>86.6177</td>\n",
" <td>87.1205</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>SELU</td>\n",
" <td>83.91</td>\n",
" <td>83.6087</td>\n",
" <td>84.2130</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>SoftPlus(? = 1)</td>\n",
" <td>83.00</td>\n",
" <td>82.7022</td>\n",
" <td>83.3065</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>HardShrink(? = 0.5)</td>\n",
" <td>75.03</td>\n",
" <td>74.5852</td>\n",
" <td>75.4765</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>Hardtanh</td>\n",
" <td>82.78</td>\n",
" <td>82.5278</td>\n",
" <td>83.0314</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Name Acc ci_lo ci_hi\n",
"0 Mish 87.48 87.3085 87.6515\n",
"1 Swish-1 87.32 87.0825 87.5644\n",
"2 E-Swish (?=1.75) 87.49 87.2539 87.7339\n",
"3 GELU 87.37 87.1118 87.6299\n",
"4 ReLU 86.66 86.3621 86.9553\n",
"5 ELU(?=1.0) 86.41 86.1869 86.6244\n",
"6 Leaky ReLU(?=0.3) 86.85 86.5940 87.1026\n",
"7 RReLU 86.87 86.6177 87.1205\n",
"8 SELU 83.91 83.6087 84.2130\n",
"9 SoftPlus(? = 1) 83.00 82.7022 83.3065\n",
"10 HardShrink(? = 0.5) 75.03 74.5852 75.4765\n",
"11 Hardtanh 82.78 82.5278 83.0314"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[['Name','Acc','ci_lo','ci_hi']]"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [],
"source": [
"errs = df[['ci_lo','ci_hi']].to_numpy() - df['Acc'].to_numpy()[:,None]"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [],
"source": [
"errs = np.abs(errs).transpose()"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(7,3))\n",
"plt.errorbar(df.Acc, range(len(df)), xerr=errs, ls='', marker='o', ms='3', capsize=1, capthick=0.5);\n",
"plt.yticks(range(len(df)), df.Name);\n",
"plt.gca().invert_yaxis()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-fastai]",
"language": "python",
"name": "conda-env-.conda-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment