Skip to content

Instantly share code, notes, and snippets.

@jcreinhold
Created October 25, 2018 19:55
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save jcreinhold/fbd28adf4e0c423b6cb93cb49f4232d0 to your computer and use it in GitHub Desktop.
Save jcreinhold/fbd28adf4e0c423b6cb93cb49f4232d0 to your computer and use it in GitHub Desktop.
test 3d network with Learner class in fastai
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test 3D Conv with Learner"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup notebook"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torchvision.transforms as torch_tfms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Support in-notebook plotting"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Report versions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"numpy version: 1.15.3\n",
"matplotlib version: 3.0.0\n"
]
}
],
"source": [
"from matplotlib import __version__ as mplver\n",
"\n",
"# record the versions of the array/plotting stack this notebook was run with\n",
"print(f'numpy version: {np.__version__}')\n",
"print(f'matplotlib version: {mplver}')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"python version: 3.7.0\n"
]
}
],
"source": [
"# the first three fields of sys.version_info are (major, minor, micro)\n",
"pv = sys.version_info\n",
"print('python version: ' + '.'.join(str(part) for part in pv[:3]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Automatically reload imported packages when their source changes (useful during package development)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define test images"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Root of the train/test image folders. Set the NN_TEST_DATA_DIR environment\n",
"# variable to point somewhere else instead of editing this cell; the default\n",
"# is the original author's local path.\n",
"data_root = os.environ.get('NN_TEST_DATA_DIR', '/Users/jcreinhold/Research/data/nn_test/real')\n",
"# the trailing '' makes os.path.join end each path with a separator, which the\n",
"# later `train_dir+'t1'` style concatenations rely on\n",
"train_dir = os.path.join(data_root, 'train', '')\n",
"val_dir = os.path.join(data_root, 'test', '')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from niftidataset import *"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"class AddChannel:\n",
"    \"\"\" Add a leading channel dimension to both tensors of a (source, target) sample.\n",
"\n",
"    The sample elements must be torch tensors (this transform calls\n",
"    `.unsqueeze`, which numpy arrays do not have), so place it after the\n",
"    tensor-conversion transform in a pipeline.\n",
"    \"\"\"\n",
"    def __call__(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n",
"        src, tgt = sample\n",
"        assert src.shape == tgt.shape, f'source/target shape mismatch: {tuple(src.shape)} vs {tuple(tgt.shape)}'\n",
"        # unsqueeze(0) inserts a new size-1 axis at position 0\n",
"        return (src.unsqueeze(0), tgt.unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"patch_sz = 32  # size handed to RandomCrop3D\n",
"\n",
"# AddChannel calls .unsqueeze on its inputs, so it has to come after ToTensor\n",
"# (presumably the step that converts the cropped arrays to torch tensors)\n",
"tfms = torch_tfms.Compose([\n",
"    RandomCrop3D(patch_sz),\n",
"    ToTensor(),\n",
"    AddChannel(),\n",
"])\n",
"\n",
"# same transform pipeline and the same 't1' / 'flair' sub-folders for both splits\n",
"tds, vds = (NiftiDataset(root+'t1', root+'flair', tfms) for root in (train_dir, val_dir))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAADoBJREFUeJzt3VlrFF0XhuHtPCTGDGriFMVEED0RRPD/nwkeiJ6EEOegJkbN6Dy8f6CfW9yf1Of7rvs67EV1V1f3oqAe1t77fv782ST99+3/f5+ApGHY7FIRNrtUhM0uFWGzS0XY7FIRB4f8sPv378ec78uXL/G4VNvc3IzHrKysxNqDBw9i7e3bt7F2+PDhka8fOnQoHnPgwIFYO3LkSKwdP3481vbt2xdr6VptbGzEY+7duxdrT58+jbVr167F2uLi4sjX9/b24jFfv36NtXPnzsVa+l1aa21sbGzk67dv347H3Lx5M9bot6YY+8ePH7H2/fv333r9V591586dkX8Q7+xSETa7VITNLhVhs0tF2OxSETa7VMSg0dvOzk6s9cYMCUUdFJ/MzMzE2tGjR0e+/vHjx3gMxUkUGU1NTcUaxZTp806cOBGPWVhYiDWKDs+ePRtrp0+fHvk6fS/6f+zfn+9LdI7b29sjX6dI8eLFi7F24cKFWKPf+vPnz7H24cOHka/Td56YmIi1xDu7VITNLhVhs0tF2OxSETa7VITNLhUxaPRGkdHa2lqsvXr1auTrBw/m00/TTq1x/PP+/ftYS5Hd+Ph4PIYiEorDaDosRTWt5bimJ65rrbXJyclYS1FkazlW/NPTfK3x9UixHH3n9H9rjSNi+l/RfyTVKNL99OlTrCXe2aUibHapCJtdKsJml4qw2aUiBn0af+zYsVjrGfygp/FnzpyJNXrC/Pr161hL5z83NxePoWEXqtH6ektLS7GWhknoSTclF5QYnD9/PtbSU3d6ck7nSP8dekKerjEN8dAwFD0Fp3X+KLlIiQ0NbG1tbcVa4p1dKsJml4qw2aUibHapCJtdKsJml4oYNHqjNbUo4kkxDkVGNDhBn0VbGqW18GjtNIpqaG09Om53dzfW0tp7dO3pWlE82HONKfZM68W1xpEdvWfaNorW3Zufn+86D4pL6VpRdJhQlJd4Z5eKsNmlImx2qQibXSrCZpeKsNmlIgaN3npjnBRDUTxF8RpNV9Hab3RcQmudUdT0+PHjWKOoL20z9OLFi673W1xcjDWaREu/DR0zOzsbaz1ruLWWJxLpPNbX12ONIq8U87XG/4Nv376NfJ22tUrHEO/sUhE2u1SEzS4VYbNLRdjsUhE2u1TEX7P9E20LlCa2aHucni2B6LNay+dPESBFLqurq7FGC19SZPf8+fORrz958iQeQwtmkjRh11qeSKRJRYrQrly5EmsUh/3piTJa5JQ+i2rp96TzcOpNUmSzS0XY7FIRNrtUhM0uFWGzS0UMGr1RVEPxVYrDaAqNFnOkiSc6xxSjUZS3trYWazSJRhEgefTo0cjXX758GY+5evVqrNG+Z/S90/Wn6Gp6ejrWeqcY03VM+6u1xnusUdxL0TJJn9fbL4l3dqkIm10qwmaXirDZpSJsdqmIQZ/G09NbGlzpWc+MnuzSU1Nagy49iaWhlaWlpVijIRl6Gr+xsRFr6Sk4bWlET9zpPOj3TINN9LvQE3JCAzRpYISeuNNTcBpAoVSABmF6koueAR/v7FIRNrtUhM0uFWGzS0XY7FIRNrtUxKDRW9qaqDUeXOmJT969exdrNERAcV7y6tWrWHv//n2sUVSzubkZazRck+I8ipPosygOo2uVfs8jR450fRZtHUb/nYS2T6Jr1RuvUYyW3pPOo4d3dqkIm10qwmaXirDZpSJsdqkIm10qYtDojSaoaJooRWwUn1BUQ1NSNBGXtumhiOTkyZOxRpHR7u5urNF3S9eK3o8iUfrN6Pqn4yYnJ+
MxFL3RhCBdxzSZRzEZXQ+69jQFSJFj+v/QJGhP3OidXSrCZpeKsNmlImx2qQibXSrCZpeKGDR6oy18KPJKsQtNXdH7bW1txVqK11rLcc2pU6fiMRS5UFRDkdfY2Fis3bhxY+TrFJP96W2LWsvfu2d7LXq//6WW0O9C59gr/TYUAdLvmXhnl4qw2aUibHapCJtdKsJml4qw2aUi/pq93ihqSij6SXuNtcbR1dTUVKylyK53EUU6bm5uLtYovkrXmCJF2t+uN95M15jiV7oeNG1Gx6X/Ve9+bvTfoUk6WuQ0fTe6VjQRl3hnl4qw2aUibHapCJtdKsJml4qw2aUi/projeKTtE9WT+TyK7QnV/q83jiG4h+Kw0iK+uha0WKUVKNFD1OcRFEeXY/eBSLT70mxFu3nRpEXTcTR904TbHQ9KNJNvLNLRdjsUhE2u1SEzS4VYbNLRQz6NJ6e7PYMtezs7MRj6Gk8fRY9AU3vSWvQ0VPk3u2r6D1nZ2dHvj4zMxOP2dzcjDUa4KCBkXSOHz9+jMfQb0Y1SlBSAkRDPHR9KYGg4StKotLTf3q6T//TxDu7VITNLhVhs0tF2OxSETa7VITNLhUxaPRGEQkNM6QaRR29W01RLcU/FNXQUAVdj+np6a7jUo0GOCgWopiSpOtPAxx0jhQB0jVO8SZdQ4r5KHqja0Uxa/r/0O9CtcQ7u1SEzS4VYbNLRdjsUhE2u1SEzS4VMWj0RtFET+RF6P1oYogimXQcRW+0Ph2hqIbinxRD0RRdT3TVWt92XhSh9WyR9KvzSN+td/KRImI6/54tuwh9VuKdXSrCZpeKsNmlImx2qQibXSrCZpeKGDR665ViHNpShxZDpNiCFkTsiXEouqJIkeI8igfTOdL70YQgbRtF3y1df3o/uvYUedEkXfptaDHHni3AWmtta2sr1uj6p/d0wUlJXWx2qQibXSrCZpeKsNmlImx2qYhBozeKyijGSTEDxWtkb28v1mgCLEUkdO69C05SnEdTbz2LF5LeOCzFRj0Te63xNaYoNZ0HfVbvgqT03ej6p9+MolmKABPv7FIRNrtUhM0uFWGzS0XY7FIRgz6Np0EBerKb1h/b3t6Ox9DTT0oF6Mnuzs7Ob79f75ZXhNZjS0MttBYePUWmxKNnAIgGQnqvIz21Tv+DnkGS1vq3w6Kn/6lG/2H6zRLv7FIRNrtUhM0uFWGzS0XY7FIRNrtUxKDRG60V9uHDh1jrGWYgNNxBQzIp/qEojCKe3qEKGgrZ3d397c+ia0/xT89WWfR+4+PjsUYRIJ1Huh70u1BM2TtAQ/FgOhd6P/oPxHP47SMk/SvZ7FIRNrtUhM0uFWGzS0XY7FIRg0ZvFHfQmlopCqEYh+I1mlyi2CVFIRSRUGS0ubkZa2/evIk1msxLk4BPnjyJxywvL8caXavLly/H2szMzMjXT58+HY9ZWFiItenp6Vijc0yRF/13erfsohpFqSm6pfejic/EO7tUhM0uFWGzS0XY7FIRNrtUhM0uFTFo9EYTZZOTk7GWYgtavDAtvEjv1xqfY4pCKPpJE3uttfb69etYe/jwYaxR1JeuybNnz+Ixa2trsdYTiRI6d5ooo7hxdnY21lLERr8Z1Xq3hqL/d5rao+9Mk5aJd3apCJtdKsJml4qw2aUibHapCJtdKmLQ6I0mwGgKKcVJtNAgxXK9e5uluIMWbHz37l2sEdr3jOK8dE3SFFprfI40LZf2vmuttevXr498na4VTXLR9aCJOIpSk+PHj8caRZE98VpreUKzN6ZMvLNLRdjsUhE2u1SEzS4VYbNLRfw1T+Pp6Xlau47ej9YRo/Xp6LieNejoye6pU6dijQZ5aHBldXV15Ov0pJvWwqOn8XStXrx4MfJ1+l7nzp2LNVqTb25uLtampqZiLaG13+j3pP8j/Udo8CahRCbxzi4VYbNLRdjsUhE2u1SEzS4VYbNLRQwavVG8RmtqpYhnd3c3HtM7OEEx1Pr6+sjXKc
qjqIkiHroeFNmliIeu1aVLl2KNhmRoECah7Z9ouIMGSej3TO9JURj9d3ojXYrXUl/Q9aCBnMQ7u1SEzS4VYbNLRdjsUhE2u1SEzS4VMWj0RvHPyZMnYy1NE1EMQjWaGKJII0Ve9L02NjZibWVlJdYoxulZz4xivoWFhVij86drfOXKlZGvT0xMxGPo2l+4cCHWKNJNEVuapGyNY086R6rRfy5FbD2/M/HOLhVhs0tF2OxSETa7VITNLhVhs0tFDBq90ZQUxR1pWx2KXKhGWxDRNlRpCyWKte7fvx9rT58+jbXe7Z/Sd6NzTNN8reUIrTWOw1IsRws2Xrt2LdbI8vJyrKXvRluALS4uxhqdP/3n6H9FtcTtnyRFNrtUhM0uFWGzS0XY7FIRNrtUxKDR26dPn2KNpnjSIoW0mOPW1lasUXRF55jiJNrji/Yho33UaKFHin/S9CBdD4oiaZKLpt7m5+dHvn7+/Pl4DMVhS0tLsUZThyl6o9/szJkzsfb58+dYo0UsKfpM+9HRIpX0WfEcfvsISf9KNrtUhM0uFWGzS0XY7FIRNrtUxKDRG+1RRjHO9vb2yNcpzqBJIorX6D3TpBFFJLdu3Yq1sbGxWHv06FGspevRWj5HWuiRFjakyI7in7RnHkWsFL3RHnwUh6VrTPEaXSs6R5pEo/9jek/6rJ5JOe/sUhE2u1SEzS4VYbNLRdjsUhGDPo0fHx+PNXrqm4Yx0ppwrfFwRM9T5NbyGmM0kENDNzSMcePGjVijtfxevnw58vW3b9/GY+hpPCUN9PQ5XUd6ct77WfRkOqU8tF4cnQc9qaffmv5z6VwooaLfLPHOLhVhs0tF2OxSETa7VITNLhVhs0tFDBq9UaSRtnhqLccMb968icf0DsnQcEqKT2gbp7t378Yarf128eLFWKNzTLEcDf/QmnbPnz+PtRTztdba5cuXf/uzVldXY41+TxpqSbFo7yAMnUda/681jvpSTLy3txeP6eGdXSrCZpeKsNmlImx2qQibXSrCZpeK2EfrXEn67/DOLhVhs0tF2OxSETa7VITNLhVhs0tF2OxSETa7VITNLhVhs0tF2OxSETa7VITNLhVhs0tF2OxSETa7VITNLhVhs0tF2OxSETa7VITNLhVhs0tF2OxSEf8AsxST+PEYSnkAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"src, tgt = tds[0]\n",
"# Show the middle slice along axis 2 of channel 0. The index is derived from the\n",
"# sample's own shape (16 for a 32-wide patch) so the cell keeps working when\n",
"# patch_sz is changed instead of raising IndexError for small patches.\n",
"mid = src.shape[2] // 2\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(np.rot90(src[0, :, mid, :]), cmap='gist_gray')\n",
"ax.axis('off');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test fastai"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import fastai as fai\n",
"import fastai.vision as faiv\n",
"import torchvision\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fastai version: 1.0.12\n",
"pytorch version: 1.0.0.dev20181014\n",
"torchvision version: 0.2.1\n"
]
}
],
"source": [
"# one line per deep-learning dependency, in the same order as before\n",
"for name, module in (('fastai', fai), ('pytorch', torch), ('torchvision', torchvision)):\n",
"    print(f'{name} version: {module.__version__}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a 3D Unet model\n",
"\n",
"The Unet model is defined in the synthit package: https://github.com/jcreinhold/synthit. I'll probably split it out into a separate package at some point."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from synthit.models import Unet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup a 3D fastai learner"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"idb = faiv.ImageDataBunch.create(tds, vds, bs=2, num_workers=1)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"loss = nn.MSELoss()\n",
"loss.__name__ = 'MSE'"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"model = Unet(1, channel_base_power=1)\n",
"learner = fai.Learner(idb, model, loss_func=loss, metrics=[loss])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(617962.5000, grad_fn=<MseLossBackward>)\n",
"tensor(736843.1250, grad_fn=<MseLossBackward>)\n"
]
}
],
"source": [
"# Smoke test: run the first batch of each dataloader through the model and\n",
"# print the resulting loss (training batch first, then validation batch).\n",
"for dl in (idb.train_dl, idb.valid_dl):\n",
"    for x, y in dl:\n",
"        print(loss(learner.model(x), y))\n",
"        break"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "097faba9fb694adca3a0deb78a7d3c02",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HBox(children=(IntProgress(value=0, max=34), HTML(value='0.00% [0/34 00:00<00:00]'))), HTML(val…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learner.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learner.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 00:03\n",
"epoch train loss valid loss MSE \n",
"1 994864.687500 736764.625000 736764.625000 (00:03)\n",
"\n"
]
}
],
"source": [
"learner.fit(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:neuropp]",
"language": "python",
"name": "conda-env-neuropp-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment