Created
June 8, 2022 17:26
Problem with channels_last and fastai custom tensors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Problem with channels_last and fastai custom tensors.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Problem with channels_last and fastai custom tensors\n", | |
"\n", | |
"This notebook contains experiments showing that passing fastai's custom tensor classes (TensorBase & TensorImage) as model inputs results in a degradation of channels last performance. The only change needed to get the expected channels last speedup is to cast the TensorBase-derived input to a plain torch.Tensor, after which channels last trains faster, as expected." | |
], | |
"metadata": { | |
"id": "DgDxBVq9TTpe" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"! nvidia-smi" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "WvJkCesSEx2V", | |
"outputId": "12ea5700-85f2-4ee6-8acf-e9c183618d8b" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Wed Jun 8 16:08:37 2022 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"| | | MIG M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 62C P8 11W / 70W | 0MiB / 15109MiB | 0% Default |\n", | |
"| | | N/A |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: |\n", | |
"| GPU GI CI PID Type Process name GPU Memory |\n", | |
"| ID ID Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "u-D43aOA-N_z" | |
}, | |
"outputs": [], | |
"source": [ | |
"! pip install fastcore fastai fastxtend[vision] -U" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"! pip uninstall pillow -y\n", | |
"! CC=\"cc -mavx2\" pip install -U --force-reinstall pillow-simd" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "m0FrzR7EHzBZ", | |
"outputId": "6aee40bb-8ed0-4af1-c919-df2dd629ed12" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Found existing installation: Pillow 7.1.2\n", | |
"Uninstalling Pillow-7.1.2:\n", | |
" Successfully uninstalled Pillow-7.1.2\n", | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting pillow-simd\n", | |
" Downloading Pillow-SIMD-9.0.0.post1.tar.gz (849 kB)\n", | |
"\u001b[K |████████████████████████████████| 849 kB 7.6 MB/s \n", | |
"\u001b[?25hBuilding wheels for collected packages: pillow-simd\n", | |
" Building wheel for pillow-simd (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pillow-simd: filename=Pillow_SIMD-9.0.0.post1-cp37-cp37m-linux_x86_64.whl size=1255565 sha256=5da71e6b46623d6920cc321db69258300c8543cdab25e5fc5711d397036233dc\n", | |
" Stored in directory: /root/.cache/pip/wheels/9b/3f/fd/ca7133b4f7f509eb1de652bb8c128529a0c04b25ef0a6c535a\n", | |
"Successfully built pillow-simd\n", | |
"Installing collected packages: pillow-simd\n", | |
"Successfully installed pillow-simd-9.0.0.post1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from __future__ import annotations\n", | |
"from fastai.vision.all import *\n", | |
"from fastxtend.vision.all import *\n", | |
"from fastxtend.callback import simpleprofiler" | |
], | |
"metadata": { | |
"id": "55BRyAzH-S-f" | |
}, | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# DataLoaders" | |
], | |
"metadata": { | |
"id": "J1tiDmSY-sxr" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"imagewoof_stats = ([0.496,0.461,0.399],[0.257,0.249,0.258])\n", | |
"imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])\n", | |
"\n", | |
"def get_dls(size:int, woof:bool, bs:int, sh:float=0., augs:list=None, workers:int=None, stats:bool=True) -> DataLoaders:\n", | |
" if size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320\n", | |
" else : path = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE\n", | |
" source = untar_data(path)\n", | |
" if workers is None: workers = min(8, num_cpus())\n", | |
" batch_tfms = []\n", | |
" if stats:\n", | |
" if woof: \n", | |
" batch_tfms += [Normalize.from_stats(*imagewoof_stats)]\n", | |
" else:\n", | |
" batch_tfms += [Normalize.from_stats(*imagenette_stats)]\n", | |
" if augs: batch_tfms += augs\n", | |
" if sh: batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))\n", | |
" dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n", | |
" splitter=GrandparentSplitter(valid_name='val'),\n", | |
" get_items=get_image_files, get_y=parent_label,\n", | |
" item_tfms=[Resize(size)],\n", | |
" batch_tfms=batch_tfms)\n", | |
" return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)" | |
], | |
"metadata": { | |
"id": "pNgVl8Gs-dxf" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Train fastai default mixed precision" | |
], | |
"metadata": { | |
"id": "beQssA6n-wam" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "bOYaHgr--2yJ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam, metrics=[Accuracy()]).to_fp16()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 81 | |
}, | |
"id": "gXqWSo7R-v28", | |
"outputId": "a2fbfba6-15bc-48b1-e5eb-295f7cab29be" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.912659</td>\n", | |
" <td>1.717486</td>\n", | |
" <td>0.414777</td>\n", | |
" <td>01:32</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "h8luVtBs-4F7" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam, metrics=[Accuracy()]).to_fp16().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 632 | |
}, | |
"id": "lkQ7iTGJ--dd", | |
"outputId": "738d77bf-6cbc-4faa-cd21-79c1ac57d7ef" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.976799</td>\n", | |
" <td>1.873074</td>\n", | |
" <td>0.344968</td>\n", | |
" <td>01:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.385735</td>\n", | |
" <td>1.268956</td>\n", | |
" <td>0.610701</td>\n", | |
" <td>01:25</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7f763ecf4a90>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_e689c_\" class=\"dataframe\">\n", | |
" <caption>Simple Profiler Results</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"col_heading level0 col0\" >Phase</th>\n", | |
" <th class=\"col_heading level0 col1\" >Action</th>\n", | |
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n", | |
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n", | |
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n", | |
" <th class=\"col_heading level0 col5\" >Total Time</th>\n", | |
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row0_col0\" class=\"data row0 col0\" >fit</td>\n", | |
" <td id=\"T_e689c_row0_col1\" class=\"data row0 col1\" >fit</td>\n", | |
" <td id=\"T_e689c_row0_col2\" class=\"data row0 col2\" >-</td>\n", | |
" <td id=\"T_e689c_row0_col3\" class=\"data row0 col3\" >-</td>\n", | |
" <td id=\"T_e689c_row0_col4\" class=\"data row0 col4\" >1</td>\n", | |
" <td id=\"T_e689c_row0_col5\" class=\"data row0 col5\" >171.6 s</td>\n", | |
" <td id=\"T_e689c_row0_col6\" class=\"data row0 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row1_col0\" class=\"data row1 col0\" ></td>\n", | |
" <td id=\"T_e689c_row1_col1\" class=\"data row1 col1\" >epoch</td>\n", | |
" <td id=\"T_e689c_row1_col2\" class=\"data row1 col2\" >85.81 s</td>\n", | |
" <td id=\"T_e689c_row1_col3\" class=\"data row1 col3\" >422.4ms</td>\n", | |
" <td id=\"T_e689c_row1_col4\" class=\"data row1 col4\" >2</td>\n", | |
" <td id=\"T_e689c_row1_col5\" class=\"data row1 col5\" >171.6 s</td>\n", | |
" <td id=\"T_e689c_row1_col6\" class=\"data row1 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row2_col0\" class=\"data row2 col0\" ></td>\n", | |
" <td id=\"T_e689c_row2_col1\" class=\"data row2 col1\" >train</td>\n", | |
" <td id=\"T_e689c_row2_col2\" class=\"data row2 col2\" >66.26 s</td>\n", | |
" <td id=\"T_e689c_row2_col3\" class=\"data row2 col3\" >353.6ms</td>\n", | |
" <td id=\"T_e689c_row2_col4\" class=\"data row2 col4\" >2</td>\n", | |
" <td id=\"T_e689c_row2_col5\" class=\"data row2 col5\" >132.5 s</td>\n", | |
" <td id=\"T_e689c_row2_col6\" class=\"data row2 col6\" >77%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row3_col0\" class=\"data row3 col0\" ></td>\n", | |
" <td id=\"T_e689c_row3_col1\" class=\"data row3 col1\" >validate</td>\n", | |
" <td id=\"T_e689c_row3_col2\" class=\"data row3 col2\" >19.54 s</td>\n", | |
" <td id=\"T_e689c_row3_col3\" class=\"data row3 col3\" >65.33ms</td>\n", | |
" <td id=\"T_e689c_row3_col4\" class=\"data row3 col4\" >2</td>\n", | |
" <td id=\"T_e689c_row3_col5\" class=\"data row3 col5\" >39.08 s</td>\n", | |
" <td id=\"T_e689c_row3_col6\" class=\"data row3 col6\" >23%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row4_col0\" class=\"data row4 col0\" >train</td>\n", | |
" <td id=\"T_e689c_row4_col1\" class=\"data row4 col1\" >batch</td>\n", | |
" <td id=\"T_e689c_row4_col2\" class=\"data row4 col2\" >443.2ms</td>\n", | |
" <td id=\"T_e689c_row4_col3\" class=\"data row4 col3\" >67.11ms</td>\n", | |
" <td id=\"T_e689c_row4_col4\" class=\"data row4 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row4_col5\" class=\"data row4 col5\" >130.3 s</td>\n", | |
" <td id=\"T_e689c_row4_col6\" class=\"data row4 col6\" >76%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row5_col0\" class=\"data row5 col0\" ></td>\n", | |
" <td id=\"T_e689c_row5_col1\" class=\"data row5 col1\" >step</td>\n", | |
" <td id=\"T_e689c_row5_col2\" class=\"data row5 col2\" >280.0ms</td>\n", | |
" <td id=\"T_e689c_row5_col3\" class=\"data row5 col3\" >23.06ms</td>\n", | |
" <td id=\"T_e689c_row5_col4\" class=\"data row5 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row5_col5\" class=\"data row5 col5\" >82.32 s</td>\n", | |
" <td id=\"T_e689c_row5_col6\" class=\"data row5 col6\" >48%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row6_col0\" class=\"data row6 col0\" ></td>\n", | |
" <td id=\"T_e689c_row6_col1\" class=\"data row6 col1\" >backward</td>\n", | |
" <td id=\"T_e689c_row6_col2\" class=\"data row6 col2\" >73.15ms</td>\n", | |
" <td id=\"T_e689c_row6_col3\" class=\"data row6 col3\" >13.76ms</td>\n", | |
" <td id=\"T_e689c_row6_col4\" class=\"data row6 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row6_col5\" class=\"data row6 col5\" >21.51 s</td>\n", | |
" <td id=\"T_e689c_row6_col6\" class=\"data row6 col6\" >13%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row7_col0\" class=\"data row7 col0\" ></td>\n", | |
" <td id=\"T_e689c_row7_col1\" class=\"data row7 col1\" >pred</td>\n", | |
" <td id=\"T_e689c_row7_col2\" class=\"data row7 col2\" >61.11ms</td>\n", | |
" <td id=\"T_e689c_row7_col3\" class=\"data row7 col3\" >19.80ms</td>\n", | |
" <td id=\"T_e689c_row7_col4\" class=\"data row7 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row7_col5\" class=\"data row7 col5\" >17.97 s</td>\n", | |
" <td id=\"T_e689c_row7_col6\" class=\"data row7 col6\" >10%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row8_col0\" class=\"data row8 col0\" ></td>\n", | |
" <td id=\"T_e689c_row8_col1\" class=\"data row8 col1\" >draw</td>\n", | |
" <td id=\"T_e689c_row8_col2\" class=\"data row8 col2\" >23.78ms</td>\n", | |
" <td id=\"T_e689c_row8_col3\" class=\"data row8 col3\" >62.00ms</td>\n", | |
" <td id=\"T_e689c_row8_col4\" class=\"data row8 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row8_col5\" class=\"data row8 col5\" >6.991 s</td>\n", | |
" <td id=\"T_e689c_row8_col6\" class=\"data row8 col6\" >4%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row9_col0\" class=\"data row9 col0\" ></td>\n", | |
" <td id=\"T_e689c_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n", | |
" <td id=\"T_e689c_row9_col2\" class=\"data row9 col2\" >2.996ms</td>\n", | |
" <td id=\"T_e689c_row9_col3\" class=\"data row9 col3\" >1.757ms</td>\n", | |
" <td id=\"T_e689c_row9_col4\" class=\"data row9 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row9_col5\" class=\"data row9 col5\" >880.8ms</td>\n", | |
" <td id=\"T_e689c_row9_col6\" class=\"data row9 col6\" >1%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row10_col0\" class=\"data row10 col0\" ></td>\n", | |
" <td id=\"T_e689c_row10_col1\" class=\"data row10 col1\" >loss</td>\n", | |
" <td id=\"T_e689c_row10_col2\" class=\"data row10 col2\" >1.882ms</td>\n", | |
" <td id=\"T_e689c_row10_col3\" class=\"data row10 col3\" >1.883ms</td>\n", | |
" <td id=\"T_e689c_row10_col4\" class=\"data row10 col4\" >294</td>\n", | |
" <td id=\"T_e689c_row10_col5\" class=\"data row10 col5\" >553.3ms</td>\n", | |
" <td id=\"T_e689c_row10_col6\" class=\"data row10 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row11_col0\" class=\"data row11 col0\" >valid</td>\n", | |
" <td id=\"T_e689c_row11_col1\" class=\"data row11 col1\" >batch</td>\n", | |
" <td id=\"T_e689c_row11_col2\" class=\"data row11 col2\" >277.3ms</td>\n", | |
" <td id=\"T_e689c_row11_col3\" class=\"data row11 col3\" >155.1ms</td>\n", | |
" <td id=\"T_e689c_row11_col4\" class=\"data row11 col4\" >124</td>\n", | |
" <td id=\"T_e689c_row11_col5\" class=\"data row11 col5\" >34.38 s</td>\n", | |
" <td id=\"T_e689c_row11_col6\" class=\"data row11 col6\" >20%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row12_col0\" class=\"data row12 col0\" ></td>\n", | |
" <td id=\"T_e689c_row12_col1\" class=\"data row12 col1\" >draw</td>\n", | |
" <td id=\"T_e689c_row12_col2\" class=\"data row12 col2\" >210.5ms</td>\n", | |
" <td id=\"T_e689c_row12_col3\" class=\"data row12 col3\" >150.8ms</td>\n", | |
" <td id=\"T_e689c_row12_col4\" class=\"data row12 col4\" >124</td>\n", | |
" <td id=\"T_e689c_row12_col5\" class=\"data row12 col5\" >26.10 s</td>\n", | |
" <td id=\"T_e689c_row12_col6\" class=\"data row12 col6\" >15%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row13_col0\" class=\"data row13 col0\" ></td>\n", | |
" <td id=\"T_e689c_row13_col1\" class=\"data row13 col1\" >pred</td>\n", | |
" <td id=\"T_e689c_row13_col2\" class=\"data row13 col2\" >64.46ms</td>\n", | |
" <td id=\"T_e689c_row13_col3\" class=\"data row13 col3\" >12.82ms</td>\n", | |
" <td id=\"T_e689c_row13_col4\" class=\"data row13 col4\" >124</td>\n", | |
" <td id=\"T_e689c_row13_col5\" class=\"data row13 col5\" >7.993 s</td>\n", | |
" <td id=\"T_e689c_row13_col6\" class=\"data row13 col6\" >5%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e689c_row14_col0\" class=\"data row14 col0\" ></td>\n", | |
" <td id=\"T_e689c_row14_col1\" class=\"data row14 col1\" >loss</td>\n", | |
" <td id=\"T_e689c_row14_col2\" class=\"data row14 col2\" >1.952ms</td>\n", | |
" <td id=\"T_e689c_row14_col3\" class=\"data row14 col3\" >2.213ms</td>\n", | |
" <td id=\"T_e689c_row14_col4\" class=\"data row14 col4\" >124</td>\n", | |
" <td id=\"T_e689c_row14_col5\" class=\"data row14 col5\" >242.1ms</td>\n", | |
" <td id=\"T_e689c_row14_col6\" class=\"data row14 col6\" >0%</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Train fastai channels last" | |
], | |
"metadata": { | |
"id": "VmZkjfymAVSS" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "ybh8djjhAVST" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 81 | |
}, | |
"outputId": "93f31c24-725b-4467-8c22-f739ee09e829", | |
"id": "6axWVksaAVST" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.800101</td>\n", | |
" <td>1.618706</td>\n", | |
" <td>0.447643</td>\n", | |
" <td>01:46</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "Wr2wDyEPAVSU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam, metrics=[Accuracy()]).to_channelslast().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 632 | |
}, | |
"outputId": "aa76f239-d805-451c-a0d9-d54964ceabe2", | |
"id": "cIqWhmxhAVSU" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.945158</td>\n", | |
" <td>1.860097</td>\n", | |
" <td>0.369172</td>\n", | |
" <td>01:33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.290537</td>\n", | |
" <td>1.169882</td>\n", | |
" <td>0.627261</td>\n", | |
" <td>01:33</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7f7647fb75d0>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_c2056_\" class=\"dataframe\">\n", | |
" <caption>Simple Profiler Results</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"col_heading level0 col0\" >Phase</th>\n", | |
" <th class=\"col_heading level0 col1\" >Action</th>\n", | |
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n", | |
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n", | |
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n", | |
" <th class=\"col_heading level0 col5\" >Total Time</th>\n", | |
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row0_col0\" class=\"data row0 col0\" >fit</td>\n", | |
" <td id=\"T_c2056_row0_col1\" class=\"data row0 col1\" >fit</td>\n", | |
" <td id=\"T_c2056_row0_col2\" class=\"data row0 col2\" >-</td>\n", | |
" <td id=\"T_c2056_row0_col3\" class=\"data row0 col3\" >-</td>\n", | |
" <td id=\"T_c2056_row0_col4\" class=\"data row0 col4\" >1</td>\n", | |
" <td id=\"T_c2056_row0_col5\" class=\"data row0 col5\" >186.9 s</td>\n", | |
" <td id=\"T_c2056_row0_col6\" class=\"data row0 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row1_col0\" class=\"data row1 col0\" ></td>\n", | |
" <td id=\"T_c2056_row1_col1\" class=\"data row1 col1\" >epoch</td>\n", | |
" <td id=\"T_c2056_row1_col2\" class=\"data row1 col2\" >93.47 s</td>\n", | |
" <td id=\"T_c2056_row1_col3\" class=\"data row1 col3\" >170.4ms</td>\n", | |
" <td id=\"T_c2056_row1_col4\" class=\"data row1 col4\" >2</td>\n", | |
" <td id=\"T_c2056_row1_col5\" class=\"data row1 col5\" >186.9 s</td>\n", | |
" <td id=\"T_c2056_row1_col6\" class=\"data row1 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row2_col0\" class=\"data row2 col0\" ></td>\n", | |
" <td id=\"T_c2056_row2_col1\" class=\"data row2 col1\" >train</td>\n", | |
" <td id=\"T_c2056_row2_col2\" class=\"data row2 col2\" >74.08 s</td>\n", | |
" <td id=\"T_c2056_row2_col3\" class=\"data row2 col3\" >188.7ms</td>\n", | |
" <td id=\"T_c2056_row2_col4\" class=\"data row2 col4\" >2</td>\n", | |
" <td id=\"T_c2056_row2_col5\" class=\"data row2 col5\" >148.2 s</td>\n", | |
" <td id=\"T_c2056_row2_col6\" class=\"data row2 col6\" >79%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row3_col0\" class=\"data row3 col0\" ></td>\n", | |
" <td id=\"T_c2056_row3_col1\" class=\"data row3 col1\" >validate</td>\n", | |
" <td id=\"T_c2056_row3_col2\" class=\"data row3 col2\" >19.38 s</td>\n", | |
" <td id=\"T_c2056_row3_col3\" class=\"data row3 col3\" >21.99ms</td>\n", | |
" <td id=\"T_c2056_row3_col4\" class=\"data row3 col4\" >2</td>\n", | |
" <td id=\"T_c2056_row3_col5\" class=\"data row3 col5\" >38.77 s</td>\n", | |
" <td id=\"T_c2056_row3_col6\" class=\"data row3 col6\" >21%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row4_col0\" class=\"data row4 col0\" >train</td>\n", | |
" <td id=\"T_c2056_row4_col1\" class=\"data row4 col1\" >batch</td>\n", | |
" <td id=\"T_c2056_row4_col2\" class=\"data row4 col2\" >498.0ms</td>\n", | |
" <td id=\"T_c2056_row4_col3\" class=\"data row4 col3\" >55.88ms</td>\n", | |
" <td id=\"T_c2056_row4_col4\" class=\"data row4 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row4_col5\" class=\"data row4 col5\" >146.4 s</td>\n", | |
" <td id=\"T_c2056_row4_col6\" class=\"data row4 col6\" >78%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row5_col0\" class=\"data row5 col0\" ></td>\n", | |
" <td id=\"T_c2056_row5_col1\" class=\"data row5 col1\" >step</td>\n", | |
" <td id=\"T_c2056_row5_col2\" class=\"data row5 col2\" >365.1ms</td>\n", | |
" <td id=\"T_c2056_row5_col3\" class=\"data row5 col3\" >26.65ms</td>\n", | |
" <td id=\"T_c2056_row5_col4\" class=\"data row5 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row5_col5\" class=\"data row5 col5\" >107.3 s</td>\n", | |
" <td id=\"T_c2056_row5_col6\" class=\"data row5 col6\" >57%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row6_col0\" class=\"data row6 col0\" ></td>\n", | |
" <td id=\"T_c2056_row6_col1\" class=\"data row6 col1\" >backward</td>\n", | |
" <td id=\"T_c2056_row6_col2\" class=\"data row6 col2\" >53.55ms</td>\n", | |
" <td id=\"T_c2056_row6_col3\" class=\"data row6 col3\" >15.32ms</td>\n", | |
" <td id=\"T_c2056_row6_col4\" class=\"data row6 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row6_col5\" class=\"data row6 col5\" >15.74 s</td>\n", | |
" <td id=\"T_c2056_row6_col6\" class=\"data row6 col6\" >8%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row7_col0\" class=\"data row7 col0\" ></td>\n", | |
" <td id=\"T_c2056_row7_col1\" class=\"data row7 col1\" >pred</td>\n", | |
" <td id=\"T_c2056_row7_col2\" class=\"data row7 col2\" >52.80ms</td>\n", | |
" <td id=\"T_c2056_row7_col3\" class=\"data row7 col3\" >16.42ms</td>\n", | |
" <td id=\"T_c2056_row7_col4\" class=\"data row7 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row7_col5\" class=\"data row7 col5\" >15.52 s</td>\n", | |
" <td id=\"T_c2056_row7_col6\" class=\"data row7 col6\" >8%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row8_col0\" class=\"data row8 col0\" ></td>\n", | |
" <td id=\"T_c2056_row8_col1\" class=\"data row8 col1\" >draw</td>\n", | |
" <td id=\"T_c2056_row8_col2\" class=\"data row8 col2\" >22.18ms</td>\n", | |
" <td id=\"T_c2056_row8_col3\" class=\"data row8 col3\" >54.90ms</td>\n", | |
" <td id=\"T_c2056_row8_col4\" class=\"data row8 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row8_col5\" class=\"data row8 col5\" >6.521 s</td>\n", | |
" <td id=\"T_c2056_row8_col6\" class=\"data row8 col6\" >3%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row9_col0\" class=\"data row9 col0\" ></td>\n", | |
" <td id=\"T_c2056_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n", | |
" <td id=\"T_c2056_row9_col2\" class=\"data row9 col2\" >2.565ms</td>\n", | |
" <td id=\"T_c2056_row9_col3\" class=\"data row9 col3\" >1.181ms</td>\n", | |
" <td id=\"T_c2056_row9_col4\" class=\"data row9 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row9_col5\" class=\"data row9 col5\" >754.1ms</td>\n", | |
" <td id=\"T_c2056_row9_col6\" class=\"data row9 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row10_col0\" class=\"data row10 col0\" ></td>\n", | |
" <td id=\"T_c2056_row10_col1\" class=\"data row10 col1\" >loss</td>\n", | |
" <td id=\"T_c2056_row10_col2\" class=\"data row10 col2\" >1.518ms</td>\n", | |
" <td id=\"T_c2056_row10_col3\" class=\"data row10 col3\" >1.213ms</td>\n", | |
" <td id=\"T_c2056_row10_col4\" class=\"data row10 col4\" >294</td>\n", | |
" <td id=\"T_c2056_row10_col5\" class=\"data row10 col5\" >446.2ms</td>\n", | |
" <td id=\"T_c2056_row10_col6\" class=\"data row10 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row11_col0\" class=\"data row11 col0\" >valid</td>\n", | |
" <td id=\"T_c2056_row11_col1\" class=\"data row11 col1\" >batch</td>\n", | |
" <td id=\"T_c2056_row11_col2\" class=\"data row11 col2\" >286.2ms</td>\n", | |
" <td id=\"T_c2056_row11_col3\" class=\"data row11 col3\" >197.6ms</td>\n", | |
" <td id=\"T_c2056_row11_col4\" class=\"data row11 col4\" >124</td>\n", | |
" <td id=\"T_c2056_row11_col5\" class=\"data row11 col5\" >35.49 s</td>\n", | |
" <td id=\"T_c2056_row11_col6\" class=\"data row11 col6\" >19%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row12_col0\" class=\"data row12 col0\" ></td>\n", | |
" <td id=\"T_c2056_row12_col1\" class=\"data row12 col1\" >draw</td>\n", | |
" <td id=\"T_c2056_row12_col2\" class=\"data row12 col2\" >220.4ms</td>\n", | |
" <td id=\"T_c2056_row12_col3\" class=\"data row12 col3\" >193.4ms</td>\n", | |
" <td id=\"T_c2056_row12_col4\" class=\"data row12 col4\" >124</td>\n", | |
" <td id=\"T_c2056_row12_col5\" class=\"data row12 col5\" >27.33 s</td>\n", | |
" <td id=\"T_c2056_row12_col6\" class=\"data row12 col6\" >15%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row13_col0\" class=\"data row13 col0\" ></td>\n", | |
" <td id=\"T_c2056_row13_col1\" class=\"data row13 col1\" >pred</td>\n", | |
" <td id=\"T_c2056_row13_col2\" class=\"data row13 col2\" >63.65ms</td>\n", | |
" <td id=\"T_c2056_row13_col3\" class=\"data row13 col3\" >13.46ms</td>\n", | |
" <td id=\"T_c2056_row13_col4\" class=\"data row13 col4\" >124</td>\n", | |
" <td id=\"T_c2056_row13_col5\" class=\"data row13 col5\" >7.893 s</td>\n", | |
" <td id=\"T_c2056_row13_col6\" class=\"data row13 col6\" >4%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_c2056_row14_col0\" class=\"data row14 col0\" ></td>\n", | |
" <td id=\"T_c2056_row14_col1\" class=\"data row14 col1\" >loss</td>\n", | |
" <td id=\"T_c2056_row14_col2\" class=\"data row14 col2\" >1.820ms</td>\n", | |
" <td id=\"T_c2056_row14_col3\" class=\"data row14 col3\" >2.196ms</td>\n", | |
" <td id=\"T_c2056_row14_col4\" class=\"data row14 col4\" >124</td>\n", | |
" <td id=\"T_c2056_row14_col5\" class=\"data row14 col5\" >225.7ms</td>\n", | |
" <td id=\"T_c2056_row14_col6\" class=\"data row14 col6\" >0%</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Remove All Possible fastai Bits That Might be Interfering" | |
], | |
"metadata": { | |
"id": "-dm_BtSuELmM" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"First, create a fastai transform that doesn't retain the tensor type" | |
], | |
"metadata": { | |
"id": "yPdi-I1hEfCY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from fastcore.transform import DisplayedTransform, _is_tuple, retain_type\n", | |
"\n", | |
"class LoseTypeTransform(DisplayedTransform):\n", | |
" \"\"\n", | |
" def __init__(self, **kwargs): super().__init__(**kwargs)\n", | |
"\n", | |
" def _call(self, fn, x, split_idx=None, **kwargs):\n", | |
" if split_idx!=self.split_idx and self.split_idx is not None: return x\n", | |
" return self._do_call(getattr(self, fn), x, **kwargs)\n", | |
"\n", | |
" def _do_call(self, f, x, **kwargs):\n", | |
" if not _is_tuple(x):\n", | |
" if f is None: return x\n", | |
" return f(x, **kwargs)\n", | |
" return tuple(self._do_call(f, x_, **kwargs) for x_ in x)" | |
], | |
"metadata": { | |
"id": "LQIN7mOpECun" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
    { | |
      "cell_type": "markdown", | |
      "source": [ | |
        "Inherit from it to create a variant of `ChannelsLastTfm` that passes a plain `torch.Tensor` to the model instead of a `TensorImage`" | |
      ], | |
      "metadata": { | |
        "id": "v4XIj3esElF-" | |
      } | |
    }, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from fastxtend.callback.channelslast import ChannelsLastTfm" | |
], | |
"metadata": { | |
"id": "OjV5m8lSFGVj" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ChannelsLastTensor(LoseTypeTransform):\n", | |
" \"Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback\"\n", | |
" order = 110 # run after all other transforms if added to batch_tfms\n", | |
" def encodes(self, x:TensorImageBase|TensorMask):\n", | |
" return torch.tensor(x).to(memory_format=torch.channels_last)" | |
], | |
"metadata": { | |
"id": "uWe_J5fFE50f" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
    { | |
      "cell_type": "markdown", | |
      "source": [ | |
        "Otherwise the callback is the same as fastxtend's `ChannelsLastCallback`, other than an `astensor` switch for inputting a plain `torch.Tensor` versus a `TensorImage` to the model" | |
      ], | |
      "metadata": { | |
        "id": "QmGndPRiRymP" | |
      } | |
    }, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ChannelsLastCallback(Callback):\n", | |
" \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n", | |
" order = MixedPrecision.order+1\n", | |
" def __init__(self, astensor=False):\n", | |
" if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n", | |
" else: self._channels_last = Pipeline([ChannelsLastTfm()])\n", | |
"\n", | |
" def before_fit(self):\n", | |
" self.learn.model.to(memory_format=torch.channels_last)\n", | |
"\n", | |
" def before_batch(self):\n", | |
" self.learn.xb = self._channels_last(self.xb)" | |
], | |
"metadata": { | |
"id": "a9WecdxUEX_F" | |
}, | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@patch\n", | |
"def to_channelslast(self:Learner, to_fp16=True, astensor=True, **kwargs):\n", | |
" \"Set `Learner` and inputs to `channels_last` format and Mixed Precision by default\"\n", | |
" if to_fp16 and not hasattr(self, 'mixed_precision') and not hasattr(self, 'channels_last'): \n", | |
" return self.add_cbs([ChannelsLastCallback(astensor=astensor), MixedPrecision(**kwargs)])\n", | |
" elif not hasattr(self, 'channels_last'):\n", | |
" return self.add_cb(ChannelsLastCallback(astensor=astensor))" | |
], | |
"metadata": { | |
"id": "piSnsjp7EZRF" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "SjhRhmwRE1Mm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=nn.CrossEntropyLoss(), \n", | |
" opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n", | |
" metrics=[Accuracy()]).to_channelslast()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 135 | |
}, | |
"id": "5-MnrPkYE2EG", | |
"outputId": "f092f0dc-7cae-49da-cbab-056abcc5e5b0" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.818512</td>\n", | |
" <td>1.680830</td>\n", | |
" <td>0.428535</td>\n", | |
" <td>01:22</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "OlfL5IFKE2s9" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=nn.CrossEntropyLoss(), \n", | |
" opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n", | |
" metrics=[Accuracy()]).to_channelslast().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 687 | |
}, | |
"id": "F4hx0CEXFPut", | |
"outputId": "5c4772d1-365e-43d8-c2ce-a8a3bb730297" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.863661</td>\n", | |
" <td>1.996808</td>\n", | |
" <td>0.355669</td>\n", | |
" <td>01:17</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.320199</td>\n", | |
" <td>1.176207</td>\n", | |
" <td>0.615032</td>\n", | |
" <td>01:16</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7f763eeca4d0>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_8e94a_\" class=\"dataframe\">\n", | |
" <caption>Simple Profiler Results</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"col_heading level0 col0\" >Phase</th>\n", | |
" <th class=\"col_heading level0 col1\" >Action</th>\n", | |
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n", | |
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n", | |
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n", | |
" <th class=\"col_heading level0 col5\" >Total Time</th>\n", | |
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row0_col0\" class=\"data row0 col0\" >fit</td>\n", | |
" <td id=\"T_8e94a_row0_col1\" class=\"data row0 col1\" >fit</td>\n", | |
" <td id=\"T_8e94a_row0_col2\" class=\"data row0 col2\" >-</td>\n", | |
" <td id=\"T_8e94a_row0_col3\" class=\"data row0 col3\" >-</td>\n", | |
" <td id=\"T_8e94a_row0_col4\" class=\"data row0 col4\" >1</td>\n", | |
" <td id=\"T_8e94a_row0_col5\" class=\"data row0 col5\" >154.6 s</td>\n", | |
" <td id=\"T_8e94a_row0_col6\" class=\"data row0 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row1_col0\" class=\"data row1 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row1_col1\" class=\"data row1 col1\" >epoch</td>\n", | |
" <td id=\"T_8e94a_row1_col2\" class=\"data row1 col2\" >77.31 s</td>\n", | |
" <td id=\"T_8e94a_row1_col3\" class=\"data row1 col3\" >537.9ms</td>\n", | |
" <td id=\"T_8e94a_row1_col4\" class=\"data row1 col4\" >2</td>\n", | |
" <td id=\"T_8e94a_row1_col5\" class=\"data row1 col5\" >154.6 s</td>\n", | |
" <td id=\"T_8e94a_row1_col6\" class=\"data row1 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row2_col0\" class=\"data row2 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row2_col1\" class=\"data row2 col1\" >train</td>\n", | |
" <td id=\"T_8e94a_row2_col2\" class=\"data row2 col2\" >57.79 s</td>\n", | |
" <td id=\"T_8e94a_row2_col3\" class=\"data row2 col3\" >402.5ms</td>\n", | |
" <td id=\"T_8e94a_row2_col4\" class=\"data row2 col4\" >2</td>\n", | |
" <td id=\"T_8e94a_row2_col5\" class=\"data row2 col5\" >115.6 s</td>\n", | |
" <td id=\"T_8e94a_row2_col6\" class=\"data row2 col6\" >75%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row3_col0\" class=\"data row3 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row3_col1\" class=\"data row3 col1\" >validate</td>\n", | |
" <td id=\"T_8e94a_row3_col2\" class=\"data row3 col2\" >19.51 s</td>\n", | |
" <td id=\"T_8e94a_row3_col3\" class=\"data row3 col3\" >132.2ms</td>\n", | |
" <td id=\"T_8e94a_row3_col4\" class=\"data row3 col4\" >2</td>\n", | |
" <td id=\"T_8e94a_row3_col5\" class=\"data row3 col5\" >39.01 s</td>\n", | |
" <td id=\"T_8e94a_row3_col6\" class=\"data row3 col6\" >25%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row4_col0\" class=\"data row4 col0\" >train</td>\n", | |
" <td id=\"T_8e94a_row4_col1\" class=\"data row4 col1\" >batch</td>\n", | |
" <td id=\"T_8e94a_row4_col2\" class=\"data row4 col2\" >368.2ms</td>\n", | |
" <td id=\"T_8e94a_row4_col3\" class=\"data row4 col3\" >97.57ms</td>\n", | |
" <td id=\"T_8e94a_row4_col4\" class=\"data row4 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row4_col5\" class=\"data row4 col5\" >108.3 s</td>\n", | |
" <td id=\"T_8e94a_row4_col6\" class=\"data row4 col6\" >70%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row5_col0\" class=\"data row5 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row5_col1\" class=\"data row5 col1\" >step</td>\n", | |
" <td id=\"T_8e94a_row5_col2\" class=\"data row5 col2\" >172.8ms</td>\n", | |
" <td id=\"T_8e94a_row5_col3\" class=\"data row5 col3\" >22.10ms</td>\n", | |
" <td id=\"T_8e94a_row5_col4\" class=\"data row5 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row5_col5\" class=\"data row5 col5\" >50.81 s</td>\n", | |
" <td id=\"T_8e94a_row5_col6\" class=\"data row5 col6\" >33%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row6_col0\" class=\"data row6 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row6_col1\" class=\"data row6 col1\" >draw</td>\n", | |
" <td id=\"T_8e94a_row6_col2\" class=\"data row6 col2\" >71.45ms</td>\n", | |
" <td id=\"T_8e94a_row6_col3\" class=\"data row6 col3\" >95.57ms</td>\n", | |
" <td id=\"T_8e94a_row6_col4\" class=\"data row6 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row6_col5\" class=\"data row6 col5\" >21.00 s</td>\n", | |
" <td id=\"T_8e94a_row6_col6\" class=\"data row6 col6\" >14%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row7_col0\" class=\"data row7 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row7_col1\" class=\"data row7 col1\" >pred</td>\n", | |
" <td id=\"T_8e94a_row7_col2\" class=\"data row7 col2\" >61.65ms</td>\n", | |
" <td id=\"T_8e94a_row7_col3\" class=\"data row7 col3\" >15.55ms</td>\n", | |
" <td id=\"T_8e94a_row7_col4\" class=\"data row7 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row7_col5\" class=\"data row7 col5\" >18.13 s</td>\n", | |
" <td id=\"T_8e94a_row7_col6\" class=\"data row7 col6\" >12%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row8_col0\" class=\"data row8 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row8_col1\" class=\"data row8 col1\" >backward</td>\n", | |
" <td id=\"T_8e94a_row8_col2\" class=\"data row8 col2\" >56.72ms</td>\n", | |
" <td id=\"T_8e94a_row8_col3\" class=\"data row8 col3\" >12.65ms</td>\n", | |
" <td id=\"T_8e94a_row8_col4\" class=\"data row8 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row8_col5\" class=\"data row8 col5\" >16.68 s</td>\n", | |
" <td id=\"T_8e94a_row8_col6\" class=\"data row8 col6\" >11%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row9_col0\" class=\"data row9 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n", | |
" <td id=\"T_8e94a_row9_col2\" class=\"data row9 col2\" >3.895ms</td>\n", | |
" <td id=\"T_8e94a_row9_col3\" class=\"data row9 col3\" >2.942ms</td>\n", | |
" <td id=\"T_8e94a_row9_col4\" class=\"data row9 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row9_col5\" class=\"data row9 col5\" >1.145 s</td>\n", | |
" <td id=\"T_8e94a_row9_col6\" class=\"data row9 col6\" >1%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row10_col0\" class=\"data row10 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row10_col1\" class=\"data row10 col1\" >loss</td>\n", | |
" <td id=\"T_8e94a_row10_col2\" class=\"data row10 col2\" >1.345ms</td>\n", | |
" <td id=\"T_8e94a_row10_col3\" class=\"data row10 col3\" >1.974ms</td>\n", | |
" <td id=\"T_8e94a_row10_col4\" class=\"data row10 col4\" >294</td>\n", | |
" <td id=\"T_8e94a_row10_col5\" class=\"data row10 col5\" >395.3ms</td>\n", | |
" <td id=\"T_8e94a_row10_col6\" class=\"data row10 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row11_col0\" class=\"data row11 col0\" >valid</td>\n", | |
" <td id=\"T_8e94a_row11_col1\" class=\"data row11 col1\" >batch</td>\n", | |
" <td id=\"T_8e94a_row11_col2\" class=\"data row11 col2\" >264.9ms</td>\n", | |
" <td id=\"T_8e94a_row11_col3\" class=\"data row11 col3\" >172.1ms</td>\n", | |
" <td id=\"T_8e94a_row11_col4\" class=\"data row11 col4\" >124</td>\n", | |
" <td id=\"T_8e94a_row11_col5\" class=\"data row11 col5\" >32.85 s</td>\n", | |
" <td id=\"T_8e94a_row11_col6\" class=\"data row11 col6\" >21%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row12_col0\" class=\"data row12 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row12_col1\" class=\"data row12 col1\" >draw</td>\n", | |
" <td id=\"T_8e94a_row12_col2\" class=\"data row12 col2\" >223.5ms</td>\n", | |
" <td id=\"T_8e94a_row12_col3\" class=\"data row12 col3\" >171.4ms</td>\n", | |
" <td id=\"T_8e94a_row12_col4\" class=\"data row12 col4\" >124</td>\n", | |
" <td id=\"T_8e94a_row12_col5\" class=\"data row12 col5\" >27.72 s</td>\n", | |
" <td id=\"T_8e94a_row12_col6\" class=\"data row12 col6\" >18%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row13_col0\" class=\"data row13 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row13_col1\" class=\"data row13 col1\" >pred</td>\n", | |
" <td id=\"T_8e94a_row13_col2\" class=\"data row13 col2\" >40.13ms</td>\n", | |
" <td id=\"T_8e94a_row13_col3\" class=\"data row13 col3\" >9.556ms</td>\n", | |
" <td id=\"T_8e94a_row13_col4\" class=\"data row13 col4\" >124</td>\n", | |
" <td id=\"T_8e94a_row13_col5\" class=\"data row13 col5\" >4.976 s</td>\n", | |
" <td id=\"T_8e94a_row13_col6\" class=\"data row13 col6\" >3%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_8e94a_row14_col0\" class=\"data row14 col0\" ></td>\n", | |
" <td id=\"T_8e94a_row14_col1\" class=\"data row14 col1\" >loss</td>\n", | |
" <td id=\"T_8e94a_row14_col2\" class=\"data row14 col2\" >945.4µs</td>\n", | |
" <td id=\"T_8e94a_row14_col3\" class=\"data row14 col3\" >1.121ms</td>\n", | |
" <td id=\"T_8e94a_row14_col4\" class=\"data row14 col4\" >124</td>\n", | |
" <td id=\"T_8e94a_row14_col5\" class=\"data row14 col5\" >117.2ms</td>\n", | |
" <td id=\"T_8e94a_row14_col6\" class=\"data row14 col6\" >0%</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
    { | |
      "cell_type": "markdown", | |
      "source": [ | |
        "# Use a fastai Optimizer, but with Plain Tensor Input and a PyTorch Loss" | |
      ], | |
      "metadata": { | |
        "id": "DaezZT9eK3g-" | |
      } | |
    }, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "8P-SVF8aLAbP" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=nn.CrossEntropyLoss(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 135 | |
}, | |
"outputId": "d6785999-1946-4356-b38a-848e78594be6", | |
"id": "pU5AcaIILAbP" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.979175</td>\n", | |
" <td>1.794662</td>\n", | |
" <td>0.380382</td>\n", | |
" <td>01:15</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "LWxeP4AULAbQ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=nn.CrossEntropyLoss(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 687 | |
}, | |
"id": "UbhUwy9bLAbQ", | |
"outputId": "86da8255-0c61-4b7b-b4ef-f1c5b0b4171b" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.852839</td>\n", | |
" <td>2.054055</td>\n", | |
" <td>0.332739</td>\n", | |
" <td>01:15</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.300486</td>\n", | |
" <td>1.145100</td>\n", | |
" <td>0.634650</td>\n", | |
" <td>01:18</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7f763eff1750>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_e15fa_\" class=\"dataframe\">\n", | |
" <caption>Simple Profiler Results</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"col_heading level0 col0\" >Phase</th>\n", | |
" <th class=\"col_heading level0 col1\" >Action</th>\n", | |
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n", | |
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n", | |
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n", | |
" <th class=\"col_heading level0 col5\" >Total Time</th>\n", | |
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row0_col0\" class=\"data row0 col0\" >fit</td>\n", | |
" <td id=\"T_e15fa_row0_col1\" class=\"data row0 col1\" >fit</td>\n", | |
" <td id=\"T_e15fa_row0_col2\" class=\"data row0 col2\" >-</td>\n", | |
" <td id=\"T_e15fa_row0_col3\" class=\"data row0 col3\" >-</td>\n", | |
" <td id=\"T_e15fa_row0_col4\" class=\"data row0 col4\" >1</td>\n", | |
" <td id=\"T_e15fa_row0_col5\" class=\"data row0 col5\" >154.5 s</td>\n", | |
" <td id=\"T_e15fa_row0_col6\" class=\"data row0 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row1_col0\" class=\"data row1 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row1_col1\" class=\"data row1 col1\" >epoch</td>\n", | |
" <td id=\"T_e15fa_row1_col2\" class=\"data row1 col2\" >77.27 s</td>\n", | |
" <td id=\"T_e15fa_row1_col3\" class=\"data row1 col3\" >1.341 s</td>\n", | |
" <td id=\"T_e15fa_row1_col4\" class=\"data row1 col4\" >2</td>\n", | |
" <td id=\"T_e15fa_row1_col5\" class=\"data row1 col5\" >154.5 s</td>\n", | |
" <td id=\"T_e15fa_row1_col6\" class=\"data row1 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row2_col0\" class=\"data row2 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row2_col1\" class=\"data row2 col1\" >train</td>\n", | |
" <td id=\"T_e15fa_row2_col2\" class=\"data row2 col2\" >57.12 s</td>\n", | |
" <td id=\"T_e15fa_row2_col3\" class=\"data row2 col3\" >1.114 s</td>\n", | |
" <td id=\"T_e15fa_row2_col4\" class=\"data row2 col4\" >2</td>\n", | |
" <td id=\"T_e15fa_row2_col5\" class=\"data row2 col5\" >114.2 s</td>\n", | |
" <td id=\"T_e15fa_row2_col6\" class=\"data row2 col6\" >74%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row3_col0\" class=\"data row3 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row3_col1\" class=\"data row3 col1\" >validate</td>\n", | |
" <td id=\"T_e15fa_row3_col2\" class=\"data row3 col2\" >20.14 s</td>\n", | |
" <td id=\"T_e15fa_row3_col3\" class=\"data row3 col3\" >230.3ms</td>\n", | |
" <td id=\"T_e15fa_row3_col4\" class=\"data row3 col4\" >2</td>\n", | |
" <td id=\"T_e15fa_row3_col5\" class=\"data row3 col5\" >40.29 s</td>\n", | |
" <td id=\"T_e15fa_row3_col6\" class=\"data row3 col6\" >26%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row4_col0\" class=\"data row4 col0\" >train</td>\n", | |
" <td id=\"T_e15fa_row4_col1\" class=\"data row4 col1\" >batch</td>\n", | |
" <td id=\"T_e15fa_row4_col2\" class=\"data row4 col2\" >379.6ms</td>\n", | |
" <td id=\"T_e15fa_row4_col3\" class=\"data row4 col3\" >113.8ms</td>\n", | |
" <td id=\"T_e15fa_row4_col4\" class=\"data row4 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row4_col5\" class=\"data row4 col5\" >111.6 s</td>\n", | |
" <td id=\"T_e15fa_row4_col6\" class=\"data row4 col6\" >72%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row5_col0\" class=\"data row5 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row5_col1\" class=\"data row5 col1\" >step</td>\n", | |
" <td id=\"T_e15fa_row5_col2\" class=\"data row5 col2\" >182.0ms</td>\n", | |
" <td id=\"T_e15fa_row5_col3\" class=\"data row5 col3\" >24.35ms</td>\n", | |
" <td id=\"T_e15fa_row5_col4\" class=\"data row5 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row5_col5\" class=\"data row5 col5\" >53.50 s</td>\n", | |
" <td id=\"T_e15fa_row5_col6\" class=\"data row5 col6\" >35%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row6_col0\" class=\"data row6 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row6_col1\" class=\"data row6 col1\" >draw</td>\n", | |
" <td id=\"T_e15fa_row6_col2\" class=\"data row6 col2\" >87.42ms</td>\n", | |
" <td id=\"T_e15fa_row6_col3\" class=\"data row6 col3\" >113.3ms</td>\n", | |
" <td id=\"T_e15fa_row6_col4\" class=\"data row6 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row6_col5\" class=\"data row6 col5\" >25.70 s</td>\n", | |
" <td id=\"T_e15fa_row6_col6\" class=\"data row6 col6\" >17%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row7_col0\" class=\"data row7 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row7_col1\" class=\"data row7 col1\" >backward</td>\n", | |
" <td id=\"T_e15fa_row7_col2\" class=\"data row7 col2\" >56.09ms</td>\n", | |
" <td id=\"T_e15fa_row7_col3\" class=\"data row7 col3\" >13.87ms</td>\n", | |
" <td id=\"T_e15fa_row7_col4\" class=\"data row7 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row7_col5\" class=\"data row7 col5\" >16.49 s</td>\n", | |
" <td id=\"T_e15fa_row7_col6\" class=\"data row7 col6\" >11%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row8_col0\" class=\"data row8 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row8_col1\" class=\"data row8 col1\" >pred</td>\n", | |
" <td id=\"T_e15fa_row8_col2\" class=\"data row8 col2\" >48.32ms</td>\n", | |
" <td id=\"T_e15fa_row8_col3\" class=\"data row8 col3\" >13.84ms</td>\n", | |
" <td id=\"T_e15fa_row8_col4\" class=\"data row8 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row8_col5\" class=\"data row8 col5\" >14.21 s</td>\n", | |
" <td id=\"T_e15fa_row8_col6\" class=\"data row8 col6\" >9%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row9_col0\" class=\"data row9 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n", | |
" <td id=\"T_e15fa_row9_col2\" class=\"data row9 col2\" >4.106ms</td>\n", | |
" <td id=\"T_e15fa_row9_col3\" class=\"data row9 col3\" >3.017ms</td>\n", | |
" <td id=\"T_e15fa_row9_col4\" class=\"data row9 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row9_col5\" class=\"data row9 col5\" >1.207 s</td>\n", | |
" <td id=\"T_e15fa_row9_col6\" class=\"data row9 col6\" >1%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row10_col0\" class=\"data row10 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row10_col1\" class=\"data row10 col1\" >loss</td>\n", | |
" <td id=\"T_e15fa_row10_col2\" class=\"data row10 col2\" >1.425ms</td>\n", | |
" <td id=\"T_e15fa_row10_col3\" class=\"data row10 col3\" >2.092ms</td>\n", | |
" <td id=\"T_e15fa_row10_col4\" class=\"data row10 col4\" >294</td>\n", | |
" <td id=\"T_e15fa_row10_col5\" class=\"data row10 col5\" >418.9ms</td>\n", | |
" <td id=\"T_e15fa_row10_col6\" class=\"data row10 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row11_col0\" class=\"data row11 col0\" >valid</td>\n", | |
" <td id=\"T_e15fa_row11_col1\" class=\"data row11 col1\" >batch</td>\n", | |
" <td id=\"T_e15fa_row11_col2\" class=\"data row11 col2\" >276.2ms</td>\n", | |
" <td id=\"T_e15fa_row11_col3\" class=\"data row11 col3\" >186.7ms</td>\n", | |
" <td id=\"T_e15fa_row11_col4\" class=\"data row11 col4\" >124</td>\n", | |
" <td id=\"T_e15fa_row11_col5\" class=\"data row11 col5\" >34.25 s</td>\n", | |
" <td id=\"T_e15fa_row11_col6\" class=\"data row11 col6\" >22%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row12_col0\" class=\"data row12 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row12_col1\" class=\"data row12 col1\" >draw</td>\n", | |
" <td id=\"T_e15fa_row12_col2\" class=\"data row12 col2\" >233.0ms</td>\n", | |
" <td id=\"T_e15fa_row12_col3\" class=\"data row12 col3\" >183.6ms</td>\n", | |
" <td id=\"T_e15fa_row12_col4\" class=\"data row12 col4\" >124</td>\n", | |
" <td id=\"T_e15fa_row12_col5\" class=\"data row12 col5\" >28.90 s</td>\n", | |
" <td id=\"T_e15fa_row12_col6\" class=\"data row12 col6\" >19%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row13_col0\" class=\"data row13 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row13_col1\" class=\"data row13 col1\" >pred</td>\n", | |
" <td id=\"T_e15fa_row13_col2\" class=\"data row13 col2\" >41.59ms</td>\n", | |
" <td id=\"T_e15fa_row13_col3\" class=\"data row13 col3\" >12.00ms</td>\n", | |
" <td id=\"T_e15fa_row13_col4\" class=\"data row13 col4\" >124</td>\n", | |
" <td id=\"T_e15fa_row13_col5\" class=\"data row13 col5\" >5.158 s</td>\n", | |
" <td id=\"T_e15fa_row13_col6\" class=\"data row13 col6\" >3%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_e15fa_row14_col0\" class=\"data row14 col0\" ></td>\n", | |
" <td id=\"T_e15fa_row14_col1\" class=\"data row14 col1\" >loss</td>\n", | |
" <td id=\"T_e15fa_row14_col2\" class=\"data row14 col2\" >1.344ms</td>\n", | |
" <td id=\"T_e15fa_row14_col3\" class=\"data row14 col3\" >2.000ms</td>\n", | |
" <td id=\"T_e15fa_row14_col4\" class=\"data row14 col4\" >124</td>\n", | |
" <td id=\"T_e15fa_row14_col5\" class=\"data row14 col5\" >166.7ms</td>\n", | |
" <td id=\"T_e15fa_row14_col6\" class=\"data row14 col6\" >0%</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use a fastai optimizer & fastai loss, but tensor input" | |
], | |
"metadata": { | |
"id": "J_V3brgbMPPF" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "29C-NpRYMPPX" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 135 | |
}, | |
"outputId": "42af2a10-2f51-4318-dbe3-edb581a69599", | |
"id": "CcmbGsBaMPPX" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.891318</td>\n", | |
" <td>1.795799</td>\n", | |
" <td>0.423694</td>\n", | |
" <td>01:15</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "RHCP86D_MPPY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 687 | |
}, | |
"outputId": "21124eb9-593b-4994-9fc9-914e888bf166", | |
"id": "OIpimLlxMPPY" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.887546</td>\n", | |
" <td>1.747342</td>\n", | |
" <td>0.409427</td>\n", | |
" <td>01:15</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.378880</td>\n", | |
" <td>1.274395</td>\n", | |
" <td>0.592611</td>\n", | |
" <td>01:18</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", | |
" \"\"\"\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7f763efedc10>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_435ce_\" class=\"dataframe\">\n", | |
" <caption>Simple Profiler Results</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"col_heading level0 col0\" >Phase</th>\n", | |
" <th class=\"col_heading level0 col1\" >Action</th>\n", | |
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n", | |
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n", | |
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n", | |
" <th class=\"col_heading level0 col5\" >Total Time</th>\n", | |
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row0_col0\" class=\"data row0 col0\" >fit</td>\n", | |
" <td id=\"T_435ce_row0_col1\" class=\"data row0 col1\" >fit</td>\n", | |
" <td id=\"T_435ce_row0_col2\" class=\"data row0 col2\" >-</td>\n", | |
" <td id=\"T_435ce_row0_col3\" class=\"data row0 col3\" >-</td>\n", | |
" <td id=\"T_435ce_row0_col4\" class=\"data row0 col4\" >1</td>\n", | |
" <td id=\"T_435ce_row0_col5\" class=\"data row0 col5\" >153.7 s</td>\n", | |
" <td id=\"T_435ce_row0_col6\" class=\"data row0 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row1_col0\" class=\"data row1 col0\" ></td>\n", | |
" <td id=\"T_435ce_row1_col1\" class=\"data row1 col1\" >epoch</td>\n", | |
" <td id=\"T_435ce_row1_col2\" class=\"data row1 col2\" >76.85 s</td>\n", | |
" <td id=\"T_435ce_row1_col3\" class=\"data row1 col3\" >1.211 s</td>\n", | |
" <td id=\"T_435ce_row1_col4\" class=\"data row1 col4\" >2</td>\n", | |
" <td id=\"T_435ce_row1_col5\" class=\"data row1 col5\" >153.7 s</td>\n", | |
" <td id=\"T_435ce_row1_col6\" class=\"data row1 col6\" >100%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row2_col0\" class=\"data row2 col0\" ></td>\n", | |
" <td id=\"T_435ce_row2_col1\" class=\"data row2 col1\" >train</td>\n", | |
" <td id=\"T_435ce_row2_col2\" class=\"data row2 col2\" >56.61 s</td>\n", | |
" <td id=\"T_435ce_row2_col3\" class=\"data row2 col3\" >356.4ms</td>\n", | |
" <td id=\"T_435ce_row2_col4\" class=\"data row2 col4\" >2</td>\n", | |
" <td id=\"T_435ce_row2_col5\" class=\"data row2 col5\" >113.2 s</td>\n", | |
" <td id=\"T_435ce_row2_col6\" class=\"data row2 col6\" >74%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row3_col0\" class=\"data row3 col0\" ></td>\n", | |
" <td id=\"T_435ce_row3_col1\" class=\"data row3 col1\" >validate</td>\n", | |
" <td id=\"T_435ce_row3_col2\" class=\"data row3 col2\" >20.23 s</td>\n", | |
" <td id=\"T_435ce_row3_col3\" class=\"data row3 col3\" >857.5ms</td>\n", | |
" <td id=\"T_435ce_row3_col4\" class=\"data row3 col4\" >2</td>\n", | |
" <td id=\"T_435ce_row3_col5\" class=\"data row3 col5\" >40.45 s</td>\n", | |
" <td id=\"T_435ce_row3_col6\" class=\"data row3 col6\" >26%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row4_col0\" class=\"data row4 col0\" >train</td>\n", | |
" <td id=\"T_435ce_row4_col1\" class=\"data row4 col1\" >batch</td>\n", | |
" <td id=\"T_435ce_row4_col2\" class=\"data row4 col2\" >375.5ms</td>\n", | |
" <td id=\"T_435ce_row4_col3\" class=\"data row4 col3\" >114.0ms</td>\n", | |
" <td id=\"T_435ce_row4_col4\" class=\"data row4 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row4_col5\" class=\"data row4 col5\" >110.4 s</td>\n", | |
" <td id=\"T_435ce_row4_col6\" class=\"data row4 col6\" >72%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row5_col0\" class=\"data row5 col0\" ></td>\n", | |
" <td id=\"T_435ce_row5_col1\" class=\"data row5 col1\" >step</td>\n", | |
" <td id=\"T_435ce_row5_col2\" class=\"data row5 col2\" >177.2ms</td>\n", | |
" <td id=\"T_435ce_row5_col3\" class=\"data row5 col3\" >24.03ms</td>\n", | |
" <td id=\"T_435ce_row5_col4\" class=\"data row5 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row5_col5\" class=\"data row5 col5\" >52.11 s</td>\n", | |
" <td id=\"T_435ce_row5_col6\" class=\"data row5 col6\" >34%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row6_col0\" class=\"data row6 col0\" ></td>\n", | |
" <td id=\"T_435ce_row6_col1\" class=\"data row6 col1\" >draw</td>\n", | |
" <td id=\"T_435ce_row6_col2\" class=\"data row6 col2\" >83.93ms</td>\n", | |
" <td id=\"T_435ce_row6_col3\" class=\"data row6 col3\" >113.6ms</td>\n", | |
" <td id=\"T_435ce_row6_col4\" class=\"data row6 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row6_col5\" class=\"data row6 col5\" >24.68 s</td>\n", | |
" <td id=\"T_435ce_row6_col6\" class=\"data row6 col6\" >16%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row7_col0\" class=\"data row7 col0\" ></td>\n", | |
" <td id=\"T_435ce_row7_col1\" class=\"data row7 col1\" >backward</td>\n", | |
" <td id=\"T_435ce_row7_col2\" class=\"data row7 col2\" >57.99ms</td>\n", | |
" <td id=\"T_435ce_row7_col3\" class=\"data row7 col3\" >14.16ms</td>\n", | |
" <td id=\"T_435ce_row7_col4\" class=\"data row7 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row7_col5\" class=\"data row7 col5\" >17.05 s</td>\n", | |
" <td id=\"T_435ce_row7_col6\" class=\"data row7 col6\" >11%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row8_col0\" class=\"data row8 col0\" ></td>\n", | |
" <td id=\"T_435ce_row8_col1\" class=\"data row8 col1\" >pred</td>\n", | |
" <td id=\"T_435ce_row8_col2\" class=\"data row8 col2\" >49.67ms</td>\n", | |
" <td id=\"T_435ce_row8_col3\" class=\"data row8 col3\" >13.58ms</td>\n", | |
" <td id=\"T_435ce_row8_col4\" class=\"data row8 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row8_col5\" class=\"data row8 col5\" >14.60 s</td>\n", | |
" <td id=\"T_435ce_row8_col6\" class=\"data row8 col6\" >10%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row9_col0\" class=\"data row9 col0\" ></td>\n", | |
" <td id=\"T_435ce_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n", | |
" <td id=\"T_435ce_row9_col2\" class=\"data row9 col2\" >4.312ms</td>\n", | |
" <td id=\"T_435ce_row9_col3\" class=\"data row9 col3\" >2.997ms</td>\n", | |
" <td id=\"T_435ce_row9_col4\" class=\"data row9 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row9_col5\" class=\"data row9 col5\" >1.268 s</td>\n", | |
" <td id=\"T_435ce_row9_col6\" class=\"data row9 col6\" >1%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row10_col0\" class=\"data row10 col0\" ></td>\n", | |
" <td id=\"T_435ce_row10_col1\" class=\"data row10 col1\" >loss</td>\n", | |
" <td id=\"T_435ce_row10_col2\" class=\"data row10 col2\" >2.058ms</td>\n", | |
" <td id=\"T_435ce_row10_col3\" class=\"data row10 col3\" >2.189ms</td>\n", | |
" <td id=\"T_435ce_row10_col4\" class=\"data row10 col4\" >294</td>\n", | |
" <td id=\"T_435ce_row10_col5\" class=\"data row10 col5\" >605.1ms</td>\n", | |
" <td id=\"T_435ce_row10_col6\" class=\"data row10 col6\" >0%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row11_col0\" class=\"data row11 col0\" >valid</td>\n", | |
" <td id=\"T_435ce_row11_col1\" class=\"data row11 col1\" >batch</td>\n", | |
" <td id=\"T_435ce_row11_col2\" class=\"data row11 col2\" >279.1ms</td>\n", | |
" <td id=\"T_435ce_row11_col3\" class=\"data row11 col3\" >192.9ms</td>\n", | |
" <td id=\"T_435ce_row11_col4\" class=\"data row11 col4\" >124</td>\n", | |
" <td id=\"T_435ce_row11_col5\" class=\"data row11 col5\" >34.61 s</td>\n", | |
" <td id=\"T_435ce_row11_col6\" class=\"data row11 col6\" >23%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row12_col0\" class=\"data row12 col0\" ></td>\n", | |
" <td id=\"T_435ce_row12_col1\" class=\"data row12 col1\" >draw</td>\n", | |
" <td id=\"T_435ce_row12_col2\" class=\"data row12 col2\" >234.2ms</td>\n", | |
" <td id=\"T_435ce_row12_col3\" class=\"data row12 col3\" >191.4ms</td>\n", | |
" <td id=\"T_435ce_row12_col4\" class=\"data row12 col4\" >124</td>\n", | |
" <td id=\"T_435ce_row12_col5\" class=\"data row12 col5\" >29.04 s</td>\n", | |
" <td id=\"T_435ce_row12_col6\" class=\"data row12 col6\" >19%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row13_col0\" class=\"data row13 col0\" ></td>\n", | |
" <td id=\"T_435ce_row13_col1\" class=\"data row13 col1\" >pred</td>\n", | |
" <td id=\"T_435ce_row13_col2\" class=\"data row13 col2\" >42.37ms</td>\n", | |
" <td id=\"T_435ce_row13_col3\" class=\"data row13 col3\" >11.61ms</td>\n", | |
" <td id=\"T_435ce_row13_col4\" class=\"data row13 col4\" >124</td>\n", | |
" <td id=\"T_435ce_row13_col5\" class=\"data row13 col5\" >5.254 s</td>\n", | |
" <td id=\"T_435ce_row13_col6\" class=\"data row13 col6\" >3%</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_435ce_row14_col0\" class=\"data row14 col0\" ></td>\n", | |
" <td id=\"T_435ce_row14_col1\" class=\"data row14 col1\" >loss</td>\n", | |
" <td id=\"T_435ce_row14_col2\" class=\"data row14 col2\" >2.227ms</td>\n", | |
" <td id=\"T_435ce_row14_col3\" class=\"data row14 col3\" >2.896ms</td>\n", | |
" <td id=\"T_435ce_row14_col4\" class=\"data row14 col4\" >124</td>\n", | |
" <td id=\"T_435ce_row14_col5\" class=\"data row14 col5\" >276.1ms</td>\n", | |
" <td id=\"T_435ce_row14_col6\" class=\"data row14 col6\" >0%</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Cast to TensorBase instead of torch.tensor or TensorImage" | |
], | |
"metadata": { | |
"id": "nhSzdwgANYP2" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ChannelsLastTensor(LoseTypeTransform):\n", | |
" \"Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback\"\n", | |
" order = 110 # run after all other transforms if added to batch_tfms\n", | |
" def encodes(self, x:TensorImageBase|TensorMask):\n", | |
" return TensorBase(x).to(memory_format=torch.channels_last)" | |
], | |
"metadata": { | |
"id": "l92GVy9zNWh0" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class ChannelsLastCallback(Callback):\n", | |
" \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n", | |
" order = MixedPrecision.order+1\n", | |
" def __init__(self, astensor=False):\n", | |
" if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n", | |
" else: self._channels_last = Pipeline([ChannelsLastTfm()])\n", | |
"\n", | |
" def before_fit(self):\n", | |
" self.learn.model.to(memory_format=torch.channels_last)\n", | |
"\n", | |
" def before_batch(self):\n", | |
" self.learn.xb = self._channels_last(self.xb)" | |
], | |
"metadata": { | |
"id": "ve-xXsTBNWh9" | |
}, | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
], | |
"metadata": { | |
"id": "mNhlK0t5Ni3d" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast()\n", | |
"\n", | |
"learn.fit_one_cycle(1, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 81 | |
}, | |
"outputId": "fd7925eb-6023-4561-8c84-0588c66742d0", | |
"id": "1K0ZxQcGNi3e" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.829382</td>\n", | |
" <td>1.654897</td>\n", | |
" <td>0.438726</td>\n", | |
" <td>01:35</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Test" | |
], | |
"metadata": { | |
"id": "B-mI8ifUNi3g" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = get_dls(256, False, 64)\n", | |
"learn = Learner(dls, resnet50(num_classes=dls.c), \n", | |
" loss_func=CrossEntropyLossFlat(), \n", | |
" opt_func=Adam,\n", | |
" metrics=[Accuracy()]).to_channelslast().profile()\n", | |
"\n", | |
"learn.fit_one_cycle(2, 3e-3)\n", | |
"\n", | |
"free_gpu_memory(learn, dls)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 632 | |
}, | |
"outputId": "33c960de-0f87-452d-ba5f-59a015f54b5f", | |
"id": "EXgwBMH4Ni3g" | |
}, | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.935964</td>\n", | |
" <td>2.968615</td>\n", | |
" <td>0.291720</td>\n", | |
" <td>01:33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.425070</td>\n", | |
" <td>1.298448</td>\n", | |
" <td>0.583439</td>\n", | |
" <td>01:33</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data& |