Skip to content

Instantly share code, notes, and snippets.

@Ben-Karr
Created November 10, 2022 10:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ben-Karr/c663d4ad8f12ca9ceba9769eeba34028 to your computer and use it in GitHub Desktop.
Save Ben-Karr/c663d4ad8f12ca9ceba9769eeba34028 to your computer and use it in GitHub Desktop.
coding/tmp_store/[fastai]-Captum-Callback.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"id": "5f376a0d",
"cell_type": "code",
"source": "from fastai.vision.all import *\nfrom fastai.callback.captum import *\nfrom random import randint",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "fa890213",
"cell_type": "code",
"source": "path = untar_data(URLs.PASCAL_2007)\ndf = pd.read_csv(path/'train.csv')\ndf.columns = ['image_id', 'labels', 'valid']\ndf['image_id'] = df.image_id.apply(lambda x: path/'train'/x)",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "ef31b840",
"cell_type": "code",
"source": "def splitter(df):\n    \"\"\"Return `(train, valid)` lists of row indices, split on the boolean `valid` column of `df`.\"\"\"\n    is_valid = df['valid']\n    return df.index[~is_valid].to_list(), df.index[is_valid].to_list()\n\ndef get_x(r):\n    \"\"\"Return the `image_id` entry of row `r` (the image path).\"\"\"\n    return r['image_id']\n\ndef get_y(r):\n    \"\"\"Return the labels of row `r`: its space-separated `labels` string split into a list.\"\"\"\n    return r['labels'].split(' ')",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "05143823",
"cell_type": "code",
"source": "def get_dblock(size):\n    \"\"\"Build the multi-label `DataBlock`: images squished to `size`, augmented and ImageNet-normalized per batch.\"\"\"\n    return DataBlock(\n        blocks=(ImageBlock, MultiCategoryBlock()),\n        get_x=get_x,\n        get_y=get_y,\n        splitter=splitter,\n        item_tfms=Resize(size, method=ResizeMethod.Squish),\n        batch_tfms=[*aug_transforms(size=size, min_scale=1), Normalize.from_stats(*imagenet_stats)]\n    )\n\ndblock = get_dblock(90)\n# fastai's DataLoader keyword is `num_workers`; `n_workers` is not one of its parameters\ndls = dblock.dataloaders(df, num_workers=8)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"id": "cbaf4c6a",
"cell_type": "code",
"source": "learn = vision_learner(dls, 'convnext_small', metrics=partial(accuracy_multi, thresh=0.5))\nlearn.fit_one_cycle(1)",
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "\n<style>\n /* Turns off some styling */\n progress {\n /* gets rid of default border in Firefox and Opera. */\n border: none;\n /* Needs to be in here for Safari polyfill so background images work as expected. */\n background-size: auto;\n }\n progress:not([value]), progress:not([value])::-webkit-progress-bar {\n background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n }\n .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n background: #F44336;\n }\n</style>\n"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy_multi</th>\n <th>time</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>0</td>\n <td>0.904057</td>\n <td>0.656957</td>\n <td>0.652889</td>\n <td>00:09</td>\n </tr>\n </tbody>\n</table>"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Reproduce error"
},
{
"metadata": {
"trusted": true,
"scrolled": false
},
"cell_type": "code",
"source": "captum=CaptumInterpretation(learn)\nidx=randint(0,len(dls.valid.items)-1)  # randint's upper bound is inclusive, so stop at the last valid position\nf = dls.valid.items.iloc[[idx]]\ncaptum.visualize(f)",
"execution_count": 6,
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "grid_sampler(): expected grid to have size 1 in last dimension, but got grid with sizes [3, 90, 90, 2]",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [6], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m idx\u001b[38;5;241m=\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;28mlen\u001b[39m(dls\u001b[38;5;241m.\u001b[39mvalid\u001b[38;5;241m.\u001b[39mitems))\n\u001b[1;32m 3\u001b[0m f \u001b[38;5;241m=\u001b[39m dls\u001b[38;5;241m.\u001b[39mvalid\u001b[38;5;241m.\u001b[39mitems\u001b[38;5;241m.\u001b[39miloc[[idx]]\n\u001b[0;32m----> 4\u001b[0m captum\u001b[38;5;241m.\u001b[39mvisualize(f)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/callback/captum.py:55\u001b[0m, in \u001b[0;36mCaptumInterpretation.visualize\u001b[0;34m(self, inp, metric, n_steps, baseline_type, nt_type, strides, sliding_window_shapes)\u001b[0m\n\u001b[1;32m 53\u001b[0m tls \u001b[38;5;241m=\u001b[39m L([TfmdLists(inp, t) \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m L(ifnone(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdls\u001b[38;5;241m.\u001b[39mtfms,[\u001b[38;5;28;01mNone\u001b[39;00m]))])\n\u001b[1;32m 54\u001b[0m inp_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39m(tls[\u001b[38;5;241m0\u001b[39m],tls[\u001b[38;5;241m1\u001b[39m])))[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 55\u001b[0m enc_data,dec_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_enc_dec_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43minp_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m attributions\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_attributions(enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_viz(attributions,dec_data,metric)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/callback/captum.py:73\u001b[0m, in \u001b[0;36mCaptumInterpretation._get_enc_dec_data\u001b[0;34m(self, inp_data)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_enc_dec_data\u001b[39m(\u001b[38;5;28mself\u001b[39m,inp_data):\n\u001b[1;32m 72\u001b[0m dec_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdls\u001b[38;5;241m.\u001b[39mafter_item(inp_data)\n\u001b[0;32m---> 73\u001b[0m enc_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdls\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mafter_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mto_device\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdls\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbefore_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdec_data\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdls\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m(enc_data,dec_data)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:208\u001b[0m, in \u001b[0;36mPipeline.__call__\u001b[0;34m(self, o)\u001b[0m\n\u001b[0;32m--> 208\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, o): \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompose_tfms\u001b[49m\u001b[43m(\u001b[49m\u001b[43mo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtfms\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit_idx\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:158\u001b[0m, in \u001b[0;36mcompose_tfms\u001b[0;34m(x, tfms, is_enc, reverse, **kwargs)\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m tfms:\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_enc: f \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mdecode\n\u001b[0;32m--> 158\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/vision/augment.py:49\u001b[0m, in \u001b[0;36mRandTransform.__call__\u001b[0;34m(self, b, split_idx, **kwargs)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \n\u001b[1;32m 44\u001b[0m b, \n\u001b[1;32m 45\u001b[0m split_idx:\u001b[38;5;28mint\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;66;03m# Index of the train/valid dataset\u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 47\u001b[0m ):\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbefore_call(b, split_idx\u001b[38;5;241m=\u001b[39msplit_idx)\n\u001b[0;32m---> 49\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdo \u001b[38;5;28;01melse\u001b[39;00m b\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:81\u001b[0m, in \u001b[0;36mTransform.__call__\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs): \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mencodes\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:91\u001b[0m, in \u001b[0;36mTransform._call\u001b[0;34m(self, fn, x, split_idx, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call\u001b[39m(\u001b[38;5;28mself\u001b[39m, fn, x, split_idx\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m split_idx\u001b[38;5;241m!=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msplit_idx \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msplit_idx \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[0;32m---> 91\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_do_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:98\u001b[0m, in \u001b[0;36mTransform._do_call\u001b[0;34m(self, f, x, **kwargs)\u001b[0m\n\u001b[1;32m 96\u001b[0m ret \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mreturns(x) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreturns\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(f(x, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs), x, ret)\n\u001b[0;32m---> 98\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_do_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx_\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(res, x)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:98\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 96\u001b[0m ret \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mreturns(x) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreturns\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(f(x, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs), x, ret)\n\u001b[0;32m---> 98\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_do_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m x_ \u001b[38;5;129;01min\u001b[39;00m x)\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(res, x)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/transform.py:97\u001b[0m, in \u001b[0;36mTransform._do_call\u001b[0;34m(self, f, x, **kwargs)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m 96\u001b[0m ret \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mreturns(x) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreturns\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 97\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m, x, ret)\n\u001b[1;32m 98\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_do_call(f, x_, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;28;01mfor\u001b[39;00m x_ \u001b[38;5;129;01min\u001b[39;00m x)\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retain_type(res, x)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastcore/dispatch.py:120\u001b[0m, in \u001b[0;36mTypeDispatch.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minst \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: f \u001b[38;5;241m=\u001b[39m MethodType(f, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minst)\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mowner \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: f \u001b[38;5;241m=\u001b[39m MethodType(f, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mowner)\n\u001b[0;32m--> 120\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/vision/augment.py:501\u001b[0m, in \u001b[0;36mAffineCoordTfm.encodes\u001b[0;34m(self, x)\u001b[0m\n\u001b[0;32m--> 501\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mencodes\u001b[39m(\u001b[38;5;28mself\u001b[39m, x:TensorImage): \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_encode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/vision/augment.py:499\u001b[0m, in \u001b[0;36mAffineCoordTfm._encode\u001b[0;34m(self, x, mode, reverse)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_encode\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, mode, reverse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 498\u001b[0m coord_func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcoord_fs)\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msplit_idx \u001b[38;5;28;01melse\u001b[39;00m partial(compose_tfms, tfms\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcoord_fs, reverse\u001b[38;5;241m=\u001b[39mreverse)\n\u001b[0;32m--> 499\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maffine_coord\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcoord_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msz\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpad_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malign_corners\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/vision/augment.py:391\u001b[0m, in \u001b[0;36maffine_coord\u001b[0;34m(x, mat, coord_tfm, sz, mode, pad_mode, align_corners)\u001b[0m\n\u001b[1;32m 389\u001b[0m coords \u001b[38;5;241m=\u001b[39m affine_grid(mat, x\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m+\u001b[39m size, align_corners\u001b[38;5;241m=\u001b[39malign_corners)\n\u001b[1;32m 390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m coord_tfm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: coords \u001b[38;5;241m=\u001b[39m coord_tfm(coords)\n\u001b[0;32m--> 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m TensorImage(\u001b[43m_grid_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcoords\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_corners\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/vision/augment.py:364\u001b[0m, in \u001b[0;36m_grid_sample\u001b[0;34m(x, coords, mode, padding_mode, align_corners)\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m d\u001b[38;5;241m>\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m d\u001b[38;5;241m>\u001b[39mz:\n\u001b[1;32m 363\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39minterpolate(x, scale_factor\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;241m/\u001b[39md, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marea\u001b[39m\u001b[38;5;124m'\u001b[39m, recompute_scale_factor\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 364\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrid_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcoords\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_corners\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/torch/nn/functional.py:4185\u001b[0m, in \u001b[0;36mgrid_sample\u001b[0;34m(input, grid, mode, padding_mode, align_corners)\u001b[0m\n\u001b[1;32m 4085\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Given an :attr:`input` and a flow-field :attr:`grid`, computes the\u001b[39;00m\n\u001b[1;32m 4086\u001b[0m \u001b[38;5;124;03m``output`` using :attr:`input` values and pixel locations from :attr:`grid`.\u001b[39;00m\n\u001b[1;32m 4087\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 4182\u001b[0m \u001b[38;5;124;03m.. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908\u001b[39;00m\n\u001b[1;32m 4183\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 4184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_variadic(\u001b[38;5;28minput\u001b[39m, grid):\n\u001b[0;32m-> 4185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhandle_torch_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4186\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid_sample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malign_corners\u001b[49m\n\u001b[1;32m 4187\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4188\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbilinear\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnearest\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbicubic\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 4189\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 4190\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnn.functional.grid_sample(): expected mode to be \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4191\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbilinear\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnearest\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbicubic\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, but got: \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(mode)\n\u001b[1;32m 4192\u001b[0m )\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/torch/overrides.py:1498\u001b[0m, in \u001b[0;36mhandle_torch_function\u001b[0;34m(public_api, relevant_args, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1492\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDefining your `__torch_function__ as a plain method is deprecated and \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1493\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwill be an error in future, please define it as a classmethod.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1494\u001b[0m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m)\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# Use `public_api` instead of `implementation` so __torch_function__\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# implementations can do equality/identity comparisons.\u001b[39;00m\n\u001b[0;32m-> 1498\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mtorch_func_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpublic_api\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m:\n\u001b[1;32m 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/torch_core.py:376\u001b[0m, in \u001b[0;36mTensorBase.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mdebug \u001b[38;5;129;01mand\u001b[39;00m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__str__\u001b[39m\u001b[38;5;124m'\u001b[39m,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__repr__\u001b[39m\u001b[38;5;124m'\u001b[39m): \u001b[38;5;28mprint\u001b[39m(func, types, args, kwargs)\n\u001b[1;32m 375\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _torch_handled(args, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_opt, func): types \u001b[38;5;241m=\u001b[39m (torch\u001b[38;5;241m.\u001b[39mTensor,)\n\u001b[0;32m--> 376\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__torch_function__\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mifnone\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 377\u001b[0m dict_objs \u001b[38;5;241m=\u001b[39m _find_args(args) \u001b[38;5;28;01mif\u001b[39;00m args \u001b[38;5;28;01melse\u001b[39;00m _find_args(\u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mvalues()))\n\u001b[1;32m 378\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(\u001b[38;5;28mtype\u001b[39m(res),TensorBase) 
\u001b[38;5;129;01mand\u001b[39;00m dict_objs: res\u001b[38;5;241m.\u001b[39mset_meta(dict_objs[\u001b[38;5;241m0\u001b[39m],as_copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/torch/_tensor.py:1121\u001b[0m, in \u001b[0;36mTensor.__torch_function__\u001b[0;34m(cls, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _C\u001b[38;5;241m.\u001b[39mDisableTorchFunction():\n\u001b[0;32m-> 1121\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01min\u001b[39;00m get_default_nowrap_functions():\n\u001b[1;32m 1123\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ret\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/torch/nn/functional.py:4223\u001b[0m, in \u001b[0;36mgrid_sample\u001b[0;34m(input, grid, mode, padding_mode, align_corners)\u001b[0m\n\u001b[1;32m 4215\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 4216\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDefault grid_sample and affine_grid behavior has changed \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4217\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto align_corners=False since 1.3.0. Please specify \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4218\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124malign_corners=True if the old behavior is desired. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4219\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee the documentation of grid_sample for details.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4220\u001b[0m )\n\u001b[1;32m 4221\u001b[0m align_corners \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 4223\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrid_sampler\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode_enum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_mode_enum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malign_corners\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: grid_sampler(): expected grid to have size 1 in last dimension, but got grid with sizes [3, 90, 90, 2]"
]
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Remove `Flip` from transforms"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.dls.after_batch[1]",
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 7,
"data": {
"text/plain": "Flip -- {'size': 90, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 0.5}:\nencodes: (TensorImage,object) -> encodes\n(TensorMask,object) -> encodes\n(TensorBBox,object) -> encodes\n(TensorPoint,object) -> encodes\ndecodes: "
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Drop the `Flip` transform by type rather than by position: removing index 1 is not\n# idempotent, because re-running the cell would delete whichever transform moved into that slot.\nlearn.dls.after_batch = Pipeline(funcs=[t for t in learn.dls.after_batch if not isinstance(t, Flip)])",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "captum=CaptumInterpretation(learn)\n# randint includes both endpoints, so the largest valid row position is len-1\nidx=randint(0,len(dls.valid.items)-1)\nf = dls.valid.items.iloc[[idx]]\ncaptum.visualize(f)",
"execution_count": 9,
"outputs": [
{
"output_type": "error",
"ename": "AssertionError",
"evalue": "Tensor target dimension torch.Size([4000]) is not valid. torch.Size([200, 20])",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [9], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m idx\u001b[38;5;241m=\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;28mlen\u001b[39m(dls\u001b[38;5;241m.\u001b[39mvalid\u001b[38;5;241m.\u001b[39mitems))\n\u001b[1;32m 3\u001b[0m f \u001b[38;5;241m=\u001b[39m dls\u001b[38;5;241m.\u001b[39mvalid\u001b[38;5;241m.\u001b[39mitems\u001b[38;5;241m.\u001b[39miloc[[idx]]\n\u001b[0;32m----> 4\u001b[0m captum\u001b[38;5;241m.\u001b[39mvisualize(f)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/callback/captum.py:56\u001b[0m, in \u001b[0;36mCaptumInterpretation.visualize\u001b[0;34m(self, inp, metric, n_steps, baseline_type, nt_type, strides, sliding_window_shapes)\u001b[0m\n\u001b[1;32m 54\u001b[0m inp_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39m(tls[\u001b[38;5;241m0\u001b[39m],tls[\u001b[38;5;241m1\u001b[39m])))[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 55\u001b[0m enc_data,dec_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_enc_dec_data(inp_data)\n\u001b[0;32m---> 56\u001b[0m attributions\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_attributions\u001b[49m\u001b[43m(\u001b[49m\u001b[43menc_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\u001b[43mn_steps\u001b[49m\u001b[43m,\u001b[49m\u001b[43mnt_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43mbaseline_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43mstrides\u001b[49m\u001b[43m,\u001b[49m\u001b[43msliding_window_shapes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_viz(attributions,dec_data,metric)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/fastai/callback/captum.py:82\u001b[0m, in \u001b[0;36mCaptumInterpretation._get_attributions\u001b[0;34m(self, enc_data, metric, n_steps, nt_type, baseline_type, strides, sliding_window_shapes)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m metric \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mIG\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_grads \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_grads \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_int_grads\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m IntegratedGradients(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel)\n\u001b[0;32m---> 82\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_int_grads\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattribute\u001b[49m\u001b[43m(\u001b[49m\u001b[43menc_data\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43mbaseline\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menc_data\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m metric \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNT\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_grads 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_int_grads \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m,\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_int_grads\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m IntegratedGradients(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel)\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/log/__init__.py:35\u001b[0m, in \u001b[0;36mlog_usage.<locals>._log_usage.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py:285\u001b[0m, in \u001b[0;36mIntegratedGradients.attribute\u001b[0;34m(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)\u001b[0m\n\u001b[1;32m 273\u001b[0m attributions \u001b[38;5;241m=\u001b[39m _batch_attribution(\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 275\u001b[0m num_examples,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 282\u001b[0m method\u001b[38;5;241m=\u001b[39mmethod,\n\u001b[1;32m 283\u001b[0m )\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 285\u001b[0m attributions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_attribute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 286\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 287\u001b[0m \u001b[43m \u001b[49m\u001b[43mbaselines\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbaselines\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 288\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 289\u001b[0m \u001b[43m \u001b[49m\u001b[43madditional_forward_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madditional_forward_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 290\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 291\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 292\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
return_convergence_delta:\n\u001b[1;32m 295\u001b[0m start_point, end_point \u001b[38;5;241m=\u001b[39m baselines, inputs\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/attr/_core/integrated_gradients.py:350\u001b[0m, in \u001b[0;36mIntegratedGradients._attribute\u001b[0;34m(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)\u001b[0m\n\u001b[1;32m 347\u001b[0m expanded_target \u001b[38;5;241m=\u001b[39m _expand_target(target, n_steps)\n\u001b[1;32m 349\u001b[0m \u001b[38;5;66;03m# grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)\u001b[39;00m\n\u001b[0;32m--> 350\u001b[0m grads \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgradient_func\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mforward_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscaled_features_tpl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget_ind\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexpanded_target\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43madditional_forward_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_additional_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 357\u001b[0m \u001b[38;5;66;03m# flattening grads so that we can multilpy it with step-size\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[38;5;66;03m# calling contiguous to avoid `memory whole` problems\u001b[39;00m\n\u001b[1;32m 359\u001b[0m scaled_grads \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 360\u001b[0m grad\u001b[38;5;241m.\u001b[39mcontiguous()\u001b[38;5;241m.\u001b[39mview(n_steps, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 
361\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(step_sizes)\u001b[38;5;241m.\u001b[39mview(n_steps, \u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mto(grad\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m grad \u001b[38;5;129;01min\u001b[39;00m grads\n\u001b[1;32m 363\u001b[0m ]\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/_utils/gradient.py:112\u001b[0m, in \u001b[0;36mcompute_gradients\u001b[0;34m(forward_fn, inputs, target_ind, additional_forward_args)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;124;03mComputes gradients of the output with respect to inputs for an\u001b[39;00m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;124;03marbitrary forward function.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;124;03m arguments) if no additional arguments are required\u001b[39;00m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 111\u001b[0m \u001b[38;5;66;03m# runs forward pass\u001b[39;00m\n\u001b[0;32m--> 112\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43m_run_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mforward_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_ind\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madditional_forward_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m outputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mnumel() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m, (\n\u001b[1;32m 114\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTarget not provided when necessary, cannot\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m take gradient with respect to multiple outputs.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m )\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# torch.unbind(forward_out) is a list of 
scalar tensor tuples and\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;66;03m# contains batch_size * #steps elements\u001b[39;00m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/_utils/common.py:461\u001b[0m, in \u001b[0;36m_run_forward\u001b[0;34m(forward_func, inputs, target, additional_forward_args)\u001b[0m\n\u001b[1;32m 454\u001b[0m additional_forward_args \u001b[38;5;241m=\u001b[39m _format_additional_forward_args(additional_forward_args)\n\u001b[1;32m 456\u001b[0m output \u001b[38;5;241m=\u001b[39m forward_func(\n\u001b[1;32m 457\u001b[0m \u001b[38;5;241m*\u001b[39m(\u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39madditional_forward_args)\n\u001b[1;32m 458\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m additional_forward_args \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m inputs\n\u001b[1;32m 460\u001b[0m )\n\u001b[0;32m--> 461\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_select_targets\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/lib/python3.9/site-packages/captum/_utils/common.py:480\u001b[0m, in \u001b[0;36m_select_targets\u001b[0;34m(output, target)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mgather(output, \u001b[38;5;241m1\u001b[39m, target\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;28mlen\u001b[39m(output), \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensor target dimension \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m is not valid. \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;241m%\u001b[39m (target\u001b[38;5;241m.\u001b[39mshape, output\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 483\u001b[0m )\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(target, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(target) \u001b[38;5;241m==\u001b[39m num_examples, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTarget list length does not match output!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
"\u001b[0;31mAssertionError\u001b[0m: Tensor target dimension torch.Size([4000]) is not valid. torch.Size([200, 20])"
]
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Adjust for MultiCategories"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from captum.attr import IntegratedGradients,NoiseTunnel,GradientShap,Occlusion\nfrom matplotlib.colors import LinearSegmentedColormap\nfrom captum.attr import visualization as viz",
"execution_count": 10,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Hardcode particular target category"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "target_idx = 12\n\n@patch\ndef _get_attributions(self:CaptumInterpretation,enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes):\n    \"\"\"Compute Captum attributions for one fixed class instead of the multi-hot target.\n\n    `enc_data` is an `(input, target)` pair; its target is discarded and replaced by the\n    module-level `target_idx`, because Captum rejects the multi-hot target tensor.\n    `metric` selects the algorithm: 'IG', 'NT' or 'Occl'.\n    `n_steps` is accepted for signature compatibility but not used: 'IG' always runs 200 steps.\n    `strides` and `sliding_window_shapes` are only used by 'Occl', `nt_type` only by 'NT'.\n    Raises `ValueError` for any other `metric` instead of silently returning None.\n    \"\"\"\n    # Hardcode target category idx (read at call time, so it can be changed between calls)\n    inp, target = enc_data[0], target_idx\n    # Get Baseline\n    baseline=self.get_baseline_img(inp,baseline_type)\n    if metric == 'IG':\n        # Attribution objects are created once and cached on the instance\n        if not hasattr(self,'_int_grads'): self._int_grads = IntegratedGradients(self.model)\n        return self._int_grads.attribute(inp,baseline, target=target, n_steps=200)\n    elif metric == 'NT':\n        if not hasattr(self,'_int_grads'): self._int_grads = IntegratedGradients(self.model)\n        if not hasattr(self,'_noise_tunnel'): self._noise_tunnel = NoiseTunnel(self._int_grads)\n        return self._noise_tunnel.attribute(inp.to(self.dls.device), nt_samples=1, nt_type=nt_type, target=target)\n    elif metric == 'Occl':\n        if not hasattr(self,'_occlusion'): self._occlusion = Occlusion(self.model)\n        return self._occlusion.attribute(inp.to(self.dls.device),\n                                       strides = strides,\n                                       target=target,\n                                       sliding_window_shapes=sliding_window_shapes,\n                                       baselines=baseline)\n    raise ValueError(f'Unsupported metric {metric!r}, expected one of IG, NT, Occl')",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "captum=CaptumInterpretation(learn)",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# randint includes both endpoints, so the largest valid row position is len-1\nidx=randint(0,len(dls.valid.items)-1)\nf = dls.valid.items.iloc[[idx]]\ncaptum.visualize(f)",
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 800x600 with 4 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Average multiple categories"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "@patch\ndef _get_attributions(self:CaptumInterpretation,enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes):\n    \"\"\"Average the Captum attributions over every class that is set in the multi-hot target.\n\n    `enc_data` is an `(input, target)` pair. Captum needs one class index per call, so the\n    attribution is computed once for each index where the target equals 1 and the results are\n    averaged. The target comes from the item passed to `visualize`, so these are presumably\n    the item's labels rather than model predictions (confirm against the caller).\n    `metric` selects the algorithm ('IG', 'NT' or 'Occl'); all three are averaged the same way.\n    `n_steps` is accepted for signature compatibility but not used: 'IG' always runs 200 steps.\n    Raises `ValueError` if no class is set in the target or if `metric` is unsupported.\n    \"\"\"\n    inp = enc_data[0]\n    # Get Baseline\n    baseline=self.get_baseline_img(inp,baseline_type)\n    # Class indices (plain ints) whose entry in the target is 1\n    #classes = list(range(enc_data[1].numel())) ## Average all classes\n    classes = torch.argwhere(enc_data[1].flatten()==1).flatten().tolist() ## Average the classes set in the target\n    # Dividing by zero below would silently produce NaN attributions\n    if not classes: raise ValueError('No class is set in the target, nothing to average')\n    # Pick the per-class attribution call; attribution objects are cached on the instance\n    if metric == 'IG':\n        if not hasattr(self,'_int_grads'): self._int_grads = IntegratedGradients(self.model)\n        attribute_for = lambda k: self._int_grads.attribute(inp,baseline, target=k, n_steps=200)\n    elif metric == 'NT':\n        if not hasattr(self,'_int_grads'): self._int_grads = IntegratedGradients(self.model)\n        if not hasattr(self,'_noise_tunnel'): self._noise_tunnel = NoiseTunnel(self._int_grads)\n        attribute_for = lambda k: self._noise_tunnel.attribute(inp.to(self.dls.device), nt_samples=1, nt_type=nt_type, target=k)\n    elif metric == 'Occl':\n        if not hasattr(self,'_occlusion'): self._occlusion = Occlusion(self.model)\n        attribute_for = lambda k: self._occlusion.attribute(inp.to(self.dls.device),\n                                                            strides = strides,\n                                                            target=k,\n                                                            sliding_window_shapes=sliding_window_shapes,\n                                                            baselines=baseline)\n    else:\n        raise ValueError(f'Unsupported metric {metric!r}, expected one of IG, NT, Occl')\n    # Sum the per-class attribution maps, then divide by the number of classes\n    total = torch.zeros(inp.shape, device=self.dls.device)\n    for k in classes: total += attribute_for(k)\n    return total/len(classes)",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "captum=CaptumInterpretation(learn)",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# randint includes both endpoints, so the largest valid row position is len-1\nidx=randint(0,len(dls.valid.items)-1)\nf = dls.valid.items.iloc[[idx]]\ncaptum.visualize(f)",
"execution_count": 16,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 800x600 with 4 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
},
"language_info": {
"name": "python",
"version": "3.9.13",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "coding/tmp_store/[fastai]-Captum-Callback.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment