Problem with channels_last and fastai custom tensors
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Problem with channels_last and fastai custom tensors.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Problem with channels_last and fastai custom tensors\n",
"\n",
"This notebook contains experiments showing that passing fastai's custom tensor classes (TensorBase and TensorImage) to the model degrades channels last performance. The only change needed to get the expected channels last speedup is to convert the TensorBase-derived input to a plain torch.Tensor; channels last then trains faster, as expected."
],
"metadata": {
"id": "DgDxBVq9TTpe"
}
},
{
"cell_type": "code",
"source": [
"! nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WvJkCesSEx2V",
"outputId": "12ea5700-85f2-4ee6-8acf-e9c183618d8b"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Wed Jun 8 16:08:37 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 62C P8 11W / 70W | 0MiB / 15109MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u-D43aOA-N_z"
},
"outputs": [],
"source": [
"! pip install fastcore fastai fastxtend[vision] -U"
]
},
{
"cell_type": "code",
"source": [
"! pip uninstall pillow -y\n",
"! CC=\"cc -mavx2\" pip install -U --force-reinstall pillow-simd"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "m0FrzR7EHzBZ",
"outputId": "6aee40bb-8ed0-4af1-c919-df2dd629ed12"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: Pillow 7.1.2\n",
"Uninstalling Pillow-7.1.2:\n",
" Successfully uninstalled Pillow-7.1.2\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting pillow-simd\n",
" Downloading Pillow-SIMD-9.0.0.post1.tar.gz (849 kB)\n",
"\u001b[K |████████████████████████████████| 849 kB 7.6 MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: pillow-simd\n",
" Building wheel for pillow-simd (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pillow-simd: filename=Pillow_SIMD-9.0.0.post1-cp37-cp37m-linux_x86_64.whl size=1255565 sha256=5da71e6b46623d6920cc321db69258300c8543cdab25e5fc5711d397036233dc\n",
" Stored in directory: /root/.cache/pip/wheels/9b/3f/fd/ca7133b4f7f509eb1de652bb8c128529a0c04b25ef0a6c535a\n",
"Successfully built pillow-simd\n",
"Installing collected packages: pillow-simd\n",
"Successfully installed pillow-simd-9.0.0.post1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from __future__ import annotations\n",
"from fastai.vision.all import *\n",
"from fastxtend.vision.all import *\n",
"from fastxtend.callback import simpleprofiler"
],
"metadata": {
"id": "55BRyAzH-S-f"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# DataLoaders"
],
"metadata": {
"id": "J1tiDmSY-sxr"
}
},
{
"cell_type": "code",
"source": [
"imagewoof_stats = ([0.496,0.461,0.399],[0.257,0.249,0.258])\n",
"imagenette_stats = ([0.465,0.458,0.429],[0.285,0.28,0.301])\n",
"\n",
"def get_dls(size:int, woof:bool, bs:int, sh:float=0., augs:list=None, workers:int=None, stats:bool=True) -> DataLoaders:\n",
"    if size<=224: path = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320\n",
"    else: path = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE\n",
"    source = untar_data(path)\n",
"    if workers is None: workers = min(8, num_cpus())\n",
"    batch_tfms = []\n",
"    if stats:\n",
"        if woof:\n",
"            batch_tfms += [Normalize.from_stats(*imagewoof_stats)]\n",
"        else:\n",
"            batch_tfms += [Normalize.from_stats(*imagenette_stats)]\n",
"    if augs: batch_tfms += augs\n",
"    if sh: batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))\n",
"    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n",
"                       splitter=GrandparentSplitter(valid_name='val'),\n",
"                       get_items=get_image_files, get_y=parent_label,\n",
"                       item_tfms=[Resize(size)],\n",
"                       batch_tfms=batch_tfms)\n",
"    return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)"
],
"metadata": {
"id": "pNgVl8Gs-dxf"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Train fastai default mixed precision"
],
"metadata": {
"id": "beQssA6n-wam"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "bOYaHgr--2yJ"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam, metrics=[Accuracy()]).to_fp16()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "gXqWSo7R-v28",
"outputId": "a2fbfba6-15bc-48b1-e5eb-295f7cab29be"
},
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.912659</td>\n",
" <td>1.717486</td>\n",
" <td>0.414777</td>\n",
" <td>01:32</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "h8luVtBs-4F7"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam, metrics=[Accuracy()]).to_fp16().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"id": "lkQ7iTGJ--dd",
"outputId": "738d77bf-6cbc-4faa-cd21-79c1ac57d7ef"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.976799</td>\n",
" <td>1.873074</td>\n",
" <td>0.344968</td>\n",
" <td>01:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.385735</td>\n",
" <td>1.268956</td>\n",
" <td>0.610701</td>\n",
" <td>01:25</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763ecf4a90>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_e689c_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_e689c_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_e689c_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_e689c_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_e689c_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_e689c_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_e689c_row0_col5\" class=\"data row0 col5\" >171.6 s</td>\n",
" <td id=\"T_e689c_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_e689c_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_e689c_row1_col2\" class=\"data row1 col2\" >85.81 s</td>\n",
" <td id=\"T_e689c_row1_col3\" class=\"data row1 col3\" >422.4ms</td>\n",
" <td id=\"T_e689c_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_e689c_row1_col5\" class=\"data row1 col5\" >171.6 s</td>\n",
" <td id=\"T_e689c_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_e689c_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_e689c_row2_col2\" class=\"data row2 col2\" >66.26 s</td>\n",
" <td id=\"T_e689c_row2_col3\" class=\"data row2 col3\" >353.6ms</td>\n",
" <td id=\"T_e689c_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_e689c_row2_col5\" class=\"data row2 col5\" >132.5 s</td>\n",
" <td id=\"T_e689c_row2_col6\" class=\"data row2 col6\" >77%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_e689c_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_e689c_row3_col2\" class=\"data row3 col2\" >19.54 s</td>\n",
" <td id=\"T_e689c_row3_col3\" class=\"data row3 col3\" >65.33ms</td>\n",
" <td id=\"T_e689c_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_e689c_row3_col5\" class=\"data row3 col5\" >39.08 s</td>\n",
" <td id=\"T_e689c_row3_col6\" class=\"data row3 col6\" >23%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_e689c_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_e689c_row4_col2\" class=\"data row4 col2\" >443.2ms</td>\n",
" <td id=\"T_e689c_row4_col3\" class=\"data row4 col3\" >67.11ms</td>\n",
" <td id=\"T_e689c_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_e689c_row4_col5\" class=\"data row4 col5\" >130.3 s</td>\n",
" <td id=\"T_e689c_row4_col6\" class=\"data row4 col6\" >76%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_e689c_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_e689c_row5_col2\" class=\"data row5 col2\" >280.0ms</td>\n",
" <td id=\"T_e689c_row5_col3\" class=\"data row5 col3\" >23.06ms</td>\n",
" <td id=\"T_e689c_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_e689c_row5_col5\" class=\"data row5 col5\" >82.32 s</td>\n",
" <td id=\"T_e689c_row5_col6\" class=\"data row5 col6\" >48%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_e689c_row6_col1\" class=\"data row6 col1\" >backward</td>\n",
" <td id=\"T_e689c_row6_col2\" class=\"data row6 col2\" >73.15ms</td>\n",
" <td id=\"T_e689c_row6_col3\" class=\"data row6 col3\" >13.76ms</td>\n",
" <td id=\"T_e689c_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_e689c_row6_col5\" class=\"data row6 col5\" >21.51 s</td>\n",
" <td id=\"T_e689c_row6_col6\" class=\"data row6 col6\" >13%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_e689c_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_e689c_row7_col2\" class=\"data row7 col2\" >61.11ms</td>\n",
" <td id=\"T_e689c_row7_col3\" class=\"data row7 col3\" >19.80ms</td>\n",
" <td id=\"T_e689c_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_e689c_row7_col5\" class=\"data row7 col5\" >17.97 s</td>\n",
" <td id=\"T_e689c_row7_col6\" class=\"data row7 col6\" >10%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_e689c_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_e689c_row8_col2\" class=\"data row8 col2\" >23.78ms</td>\n",
" <td id=\"T_e689c_row8_col3\" class=\"data row8 col3\" >62.00ms</td>\n",
" <td id=\"T_e689c_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_e689c_row8_col5\" class=\"data row8 col5\" >6.991 s</td>\n",
" <td id=\"T_e689c_row8_col6\" class=\"data row8 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_e689c_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_e689c_row9_col2\" class=\"data row9 col2\" >2.996ms</td>\n",
" <td id=\"T_e689c_row9_col3\" class=\"data row9 col3\" >1.757ms</td>\n",
" <td id=\"T_e689c_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_e689c_row9_col5\" class=\"data row9 col5\" >880.8ms</td>\n",
" <td id=\"T_e689c_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_e689c_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_e689c_row10_col2\" class=\"data row10 col2\" >1.882ms</td>\n",
" <td id=\"T_e689c_row10_col3\" class=\"data row10 col3\" >1.883ms</td>\n",
" <td id=\"T_e689c_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_e689c_row10_col5\" class=\"data row10 col5\" >553.3ms</td>\n",
" <td id=\"T_e689c_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_e689c_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_e689c_row11_col2\" class=\"data row11 col2\" >277.3ms</td>\n",
" <td id=\"T_e689c_row11_col3\" class=\"data row11 col3\" >155.1ms</td>\n",
" <td id=\"T_e689c_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_e689c_row11_col5\" class=\"data row11 col5\" >34.38 s</td>\n",
" <td id=\"T_e689c_row11_col6\" class=\"data row11 col6\" >20%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_e689c_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_e689c_row12_col2\" class=\"data row12 col2\" >210.5ms</td>\n",
" <td id=\"T_e689c_row12_col3\" class=\"data row12 col3\" >150.8ms</td>\n",
" <td id=\"T_e689c_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_e689c_row12_col5\" class=\"data row12 col5\" >26.10 s</td>\n",
" <td id=\"T_e689c_row12_col6\" class=\"data row12 col6\" >15%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_e689c_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_e689c_row13_col2\" class=\"data row13 col2\" >64.46ms</td>\n",
" <td id=\"T_e689c_row13_col3\" class=\"data row13 col3\" >12.82ms</td>\n",
" <td id=\"T_e689c_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_e689c_row13_col5\" class=\"data row13 col5\" >7.993 s</td>\n",
" <td id=\"T_e689c_row13_col6\" class=\"data row13 col6\" >5%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e689c_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_e689c_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_e689c_row14_col2\" class=\"data row14 col2\" >1.952ms</td>\n",
" <td id=\"T_e689c_row14_col3\" class=\"data row14 col3\" >2.213ms</td>\n",
" <td id=\"T_e689c_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_e689c_row14_col5\" class=\"data row14 col5\" >242.1ms</td>\n",
" <td id=\"T_e689c_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Train fastai channels last"
],
"metadata": {
"id": "VmZkjfymAVSS"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "ybh8djjhAVST"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam, metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "93f31c24-725b-4467-8c22-f739ee09e829",
"id": "6axWVksaAVST"
},
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.800101</td>\n",
" <td>1.618706</td>\n",
" <td>0.447643</td>\n",
" <td>01:46</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "Wr2wDyEPAVSU"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam, metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 632
},
"outputId": "aa76f239-d805-451c-a0d9-d54964ceabe2",
"id": "cIqWhmxhAVSU"
},
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.945158</td>\n",
" <td>1.860097</td>\n",
" <td>0.369172</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.290537</td>\n",
" <td>1.169882</td>\n",
" <td>0.627261</td>\n",
" <td>01:33</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f7647fb75d0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_c2056_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_c2056_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_c2056_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_c2056_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_c2056_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_c2056_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_c2056_row0_col5\" class=\"data row0 col5\" >186.9 s</td>\n",
" <td id=\"T_c2056_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_c2056_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_c2056_row1_col2\" class=\"data row1 col2\" >93.47 s</td>\n",
" <td id=\"T_c2056_row1_col3\" class=\"data row1 col3\" >170.4ms</td>\n",
" <td id=\"T_c2056_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_c2056_row1_col5\" class=\"data row1 col5\" >186.9 s</td>\n",
" <td id=\"T_c2056_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_c2056_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_c2056_row2_col2\" class=\"data row2 col2\" >74.08 s</td>\n",
" <td id=\"T_c2056_row2_col3\" class=\"data row2 col3\" >188.7ms</td>\n",
" <td id=\"T_c2056_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_c2056_row2_col5\" class=\"data row2 col5\" >148.2 s</td>\n",
" <td id=\"T_c2056_row2_col6\" class=\"data row2 col6\" >79%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_c2056_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_c2056_row3_col2\" class=\"data row3 col2\" >19.38 s</td>\n",
" <td id=\"T_c2056_row3_col3\" class=\"data row3 col3\" >21.99ms</td>\n",
" <td id=\"T_c2056_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_c2056_row3_col5\" class=\"data row3 col5\" >38.77 s</td>\n",
" <td id=\"T_c2056_row3_col6\" class=\"data row3 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_c2056_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_c2056_row4_col2\" class=\"data row4 col2\" >498.0ms</td>\n",
" <td id=\"T_c2056_row4_col3\" class=\"data row4 col3\" >55.88ms</td>\n",
" <td id=\"T_c2056_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_c2056_row4_col5\" class=\"data row4 col5\" >146.4 s</td>\n",
" <td id=\"T_c2056_row4_col6\" class=\"data row4 col6\" >78%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_c2056_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_c2056_row5_col2\" class=\"data row5 col2\" >365.1ms</td>\n",
" <td id=\"T_c2056_row5_col3\" class=\"data row5 col3\" >26.65ms</td>\n",
" <td id=\"T_c2056_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_c2056_row5_col5\" class=\"data row5 col5\" >107.3 s</td>\n",
" <td id=\"T_c2056_row5_col6\" class=\"data row5 col6\" >57%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_c2056_row6_col1\" class=\"data row6 col1\" >backward</td>\n",
" <td id=\"T_c2056_row6_col2\" class=\"data row6 col2\" >53.55ms</td>\n",
" <td id=\"T_c2056_row6_col3\" class=\"data row6 col3\" >15.32ms</td>\n",
" <td id=\"T_c2056_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_c2056_row6_col5\" class=\"data row6 col5\" >15.74 s</td>\n",
" <td id=\"T_c2056_row6_col6\" class=\"data row6 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_c2056_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_c2056_row7_col2\" class=\"data row7 col2\" >52.80ms</td>\n",
" <td id=\"T_c2056_row7_col3\" class=\"data row7 col3\" >16.42ms</td>\n",
" <td id=\"T_c2056_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_c2056_row7_col5\" class=\"data row7 col5\" >15.52 s</td>\n",
" <td id=\"T_c2056_row7_col6\" class=\"data row7 col6\" >8%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_c2056_row8_col1\" class=\"data row8 col1\" >draw</td>\n",
" <td id=\"T_c2056_row8_col2\" class=\"data row8 col2\" >22.18ms</td>\n",
" <td id=\"T_c2056_row8_col3\" class=\"data row8 col3\" >54.90ms</td>\n",
" <td id=\"T_c2056_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_c2056_row8_col5\" class=\"data row8 col5\" >6.521 s</td>\n",
" <td id=\"T_c2056_row8_col6\" class=\"data row8 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_c2056_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_c2056_row9_col2\" class=\"data row9 col2\" >2.565ms</td>\n",
" <td id=\"T_c2056_row9_col3\" class=\"data row9 col3\" >1.181ms</td>\n",
" <td id=\"T_c2056_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_c2056_row9_col5\" class=\"data row9 col5\" >754.1ms</td>\n",
" <td id=\"T_c2056_row9_col6\" class=\"data row9 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_c2056_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_c2056_row10_col2\" class=\"data row10 col2\" >1.518ms</td>\n",
" <td id=\"T_c2056_row10_col3\" class=\"data row10 col3\" >1.213ms</td>\n",
" <td id=\"T_c2056_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_c2056_row10_col5\" class=\"data row10 col5\" >446.2ms</td>\n",
" <td id=\"T_c2056_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_c2056_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_c2056_row11_col2\" class=\"data row11 col2\" >286.2ms</td>\n",
" <td id=\"T_c2056_row11_col3\" class=\"data row11 col3\" >197.6ms</td>\n",
" <td id=\"T_c2056_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_c2056_row11_col5\" class=\"data row11 col5\" >35.49 s</td>\n",
" <td id=\"T_c2056_row11_col6\" class=\"data row11 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_c2056_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_c2056_row12_col2\" class=\"data row12 col2\" >220.4ms</td>\n",
" <td id=\"T_c2056_row12_col3\" class=\"data row12 col3\" >193.4ms</td>\n",
" <td id=\"T_c2056_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_c2056_row12_col5\" class=\"data row12 col5\" >27.33 s</td>\n",
" <td id=\"T_c2056_row12_col6\" class=\"data row12 col6\" >15%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_c2056_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_c2056_row13_col2\" class=\"data row13 col2\" >63.65ms</td>\n",
" <td id=\"T_c2056_row13_col3\" class=\"data row13 col3\" >13.46ms</td>\n",
" <td id=\"T_c2056_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_c2056_row13_col5\" class=\"data row13 col5\" >7.893 s</td>\n",
" <td id=\"T_c2056_row13_col6\" class=\"data row13 col6\" >4%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_c2056_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_c2056_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_c2056_row14_col2\" class=\"data row14 col2\" >1.820ms</td>\n",
" <td id=\"T_c2056_row14_col3\" class=\"data row14 col3\" >2.196ms</td>\n",
" <td id=\"T_c2056_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_c2056_row14_col5\" class=\"data row14 col5\" >225.7ms</td>\n",
" <td id=\"T_c2056_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Remove All Possible fastai Bits That Might be Interfering"
],
"metadata": {
"id": "-dm_BtSuELmM"
}
},
{
"cell_type": "markdown",
"source": [
"First, create a fastai transform that skips fastcore's retain_type step, so its output is not cast back to the input's tensor type."
],
"metadata": {
"id": "yPdi-I1hEfCY"
}
},
{
"cell_type": "code",
"source": [
"from fastcore.transform import DisplayedTransform, _is_tuple, retain_type\n",
"\n",
"class LoseTypeTransform(DisplayedTransform):\n",
"    \"A `DisplayedTransform` that does not call `retain_type`, so outputs keep whatever type `encodes` returns\"\n",
"    def __init__(self, **kwargs): super().__init__(**kwargs)\n",
"\n",
"    def _call(self, fn, x, split_idx=None, **kwargs):\n",
"        if split_idx!=self.split_idx and self.split_idx is not None: return x\n",
"        return self._do_call(getattr(self, fn), x, **kwargs)\n",
"\n",
"    def _do_call(self, f, x, **kwargs):\n",
"        if not _is_tuple(x):\n",
"            if f is None: return x\n",
"            return f(x, **kwargs) # no `retain_type` here, so the result is not cast back to type(x)\n",
"        return tuple(self._do_call(f, x_, **kwargs) for x_ in x)"
],
"metadata": {
"id": "LQIN7mOpECun"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Inherit from it to build a replacement for ChannelsLastTfm that passes a plain torch.Tensor to the model instead of a TensorImage."
],
"metadata": {
"id": "v4XIj3esElF-"
}
},
{
"cell_type": "code",
"source": [
"from fastxtend.callback.channelslast import ChannelsLastTfm"
],
"metadata": {
"id": "OjV5m8lSFGVj"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ChannelsLastTensor(LoseTypeTransform):\n",
"    \"Sets image-like inputs to `channels_last` format. For use in ChannelsLastCallback\"\n",
"    order = 110 # run after all other transforms if added to batch_tfms\n",
"    def encodes(self, x:TensorImageBase|TensorMask):\n",
"        return torch.tensor(x).to(memory_format=torch.channels_last)"
],
"metadata": {
"id": "uWe_J5fFE50f"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Otherwise, fastxtend's channels last callback looks the same, other than a switch for inputting a `torch.Tensor` versus a `TensorImage`."
],
"metadata": {
"id": "QmGndPRiRymP"
}
},
{
"cell_type": "code",
"source": [
"class ChannelsLastCallback(Callback):\n",
"    \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n",
"    order = MixedPrecision.order+1\n",
"    def __init__(self, astensor=False):\n",
"        if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n",
"        else: self._channels_last = Pipeline([ChannelsLastTfm()])\n",
"\n",
"    def before_fit(self):\n",
"        self.learn.model.to(memory_format=torch.channels_last)\n",
"\n",
"    def before_batch(self):\n",
"        self.learn.xb = self._channels_last(self.xb)"
],
"metadata": {
"id": "a9WecdxUEX_F"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@patch\n",
"def to_channelslast(self:Learner, to_fp16=True, astensor=True, **kwargs):\n",
"    \"Set `Learner` and inputs to `channels_last` format and Mixed Precision by default\"\n",
"    if to_fp16 and not hasattr(self, 'mixed_precision') and not hasattr(self, 'channels_last'):\n",
"        return self.add_cbs([ChannelsLastCallback(astensor=astensor), MixedPrecision(**kwargs)])\n",
"    elif not hasattr(self, 'channels_last'):\n",
"        return self.add_cb(ChannelsLastCallback(astensor=astensor))\n",
"    return self # callback already attached: return the Learner so calls can still be chained"
],
"metadata": {
"id": "piSnsjp7EZRF"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "SjhRhmwRE1Mm"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=nn.CrossEntropyLoss(),\n",
"                opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n",
"                metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"id": "5-MnrPkYE2EG",
"outputId": "f092f0dc-7cae-49da-cbab-056abcc5e5b0"
},
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.818512</td>\n",
" <td>1.680830</td>\n",
" <td>0.428535</td>\n",
" <td>01:22</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "OlfL5IFKE2s9"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=nn.CrossEntropyLoss(),\n",
"                opt_func=partial(OptimWrapper, opt=torch.optim.AdamW),\n",
"                metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"id": "F4hx0CEXFPut",
"outputId": "5c4772d1-365e-43d8-c2ce-a8a3bb730297"
},
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.863661</td>\n",
" <td>1.996808</td>\n",
" <td>0.355669</td>\n",
" <td>01:17</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.320199</td>\n",
" <td>1.176207</td>\n",
" <td>0.615032</td>\n",
" <td>01:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763eeca4d0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_8e94a_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_8e94a_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_8e94a_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_8e94a_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_8e94a_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_8e94a_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_8e94a_row0_col5\" class=\"data row0 col5\" >154.6 s</td>\n",
" <td id=\"T_8e94a_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_8e94a_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_8e94a_row1_col2\" class=\"data row1 col2\" >77.31 s</td>\n",
" <td id=\"T_8e94a_row1_col3\" class=\"data row1 col3\" >537.9ms</td>\n",
" <td id=\"T_8e94a_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_8e94a_row1_col5\" class=\"data row1 col5\" >154.6 s</td>\n",
" <td id=\"T_8e94a_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_8e94a_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_8e94a_row2_col2\" class=\"data row2 col2\" >57.79 s</td>\n",
" <td id=\"T_8e94a_row2_col3\" class=\"data row2 col3\" >402.5ms</td>\n",
" <td id=\"T_8e94a_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_8e94a_row2_col5\" class=\"data row2 col5\" >115.6 s</td>\n",
" <td id=\"T_8e94a_row2_col6\" class=\"data row2 col6\" >75%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_8e94a_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_8e94a_row3_col2\" class=\"data row3 col2\" >19.51 s</td>\n",
" <td id=\"T_8e94a_row3_col3\" class=\"data row3 col3\" >132.2ms</td>\n",
" <td id=\"T_8e94a_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_8e94a_row3_col5\" class=\"data row3 col5\" >39.01 s</td>\n",
" <td id=\"T_8e94a_row3_col6\" class=\"data row3 col6\" >25%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_8e94a_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_8e94a_row4_col2\" class=\"data row4 col2\" >368.2ms</td>\n",
" <td id=\"T_8e94a_row4_col3\" class=\"data row4 col3\" >97.57ms</td>\n",
" <td id=\"T_8e94a_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_8e94a_row4_col5\" class=\"data row4 col5\" >108.3 s</td>\n",
" <td id=\"T_8e94a_row4_col6\" class=\"data row4 col6\" >70%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_8e94a_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_8e94a_row5_col2\" class=\"data row5 col2\" >172.8ms</td>\n",
" <td id=\"T_8e94a_row5_col3\" class=\"data row5 col3\" >22.10ms</td>\n",
" <td id=\"T_8e94a_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_8e94a_row5_col5\" class=\"data row5 col5\" >50.81 s</td>\n",
" <td id=\"T_8e94a_row5_col6\" class=\"data row5 col6\" >33%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_8e94a_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_8e94a_row6_col2\" class=\"data row6 col2\" >71.45ms</td>\n",
" <td id=\"T_8e94a_row6_col3\" class=\"data row6 col3\" >95.57ms</td>\n",
" <td id=\"T_8e94a_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_8e94a_row6_col5\" class=\"data row6 col5\" >21.00 s</td>\n",
" <td id=\"T_8e94a_row6_col6\" class=\"data row6 col6\" >14%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_8e94a_row7_col1\" class=\"data row7 col1\" >pred</td>\n",
" <td id=\"T_8e94a_row7_col2\" class=\"data row7 col2\" >61.65ms</td>\n",
" <td id=\"T_8e94a_row7_col3\" class=\"data row7 col3\" >15.55ms</td>\n",
" <td id=\"T_8e94a_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_8e94a_row7_col5\" class=\"data row7 col5\" >18.13 s</td>\n",
" <td id=\"T_8e94a_row7_col6\" class=\"data row7 col6\" >12%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_8e94a_row8_col1\" class=\"data row8 col1\" >backward</td>\n",
" <td id=\"T_8e94a_row8_col2\" class=\"data row8 col2\" >56.72ms</td>\n",
" <td id=\"T_8e94a_row8_col3\" class=\"data row8 col3\" >12.65ms</td>\n",
" <td id=\"T_8e94a_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_8e94a_row8_col5\" class=\"data row8 col5\" >16.68 s</td>\n",
" <td id=\"T_8e94a_row8_col6\" class=\"data row8 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_8e94a_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_8e94a_row9_col2\" class=\"data row9 col2\" >3.895ms</td>\n",
" <td id=\"T_8e94a_row9_col3\" class=\"data row9 col3\" >2.942ms</td>\n",
" <td id=\"T_8e94a_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_8e94a_row9_col5\" class=\"data row9 col5\" >1.145 s</td>\n",
" <td id=\"T_8e94a_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_8e94a_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_8e94a_row10_col2\" class=\"data row10 col2\" >1.345ms</td>\n",
" <td id=\"T_8e94a_row10_col3\" class=\"data row10 col3\" >1.974ms</td>\n",
" <td id=\"T_8e94a_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_8e94a_row10_col5\" class=\"data row10 col5\" >395.3ms</td>\n",
" <td id=\"T_8e94a_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_8e94a_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_8e94a_row11_col2\" class=\"data row11 col2\" >264.9ms</td>\n",
" <td id=\"T_8e94a_row11_col3\" class=\"data row11 col3\" >172.1ms</td>\n",
" <td id=\"T_8e94a_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_8e94a_row11_col5\" class=\"data row11 col5\" >32.85 s</td>\n",
" <td id=\"T_8e94a_row11_col6\" class=\"data row11 col6\" >21%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_8e94a_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_8e94a_row12_col2\" class=\"data row12 col2\" >223.5ms</td>\n",
" <td id=\"T_8e94a_row12_col3\" class=\"data row12 col3\" >171.4ms</td>\n",
" <td id=\"T_8e94a_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_8e94a_row12_col5\" class=\"data row12 col5\" >27.72 s</td>\n",
" <td id=\"T_8e94a_row12_col6\" class=\"data row12 col6\" >18%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_8e94a_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_8e94a_row13_col2\" class=\"data row13 col2\" >40.13ms</td>\n",
" <td id=\"T_8e94a_row13_col3\" class=\"data row13 col3\" >9.556ms</td>\n",
" <td id=\"T_8e94a_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_8e94a_row13_col5\" class=\"data row13 col5\" >4.976 s</td>\n",
" <td id=\"T_8e94a_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_8e94a_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_8e94a_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_8e94a_row14_col2\" class=\"data row14 col2\" >945.4µs</td>\n",
" <td id=\"T_8e94a_row14_col3\" class=\"data row14 col3\" >1.121ms</td>\n",
" <td id=\"T_8e94a_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_8e94a_row14_col5\" class=\"data row14 col5\" >117.2ms</td>\n",
" <td id=\"T_8e94a_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Use a fastai optimizer, but tensor input and PyTorch loss"
],
"metadata": {
"id": "DaezZT9eK3g-"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "8P-SVF8aLAbP"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=nn.CrossEntropyLoss(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"outputId": "d6785999-1946-4356-b38a-848e78594be6",
"id": "pU5AcaIILAbP"
},
"execution_count": 14,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.979175</td>\n",
" <td>1.794662</td>\n",
" <td>0.380382</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "LWxeP4AULAbQ"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=nn.CrossEntropyLoss(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"id": "UbhUwy9bLAbQ",
"outputId": "86da8255-0c61-4b7b-b4ef-f1c5b0b4171b"
},
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.852839</td>\n",
" <td>2.054055</td>\n",
" <td>0.332739</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.300486</td>\n",
" <td>1.145100</td>\n",
" <td>0.634650</td>\n",
" <td>01:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763eff1750>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_e15fa_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_e15fa_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_e15fa_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_e15fa_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_e15fa_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_e15fa_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_e15fa_row0_col5\" class=\"data row0 col5\" >154.5 s</td>\n",
" <td id=\"T_e15fa_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_e15fa_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_e15fa_row1_col2\" class=\"data row1 col2\" >77.27 s</td>\n",
" <td id=\"T_e15fa_row1_col3\" class=\"data row1 col3\" >1.341 s</td>\n",
" <td id=\"T_e15fa_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_e15fa_row1_col5\" class=\"data row1 col5\" >154.5 s</td>\n",
" <td id=\"T_e15fa_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_e15fa_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_e15fa_row2_col2\" class=\"data row2 col2\" >57.12 s</td>\n",
" <td id=\"T_e15fa_row2_col3\" class=\"data row2 col3\" >1.114 s</td>\n",
" <td id=\"T_e15fa_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_e15fa_row2_col5\" class=\"data row2 col5\" >114.2 s</td>\n",
" <td id=\"T_e15fa_row2_col6\" class=\"data row2 col6\" >74%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_e15fa_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_e15fa_row3_col2\" class=\"data row3 col2\" >20.14 s</td>\n",
" <td id=\"T_e15fa_row3_col3\" class=\"data row3 col3\" >230.3ms</td>\n",
" <td id=\"T_e15fa_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_e15fa_row3_col5\" class=\"data row3 col5\" >40.29 s</td>\n",
" <td id=\"T_e15fa_row3_col6\" class=\"data row3 col6\" >26%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_e15fa_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_e15fa_row4_col2\" class=\"data row4 col2\" >379.6ms</td>\n",
" <td id=\"T_e15fa_row4_col3\" class=\"data row4 col3\" >113.8ms</td>\n",
" <td id=\"T_e15fa_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_e15fa_row4_col5\" class=\"data row4 col5\" >111.6 s</td>\n",
" <td id=\"T_e15fa_row4_col6\" class=\"data row4 col6\" >72%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_e15fa_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_e15fa_row5_col2\" class=\"data row5 col2\" >182.0ms</td>\n",
" <td id=\"T_e15fa_row5_col3\" class=\"data row5 col3\" >24.35ms</td>\n",
" <td id=\"T_e15fa_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_e15fa_row5_col5\" class=\"data row5 col5\" >53.50 s</td>\n",
" <td id=\"T_e15fa_row5_col6\" class=\"data row5 col6\" >35%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_e15fa_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_e15fa_row6_col2\" class=\"data row6 col2\" >87.42ms</td>\n",
" <td id=\"T_e15fa_row6_col3\" class=\"data row6 col3\" >113.3ms</td>\n",
" <td id=\"T_e15fa_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_e15fa_row6_col5\" class=\"data row6 col5\" >25.70 s</td>\n",
" <td id=\"T_e15fa_row6_col6\" class=\"data row6 col6\" >17%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_e15fa_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_e15fa_row7_col2\" class=\"data row7 col2\" >56.09ms</td>\n",
" <td id=\"T_e15fa_row7_col3\" class=\"data row7 col3\" >13.87ms</td>\n",
" <td id=\"T_e15fa_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_e15fa_row7_col5\" class=\"data row7 col5\" >16.49 s</td>\n",
" <td id=\"T_e15fa_row7_col6\" class=\"data row7 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_e15fa_row8_col1\" class=\"data row8 col1\" >pred</td>\n",
" <td id=\"T_e15fa_row8_col2\" class=\"data row8 col2\" >48.32ms</td>\n",
" <td id=\"T_e15fa_row8_col3\" class=\"data row8 col3\" >13.84ms</td>\n",
" <td id=\"T_e15fa_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_e15fa_row8_col5\" class=\"data row8 col5\" >14.21 s</td>\n",
" <td id=\"T_e15fa_row8_col6\" class=\"data row8 col6\" >9%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_e15fa_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_e15fa_row9_col2\" class=\"data row9 col2\" >4.106ms</td>\n",
" <td id=\"T_e15fa_row9_col3\" class=\"data row9 col3\" >3.017ms</td>\n",
" <td id=\"T_e15fa_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_e15fa_row9_col5\" class=\"data row9 col5\" >1.207 s</td>\n",
" <td id=\"T_e15fa_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_e15fa_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_e15fa_row10_col2\" class=\"data row10 col2\" >1.425ms</td>\n",
" <td id=\"T_e15fa_row10_col3\" class=\"data row10 col3\" >2.092ms</td>\n",
" <td id=\"T_e15fa_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_e15fa_row10_col5\" class=\"data row10 col5\" >418.9ms</td>\n",
" <td id=\"T_e15fa_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_e15fa_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_e15fa_row11_col2\" class=\"data row11 col2\" >276.2ms</td>\n",
" <td id=\"T_e15fa_row11_col3\" class=\"data row11 col3\" >186.7ms</td>\n",
" <td id=\"T_e15fa_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_e15fa_row11_col5\" class=\"data row11 col5\" >34.25 s</td>\n",
" <td id=\"T_e15fa_row11_col6\" class=\"data row11 col6\" >22%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_e15fa_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_e15fa_row12_col2\" class=\"data row12 col2\" >233.0ms</td>\n",
" <td id=\"T_e15fa_row12_col3\" class=\"data row12 col3\" >183.6ms</td>\n",
" <td id=\"T_e15fa_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_e15fa_row12_col5\" class=\"data row12 col5\" >28.90 s</td>\n",
" <td id=\"T_e15fa_row12_col6\" class=\"data row12 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_e15fa_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_e15fa_row13_col2\" class=\"data row13 col2\" >41.59ms</td>\n",
" <td id=\"T_e15fa_row13_col3\" class=\"data row13 col3\" >12.00ms</td>\n",
" <td id=\"T_e15fa_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_e15fa_row13_col5\" class=\"data row13 col5\" >5.158 s</td>\n",
" <td id=\"T_e15fa_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_e15fa_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_e15fa_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_e15fa_row14_col2\" class=\"data row14 col2\" >1.344ms</td>\n",
" <td id=\"T_e15fa_row14_col3\" class=\"data row14 col3\" >2.000ms</td>\n",
" <td id=\"T_e15fa_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_e15fa_row14_col5\" class=\"data row14 col5\" >166.7ms</td>\n",
" <td id=\"T_e15fa_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Use a fastai optimizer & fastai loss, but tensor input"
],
"metadata": {
"id": "J_V3brgbMPPF"
}
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "29C-NpRYMPPX"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 135
},
"outputId": "42af2a10-2f51-4318-dbe3-edb581a69599",
"id": "CcmbGsBaMPPX"
},
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.891318</td>\n",
" <td>1.795799</td>\n",
" <td>0.423694</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "RHCP86D_MPPY"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 687
},
"outputId": "21124eb9-593b-4994-9fc9-914e888bf166",
"id": "OIpimLlxMPPY"
},
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.887546</td>\n",
" <td>1.747342</td>\n",
" <td>0.409427</td>\n",
" <td>01:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.378880</td>\n",
" <td>1.274395</td>\n",
" <td>0.592611</td>\n",
" <td>01:18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f763efedc10>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_435ce_\" class=\"dataframe\">\n",
" <caption>Simple Profiler Results</caption>\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >Phase</th>\n",
" <th class=\"col_heading level0 col1\" >Action</th>\n",
" <th class=\"col_heading level0 col2\" >Mean Duration</th>\n",
" <th class=\"col_heading level0 col3\" >Duration Std Dev</th>\n",
" <th class=\"col_heading level0 col4\" >Number of Calls</th>\n",
" <th class=\"col_heading level0 col5\" >Total Time</th>\n",
" <th class=\"col_heading level0 col6\" >Percent of Total</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_435ce_row0_col0\" class=\"data row0 col0\" >fit</td>\n",
" <td id=\"T_435ce_row0_col1\" class=\"data row0 col1\" >fit</td>\n",
" <td id=\"T_435ce_row0_col2\" class=\"data row0 col2\" >-</td>\n",
" <td id=\"T_435ce_row0_col3\" class=\"data row0 col3\" >-</td>\n",
" <td id=\"T_435ce_row0_col4\" class=\"data row0 col4\" >1</td>\n",
" <td id=\"T_435ce_row0_col5\" class=\"data row0 col5\" >153.7 s</td>\n",
" <td id=\"T_435ce_row0_col6\" class=\"data row0 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row1_col0\" class=\"data row1 col0\" ></td>\n",
" <td id=\"T_435ce_row1_col1\" class=\"data row1 col1\" >epoch</td>\n",
" <td id=\"T_435ce_row1_col2\" class=\"data row1 col2\" >76.85 s</td>\n",
" <td id=\"T_435ce_row1_col3\" class=\"data row1 col3\" >1.211 s</td>\n",
" <td id=\"T_435ce_row1_col4\" class=\"data row1 col4\" >2</td>\n",
" <td id=\"T_435ce_row1_col5\" class=\"data row1 col5\" >153.7 s</td>\n",
" <td id=\"T_435ce_row1_col6\" class=\"data row1 col6\" >100%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row2_col0\" class=\"data row2 col0\" ></td>\n",
" <td id=\"T_435ce_row2_col1\" class=\"data row2 col1\" >train</td>\n",
" <td id=\"T_435ce_row2_col2\" class=\"data row2 col2\" >56.61 s</td>\n",
" <td id=\"T_435ce_row2_col3\" class=\"data row2 col3\" >356.4ms</td>\n",
" <td id=\"T_435ce_row2_col4\" class=\"data row2 col4\" >2</td>\n",
" <td id=\"T_435ce_row2_col5\" class=\"data row2 col5\" >113.2 s</td>\n",
" <td id=\"T_435ce_row2_col6\" class=\"data row2 col6\" >74%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row3_col0\" class=\"data row3 col0\" ></td>\n",
" <td id=\"T_435ce_row3_col1\" class=\"data row3 col1\" >validate</td>\n",
" <td id=\"T_435ce_row3_col2\" class=\"data row3 col2\" >20.23 s</td>\n",
" <td id=\"T_435ce_row3_col3\" class=\"data row3 col3\" >857.5ms</td>\n",
" <td id=\"T_435ce_row3_col4\" class=\"data row3 col4\" >2</td>\n",
" <td id=\"T_435ce_row3_col5\" class=\"data row3 col5\" >40.45 s</td>\n",
" <td id=\"T_435ce_row3_col6\" class=\"data row3 col6\" >26%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row4_col0\" class=\"data row4 col0\" >train</td>\n",
" <td id=\"T_435ce_row4_col1\" class=\"data row4 col1\" >batch</td>\n",
" <td id=\"T_435ce_row4_col2\" class=\"data row4 col2\" >375.5ms</td>\n",
" <td id=\"T_435ce_row4_col3\" class=\"data row4 col3\" >114.0ms</td>\n",
" <td id=\"T_435ce_row4_col4\" class=\"data row4 col4\" >294</td>\n",
" <td id=\"T_435ce_row4_col5\" class=\"data row4 col5\" >110.4 s</td>\n",
" <td id=\"T_435ce_row4_col6\" class=\"data row4 col6\" >72%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row5_col0\" class=\"data row5 col0\" ></td>\n",
" <td id=\"T_435ce_row5_col1\" class=\"data row5 col1\" >step</td>\n",
" <td id=\"T_435ce_row5_col2\" class=\"data row5 col2\" >177.2ms</td>\n",
" <td id=\"T_435ce_row5_col3\" class=\"data row5 col3\" >24.03ms</td>\n",
" <td id=\"T_435ce_row5_col4\" class=\"data row5 col4\" >294</td>\n",
" <td id=\"T_435ce_row5_col5\" class=\"data row5 col5\" >52.11 s</td>\n",
" <td id=\"T_435ce_row5_col6\" class=\"data row5 col6\" >34%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row6_col0\" class=\"data row6 col0\" ></td>\n",
" <td id=\"T_435ce_row6_col1\" class=\"data row6 col1\" >draw</td>\n",
" <td id=\"T_435ce_row6_col2\" class=\"data row6 col2\" >83.93ms</td>\n",
" <td id=\"T_435ce_row6_col3\" class=\"data row6 col3\" >113.6ms</td>\n",
" <td id=\"T_435ce_row6_col4\" class=\"data row6 col4\" >294</td>\n",
" <td id=\"T_435ce_row6_col5\" class=\"data row6 col5\" >24.68 s</td>\n",
" <td id=\"T_435ce_row6_col6\" class=\"data row6 col6\" >16%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row7_col0\" class=\"data row7 col0\" ></td>\n",
" <td id=\"T_435ce_row7_col1\" class=\"data row7 col1\" >backward</td>\n",
" <td id=\"T_435ce_row7_col2\" class=\"data row7 col2\" >57.99ms</td>\n",
" <td id=\"T_435ce_row7_col3\" class=\"data row7 col3\" >14.16ms</td>\n",
" <td id=\"T_435ce_row7_col4\" class=\"data row7 col4\" >294</td>\n",
" <td id=\"T_435ce_row7_col5\" class=\"data row7 col5\" >17.05 s</td>\n",
" <td id=\"T_435ce_row7_col6\" class=\"data row7 col6\" >11%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row8_col0\" class=\"data row8 col0\" ></td>\n",
" <td id=\"T_435ce_row8_col1\" class=\"data row8 col1\" >pred</td>\n",
" <td id=\"T_435ce_row8_col2\" class=\"data row8 col2\" >49.67ms</td>\n",
" <td id=\"T_435ce_row8_col3\" class=\"data row8 col3\" >13.58ms</td>\n",
" <td id=\"T_435ce_row8_col4\" class=\"data row8 col4\" >294</td>\n",
" <td id=\"T_435ce_row8_col5\" class=\"data row8 col5\" >14.60 s</td>\n",
" <td id=\"T_435ce_row8_col6\" class=\"data row8 col6\" >10%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row9_col0\" class=\"data row9 col0\" ></td>\n",
" <td id=\"T_435ce_row9_col1\" class=\"data row9 col1\" >zero_grad</td>\n",
" <td id=\"T_435ce_row9_col2\" class=\"data row9 col2\" >4.312ms</td>\n",
" <td id=\"T_435ce_row9_col3\" class=\"data row9 col3\" >2.997ms</td>\n",
" <td id=\"T_435ce_row9_col4\" class=\"data row9 col4\" >294</td>\n",
" <td id=\"T_435ce_row9_col5\" class=\"data row9 col5\" >1.268 s</td>\n",
" <td id=\"T_435ce_row9_col6\" class=\"data row9 col6\" >1%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row10_col0\" class=\"data row10 col0\" ></td>\n",
" <td id=\"T_435ce_row10_col1\" class=\"data row10 col1\" >loss</td>\n",
" <td id=\"T_435ce_row10_col2\" class=\"data row10 col2\" >2.058ms</td>\n",
" <td id=\"T_435ce_row10_col3\" class=\"data row10 col3\" >2.189ms</td>\n",
" <td id=\"T_435ce_row10_col4\" class=\"data row10 col4\" >294</td>\n",
" <td id=\"T_435ce_row10_col5\" class=\"data row10 col5\" >605.1ms</td>\n",
" <td id=\"T_435ce_row10_col6\" class=\"data row10 col6\" >0%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row11_col0\" class=\"data row11 col0\" >valid</td>\n",
" <td id=\"T_435ce_row11_col1\" class=\"data row11 col1\" >batch</td>\n",
" <td id=\"T_435ce_row11_col2\" class=\"data row11 col2\" >279.1ms</td>\n",
" <td id=\"T_435ce_row11_col3\" class=\"data row11 col3\" >192.9ms</td>\n",
" <td id=\"T_435ce_row11_col4\" class=\"data row11 col4\" >124</td>\n",
" <td id=\"T_435ce_row11_col5\" class=\"data row11 col5\" >34.61 s</td>\n",
" <td id=\"T_435ce_row11_col6\" class=\"data row11 col6\" >23%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row12_col0\" class=\"data row12 col0\" ></td>\n",
" <td id=\"T_435ce_row12_col1\" class=\"data row12 col1\" >draw</td>\n",
" <td id=\"T_435ce_row12_col2\" class=\"data row12 col2\" >234.2ms</td>\n",
" <td id=\"T_435ce_row12_col3\" class=\"data row12 col3\" >191.4ms</td>\n",
" <td id=\"T_435ce_row12_col4\" class=\"data row12 col4\" >124</td>\n",
" <td id=\"T_435ce_row12_col5\" class=\"data row12 col5\" >29.04 s</td>\n",
" <td id=\"T_435ce_row12_col6\" class=\"data row12 col6\" >19%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row13_col0\" class=\"data row13 col0\" ></td>\n",
" <td id=\"T_435ce_row13_col1\" class=\"data row13 col1\" >pred</td>\n",
" <td id=\"T_435ce_row13_col2\" class=\"data row13 col2\" >42.37ms</td>\n",
" <td id=\"T_435ce_row13_col3\" class=\"data row13 col3\" >11.61ms</td>\n",
" <td id=\"T_435ce_row13_col4\" class=\"data row13 col4\" >124</td>\n",
" <td id=\"T_435ce_row13_col5\" class=\"data row13 col5\" >5.254 s</td>\n",
" <td id=\"T_435ce_row13_col6\" class=\"data row13 col6\" >3%</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_435ce_row14_col0\" class=\"data row14 col0\" ></td>\n",
" <td id=\"T_435ce_row14_col1\" class=\"data row14 col1\" >loss</td>\n",
" <td id=\"T_435ce_row14_col2\" class=\"data row14 col2\" >2.227ms</td>\n",
" <td id=\"T_435ce_row14_col3\" class=\"data row14 col3\" >2.896ms</td>\n",
" <td id=\"T_435ce_row14_col4\" class=\"data row14 col4\" >124</td>\n",
" <td id=\"T_435ce_row14_col5\" class=\"data row14 col5\" >276.1ms</td>\n",
" <td id=\"T_435ce_row14_col6\" class=\"data row14 col6\" >0%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Cast to TensorBase instead of torch.Tensor or TensorImage"
],
"metadata": {
"id": "nhSzdwgANYP2"
}
},
{
"cell_type": "code",
"source": [
"class ChannelsLastTensor(LoseTypeTransform):\n",
"    \"Casts image-like inputs to `TensorBase` and sets them to `channels_last` format. For use in `ChannelsLastCallback`\"\n",
"    order = 110 # run after all other transforms if added to batch_tfms\n",
"    def encodes(self, x:TensorImageBase|TensorMask):\n",
"        # cast to TensorBase first, then convert the memory format\n",
"        return TensorBase(x).to(memory_format=torch.channels_last)"
],
"metadata": {
"id": "l92GVy9zNWh0"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class ChannelsLastCallback(Callback):\n",
"    \"Channels last training using PyTorch's Channels Last Memory Format (beta)\"\n",
"    order = MixedPrecision.order+1\n",
"    def __init__(self, astensor=False):\n",
"        # astensor=True selects ChannelsLastTensor; otherwise ChannelsLastTfm is used\n",
"        if astensor: self._channels_last = Pipeline([ChannelsLastTensor()])\n",
"        else: self._channels_last = Pipeline([ChannelsLastTfm()])\n",
"\n",
"    def before_fit(self):\n",
"        # convert the model to channels_last once, before training starts\n",
"        self.learn.model.to(memory_format=torch.channels_last)\n",
"\n",
"    def before_batch(self):\n",
"        # convert each input batch to channels_last\n",
"        self.learn.xb = self._channels_last(self.xb)"
],
"metadata": {
"id": "ve-xXsTBNWh9"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Warmup"
],
"metadata": {
"id": "mNhlK0t5Ni3d"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast()\n",
"\n",
"learn.fit_one_cycle(1, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"outputId": "fd7925eb-6023-4561-8c84-0588c66742d0",
"id": "1K0ZxQcGNi3e"
},
"execution_count": 20,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.829382</td>\n",
" <td>1.654897</td>\n",
" <td>0.438726</td>\n",
" <td>01:35</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Test"
],
"metadata": {
"id": "B-mI8ifUNi3g"
}
},
{
"cell_type": "code",
"source": [
"dls = get_dls(256, False, 64)\n",
"learn = Learner(dls, resnet50(num_classes=dls.c),\n",
"                loss_func=CrossEntropyLossFlat(),\n",
"                opt_func=Adam,\n",
"                metrics=[Accuracy()]).to_channelslast().profile()\n",
"\n",
"learn.fit_one_cycle(2, 3e-3)\n",
"\n",
"free_gpu_memory(learn, dls)"
],
"metadata": {
"colab": {