Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Created September 26, 2019 09:53
Show Gist options
Save thomasbrandon/16915d8a01bcdd9d4b74abbc7cf6638b to your computer and use it in GitHub Desktop.
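Semi-successful training of Mish (MishCuda): validation accuracy reaches 0.984 after two epochs, then the loss becomes NaN in the third.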
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from fastai.vision import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"ImageDataBunch;\n",
"\n",
"Train: LabelList (12396 items)\n",
"x: ImageList\n",
"Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\n",
"y: CategoryList\n",
"3,3,3,3,3\n",
"Path: /home/user/.fastai/data/mnist_sample;\n",
"\n",
"Valid: LabelList (2038 items)\n",
"x: ImageList\n",
"Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\n",
"y: CategoryList\n",
"3,3,3,3,3\n",
"Path: /home/user/.fastai/data/mnist_sample;\n",
"\n",
"Test: None"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fetch the MNIST sample dataset and read the images as single-channel ('L').\n",
"DATA = untar_data(URLs.MNIST_SAMPLE)\n",
"src = ImageList.from_folder(DATA, convert_mode='L')\n",
"# Split on the existing 'valid' folder and label each image by its folder name.\n",
"src = src.split_by_folder(valid='valid').label_from_folder()\n",
"# Small batches with no worker processes, normalized with fixed precomputed stats.\n",
"data = src.databunch(bs=8, num_workers=0)\n",
"data = data.normalize((tensor(0.128), tensor(0.305)))\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"def get_mdl(actn:Callable, n_out:int=None):\n",
"    'Two stride-2 3x3 convs, each followed by `actn()`, then concat pooling and a linear head with `n_out` outputs (defaults to `data.c`).'\n",
"    if n_out is None: n_out = data.c  # preserves the original behaviour of sizing the head from the global DataBunch\n",
"    nf = 32  # channels produced by the last conv\n",
"    layers = [conv2d(1, 16, stride=2), actn(),\n",
"              conv2d(16, nf, stride=2), actn(),\n",
"              AdaptiveConcatPool2d(1), Flatten(),  # avg-pool and max-pool are concatenated, doubling the channels\n",
"              nn.Linear(2*nf, n_out)]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (1): MishCuda()\n",
" (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (3): MishCuda()\n",
" (4): AdaptiveConcatPool2d(\n",
" (ap): AdaptiveAvgPool2d(output_size=1)\n",
" (mp): AdaptiveMaxPool2d(output_size=1)\n",
" )\n",
" (5): Flatten()\n",
" (6): Linear(in_features=64, out_features=2, bias=True)\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from mish_cuda import *\n",
"# Build the CNN using the `MishCuda` activation from `mish_cuda`, then display its layers.\n",
"mdl_mish = get_mdl(actn=MishCuda)\n",
"mdl_mish"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# Wrap the DataBunch and the Mish model in a Learner that reports accuracy each epoch.\n",
"lrn_mish = Learner(data, mdl_mish, metrics=[accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.093107</td>\n",
" <td>0.102739</td>\n",
" <td>0.959764</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.043055</td>\n",
" <td>0.049509</td>\n",
" <td>0.983808</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>nan</td>\n",
" <td>nan</td>\n",
" <td>0.495584</td>\n",
" <td>00:06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One-cycle schedule: 3 epochs with a peak learning rate of 1e-3.\n",
"# NOTE: in the recorded run, train/valid loss became nan in the last epoch (epoch 2).\n",
"lrn_mish.fit_one_cycle(cyc_len=3, max_lr=1e-3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-fastai]",
"language": "python",
"name": "conda-env-.conda-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment