Problem with channels_last and fastai custom tensors
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Problem with channels_last and fastai custom tensors.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Problem with channels_last and fastai custom tensors\n",
"\n",
"This notebook contains experiments showing that passing fastai's custom tensor classes (`TensorBase` and `TensorImage`) to the model degrades channels last performance. The only change needed to get the expected channels last speedup is to cast the `TensorBase`-derived input to a plain `torch.tensor`; channels last then trains faster, as expected."
],
"metadata": {
"id": "DgDxBVq9TTpe"
}
},
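{
"cell_type": "markdown",
"source": [
"For reference, `channels_last` does not change a tensor's logical NCHW shape; it changes the strides so that the channels of each pixel sit next to each other in memory (NHWC order). The pure-Python sketch below computes both sets of strides for this notebook's batch shape (`bs=64`, 3x256x256 images). The helper names are illustrative only, not a PyTorch or fastai API; for a real tensor, `x.to(memory_format=torch.channels_last).stride()` should report the same channels last values."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical helpers illustrating memory layout; not part of PyTorch or fastai\n",
"def contiguous_strides(shape):\n",
"    # Row-major (C-contiguous) strides, in elements: the last dim moves fastest\n",
"    strides, step = [], 1\n",
"    for dim in reversed(shape):\n",
"        strides.append(step)\n",
"        step *= dim\n",
"    return tuple(reversed(strides))\n",
"\n",
"def channels_last_strides(n, c, h, w):\n",
"    # Store the logical NCHW tensor physically as NHWC, then report strides in NCHW order\n",
"    s_n, s_h, s_w, s_c = contiguous_strides((n, h, w, c))\n",
"    return (s_n, s_c, s_h, s_w)\n",
"\n",
"print(contiguous_strides((64, 3, 256, 256)))   # default layout: (196608, 65536, 256, 1)\n",
"print(channels_last_strides(64, 3, 256, 256))  # channels_last layout: (196608, 1, 768, 3)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},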
{
"cell_type": "code",
"source": [
"! nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WvJkCesSEx2V",
"outputId": "12ea5700-85f2-4ee6-8acf-e9c183618d8b"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Wed Jun 8 16:08:37 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 62C P8 11W / 70W | 0MiB / 15109MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u-D43aOA-N_z"
},
"outputs": [],
"source": [
"! pip install fastcore fastai fastxtend[vision] -U"
]
},
{
"cell_type": "code",
"source": [
"! pip uninstall pillow -y\n",
"! CC=\"cc -mavx2\" pip install -U --force-reinstall pillow-simd"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "m0FrzR7EHzBZ",
"outputId": "6aee40bb-8ed0-4af1-c919-df2dd629ed12"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: Pillow 7.1.2\n",
"Uninstalling Pillow-7.1.2:\n",
" Successfully uninstalled Pillow-7.1.2\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting pillow-simd\n",
" Downloading Pillow-SIMD-9.0.0.post1.tar.gz (849 kB)\n",
"\u001b[K |████████████████████████████████| 849 kB 7.6 MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: pillow-simd\n",
" Building wheel for pillow-simd (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pillow-simd: filename=Pillow_SIMD-9.0.0.post1-cp37-cp37m-linux_x86_64.whl size=1255565 sha256=5da71e6b46623d6920cc321db69258300c8543cdab25e5fc5711d397036233dc\n",
" Stored in directory: /root/.cache/pip/wheels/9b/3f/fd/ca7133b4f7f509eb1de652bb8c128529a0c04b25ef0a6c535a\n",
"Successfully built pillow-simd\n",
"Installing collected packages: pillow-simd\n",
"Successfully installed pillow-simd-9.0.0.post1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from __future__ import annotations\n",
"from fastai.vision.all import *\n",
"from fastxtend.vision.all import *\n",
"from fastxtend.callback import simpleprofiler"
],
"metadata": {
"id": "55BRyAzH-S-f"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# DataLoaders"
],
"metadata": {
"id": "J1tiDmSY-sxr"
}
},
{
"cell_type": "code",
"source": [
"imagewoof_stats = ([0.496,0.461,0.399],[0.257,0.249,0.258])\n",
"imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])\n",
"\n",
"def get_dls(size:int, woof:bool, bs:int, sh:float=0., augs:list=None, workers:int=None, stats:bool=True) -> DataLoaders:\n",
"    if size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320\n",
"    else        : path = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE\n",
"    source = untar_data(path)\n",
"    if workers is None: workers = min(8, num_cpus())\n",
"    batch_tfms = []\n",
"    if stats:\n",
"        if woof:\n",
"            batch_tfms += [Normalize.from_stats(*imagewoof_stats)]\n",
"        else:\n",
"            batch_tfms += [Normalize.from_stats(*imagenette_stats)]\n",
"    if augs: batch_tfms += augs\n",
"    if sh: batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))\n",
"    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n",
"                       splitter=GrandparentSplitter(valid_name='val'),\n",
"                       get_items=get_image_files, get_y=parent_label,\n",
"                       item_tfms=[Resize(size)],\n",
"                       batch_tfms=batch_tfms)\n",
"    return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)"
],
"metadata": {
"id": "pNgVl8Gs-dxf"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Train fastai default mixed precision"
],
"metadata": {
"id": "beQssA6n-wam"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "bOYaHgr--2yJ"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_fp16()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "gXqWSo7R-v28",
"outputId": "a2fbfba6-15bc-48b1-e5eb-295f7cab29be"
},
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.912659</td>\n",
" <td>1.717486</td>\n",
" <td>0.414777</td>\n",
" <td>01:32</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "h8luVtBs-4F7"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_fp16().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"id": "lkQ7iTGJ--dd",
"outputId": "738d77bf-6cbc-4faa-cd21-79c1ac57d7ef"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.976799</td>\n",
" <td>1.873074</td>\n",
" <td>0.344968</td>\n",
" <td>01:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.385735</td>\n",
" <td>1.268956</td>\n",
" <td>0.610701</td>\n",
" <td>01:25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763ecf4a90>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_e689c_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_e689c_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_e689c_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_e689c_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_e689c_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_e689c_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_e689c_row0_col5\" class=\"data row0 col5\" >171.6 s</td>\n",
" <td id=\"T_e689c_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_e689c_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_e689c_row1_col2\" class=\"data row1 col2\" >85.81 s</td>\n",
" <td id=\"T_e689c_row1_col3\" class=\"data row1 col3\" >422.4ms</td>\n",
" <td id=\"T_e689c_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_e689c_row1_col5\" class=\"data row1 col5\" >171.6 s</td>\n",
" <td id=\"T_e689c_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_e689c_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_e689c_row2_col2\" class=\"data row2 col2\" >66.26 s</td>\n",
" <td id=\"T_e689c_row2_col3\" class=\"data row2 col3\" >353.6ms</td>\n",
" <td id=\"T_e689c_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_e689c_row2_col5\" class=\"data row2 col5\" >132.5 s</td>\n",
" <td id=\"T_e689c_row2_col6\" class=\"data row2 col6\" >77%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_e689c_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_e689c_row3_col2\" class=\"data row3 col2\" >19.54 s</td>\n",
" <td id=\"T_e689c_row3_col3\" class=\"data row3 col3\" >65.33ms</td>\n",
" <td id=\"T_e689c_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_e689c_row3_col5\" class=\"data row3 col5\" >39.08 s</td>\n",
" <td id=\"T_e689c_row3_col6\" class=\"data row3 col6\" >23%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_e689c_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_e689c_row4_col2\" class=\"data row4 col2\" >443.2ms</td>\n",
" <td id=\"T_e689c_row4_col3\" class=\"data row4 col3\" >67.11ms</td>\n",
" <td id=\"T_e689c_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_e689c_row4_col5\" class=\"data row4 col5\" >130.3 s</td>\n",
" <td id=\"T_e689c_row4_col6\" class=\"data row4 col6\" >76%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_e689c_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_e689c_row5_col2\" class=\"data row5 col2\" >280.0ms</td>\n",
" <td id=\"T_e689c_row5_col3\" class=\"data row5 col3\" >23.06ms</td>\n",
" <td id=\"T_e689c_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_e689c_row5_col5\" class=\"data row5 col5\" >82.32 s</td>\n",
" <td id=\"T_e689c_row5_col6\" class=\"data row5 col6\" >48%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_e689c_row6_col1\" class=\"data row6 col1\" >backward</td>\n",
" <td id=\"T_e689c_row6_col2\" class=\"data row6 col2\" >73.15ms</td>\n",
" <td id=\"T_e689c_row6_col3\" class=\"data row6 col3\" >13.76ms</td>\n",
" <td id=\"T_e689c_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_e689c_row6_col5\" class=\"data row6 col5\" >21.51 s</td>\n",
" <td id=\"T_e689c_row6_col6\" class=\"data row6 col6\" >13%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_e689c_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_e689c_row7_col2\" class=\"data row7 col2\" >61.11ms</td>\n",
" <td id=\"T_e689c_row7_col3\" class=\"data row7 col3\" >19.80ms</td>\n",
" <td id=\"T_e689c_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_e689c_row7_col5\" class=\"data row7 col5\" >17.97 s</td>\n",
" <td id=\"T_e689c_row7_col6\" class=\"data row7 col6\" >10%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_e689c_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_e689c_row8_col2\" class=\"data row8 col2\" >23.78ms</td>\n",
" <td id=\"T_e689c_row8_col3\" class=\"data row8 col3\" >62.00ms</td>\n",
" <td id=\"T_e689c_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_e689c_row8_col5\" class=\"data row8 col5\" >6.991 s</td>\n",
" <td id=\"T_e689c_row8_col6\" class=\"data row8 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_e689c_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_e689c_row9_col2\" class=\"data row9 col2\" >2.996ms</td>\n",
" <td id=\"T_e689c_row9_col3\" class=\"data row9 col3\" >1.757ms</td>\n",
" <td id=\"T_e689c_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_e689c_row9_col5\" class=\"data row9 col5\" >880.8ms</td>\n",
" <td id=\"T_e689c_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_e689c_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_e689c_row10_col2\" class=\"data row10 col2\" >1.882ms</td>\n",
" <td id=\"T_e689c_row10_col3\" class=\"data row10 col3\" >1.883ms</td>\n",
" <td id=\"T_e689c_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_e689c_row10_col5\" class=\"data row10 col5\" >553.3ms</td>\n",
" <td id=\"T_e689c_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_e689c_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_e689c_row11_col2\" class=\"data row11 col2\" >277.3ms</td>\n",
" <td id=\"T_e689c_row11_col3\" class=\"data row11 col3\" >155.1ms</td>\n",
" <td id=\"T_e689c_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_e689c_row11_col5\" class=\"data row11 col5\" >34.38 s</td>\n",
" <td id=\"T_e689c_row11_col6\" class=\"data row11 col6\" >20%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_e689c_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_e689c_row12_col2\" class=\"data row12 col2\" >210.5ms</td>\n",
" <td id=\"T_e689c_row12_col3\" class=\"data row12 col3\" >150.8ms</td>\n",
" <td id=\"T_e689c_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_e689c_row12_col5\" class=\"data row12 col5\" >26.10 s</td>\n",
" <td id=\"T_e689c_row12_col6\" class=\"data row12 col6\" >15%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_e689c_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_e689c_row13_col2\" class=\"data row13 col2\" >64.46ms</td>\n",
" <td id=\"T_e689c_row13_col3\" class=\"data row13 col3\" >12.82ms</td>\n",
" <td id=\"T_e689c_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_e689c_row13_col5\" class=\"data row13 col5\" >7.993 s</td>\n",
" <td id=\"T_e689c_row13_col6\" class=\"data row13 col6\" >5%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_e689c_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_e689c_row14_col2\" class=\"data row14 col2\" >1.952ms</td>\n",
" <td id=\"T_e689c_row14_col3\" class=\"data row14 col3\" >2.213ms</td>\n",
" <td id=\"T_e689c_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_e689c_row14_col5\" class=\"data row14 col5\" >242.1ms</td>\n",
" <td id=\"T_e689c_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Train fastai channels last"
],
"metadata": {
"id": "VmZkjfymAVSS"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "ybh8djjhAVST"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "93f31c24-725b-4467-8c22-f739ee09e829",
"id": "6axWVksaAVST"
},
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.800101</td>\n",
" <td>1.618706</td>\n",
" <td>0.447643</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "Wr2wDyEPAVSU"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"outputId": "aa76f239-d805-451c-a0d9-d54964ceabe2",
"id": "cIqWhmxhAVSU"
},
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.945158</td>\n",
" <td>1.860097</td>\n",
" <td>0.369172</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.290537</td>\n",
" <td>1.169882</td>\n",
" <td>0.627261</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f7647fb75d0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_c2056_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_c2056_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_c2056_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_c2056_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_c2056_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_c2056_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_c2056_row0_col5\" class=\"data row0 col5\" >186.9 s</td>\n",
" <td id=\"T_c2056_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_c2056_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_c2056_row1_col2\" class=\"data row1 col2\" >93.47 s</td>\n",
" <td id=\"T_c2056_row1_col3\" class=\"data row1 col3\" >170.4ms</td>\n",
" <td id=\"T_c2056_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_c2056_row1_col5\" class=\"data row1 col5\" >186.9 s</td>\n",
" <td id=\"T_c2056_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_c2056_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_c2056_row2_col2\" class=\"data row2 col2\" >74.08 s</td>\n",
" <td id=\"T_c2056_row2_col3\" class=\"data row2 col3\" >188.7ms</td>\n",
" <td id=\"T_c2056_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_c2056_row2_col5\" class=\"data row2 col5\" >148.2 s</td>\n",
" <td id=\"T_c2056_row2_col6\" class=\"data row2 col6\" >79%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_c2056_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_c2056_row3_col2\" class=\"data row3 col2\" >19.38 s</td>\n",
" <td id=\"T_c2056_row3_col3\" class=\"data row3 col3\" >21.99ms</td>\n",
" <td id=\"T_c2056_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_c2056_row3_col5\" class=\"data row3 col5\" >38.77 s</td>\n",
" <td id=\"T_c2056_row3_col6\" class=\"data row3 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_c2056_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_c2056_row4_col2\" class=\"data row4 col2\" >498.0ms</td>\n",
" <td id=\"T_c2056_row4_col3\" class=\"data row4 col3\" >55.88ms</td>\n",
" <td id=\"T_c2056_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_c2056_row4_col5\" class=\"data row4 col5\" >146.4 s</td>\n",
" <td id=\"T_c2056_row4_col6\" class=\"data row4 col6\" >78%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_c2056_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_c2056_row5_col2\" class=\"data row5 col2\" >365.1ms</td>\n",
" <td id=\"T_c2056_row5_col3\" class=\"data row5 col3\" >26.65ms</td>\n",
" <td id=\"T_c2056_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_c2056_row5_col5\" class=\"data row5 col5\" >107.3 s</td>\n",
" <td id=\"T_c2056_row5_col6\" class=\"data row5 col6\" >57%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_c2056_row6_col1\" class=\"data row6 col1\" >backward</td>\n",
" <td id=\"T_c2056_row6_col2\" class=\"data row6 col2\" >53.55ms</td>\n",
" <td id=\"T_c2056_row6_col3\" class=\"data row6 col3\" >15.32ms</td>\n",
" <td id=\"T_c2056_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_c2056_row6_col5\" class=\"data row6 col5\" >15.74 s</td>\n",
" <td id=\"T_c2056_row6_col6\" class=\"data row6 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_c2056_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_c2056_row7_col2\" class=\"data row7 col2\" >52.80ms</td>\n",
" <td id=\"T_c2056_row7_col3\" class=\"data row7 col3\" >16.42ms</td>\n",
" <td id=\"T_c2056_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_c2056_row7_col5\" class=\"data row7 col5\" >15.52 s</td>\n",
" <td id=\"T_c2056_row7_col6\" class=\"data row7 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_c2056_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_c2056_row8_col2\" class=\"data row8 col2\" >22.18ms</td>\n",
" <td id=\"T_c2056_row8_col3\" class=\"data row8 col3\" >54.90ms</td>\n",
" <td id=\"T_c2056_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_c2056_row8_col5\" class=\"data row8 col5\" >6.521 s</td>\n",
" <td id=\"T_c2056_row8_col6\" class=\"data row8 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_c2056_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_c2056_row9_col2\" class=\"data row9 col2\" >2.565ms</td>\n",
" <td id=\"T_c2056_row9_col3\" class=\"data row9 col3\" >1.181ms</td>\n",
" <td id=\"T_c2056_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_c2056_row9_col5\" class=\"data row9 col5\" >754.1ms</td>\n",
" <td id=\"T_c2056_row9_col6\" class=\"data row9 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_c2056_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_c2056_row10_col2\" class=\"data row10 col2\" >1.518ms</td>\n",
" <td id=\"T_c2056_row10_col3\" class=\"data row10 col3\" >1.213ms</td>\n",
" <td id=\"T_c2056_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_c2056_row10_col5\" class=\"data row10 col5\" >446.2ms</td>\n",
" <td id=\"T_c2056_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_c2056_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_c2056_row11_col2\" class=\"data row11 col2\" >286.2ms</td>\n",
" <td id=\"T_c2056_row11_col3\" class=\"data row11 col3\" >197.6ms</td>\n",
" <td id=\"T_c2056_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_c2056_row11_col5\" class=\"data row11 col5\" >35.49 s</td>\n",
" <td id=\"T_c2056_row11_col6\" class=\"data row11 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_c2056_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_c2056_row12_col2\" class=\"data row12 col2\" >220.4ms</td>\n",
" <td id=\"T_c2056_row12_col3\" class=\"data row12 col3\" >193.4ms</td>\n",
" <td id=\"T_c2056_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_c2056_row12_col5\" class=\"data row12 col5\" >27.33 s</td>\n",
" <td id=\"T_c2056_row12_col6\" class=\"data row12 col6\" >15%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_c2056_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_c2056_row13_col2\" class=\"data row13 col2\" >63.65ms</td>\n",
" <td id=\"T_c2056_row13_col3\" class=\"data row13 col3\" >13.46ms</td>\n",
" <td id=\"T_c2056_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_c2056_row13_col5\" class=\"data row13 col5\" >7.893 s</td>\n",
" <td id=\"T_c2056_row13_col6\" class=\"data row13 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_c2056_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_c2056_row14_col2\" class=\"data row14 col2\" >1.820ms</td>\n",
" <td id=\"T_c2056_row14_col3\" class=\"data row14 col3\" >2.196ms</td>\n",
" <td id=\"T_c2056_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_c2056_row14_col5\" class=\"data row14 col5\" >225.7ms</td>\n",
" <td id=\"T_c2056_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
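{
"cell_type": "markdown",
"source": [
"To quantify the gap, the cell below compares the mean durations from the two Simple Profiler tables above. This is plain arithmetic on the reported numbers; nothing is re-measured. By these numbers, channels last is roughly 9% slower per epoch and 12% slower per training batch. Read the per-phase split (faster `pred` and `backward`, much slower `step`) with caution: CUDA kernels run asynchronously, so unless the profiler synchronizes the GPU between phases (not verified here), time may show up in whichever phase next waits on the GPU."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Mean durations copied from the two Simple Profiler tables above (ms unless noted)\n",
"default = {'epoch (s)': 85.81, 'train batch': 443.2, 'step': 280.0, 'backward': 73.15, 'pred': 61.11}\n",
"channels_last = {'epoch (s)': 93.47, 'train batch': 498.0, 'step': 365.1, 'backward': 53.55, 'pred': 52.80}\n",
"\n",
"def pct_change(old, new):\n",
"    # Positive means the channels_last run took longer than the default run\n",
"    return (new / old - 1) * 100\n",
"\n",
"changes = {k: round(pct_change(default[k], channels_last[k]), 1) for k in default}\n",
"for k, v in changes.items():\n",
"    print(f'{k:>12}: {v:+.1f}%')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},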
{
"cell_type": "markdown",
"source": [
"# Remove All Possible fastai Bits That Might Be Interfering"
],
"metadata": {
"id": "-dm_BtSuELmM"
}
},
{
"cell_type": "markdown",
"source": [
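"First, create a fastai transform that doesn't retain the input's tensor type: unlike fastcore's `Transform`, it doesn't pass outputs through `retain_type`, so `encodes` results are not cast back to the input's type."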
],
"metadata": {
"id": "yPdi-I1hEfCY"
}
},
{
"cell_type": "code",
"source": [
"from fastcore.transform import DisplayedTransform, _is_tuple, retain_type\n",
"\n",
"class LoseTypeTransform(DisplayedTransform):\n",
"    \"A `DisplayedTransform` that skips `retain_type`, so outputs keep the type `encodes` returns\"\n",
"    def __init__(self, **kwargs): super().__init__(**kwargs)\n",
"\n",
"    def _call(self, fn, x, split_idx=None, **kwargs):\n",
"        if split_idx!=self.split_idx and self.split_idx is not None: return x\n",
"        return self._do_call(getattr(self, fn), x, **kwargs)\n",
"\n",
"    def _do_call(self, f, x, **kwargs):\n",
"        if not _is_tuple(x):\n",
"            if f is None: return x\n",
"            return f(x, **kwargs)  # no retain_type: the input's type is not restored\n",
"        return tuple(self._do_call(f, x_, **kwargs) for x_ in x)"
],
"metadata": {
"id": "LQIN7mOpECun"
},
"execution_count": 7,
"outputs": []
},
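{
"cell_type": "markdown",
"source": [
"Why the custom transform matters: fastcore's `Transform._do_call` passes each `encodes` result through `retain_type`, which casts the output back to the input's type when the output's type is a superclass of it. A `TensorImage` going in therefore comes back out as a `TensorImage` even if `encodes` returns a plain tensor; `LoseTypeTransform` simply omits that call. The toy below mimics the difference with a `float` subclass standing in for `TensorImage`. `retain_type_sketch`, `call_keep_type`, and `call_lose_type` are hypothetical simplifications, not fastcore's code (for example, fastcore also excludes namedtuples from its tuple handling)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical, simplified imitation of fastcore's type retention; not fastcore's code\n",
"class TaggedFloat(float):\n",
"    pass  # toy stand-in for a fastai subclass such as TensorImage\n",
"\n",
"def retain_type_sketch(new, old):\n",
"    # Cast `new` back to type(old) when type(new) is a superclass of type(old)\n",
"    if isinstance(old, type(new)) and type(old) is not type(new):\n",
"        return type(old)(new)\n",
"    return new\n",
"\n",
"def call_keep_type(f, x):\n",
"    # Like Transform._do_call: apply f, then restore the input's type\n",
"    if isinstance(x, tuple):\n",
"        return tuple(call_keep_type(f, x_) for x_ in x)\n",
"    return retain_type_sketch(f(x), x)\n",
"\n",
"def call_lose_type(f, x):\n",
"    # Like LoseTypeTransform._do_call: apply f and keep whatever type it returns\n",
"    if isinstance(x, tuple):\n",
"        return tuple(call_lose_type(f, x_) for x_ in x)\n",
"    return f(x)\n",
"\n",
"def to_plain(v):\n",
"    return float(v) * 2  # returns a plain float, much like torch.tensor(x) returns a plain tensor\n",
"\n",
"x = TaggedFloat(1.5)\n",
"print(type(call_keep_type(to_plain, x)).__name__)  # TaggedFloat\n",
"print(type(call_lose_type(to_plain, x)).__name__)  # float"
],
"metadata": {},
"execution_count": null,
"outputs": []
},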
{
"cell_type": "markdown",
"source": [
"Inherit from it to build `ChannelsLastTensor`, an alternative to fastxtend's `ChannelsLastTfm` that passes a plain `torch.tensor` to the model instead of a `TensorImage`."
],
"metadata": {
"id": "v4XIj3esElF-"
}
},
{
"cell_type": "code",
"source": [
"from fastxtend.callback.channelslast import ChannelsLastTfm"
],
"metadata": {
"id": "OjV5m8lSFGVj"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ChannelsLastTensor(LoseTypeTransform):\n",
" \"Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback\"\n",
" order = 110 # run after all other transforms if added to batch_tfms\n",
" def encodes(self, x:TensorImageBase|TensorMask):\n",
" return torch.tensor(x).to(memory_format=torch.channels_last)"
],
"metadata": {
"id": "uWe_J5fFE50f"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Otherwise, this callback is the same as fastxtend's `ChannelsLastCallback`, apart from an `astensor` switch that selects whether a plain `torch.Tensor` or a `TensorImage` is passed to the model."
],
"metadata": {
"id": "QmGndPRiRymP"
}
},
{
"cell_type": "code",
"source": [
"class ChannelsLastCallback(Callback):\n",
" \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n",
" order = MixedPrecision.order+1\n",
" def __init__(self, astensor=False):\n",
" if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n",
" else: self._channels_last = Pipeline([ChannelsLastTfm()])\n",
"\n",
" def before_fit(self):\n",
" self.learn.model.to(memory_format=torch.channels_last)\n",
"\n",
" def before_batch(self):\n",
" self.learn.xb = self._channels_last(self.xb)"
],
"metadata": {
"id": "a9WecdxUEX_F"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@patch\n",
"def to_channelslast(self:Learner, to_fp16=True, astensor=True, **kwargs):\n",
" \"Set `Learner` and inputs to `channels_last` format and Mixed Precision by default\"\n",
" if to_fp16 and not hasattr(self, 'mixed_precision') and not hasattr(self, 'channels_last'): \n",
" return self.add_cbs([ChannelsLastCallback(astensor=astensor), MixedPrecision(**kwargs)])\n",
" elif not hasattr(self, 'channels_last'):\n",
" return self.add_cb(ChannelsLastCallback(astensor=astensor))"
],
"metadata": {
"id": "piSnsjp7EZRF"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "SjhRhmwRE1Mm"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=nn.CrossEntropyLoss(), \n",
" opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n",
" metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "5-MnrPkYE2EG",
"outputId": "f092f0dc-7cae-49da-cbab-056abcc5e5b0"
},
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.818512</td>\n",
" <td>1.680830</td>\n",
" <td>0.428535</td>\n",
" <td>01:22</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "OlfL5IFKE2s9"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=nn.CrossEntropyLoss(), \n",
" opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n",
" metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"id": "F4hx0CEXFPut",
"outputId": "5c4772d1-365e-43d8-c2ce-a8a3bb730297"
},
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.863661</td>\n",
" <td>1.996808</td>\n",
" <td>0.355669</td>\n",
" <td>01:17</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.320199</td>\n",
" <td>1.176207</td>\n",
" <td>0.615032</td>\n",
" <td>01:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763eeca4d0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_8e94a_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_8e94a_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_8e94a_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_8e94a_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_8e94a_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_8e94a_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_8e94a_row0_col5\" class=\"data row0 col5\" >154.6 s</td>\n",
" <td id=\"T_8e94a_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_8e94a_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_8e94a_row1_col2\" class=\"data row1 col2\" >77.31 s</td>\n",
" <td id=\"T_8e94a_row1_col3\" class=\"data row1 col3\" >537.9ms</td>\n",
" <td id=\"T_8e94a_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_8e94a_row1_col5\" class=\"data row1 col5\" >154.6 s</td>\n",
" <td id=\"T_8e94a_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_8e94a_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_8e94a_row2_col2\" class=\"data row2 col2\" >57.79 s</td>\n",
" <td id=\"T_8e94a_row2_col3\" class=\"data row2 col3\" >402.5ms</td>\n",
" <td id=\"T_8e94a_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_8e94a_row2_col5\" class=\"data row2 col5\" >115.6 s</td>\n",
" <td id=\"T_8e94a_row2_col6\" class=\"data row2 col6\" >75%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_8e94a_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_8e94a_row3_col2\" class=\"data row3 col2\" >19.51 s</td>\n",
" <td id=\"T_8e94a_row3_col3\" class=\"data row3 col3\" >132.2ms</td>\n",
" <td id=\"T_8e94a_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_8e94a_row3_col5\" class=\"data row3 col5\" >39.01 s</td>\n",
" <td id=\"T_8e94a_row3_col6\" class=\"data row3 col6\" >25%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_8e94a_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_8e94a_row4_col2\" class=\"data row4 col2\" >368.2ms</td>\n",
" <td id=\"T_8e94a_row4_col3\" class=\"data row4 col3\" >97.57ms</td>\n",
" <td id=\"T_8e94a_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_8e94a_row4_col5\" class=\"data row4 col5\" >108.3 s</td>\n",
" <td id=\"T_8e94a_row4_col6\" class=\"data row4 col6\" >70%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_8e94a_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_8e94a_row5_col2\" class=\"data row5 col2\" >172.8ms</td>\n",
" <td id=\"T_8e94a_row5_col3\" class=\"data row5 col3\" >22.10ms</td>\n",
" <td id=\"T_8e94a_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_8e94a_row5_col5\" class=\"data row5 col5\" >50.81 s</td>\n",
" <td id=\"T_8e94a_row5_col6\" class=\"data row5 col6\" >33%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_8e94a_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_8e94a_row6_col2\" class=\"data row6 col2\" >71.45ms</td>\n",
" <td id=\"T_8e94a_row6_col3\" class=\"data row6 col3\" >95.57ms</td>\n",
" <td id=\"T_8e94a_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_8e94a_row6_col5\" class=\"data row6 col5\" >21.00 s</td>\n",
" <td id=\"T_8e94a_row6_col6\" class=\"data row6 col6\" >14%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_8e94a_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_8e94a_row7_col2\" class=\"data row7 col2\" >61.65ms</td>\n",
" <td id=\"T_8e94a_row7_col3\" class=\"data row7 col3\" >15.55ms</td>\n",
" <td id=\"T_8e94a_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_8e94a_row7_col5\" class=\"data row7 col5\" >18.13 s</td>\n",
" <td id=\"T_8e94a_row7_col6\" class=\"data row7 col6\" >12%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_8e94a_row8_col1\" class=\"data row8 col1\" >backward</td>\n",
" <td id=\"T_8e94a_row8_col2\" class=\"data row8 col2\" >56.72ms</td>\n",
" <td id=\"T_8e94a_row8_col3\" class=\"data row8 col3\" >12.65ms</td>\n",
" <td id=\"T_8e94a_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_8e94a_row8_col5\" class=\"data row8 col5\" >16.68 s</td>\n",
" <td id=\"T_8e94a_row8_col6\" class=\"data row8 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_8e94a_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_8e94a_row9_col2\" class=\"data row9 col2\" >3.895ms</td>\n",
" <td id=\"T_8e94a_row9_col3\" class=\"data row9 col3\" >2.942ms</td>\n",
" <td id=\"T_8e94a_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_8e94a_row9_col5\" class=\"data row9 col5\" >1.145 s</td>\n",
" <td id=\"T_8e94a_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_8e94a_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_8e94a_row10_col2\" class=\"data row10 col2\" >1.345ms</td>\n",
" <td id=\"T_8e94a_row10_col3\" class=\"data row10 col3\" >1.974ms</td>\n",
" <td id=\"T_8e94a_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_8e94a_row10_col5\" class=\"data row10 col5\" >395.3ms</td>\n",
" <td id=\"T_8e94a_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_8e94a_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_8e94a_row11_col2\" class=\"data row11 col2\" >264.9ms</td>\n",
" <td id=\"T_8e94a_row11_col3\" class=\"data row11 col3\" >172.1ms</td>\n",
" <td id=\"T_8e94a_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_8e94a_row11_col5\" class=\"data row11 col5\" >32.85 s</td>\n",
" <td id=\"T_8e94a_row11_col6\" class=\"data row11 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_8e94a_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_8e94a_row12_col2\" class=\"data row12 col2\" >223.5ms</td>\n",
" <td id=\"T_8e94a_row12_col3\" class=\"data row12 col3\" >171.4ms</td>\n",
" <td id=\"T_8e94a_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_8e94a_row12_col5\" class=\"data row12 col5\" >27.72 s</td>\n",
" <td id=\"T_8e94a_row12_col6\" class=\"data row12 col6\" >18%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_8e94a_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_8e94a_row13_col2\" class=\"data row13 col2\" >40.13ms</td>\n",
" <td id=\"T_8e94a_row13_col3\" class=\"data row13 col3\" >9.556ms</td>\n",
" <td id=\"T_8e94a_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_8e94a_row13_col5\" class=\"data row13 col5\" >4.976 s</td>\n",
" <td id=\"T_8e94a_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_8e94a_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_8e94a_row14_col2\" class=\"data row14 col2\" >945.4µs</td>\n",
" <td id=\"T_8e94a_row14_col3\" class=\"data row14 col3\" >1.121ms</td>\n",
" <td id=\"T_8e94a_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_8e94a_row14_col5\" class=\"data row14 col5\" >117.2ms</td>\n",
" <td id=\"T_8e94a_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Use a fastai optimizer & PyTorch loss, but tensor input"
],
"metadata": {
"id": "DaezZT9eK3g-"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "8P-SVF8aLAbP"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=nn.CrossEntropyLoss(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"outputId": "d6785999-1946-4356-b38a-848e78594be6",
"id": "pU5AcaIILAbP"
},
"execution_count": 14,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.979175</td>\n",
" <td>1.794662</td>\n",
" <td>0.380382</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "LWxeP4AULAbQ"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=nn.CrossEntropyLoss(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"id": "UbhUwy9bLAbQ",
"outputId": "86da8255-0c61-4b7b-b4ef-f1c5b0b4171b"
},
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.852839</td>\n",
" <td>2.054055</td>\n",
" <td>0.332739</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.300486</td>\n",
" <td>1.145100</td>\n",
" <td>0.634650</td>\n",
" <td>01:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763eff1750>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_e15fa_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_e15fa_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_e15fa_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_e15fa_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_e15fa_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_e15fa_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_e15fa_row0_col5\" class=\"data row0 col5\" >154.5 s</td>\n",
" <td id=\"T_e15fa_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_e15fa_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_e15fa_row1_col2\" class=\"data row1 col2\" >77.27 s</td>\n",
" <td id=\"T_e15fa_row1_col3\" class=\"data row1 col3\" >1.341 s</td>\n",
" <td id=\"T_e15fa_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_e15fa_row1_col5\" class=\"data row1 col5\" >154.5 s</td>\n",
" <td id=\"T_e15fa_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_e15fa_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_e15fa_row2_col2\" class=\"data row2 col2\" >57.12 s</td>\n",
" <td id=\"T_e15fa_row2_col3\" class=\"data row2 col3\" >1.114 s</td>\n",
" <td id=\"T_e15fa_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_e15fa_row2_col5\" class=\"data row2 col5\" >114.2 s</td>\n",
" <td id=\"T_e15fa_row2_col6\" class=\"data row2 col6\" >74%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_e15fa_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_e15fa_row3_col2\" class=\"data row3 col2\" >20.14 s</td>\n",
" <td id=\"T_e15fa_row3_col3\" class=\"data row3 col3\" >230.3ms</td>\n",
" <td id=\"T_e15fa_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_e15fa_row3_col5\" class=\"data row3 col5\" >40.29 s</td>\n",
" <td id=\"T_e15fa_row3_col6\" class=\"data row3 col6\" >26%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_e15fa_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_e15fa_row4_col2\" class=\"data row4 col2\" >379.6ms</td>\n",
" <td id=\"T_e15fa_row4_col3\" class=\"data row4 col3\" >113.8ms</td>\n",
" <td id=\"T_e15fa_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_e15fa_row4_col5\" class=\"data row4 col5\" >111.6 s</td>\n",
" <td id=\"T_e15fa_row4_col6\" class=\"data row4 col6\" >72%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_e15fa_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_e15fa_row5_col2\" class=\"data row5 col2\" >182.0ms</td>\n",
" <td id=\"T_e15fa_row5_col3\" class=\"data row5 col3\" >24.35ms</td>\n",
" <td id=\"T_e15fa_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_e15fa_row5_col5\" class=\"data row5 col5\" >53.50 s</td>\n",
" <td id=\"T_e15fa_row5_col6\" class=\"data row5 col6\" >35%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_e15fa_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_e15fa_row6_col2\" class=\"data row6 col2\" >87.42ms</td>\n",
" <td id=\"T_e15fa_row6_col3\" class=\"data row6 col3\" >113.3ms</td>\n",
" <td id=\"T_e15fa_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_e15fa_row6_col5\" class=\"data row6 col5\" >25.70 s</td>\n",
" <td id=\"T_e15fa_row6_col6\" class=\"data row6 col6\" >17%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_e15fa_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_e15fa_row7_col2\" class=\"data row7 col2\" >56.09ms</td>\n",
" <td id=\"T_e15fa_row7_col3\" class=\"data row7 col3\" >13.87ms</td>\n",
" <td id=\"T_e15fa_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_e15fa_row7_col5\" class=\"data row7 col5\" >16.49 s</td>\n",
" <td id=\"T_e15fa_row7_col6\" class=\"data row7 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_e15fa_row8_col1\" class=\"data row8 col1\" >pred</td>\n",
" <td id=\"T_e15fa_row8_col2\" class=\"data row8 col2\" >48.32ms</td>\n",
" <td id=\"T_e15fa_row8_col3\" class=\"data row8 col3\" >13.84ms</td>\n",
" <td id=\"T_e15fa_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_e15fa_row8_col5\" class=\"data row8 col5\" >14.21 s</td>\n",
" <td id=\"T_e15fa_row8_col6\" class=\"data row8 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_e15fa_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_e15fa_row9_col2\" class=\"data row9 col2\" >4.106ms</td>\n",
" <td id=\"T_e15fa_row9_col3\" class=\"data row9 col3\" >3.017ms</td>\n",
" <td id=\"T_e15fa_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_e15fa_row9_col5\" class=\"data row9 col5\" >1.207 s</td>\n",
" <td id=\"T_e15fa_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_e15fa_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_e15fa_row10_col2\" class=\"data row10 col2\" >1.425ms</td>\n",
" <td id=\"T_e15fa_row10_col3\" class=\"data row10 col3\" >2.092ms</td>\n",
" <td id=\"T_e15fa_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_e15fa_row10_col5\" class=\"data row10 col5\" >418.9ms</td>\n",
" <td id=\"T_e15fa_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_e15fa_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_e15fa_row11_col2\" class=\"data row11 col2\" >276.2ms</td>\n",
" <td id=\"T_e15fa_row11_col3\" class=\"data row11 col3\" >186.7ms</td>\n",
" <td id=\"T_e15fa_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_e15fa_row11_col5\" class=\"data row11 col5\" >34.25 s</td>\n",
" <td id=\"T_e15fa_row11_col6\" class=\"data row11 col6\" >22%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_e15fa_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_e15fa_row12_col2\" class=\"data row12 col2\" >233.0ms</td>\n",
" <td id=\"T_e15fa_row12_col3\" class=\"data row12 col3\" >183.6ms</td>\n",
" <td id=\"T_e15fa_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_e15fa_row12_col5\" class=\"data row12 col5\" >28.90 s</td>\n",
" <td id=\"T_e15fa_row12_col6\" class=\"data row12 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_e15fa_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_e15fa_row13_col2\" class=\"data row13 col2\" >41.59ms</td>\n",
" <td id=\"T_e15fa_row13_col3\" class=\"data row13 col3\" >12.00ms</td>\n",
" <td id=\"T_e15fa_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_e15fa_row13_col5\" class=\"data row13 col5\" >5.158 s</td>\n",
" <td id=\"T_e15fa_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_e15fa_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_e15fa_row14_col2\" class=\"data row14 col2\" >1.344ms</td>\n",
" <td id=\"T_e15fa_row14_col3\" class=\"data row14 col3\" >2.000ms</td>\n",
" <td id=\"T_e15fa_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_e15fa_row14_col5\" class=\"data row14 col5\" >166.7ms</td>\n",
" <td id=\"T_e15fa_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Use a fastai optimizer & fastai loss, but tensor input"
],
"metadata": {
"id": "J_V3brgbMPPF"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "29C-NpRYMPPX"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"outputId": "42af2a10-2f51-4318-dbe3-edb581a69599",
"id": "CcmbGsBaMPPX"
},
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.891318</td>\n",
" <td>1.795799</td>\n",
" <td>0.423694</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "RHCP86D_MPPY"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"outputId": "21124eb9-593b-4994-9fc9-914e888bf166",
"id": "OIpimLlxMPPY"
},
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.887546</td>\n",
" <td>1.747342</td>\n",
" <td>0.409427</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.378880</td>\n",
" <td>1.274395</td>\n",
" <td>0.592611</td>\n",
" <td>01:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763efedc10>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_435ce_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_435ce_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_435ce_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_435ce_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_435ce_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_435ce_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_435ce_row0_col5\" class=\"data row0 col5\" >153.7 s</td>\n",
" <td id=\"T_435ce_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_435ce_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_435ce_row1_col2\" class=\"data row1 col2\" >76.85 s</td>\n",
" <td id=\"T_435ce_row1_col3\" class=\"data row1 col3\" >1.211 s</td>\n",
" <td id=\"T_435ce_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_435ce_row1_col5\" class=\"data row1 col5\" >153.7 s</td>\n",
" <td id=\"T_435ce_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_435ce_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_435ce_row2_col2\" class=\"data row2 col2\" >56.61 s</td>\n",
" <td id=\"T_435ce_row2_col3\" class=\"data row2 col3\" >356.4ms</td>\n",
" <td id=\"T_435ce_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_435ce_row2_col5\" class=\"data row2 col5\" >113.2 s</td>\n",
" <td id=\"T_435ce_row2_col6\" class=\"data row2 col6\" >74%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_435ce_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_435ce_row3_col2\" class=\"data row3 col2\" >20.23 s</td>\n",
" <td id=\"T_435ce_row3_col3\" class=\"data row3 col3\" >857.5ms</td>\n",
" <td id=\"T_435ce_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_435ce_row3_col5\" class=\"data row3 col5\" >40.45 s</td>\n",
" <td id=\"T_435ce_row3_col6\" class=\"data row3 col6\" >26%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_435ce_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_435ce_row4_col2\" class=\"data row4 col2\" >375.5ms</td>\n",
" <td id=\"T_435ce_row4_col3\" class=\"data row4 col3\" >114.0ms</td>\n",
" <td id=\"T_435ce_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_435ce_row4_col5\" class=\"data row4 col5\" >110.4 s</td>\n",
" <td id=\"T_435ce_row4_col6\" class=\"data row4 col6\" >72%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_435ce_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_435ce_row5_col2\" class=\"data row5 col2\" >177.2ms</td>\n",
" <td id=\"T_435ce_row5_col3\" class=\"data row5 col3\" >24.03ms</td>\n",
" <td id=\"T_435ce_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_435ce_row5_col5\" class=\"data row5 col5\" >52.11 s</td>\n",
" <td id=\"T_435ce_row5_col6\" class=\"data row5 col6\" >34%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_435ce_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_435ce_row6_col2\" class=\"data row6 col2\" >83.93ms</td>\n",
" <td id=\"T_435ce_row6_col3\" class=\"data row6 col3\" >113.6ms</td>\n",
" <td id=\"T_435ce_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_435ce_row6_col5\" class=\"data row6 col5\" >24.68 s</td>\n",
" <td id=\"T_435ce_row6_col6\" class=\"data row6 col6\" >16%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_435ce_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_435ce_row7_col2\" class=\"data row7 col2\" >57.99ms</td>\n",
" <td id=\"T_435ce_row7_col3\" class=\"data row7 col3\" >14.16ms</td>\n",
" <td id=\"T_435ce_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_435ce_row7_col5\" class=\"data row7 col5\" >17.05 s</td>\n",
" <td id=\"T_435ce_row7_col6\" class=\"data row7 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_435ce_row8_col1\" class=\"data row8 col1\" >pred</td>\n",
" <td id=\"T_435ce_row8_col2\" class=\"data row8 col2\" >49.67ms</td>\n",
" <td id=\"T_435ce_row8_col3\" class=\"data row8 col3\" >13.58ms</td>\n",
" <td id=\"T_435ce_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_435ce_row8_col5\" class=\"data row8 col5\" >14.60 s</td>\n",
" <td id=\"T_435ce_row8_col6\" class=\"data row8 col6\" >10%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_435ce_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_435ce_row9_col2\" class=\"data row9 col2\" >4.312ms</td>\n",
" <td id=\"T_435ce_row9_col3\" class=\"data row9 col3\" >2.997ms</td>\n",
" <td id=\"T_435ce_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_435ce_row9_col5\" class=\"data row9 col5\" >1.268 s</td>\n",
" <td id=\"T_435ce_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_435ce_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_435ce_row10_col2\" class=\"data row10 col2\" >2.058ms</td>\n",
" <td id=\"T_435ce_row10_col3\" class=\"data row10 col3\" >2.189ms</td>\n",
" <td id=\"T_435ce_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_435ce_row10_col5\" class=\"data row10 col5\" >605.1ms</td>\n",
" <td id=\"T_435ce_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_435ce_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_435ce_row11_col2\" class=\"data row11 col2\" >279.1ms</td>\n",
" <td id=\"T_435ce_row11_col3\" class=\"data row11 col3\" >192.9ms</td>\n",
" <td id=\"T_435ce_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_435ce_row11_col5\" class=\"data row11 col5\" >34.61 s</td>\n",
" <td id=\"T_435ce_row11_col6\" class=\"data row11 col6\" >23%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_435ce_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_435ce_row12_col2\" class=\"data row12 col2\" >234.2ms</td>\n",
" <td id=\"T_435ce_row12_col3\" class=\"data row12 col3\" >191.4ms</td>\n",
" <td id=\"T_435ce_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_435ce_row12_col5\" class=\"data row12 col5\" >29.04 s</td>\n",
" <td id=\"T_435ce_row12_col6\" class=\"data row12 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_435ce_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_435ce_row13_col2\" class=\"data row13 col2\" >42.37ms</td>\n",
" <td id=\"T_435ce_row13_col3\" class=\"data row13 col3\" >11.61ms</td>\n",
" <td id=\"T_435ce_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_435ce_row13_col5\" class=\"data row13 col5\" >5.254 s</td>\n",
" <td id=\"T_435ce_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_435ce_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_435ce_row14_col2\" class=\"data row14 col2\" >2.227ms</td>\n",
" <td id=\"T_435ce_row14_col3\" class=\"data row14 col3\" >2.896ms</td>\n",
" <td id=\"T_435ce_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_435ce_row14_col5\" class=\"data row14 col5\" >276.1ms</td>\n",
" <td id=\"T_435ce_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Cast to TensorBase instead of torch.tensor or TensorImage"
],
"metadata": {
"id": "nhSzdwgANYP2"
}
},
{
"cell_type": "code",
"source": [
"class ChannelsLastTensor(LoseTypeTransform):\n",
" \"Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback\"\n",
" order = 110 # run after all other transforms if added to batch_tfms\n",
" def encodes(self, x:TensorImageBase|TensorMask):\n",
" return TensorBase(x).to(memory_format=torch.channels_last)"
],
"metadata": {
"id": "l92GVy9zNWh0"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ChannelsLastCallback(Callback):\n",
" \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n",
" order = MixedPrecision.order+1\n",
" def __init__(self, astensor=False):\n",
" if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n",
" else: self._channels_last = Pipeline([ChannelsLastTfm()])\n",
"\n",
" def before_fit(self):\n",
" self.learn.model.to(memory_format=torch.channels_last)\n",
"\n",
" def before_batch(self):\n",
" self.learn.xb = self._channels_last(self.xb)"
],
"metadata": {
"id": "ve-xXsTBNWh9"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "mNhlK0t5Ni3d"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "fd7925eb-6023-4561-8c84-0588c66742d0",
"id": "1K0ZxQcGNi3e"
},
"execution_count": 20,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.829382</td>\n",
" <td>1.654897</td>\n",
" <td>0.438726</td>\n",
" <td>01:35</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "B-mI8ifUNi3g"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam,\n",
" metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"outputId": "33c960de-0f87-452d-ba5f-59a015f54b5f",
"id": "EXgwBMH4Ni3g"
},
"execution_count": 21,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.935964</td>\n",
" <td>2.968615</td>\n",
" <td>0.291720</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.425070</td>\n",
" <td>1.298448</td>\n",
" <td>0.583439</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763ee0f390>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_b3c64_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_b3c64_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_b3c64_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_b3c64_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_b3c64_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_b3c64_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_b3c64_row0_col5\" class=\"data row0 col5\" >187.3 s</td>\n",
" <td id=\"T_b3c64_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_b3c64_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_b3c64_row1_col2\" class=\"data row1 col2\" >93.65 s</td>\n",
" <td id=\"T_b3c64_row1_col3\" class=\"data row1 col3\" >27.24ms</td>\n",
" <td id=\"T_b3c64_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_b3c64_row1_col5\" class=\"data row1 col5\" >187.3 s</td>\n",
" <td id=\"T_b3c64_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_b3c64_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_b3c64_row2_col2\" class=\"data row2 col2\" >74.25 s</td>\n",
" <td id=\"T_b3c64_row2_col3\" class=\"data row2 col3\" >48.47ms</td>\n",
" <td id=\"T_b3c64_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_b3c64_row2_col5\" class=\"data row2 col5\" >148.5 s</td>\n",
" <td id=\"T_b3c64_row2_col6\" class=\"data row2 col6\" >79%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_b3c64_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_b3c64_row3_col2\" class=\"data row3 col2\" >19.40 s</td>\n",
" <td id=\"T_b3c64_row3_col3\" class=\"data row3 col3\" >25.36ms</td>\n",
" <td id=\"T_b3c64_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_b3c64_row3_col5\" class=\"data row3 col5\" >38.79 s</td>\n",
" <td id=\"T_b3c64_row3_col6\" class=\"data row3 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_b3c64_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_b3c64_row4_col2\" class=\"data row4 col2\" >498.3ms</td>\n",
" <td id=\"T_b3c64_row4_col3\" class=\"data row4 col3\" >70.69ms</td>\n",
" <td id=\"T_b3c64_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_b3c64_row4_col5\" class=\"data row4 col5\" >146.5 s</td>\n",
" <td id=\"T_b3c64_row4_col6\" class=\"data row4 col6\" >78%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_b3c64_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_b3c64_row5_col2\" class=\"data row5 col2\" >362.4ms</td>\n",
" <td id=\"T_b3c64_row5_col3\" class=\"data row5 col3\" >25.78ms</td>\n",
" <td id=\"T_b3c64_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_b3c64_row5_col5\" class=\"data row5 col5\" >106.5 s</td>\n",
" <td id=\"T_b3c64_row5_col6\" class=\"data row5 col6\" >57%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_b3c64_row6_col1\" class=\"data row6 col1\" >backward</td>\n",
" <td id=\"T_b3c64_row6_col2\" class=\"data row6 col2\" >56.09ms</td>\n",
" <td id=\"T_b3c64_row6_col3\" class=\"data row6 col3\" >16.79ms</td>\n",
" <td id=\"T_b3c64_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_b3c64_row6_col5\" class=\"data row6 col5\" >16.49 s</td>\n",
" <td id=\"T_b3c64_row6_col6\" class=\"data row6 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_b3c64_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_b3c64_row7_col2\" class=\"data row7 col2\" >53.14ms</td>\n",
" <td id=\"T_b3c64_row7_col3\" class=\"data row7 col3\" >17.82ms</td>\n",
" <td id=\"T_b3c64_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_b3c64_row7_col5\" class=\"data row7 col5\" >15.62 s</td>\n",
" <td id=\"T_b3c64_row7_col6\" class=\"data row7 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_b3c64_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_b3c64_row8_col2\" class=\"data row8 col2\" >22.42ms</td>\n",
" <td id=\"T_b3c64_row8_col3\" class=\"data row8 col3\" >67.00ms</td>\n",
" <td id=\"T_b3c64_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_b3c64_row8_col5\" class=\"data row8 col5\" >6.591 s</td>\n",
" <td id=\"T_b3c64_row8_col6\" class=\"data row8 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_b3c64_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_b3c64_row9_col2\" class=\"data row9 col2\" >2.542ms</td>\n",
" <td id=\"T_b3c64_row9_col3\" class=\"data row9 col3\" >879.8µs</td>\n",
" <td id=\"T_b3c64_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_b3c64_row9_col5\" class=\"data row9 col5\" >747.5ms</td>\n",
" <td id=\"T_b3c64_row9_col6\" class=\"data row9 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_b3c64_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_b3c64_row10_col2\" class=\"data row10 col2\" >1.562ms</td>\n",
" <td id=\"T_b3c64_row10_col3\" class=\"data row10 col3\" >1.418ms</td>\n",
" <td id=\"T_b3c64_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_b3c64_row10_col5\" class=\"data row10 col5\" >459.3ms</td>\n",
" <td id=\"T_b3c64_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_b3c64_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_b3c64_row11_col2\" class=\"data row11 col2\" >287.5ms</td>\n",
" <td id=\"T_b3c64_row11_col3\" class=\"data row11 col3\" >201.1ms</td>\n",
" <td id=\"T_b3c64_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_b3c64_row11_col5\" class=\"data row11 col5\" >35.65 s</td>\n",
" <td id=\"T_b3c64_row11_col6\" class=\"data row11 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_b3c64_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_b3c64_row12_col2\" class=\"data row12 col2\" >221.2ms</td>\n",
" <td id=\"T_b3c64_row12_col3\" class=\"data row12 col3\" >197.2ms</td>\n",
" <td id=\"T_b3c64_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_b3c64_row12_col5\" class=\"data row12 col5\" >27.43 s</td>\n",
" <td id=\"T_b3c64_row12_col6\" class=\"data row12 col6\" >15%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_b3c64_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_b3c64_row13_col2\" class=\"data row13 col2\" >63.72ms</td>\n",
" <td id=\"T_b3c64_row13_col3\" class=\"data row13 col3\" >13.80ms</td>\n",
" <td id=\"T_b3c64_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_b3c64_row13_col5\" class=\"data row13 col5\" >7.901 s</td>\n",
" <td id=\"T_b3c64_row13_col6\" class=\"data row13 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b3c64_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_b3c64_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_b3c64_row14_col2\" class=\"data row14 col2\" >2.158ms</td>\n",
" <td id=\"T_b3c64_row14_col3\" class=\"data row14 col3\" >2.513ms</td>\n",
" <td id=\"T_b3c64_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_b3c64_row14_col5\" class=\"data row14 col5\" >267.6ms</td>\n",
" <td id=\"T_b3c64_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# PyTorch 1.10\n",
"Restarted the notebook here. The custom channels last callback defined above is not redefined after the restart, so these runs pass fastai's `TensorImage` inputs directly to the model."
],
"metadata": {
"id": "U2vOG2nsOdnp"
}
},
{
"cell_type": "code",
"source": [
"! pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html -qq"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HmQV0NcANlfE",
"outputId": "0c0b084d-2a91-4243-f7c1-cca160456ba1"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[K |████████████████████████████████| 2137.7 MB 392 bytes/s \n",
"\u001b[K |████████████████████████████████| 24.5 MB 85.6 MB/s \n",
"\u001b[K |████████████████████████████████| 2.7 MB 43.2 MB/s \n",
"\u001b[K |████████████████████████████████| 3.1 MB 8.8 MB/s \n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.10.1+cu111 which is incompatible.\n",
"albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"source": [
"from __future__ import annotations\n",
"from fastai.vision.all import *\n",
"from fastxtend.vision.all import *\n",
"from fastxtend.callback import simpleprofiler"
],
"metadata": {
"id": "OYsdrzMMQ84l"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# DataLoaders"
],
"metadata": {
"id": "I7pjgtNwQ84m"
}
},
{
"cell_type": "code",
"source": [
"imagewoof_stats = ([0.496,0.461,0.399],[0.257,0.249,0.258])\n",
"imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])\n",
"\n",
"def get_dls(size:int, woof:bool, bs:int, sh:float=0., augs:list=None, workers:int=None, stats:bool=True) -> DataLoaders:\n",
" if size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320\n",
" else : path = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE\n",
" source = untar_data(path)\n",
" if workers is None: workers = min(8, num_cpus())\n",
" batch_tfms = []\n",
" if stats:\n",
" if woof: \n",
" batch_tfms += [Normalize.from_stats(*imagewoof_stats)]\n",
" else:\n",
" batch_tfms += [Normalize.from_stats(*imagenette_stats)]\n",
" if augs: batch_tfms += augs\n",
" if sh: batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))\n",
" dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n",
" splitter=GrandparentSplitter(valid_name='val'),\n",
" get_items=get_image_files, get_y=parent_label,\n",
" item_tfms=[Resize(size)],\n",
" batch_tfms=batch_tfms)\n",
" return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)"
],
"metadata": {
"id": "HzFZ6Pa7Q84m"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Train with fastai's default mixed precision"
],
"metadata": {
"id": "c2y2zAkPQ84n"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "jsDKAQPgQ84n"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_fp16()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "736159ba-7ab9-4524-800d-916989e475ae",
"id": "eT2TcoouQ84n"
},
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.883641</td>\n",
" <td>1.696063</td>\n",
" <td>0.421146</td>\n",
" <td>01:54</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "mPwHXKDPQ84o"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_fp16().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"outputId": "d54c9f59-2746-42e7-dda8-cb74ba1fe281",
"id": "s0knkmVfQ84o"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.958247</td>\n",
" <td>3.767949</td>\n",
" <td>0.315924</td>\n",
" <td>01:37</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.415995</td>\n",
" <td>1.279303</td>\n",
" <td>0.586752</td>\n",
" <td>01:37</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f8317309810>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_eef9d_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_eef9d_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_eef9d_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_eef9d_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_eef9d_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_eef9d_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_eef9d_row0_col5\" class=\"data row0 col5\" >194.8 s</td>\n",
" <td id=\"T_eef9d_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_eef9d_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_eef9d_row1_col2\" class=\"data row1 col2\" >97.41 s</td>\n",
" <td id=\"T_eef9d_row1_col3\" class=\"data row1 col3\" >264.9ms</td>\n",
" <td id=\"T_eef9d_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_eef9d_row1_col5\" class=\"data row1 col5\" >194.8 s</td>\n",
" <td id=\"T_eef9d_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_eef9d_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_eef9d_row2_col2\" class=\"data row2 col2\" >72.51 s</td>\n",
" <td id=\"T_eef9d_row2_col3\" class=\"data row2 col3\" >345.1ms</td>\n",
" <td id=\"T_eef9d_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_eef9d_row2_col5\" class=\"data row2 col5\" >145.0 s</td>\n",
" <td id=\"T_eef9d_row2_col6\" class=\"data row2 col6\" >74%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_eef9d_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_eef9d_row3_col2\" class=\"data row3 col2\" >24.90 s</td>\n",
" <td id=\"T_eef9d_row3_col3\" class=\"data row3 col3\" >606.2ms</td>\n",
" <td id=\"T_eef9d_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_eef9d_row3_col5\" class=\"data row3 col5\" >49.79 s</td>\n",
" <td id=\"T_eef9d_row3_col6\" class=\"data row3 col6\" >26%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_eef9d_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_eef9d_row4_col2\" class=\"data row4 col2\" >483.8ms</td>\n",
" <td id=\"T_eef9d_row4_col3\" class=\"data row4 col3\" >113.0ms</td>\n",
" <td id=\"T_eef9d_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_eef9d_row4_col5\" class=\"data row4 col5\" >142.2 s</td>\n",
" <td id=\"T_eef9d_row4_col6\" class=\"data row4 col6\" >73%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_eef9d_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_eef9d_row5_col2\" class=\"data row5 col2\" >302.8ms</td>\n",
" <td id=\"T_eef9d_row5_col3\" class=\"data row5 col3\" >18.21ms</td>\n",
" <td id=\"T_eef9d_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_eef9d_row5_col5\" class=\"data row5 col5\" >89.02 s</td>\n",
" <td id=\"T_eef9d_row5_col6\" class=\"data row5 col6\" >46%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_eef9d_row6_col1\" class=\"data row6 col1\" >pred</td>\n",
" <td id=\"T_eef9d_row6_col2\" class=\"data row6 col2\" >58.76ms</td>\n",
" <td id=\"T_eef9d_row6_col3\" class=\"data row6 col3\" >16.20ms</td>\n",
" <td id=\"T_eef9d_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_eef9d_row6_col5\" class=\"data row6 col5\" >17.27 s</td>\n",
" <td id=\"T_eef9d_row6_col6\" class=\"data row6 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_eef9d_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_eef9d_row7_col2\" class=\"data row7 col2\" >58.61ms</td>\n",
" <td id=\"T_eef9d_row7_col3\" class=\"data row7 col3\" >11.02ms</td>\n",
" <td id=\"T_eef9d_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_eef9d_row7_col5\" class=\"data row7 col5\" >17.23 s</td>\n",
" <td id=\"T_eef9d_row7_col6\" class=\"data row7 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_eef9d_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_eef9d_row8_col2\" class=\"data row8 col2\" >57.44ms</td>\n",
" <td id=\"T_eef9d_row8_col3\" class=\"data row8 col3\" >114.3ms</td>\n",
" <td id=\"T_eef9d_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_eef9d_row8_col5\" class=\"data row8 col5\" >16.89 s</td>\n",
" <td id=\"T_eef9d_row8_col6\" class=\"data row8 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_eef9d_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_eef9d_row9_col2\" class=\"data row9 col2\" >3.818ms</td>\n",
" <td id=\"T_eef9d_row9_col3\" class=\"data row9 col3\" >3.007ms</td>\n",
" <td id=\"T_eef9d_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_eef9d_row9_col5\" class=\"data row9 col5\" >1.122 s</td>\n",
" <td id=\"T_eef9d_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_eef9d_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_eef9d_row10_col2\" class=\"data row10 col2\" >2.055ms</td>\n",
" <td id=\"T_eef9d_row10_col3\" class=\"data row10 col3\" >2.240ms</td>\n",
" <td id=\"T_eef9d_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_eef9d_row10_col5\" class=\"data row10 col5\" >604.0ms</td>\n",
" <td id=\"T_eef9d_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_eef9d_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_eef9d_row11_col2\" class=\"data row11 col2\" >352.1ms</td>\n",
" <td id=\"T_eef9d_row11_col3\" class=\"data row11 col3\" >285.7ms</td>\n",
" <td id=\"T_eef9d_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_eef9d_row11_col5\" class=\"data row11 col5\" >43.66 s</td>\n",
" <td id=\"T_eef9d_row11_col6\" class=\"data row11 col6\" >22%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_eef9d_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_eef9d_row12_col2\" class=\"data row12 col2\" >295.1ms</td>\n",
" <td id=\"T_eef9d_row12_col3\" class=\"data row12 col3\" >283.6ms</td>\n",
" <td id=\"T_eef9d_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_eef9d_row12_col5\" class=\"data row12 col5\" >36.60 s</td>\n",
" <td id=\"T_eef9d_row12_col6\" class=\"data row12 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_eef9d_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_eef9d_row13_col2\" class=\"data row13 col2\" >54.49ms</td>\n",
" <td id=\"T_eef9d_row13_col3\" class=\"data row13 col3\" >14.05ms</td>\n",
" <td id=\"T_eef9d_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_eef9d_row13_col5\" class=\"data row13 col5\" >6.756 s</td>\n",
" <td id=\"T_eef9d_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_eef9d_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_eef9d_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_eef9d_row14_col2\" class=\"data row14 col2\" >2.116ms</td>\n",
" <td id=\"T_eef9d_row14_col3\" class=\"data row14 col3\" >2.646ms</td>\n",
" <td id=\"T_eef9d_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_eef9d_row14_col5\" class=\"data row14 col5\" >262.4ms</td>\n",
" <td id=\"T_eef9d_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Train fastai channels last"
],
"metadata": {
"id": "m4ffdf65Q84p"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "7QDyEJkfQ84p"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "77ad87f6-20a7-4d1a-f3e7-2eac1e0c8881",
"id": "FzR97cR9Q84p"
},
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.866315</td>\n",
" <td>1.722754</td>\n",
" <td>0.450446</td>\n",
" <td>01:52</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "Lok8snIKQ84p"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"outputId": "b53a6349-871f-4d70-9601-b1c342077cf3",
"id": "_9SgqLhgQ84q"
},
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.911500</td>\n",
" <td>1.742548</td>\n",
" <td>0.390318</td>\n",
" <td>01:39</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.379123</td>\n",
" <td>1.248970</td>\n",
" <td>0.603057</td>\n",
" <td>01:39</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f83783cf310>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_165f6_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_165f6_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_165f6_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_165f6_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_165f6_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_165f6_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_165f6_row0_col5\" class=\"data row0 col5\" >199.5 s</td>\n",
" <td id=\"T_165f6_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_165f6_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_165f6_row1_col2\" class=\"data row1 col2\" >99.74 s</td>\n",
" <td id=\"T_165f6_row1_col3\" class=\"data row1 col3\" >145.6ms</td>\n",
" <td id=\"T_165f6_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_165f6_row1_col5\" class=\"data row1 col5\" >199.5 s</td>\n",
" <td id=\"T_165f6_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_165f6_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_165f6_row2_col2\" class=\"data row2 col2\" >76.77 s</td>\n",
" <td id=\"T_165f6_row2_col3\" class=\"data row2 col3\" >232.2ms</td>\n",
" <td id=\"T_165f6_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_165f6_row2_col5\" class=\"data row2 col5\" >153.5 s</td>\n",
" <td id=\"T_165f6_row2_col6\" class=\"data row2 col6\" >77%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_165f6_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_165f6_row3_col2\" class=\"data row3 col2\" >22.96 s</td>\n",
" <td id=\"T_165f6_row3_col3\" class=\"data row3 col3\" >82.68ms</td>\n",
" <td id=\"T_165f6_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_165f6_row3_col5\" class=\"data row3 col5\" >45.91 s</td>\n",
" <td id=\"T_165f6_row3_col6\" class=\"data row3 col6\" >23%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_165f6_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_165f6_row4_col2\" class=\"data row4 col2\" >514.0ms</td>\n",
" <td id=\"T_165f6_row4_col3\" class=\"data row4 col3\" >66.73ms</td>\n",
" <td id=\"T_165f6_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_165f6_row4_col5\" class=\"data row4 col5\" >151.1 s</td>\n",
" <td id=\"T_165f6_row4_col6\" class=\"data row4 col6\" >76%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_165f6_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_165f6_row5_col2\" class=\"data row5 col2\" >377.4ms</td>\n",
" <td id=\"T_165f6_row5_col3\" class=\"data row5 col3\" >23.88ms</td>\n",
" <td id=\"T_165f6_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_165f6_row5_col5\" class=\"data row5 col5\" >111.0 s</td>\n",
" <td id=\"T_165f6_row5_col6\" class=\"data row5 col6\" >56%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_165f6_row6_col1\" class=\"data row6 col1\" >pred</td>\n",
" <td id=\"T_165f6_row6_col2\" class=\"data row6 col2\" >56.88ms</td>\n",
" <td id=\"T_165f6_row6_col3\" class=\"data row6 col3\" >17.09ms</td>\n",
" <td id=\"T_165f6_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_165f6_row6_col5\" class=\"data row6 col5\" >16.72 s</td>\n",
" <td id=\"T_165f6_row6_col6\" class=\"data row6 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_165f6_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_165f6_row7_col2\" class=\"data row7 col2\" >49.06ms</td>\n",
" <td id=\"T_165f6_row7_col3\" class=\"data row7 col3\" >9.625ms</td>\n",
" <td id=\"T_165f6_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_165f6_row7_col5\" class=\"data row7 col5\" >14.42 s</td>\n",
" <td id=\"T_165f6_row7_col6\" class=\"data row7 col6\" >7%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_165f6_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_165f6_row8_col2\" class=\"data row8 col2\" >24.88ms</td>\n",
" <td id=\"T_165f6_row8_col3\" class=\"data row8 col3\" >64.92ms</td>\n",
" <td id=\"T_165f6_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_165f6_row8_col5\" class=\"data row8 col5\" >7.315 s</td>\n",
" <td id=\"T_165f6_row8_col6\" class=\"data row8 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_165f6_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_165f6_row9_col2\" class=\"data row9 col2\" >3.390ms</td>\n",
" <td id=\"T_165f6_row9_col3\" class=\"data row9 col3\" >2.429ms</td>\n",
" <td id=\"T_165f6_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_165f6_row9_col5\" class=\"data row9 col5\" >996.7ms</td>\n",
" <td id=\"T_165f6_row9_col6\" class=\"data row9 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_165f6_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_165f6_row10_col2\" class=\"data row10 col2\" >1.979ms</td>\n",
" <td id=\"T_165f6_row10_col3\" class=\"data row10 col3\" >1.894ms</td>\n",
" <td id=\"T_165f6_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_165f6_row10_col5\" class=\"data row10 col5\" >581.7ms</td>\n",
" <td id=\"T_165f6_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_165f6_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_165f6_row11_col2\" class=\"data row11 col2\" >333.9ms</td>\n",
" <td id=\"T_165f6_row11_col3\" class=\"data row11 col3\" >207.6ms</td>\n",
" <td id=\"T_165f6_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_165f6_row11_col5\" class=\"data row11 col5\" >41.40 s</td>\n",
" <td id=\"T_165f6_row11_col6\" class=\"data row11 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_165f6_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_165f6_row12_col2\" class=\"data row12 col2\" >278.4ms</td>\n",
" <td id=\"T_165f6_row12_col3\" class=\"data row12 col3\" >203.4ms</td>\n",
" <td id=\"T_165f6_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_165f6_row12_col5\" class=\"data row12 col5\" >34.52 s</td>\n",
" <td id=\"T_165f6_row12_col6\" class=\"data row12 col6\" >17%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_165f6_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_165f6_row13_col2\" class=\"data row13 col2\" >53.18ms</td>\n",
" <td id=\"T_165f6_row13_col3\" class=\"data row13 col3\" >11.34ms</td>\n",
" <td id=\"T_165f6_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_165f6_row13_col5\" class=\"data row13 col5\" >6.594 s</td>\n",
" <td id=\"T_165f6_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_165f6_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_165f6_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_165f6_row14_col2\" class=\"data row14 col2\" >1.799ms</td>\n",
" <td id=\"T_165f6_row14_col3\" class=\"data row14 col3\" >1.944ms</td>\n",
" <td id=\"T_165f6_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_165f6_row14_col5\" class=\"data row14 col5\" >223.1ms</td>\n",
" <td id=\"T_165f6_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
}
]
}