@pouannes
Last active December 19, 2018 17:05
docs_src/callbacks.ipynb
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# List of callbacks"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:01:49.315444Z",
"end_time": "2018-12-19T17:01:49.320535Z"
},
"hide_input": true,
"trusted": true
},
"cell_type": "code",
"source": "from fastai.gen_doc.nbdoc import *\nfrom fastai.vision import *\nfrom fastai.text import *\nfrom fastai.callbacks import * \nfrom fastai.basic_train import * \nfrom fastai.train import * \nfrom fastai import callbacks",
"execution_count": 5,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "fastai's training loop is highly extensible, with a rich *callback* system. See the [`callback`](/callback.html#callback) docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.\n\nEvery callback that is passed to [`Learner`](/basic_train.html#Learner) with the `callback_fns` parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance [`ActivationStats`](/callbacks.hooks.html#ActivationStats) will appear as `learn.activation_stats` (assuming your object is named `learn`)."
},
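{
"metadata": {},
"cell_type": "markdown",
"source": "As a minimal sketch of this naming rule, the cell below passes [`ActivationStats`](/callbacks.hooks.html#ActivationStats) through `callback_fns` and reads it back as a snake-cased attribute after training. It assumes the MNIST sample data and the `simple_cnn` architecture used in the other examples of this page; the variable name `stats_learn` is only illustrative."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai.vision import *\nfrom fastai.callbacks.hooks import ActivationStats\n\n# Same MNIST sample data as the other examples on this page\npath = untar_data(URLs.MNIST_SAMPLE)\ndata = ImageDataBunch.from_folder(path)\n\n# Pass the class itself (not an instance); fastai instantiates it with the learner\nstats_learn = Learner(data, simple_cnn((3,16,16,2)), callback_fns=[ActivationStats])\nstats_learn.fit(1)\n\n# ActivationStats is stored under the snake-cased name activation_stats\nstats_learn.activation_stats.stats.shape",
"execution_count": null,
"outputs": []
},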
{
"metadata": {},
"cell_type": "markdown",
"source": "## [`Callback`](/callback.html#Callback)\n\nThis sub-package contains more sophisticated callbacks, each in its own module. They are (click a link for more details):"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`LRFinder`](/callbacks.lr_finder.html#LRFinder)\n\nUse Leslie Smith's [learning rate finder](https://www.jeremyjordan.me/nn-learning-rate/) to find a good learning rate for training your model. Let's see an example of use on the MNIST dataset with a simple CNN."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:01:50.595809Z",
"end_time": "2018-12-19T17:01:50.731571Z"
},
"trusted": true
},
"cell_type": "code",
"source": "path = untar_data(URLs.MNIST_SAMPLE)\ndata = ImageDataBunch.from_folder(path)\ndef simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\nlearn = simple_learner()",
"execution_count": 6,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The fastai library already has a `Learner` method called `lr_find` that uses `LRFinder` to plot the loss as a function of the learning rate."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:01:51.507907Z",
"end_time": "2018-12-19T17:01:52.921581Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.lr_find()",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:01:53.033769Z",
"end_time": "2018-12-19T17:01:53.362186Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.recorder.plot()",
"execution_count": 8,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "In this example, a learning rate around 2e-2 seems like the right fit."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:01:59.124739Z",
"end_time": "2018-12-19T17:01:59.127542Z"
},
"trusted": true
},
"cell_type": "code",
"source": "lr = 2e-2",
"execution_count": 9,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n\nTrain with Leslie Smith's [1cycle annealing](https://sgugger.github.io/the-1cycle-policy.html) method. Let's train our simple learner using the one cycle policy."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:00.292469Z",
"end_time": "2018-12-19T17:02:08.573328Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(3, lr)",
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "Total time: 00:08 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.111205</th>\n <th>0.056460</th>\n <th>0.979882</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.040632</th>\n <th>0.023650</th>\n <th>0.987733</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.021217</th>\n <th>0.020044</th>\n <th>0.991659</th>\n </tr>\n</table>\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The learning rate and the momentum were changed during the epochs as follows (more info on the [dedicated documentation page](https://docs.fast.ai/callbacks.one_cycle.html))."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:10.549162Z",
"end_time": "2018-12-19T17:02:10.733212Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.recorder.plot_lr(show_moms=True)",
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 864x288 with 2 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)\n\nData augmentation using the method from [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412). It is very simple to add mixup in fastai:"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:13.653761Z",
"end_time": "2018-12-19T17:02:13.660217Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:14.277579Z",
"end_time": "2018-12-19T17:02:23.533832Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(3, lr)",
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "Total time: 00:09 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.359284</th>\n <th>0.148368</th>\n <th>0.965162</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.315967</th>\n <th>0.090522</th>\n <th>0.991168</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.305872</th>\n <th>0.084860</th>\n <th>0.990186</th>\n </tr>\n</table>\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`CSVLogger`](/callbacks.csv_logger.html#CSVLogger)\n\nLog the results of training in a CSV file. Simply pass the `CSVLogger` callback to the `Learner`."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:28.760214Z",
"end_time": "2018-12-19T17:02:28.766141Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:29.256292Z",
"end_time": "2018-12-19T17:02:37.887856Z"
},
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "learn.fit(3)",
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "Total time: 00:08 <p><table style='width:375px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n <th>error_rate</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.125326</th>\n <th>0.103473</th>\n <th>0.963690</th>\n <th>0.036310</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.077392</th>\n <th>0.059223</th>\n <th>0.977920</th>\n <th>0.022080</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.065756</th>\n <th>0.081031</th>\n <th>0.969578</th>\n <th>0.030422</th>\n </tr>\n</table>\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "You can then read the CSV file."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:38.014240Z",
"end_time": "2018-12-19T17:02:38.027150Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.csv_logger.read_logged_file()",
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 16,
"data": {
"text/plain": " epoch train_loss valid_loss accuracy error_rate\n0 1 0.125326 0.103473 0.963690 0.036310\n1 2 0.077392 0.059223 0.977920 0.022080\n2 3 0.065756 0.081031 0.969578 0.030422",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n <th>error_rate</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>0.125326</td>\n <td>0.103473</td>\n <td>0.963690</td>\n <td>0.036310</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2</td>\n <td>0.077392</td>\n <td>0.059223</td>\n <td>0.977920</td>\n <td>0.022080</td>\n </tr>\n <tr>\n <th>2</th>\n <td>3</td>\n <td>0.065756</td>\n <td>0.081031</td>\n <td>0.969578</td>\n <td>0.030422</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`GeneralScheduler`](/callbacks.general_sched.html#GeneralScheduler)\n\nCreate your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a two-phase schedule."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:39.690478Z",
"end_time": "2018-12-19T17:02:39.696033Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def fit_odd_schedule(learn, lr, mom):\n    # n is the number of batches in one epoch, so phase lengths are in iterations\n    n = len(learn.data.train_dl)\n    # Phase 1 lasts one epoch (cosine annealing); phase 2 lasts two epochs (polynomial annealing)\n    phases = [TrainingPhase(n, lr, mom, lr_anneal=annealing_cos), TrainingPhase(n*2, lr, mom, lr_anneal=annealing_poly(2))]\n    sched = GeneralScheduler(learn, phases)\n    learn.callbacks.append(sched)\n    # The two phases together cover exactly three epochs\n    total_epochs = 3\n    learn.fit(total_epochs)",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:40.329973Z",
"end_time": "2018-12-19T17:02:49.158050Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\nfit_odd_schedule(learn, 1e-3, 0.9)",
"execution_count": 18,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "Total time: 00:08 <p><table style='width:300px; margin-bottom:10px'>\n <tr>\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n </tr>\n <tr>\n <th>1</th>\n <th>0.178648</th>\n <th>0.161728</th>\n <th>0.944553</th>\n </tr>\n <tr>\n <th>2</th>\n <th>0.142739</th>\n <th>0.132620</th>\n <th>0.957802</th>\n </tr>\n <tr>\n <th>3</th>\n <th>0.135239</th>\n <th>0.129183</th>\n <th>0.960255</th>\n </tr>\n</table>\n"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2018-12-19T17:02:50.794571Z",
"end_time": "2018-12-19T17:02:50.894112Z"
},
"trusted": true
},
"cell_type": "code",
"source": "learn.recorder.plot_lr()",
"execution_count": 19,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision)\n\nUse fp16 to [take advantage of tensor cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) on recent NVIDIA GPUs for a speedup of 200% or more."
},
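{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of enabling mixed precision, assuming a CUDA-capable GPU and the `data` object created earlier on this page. `Learner.to_fp16()` is the convenience method that attaches the `MixedPrecision` callback; the epoch count and learning rate are illustrative values, not recommendations."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Convert the learner to mixed precision; this attaches the MixedPrecision callback\n# (assumes a CUDA GPU is available)\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy]).to_fp16()\nlearn.fit_one_cycle(1, 2e-2)",
"execution_count": null,
"outputs": []
},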
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`HookCallback`](/callbacks.hooks.html#HookCallback)\n\nConvenient wrapper for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks). Also contains a pre-defined hook callback: [`ActivationStats`](/callbacks.hooks.html#ActivationStats)."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer)\n\nCallback taking care of all the tweaks to train an RNN."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`TerminateOnNaNCallback`](/callbacks.tracker.html#TerminateOnNaNCallback)\n\nStop training if the loss reaches NaN."
},
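{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of how this callback might be used, assuming the `data` object created earlier on this page. The learning rate of `1e4` is a deliberately unreasonable, illustrative value chosen to make divergence likely; whether the loss actually becomes NaN depends on the run."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# TerminateOnNaNCallback takes no arguments, so an instance is passed through `callbacks`\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy], callbacks=[TerminateOnNaNCallback()])\n# A huge learning rate (illustrative) so that the loss is likely to diverge\nlearn.fit(2, 1e4)",
"execution_count": null,
"outputs": []
},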
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`EarlyStoppingCallback`](/callbacks.tracker.html#EarlyStoppingCallback)\n\nStop training if a given metric/validation loss doesn't improve."
},
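{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of early stopping, assuming the `data` object created earlier on this page. The callback needs the learner, so it is wrapped in `partial` and passed through `callback_fns`. The `min_delta`, `patience`, epoch count and learning rate are illustrative values only."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from functools import partial\n\n# Monitor the accuracy metric; min_delta and patience are illustrative values\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],\n                callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])\n# Ask for many epochs; training may stop earlier once accuracy stops improving\nlearn.fit(50, 1e-2)",
"execution_count": null,
"outputs": []
},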
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`SaveModelCallback`](/callbacks.tracker.html#SaveModelCallback)\n\nSave the model at every epoch, or the best model for a given metric/validation loss."
},
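{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of saving the best model, assuming the `data` object created earlier on this page. The file name `'best'`, the monitored metric and the training settings are illustrative choices, and the exact keyword arguments may differ between fastai versions."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n# every='improvement' saves only when the monitored value improves;\n# every='epoch' would save after each epoch instead\nsaver = SaveModelCallback(learn, every='improvement', monitor='accuracy', name='best')\nlearn.fit(3, 1e-2, callbacks=[saver])",
"execution_count": null,
"outputs": []
},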
{
"metadata": {},
"cell_type": "markdown",
"source": "### [`ReduceLROnPlateauCallback`](/callbacks.tracker.html#ReduceLROnPlateauCallback)\n\nReduce the learning rate by a certain factor each time a given metric/validation loss doesn't improve."
},
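{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of reducing the learning rate on a plateau, assuming the `data` object created earlier on this page. The monitored quantity is left at the library default (the validation loss); `patience`, `factor`, the epoch count and the learning rate are illustrative values only."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from functools import partial\n\n# Multiply the learning rate by `factor` when the monitored value stops improving\n# for longer than `patience` epochs (both values are illustrative)\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],\n                callback_fns=[partial(ReduceLROnPlateauCallback, patience=2, factor=0.5)])\nlearn.fit(10, 1e-2)",
"execution_count": null,
"outputs": []
},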
{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "## [`train`](/train.html#train) and [`basic_train`](/basic_train.html#basic_train)"
},
{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "### [`Recorder`](/basic_train.html#Recorder)\n\nTrack per-batch and per-epoch smoothed losses and metrics."
},
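{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "A minimal sketch of reading back what the `Recorder` tracked, assuming the `data` object created earlier on this page. `Recorder` is attached to every `Learner` automatically and is available as `learn.recorder`; the training settings are illustrative."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\nlearn.fit(2, 1e-2)\n\n# Plot the recorded training and validation losses, then the recorded metrics\nlearn.recorder.plot_losses()\nlearn.recorder.plot_metrics()",
"execution_count": null,
"outputs": []
},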
{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "### [`ShowGraph`](/train.html#ShowGraph)\n\nDynamically display a learning chart during training."
},
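{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "A minimal sketch of displaying the chart while training, assuming the `data` object created earlier on this page and a notebook environment that can render plots."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# ShowGraph needs the learner, so the class is passed through `callback_fns`\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy], callback_fns=[ShowGraph])\nlearn.fit(3)",
"execution_count": null,
"outputs": []
},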
{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "### [`BnFreeze`](/train.html#BnFreeze)\n\nFreeze batchnorm layer moving average statistics for non-trainable layers."
},
{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "### [`GradientClipping`](/train.html#GradientClipping)\n\nClip gradients during training."
},
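{
"metadata": {
"cell_style": "center"
},
"cell_type": "markdown",
"source": "A minimal sketch of gradient clipping, assuming the `data` object created earlier on this page. The clipping threshold `clip=0.1` is an illustrative value, not a recommendation."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from functools import partial\n\n# GradientClipping needs the learner, so it is wrapped in partial and passed through `callback_fns`\nlearn = Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy],\n                callback_fns=[partial(GradientClipping, clip=0.1)])\nlearn.fit(1)",
"execution_count": null,
"outputs": []
},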
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"gist": {
"id": "a937dbb3247e8582d184d9b7db585723",
"data": {
"description": "docs_src/callbacks.ipynb",
"public": true
}
},
"jekyll": {
"keywords": "fastai",
"summary": "Callbacks implemented in the fastai library",
"title": "callbacks"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.1",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"window_display": false,
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"library": "var_list.py",
"delete_cmd_prefix": "del ",
"delete_cmd_postfix": "",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"library": "var_list.r",
"delete_cmd_prefix": "rm(",
"delete_cmd_postfix": ") ",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
]
},
"_draft": {
"nbviewer_url": "https://gist.github.com/a937dbb3247e8582d184d9b7db585723"
}
},
"nbformat": 4,
"nbformat_minor": 2
}