@ottonemo
Created March 9, 2021 18:23
[[package]]
name = "atomicwrites"
version = "1.4.0"
description = "Atomic file writes."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "attrs"
version = "20.3.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[package.extras]
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"]
docs = ["furo", "sphinx", "zope.interface"]
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"]
[[package]]
name = "colorama"
version = "0.4.4"
description = "Cross-platform colored terminal text."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "more-itertools"
version = "8.7.0"
description = "More routines for operating on iterables, beyond itertools"
category = "dev"
optional = false
python-versions = ">=3.5"
[[package]]
name = "packaging"
version = "20.9"
description = "Core utilities for Python packages"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[package.dependencies]
pyparsing = ">=2.0.2"
[[package]]
name = "pluggy"
version = "0.13.1"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[package.extras]
dev = ["pre-commit", "tox"]
[[package]]
name = "py"
version = "1.10.0"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "pyparsing"
version = "2.4.7"
description = "Python parsing module"
category = "dev"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "pytest"
version = "5.4.3"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.5"
[package.dependencies]
atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
attrs = ">=17.4.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
more-itertools = ">=4.0.0"
packaging = "*"
pluggy = ">=0.12,<1.0"
py = ">=1.5.0"
wcwidth = "*"
[package.extras]
checkqa-mypy = ["mypy (==v0.761)"]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
name = "wcwidth"
version = "0.2.5"
description = "Measures the displayed width of unicode strings in a terminal"
category = "dev"
optional = false
python-versions = "*"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "c27944f25b55067b06883f1cea204be7d97841a4b8228fab69b91895347494ad"
[metadata.files]
atomicwrites = [
{file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"},
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
]
attrs = [
{file = "attrs-20.3.0-py2.py3-none-any.whl", hash = "sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6"},
{file = "attrs-20.3.0.tar.gz", hash = "sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"},
]
colorama = [
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
]
more-itertools = [
{file = "more-itertools-8.7.0.tar.gz", hash = "sha256:c5d6da9ca3ff65220c3bfd2a8db06d698f05d4d2b9be57e1deb2be5a45019713"},
{file = "more_itertools-8.7.0-py3-none-any.whl", hash = "sha256:5652a9ac72209ed7df8d9c15daf4e1aa0e3d2ccd3c87f8265a0673cd9cbc9ced"},
]
packaging = [
{file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"},
{file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"},
]
pluggy = [
{file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"},
{file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"},
]
py = [
{file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"},
{file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"},
]
pyparsing = [
{file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"},
{file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"},
]
pytest = [
{file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"},
{file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
{file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},
]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example from https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# tabnet example code"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"from pytorch_tabnet.tab_model import TabNetClassifier\n",
"\n",
"import torch\n",
"import pandas as pd\n",
"import numpy as np\n",
"np.random.seed(0)\n",
"\n",
"import os\n",
"import wget\n",
"from pathlib import Path\n",
"\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
"dataset_name = 'census-income'\n",
"out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File already exists.\n"
]
}
],
"source": [
"out.parent.mkdir(parents=True, exist_ok=True)\n",
"if out.exists():\n",
" print(\"File already exists.\")\n",
"else:\n",
" print(\"Downloading file...\")\n",
" wget.download(url, out.as_posix())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train = pd.read_csv(out)\n",
"target = ' <=50K'\n",
"if \"Set\" not in train.columns:\n",
" train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n",
"\n",
"train_indices = train[train.Set==\"train\"].index\n",
"valid_indices = train[train.Set==\"valid\"].index\n",
"test_indices = train[train.Set==\"test\"].index"
]
},
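{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the split proportions (a sketch, not part of the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train[\"Set\"].value_counts(normalize=True)"
]
},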
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"39 73\n",
" State-gov 9\n",
" Bachelors 16\n",
" 13 16\n",
" Never-married 7\n",
" Adm-clerical 15\n",
" Not-in-family 6\n",
" White 5\n",
" Male 2\n",
" 2174 119\n",
" 0 92\n",
" 40 94\n",
" United-States 42\n",
" <=50K 2\n",
"Set 3\n"
]
}
],
"source": [
"nunique = train.nunique()\n",
"types = train.dtypes\n",
"\n",
"categorical_columns = []\n",
"categorical_dims = {}\n",
"for col in train.columns:\n",
" if types[col] == 'object' or nunique[col] < 200:\n",
" print(col, train[col].nunique())\n",
" l_enc = LabelEncoder()\n",
" train[col] = train[col].fillna(\"VV_likely\")\n",
" train[col] = l_enc.fit_transform(train[col].values)\n",
" categorical_columns.append(col)\n",
" categorical_dims[col] = len(l_enc.classes_)\n",
" else:\n",
" train.fillna(train.loc[train_indices, col].mean(), inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# check that pipeline accepts strings\n",
"train.loc[train[target]==0, target] = \"wealthy\"\n",
"train.loc[train[target]==1, target] = \"not_wealthy\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"unused_feat = ['Set']\n",
"\n",
"features = [ col for col in train.columns if col not in unused_feat+[target]] \n",
"\n",
"cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n",
"\n",
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device used : cuda\n"
]
}
],
"source": [
"clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1,\n",
" optimizer_fn=torch.optim.Adam,\n",
" optimizer_params=dict(lr=2e-2),\n",
" scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
" mask_type='entmax' # \"sparsemax\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"X_train = train[features].values[train_indices]\n",
"y_train = train[target].values[train_indices]\n",
"\n",
"X_valid = train[features].values[valid_indices]\n",
"y_valid = train[target].values[valid_indices]\n",
"\n",
"X_test = train[features].values[test_indices]\n",
"y_test = train[target].values[test_indices]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 | loss: 0.668 | train_auc: 0.75705 | valid_auc: 0.7551 | 0:00:02s\n",
"epoch 1 | loss: 0.52031 | train_auc: 0.81912 | valid_auc: 0.82696 | 0:00:04s\n",
"epoch 2 | loss: 0.47527 | train_auc: 0.84816 | valid_auc: 0.85195 | 0:00:06s\n",
"epoch 3 | loss: 0.45715 | train_auc: 0.86756 | valid_auc: 0.86571 | 0:00:08s\n",
"epoch 4 | loss: 0.43029 | train_auc: 0.88064 | valid_auc: 0.87487 | 0:00:10s\n",
"epoch 5 | loss: 0.41997 | train_auc: 0.89128 | valid_auc: 0.8849 | 0:00:12s\n",
"epoch 6 | loss: 0.40586 | train_auc: 0.898 | valid_auc: 0.88995 | 0:00:14s\n",
"epoch 7 | loss: 0.40141 | train_auc: 0.90266 | valid_auc: 0.89769 | 0:00:16s\n",
"epoch 8 | loss: 0.39187 | train_auc: 0.90459 | valid_auc: 0.8956 | 0:00:18s\n",
"epoch 9 | loss: 0.37791 | train_auc: 0.91019 | valid_auc: 0.90593 | 0:00:20s\n",
"epoch 10 | loss: 0.37631 | train_auc: 0.91394 | valid_auc: 0.90945 | 0:00:22s\n",
"epoch 11 | loss: 0.36412 | train_auc: 0.91093 | valid_auc: 0.90707 | 0:00:24s\n",
"epoch 12 | loss: 0.3587 | train_auc: 0.91243 | valid_auc: 0.90965 | 0:00:26s\n",
"epoch 13 | loss: 0.35557 | train_auc: 0.915 | valid_auc: 0.90905 | 0:00:29s\n",
"epoch 14 | loss: 0.34672 | train_auc: 0.9182 | valid_auc: 0.91487 | 0:00:31s\n",
"epoch 15 | loss: 0.35145 | train_auc: 0.92211 | valid_auc: 0.91805 | 0:00:33s\n",
"epoch 16 | loss: 0.34199 | train_auc: 0.92471 | valid_auc: 0.92013 | 0:00:35s\n",
"epoch 17 | loss: 0.3372 | train_auc: 0.9272 | valid_auc: 0.92226 | 0:00:37s\n",
"epoch 18 | loss: 0.34344 | train_auc: 0.92886 | valid_auc: 0.92452 | 0:00:39s\n",
"epoch 19 | loss: 0.34549 | train_auc: 0.92919 | valid_auc: 0.92233 | 0:00:41s\n",
"epoch 20 | loss: 0.33269 | train_auc: 0.93105 | valid_auc: 0.92654 | 0:00:43s\n",
"epoch 21 | loss: 0.32923 | train_auc: 0.93199 | valid_auc: 0.92505 | 0:00:45s\n",
"epoch 22 | loss: 0.33069 | train_auc: 0.93208 | valid_auc: 0.92693 | 0:00:47s\n",
"epoch 23 | loss: 0.3301 | train_auc: 0.93287 | valid_auc: 0.92766 | 0:00:49s\n",
"epoch 24 | loss: 0.33326 | train_auc: 0.93347 | valid_auc: 0.92745 | 0:00:51s\n",
"epoch 25 | loss: 0.32665 | train_auc: 0.93452 | valid_auc: 0.92802 | 0:00:53s\n",
"epoch 26 | loss: 0.32089 | train_auc: 0.93444 | valid_auc: 0.92747 | 0:00:55s\n",
"epoch 27 | loss: 0.32657 | train_auc: 0.93284 | valid_auc: 0.92749 | 0:00:57s\n",
"epoch 28 | loss: 0.32863 | train_auc: 0.93331 | valid_auc: 0.92529 | 0:00:59s\n",
"epoch 29 | loss: 0.32456 | train_auc: 0.93459 | valid_auc: 0.92775 | 0:01:01s\n",
"epoch 30 | loss: 0.3245 | train_auc: 0.93506 | valid_auc: 0.92776 | 0:01:03s\n",
"epoch 31 | loss: 0.31973 | train_auc: 0.93558 | valid_auc: 0.92732 | 0:01:05s\n",
"epoch 32 | loss: 0.32807 | train_auc: 0.9334 | valid_auc: 0.92574 | 0:01:07s\n",
"epoch 33 | loss: 0.32806 | train_auc: 0.93508 | valid_auc: 0.92774 | 0:01:09s\n",
"epoch 34 | loss: 0.31981 | train_auc: 0.93656 | valid_auc: 0.93014 | 0:01:11s\n",
"epoch 35 | loss: 0.31738 | train_auc: 0.93678 | valid_auc: 0.92766 | 0:01:13s\n",
"epoch 36 | loss: 0.3209 | train_auc: 0.93637 | valid_auc: 0.92766 | 0:01:15s\n",
"epoch 37 | loss: 0.31531 | train_auc: 0.93336 | valid_auc: 0.92297 | 0:01:17s\n",
"epoch 38 | loss: 0.3231 | train_auc: 0.93368 | valid_auc: 0.92438 | 0:01:19s\n",
"epoch 39 | loss: 0.31914 | train_auc: 0.93741 | valid_auc: 0.92685 | 0:01:22s\n",
"epoch 40 | loss: 0.31784 | train_auc: 0.93709 | valid_auc: 0.92647 | 0:01:24s\n",
"epoch 41 | loss: 0.32154 | train_auc: 0.93775 | valid_auc: 0.92521 | 0:01:26s\n",
"epoch 42 | loss: 0.31726 | train_auc: 0.93814 | valid_auc: 0.92743 | 0:01:28s\n",
"epoch 43 | loss: 0.31768 | train_auc: 0.93822 | valid_auc: 0.9265 | 0:01:30s\n",
"epoch 44 | loss: 0.31297 | train_auc: 0.93664 | valid_auc: 0.92333 | 0:01:32s\n",
"epoch 45 | loss: 0.31219 | train_auc: 0.93833 | valid_auc: 0.92682 | 0:01:34s\n",
"epoch 46 | loss: 0.31816 | train_auc: 0.93877 | valid_auc: 0.92526 | 0:01:36s\n",
"epoch 47 | loss: 0.3168 | train_auc: 0.93903 | valid_auc: 0.92521 | 0:01:38s\n",
"epoch 48 | loss: 0.31014 | train_auc: 0.93864 | valid_auc: 0.92364 | 0:01:40s\n",
"epoch 49 | loss: 0.31637 | train_auc: 0.93793 | valid_auc: 0.92628 | 0:01:42s\n",
"epoch 50 | loss: 0.31441 | train_auc: 0.9398 | valid_auc: 0.92782 | 0:01:44s\n",
"epoch 51 | loss: 0.30673 | train_auc: 0.94062 | valid_auc: 0.92624 | 0:01:46s\n",
"epoch 52 | loss: 0.30835 | train_auc: 0.94006 | valid_auc: 0.92509 | 0:01:48s\n",
"epoch 53 | loss: 0.30838 | train_auc: 0.94081 | valid_auc: 0.92882 | 0:01:50s\n",
"epoch 54 | loss: 0.31133 | train_auc: 0.94049 | valid_auc: 0.92622 | 0:01:52s\n",
"\n",
"Early stopping occurred at epoch 54 with best_epoch = 34 and best_valid_auc = 0.93014\n",
"Best weights from best epoch are automatically used!\n"
]
}
],
"source": [
"clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False\n",
")"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"y_pred = clf.predict(X_train)\n",
"y_pred_enc = label_encoder.transform(y_pred)\n",
"roc_auc_score(y_train_enc, y_pred_enc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# skorch tabnet port"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"import skorch\n",
"from skorch.helper import predefined_split\n",
"\n",
"import pytorch_tabnet\n",
"from pytorch_tabnet.multiclass_utils import infer_output_dim\n",
"from pytorch_tabnet.tab_network import TabNet\n",
"from pytorch_tabnet.utils import (\n",
" PredictDataset,\n",
" create_explain_matrix,\n",
" validate_eval_set,\n",
" create_dataloaders,\n",
" define_device,\n",
" ComplexEncoder,\n",
")\n",
"\n",
"from torch.nn import CrossEntropyLoss\n",
"\n",
"from scipy.sparse import csc_matrix"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"class SkorchTabModel(skorch.NeuralNet):\n",
" def __init__(\n",
" self,\n",
" criterion,\n",
" module=TabNet, \n",
" module__input_dim=100, \n",
" module__output_dim=5,\n",
" **kwargs,\n",
" ):\n",
" super().__init__(\n",
" module,\n",
" criterion,\n",
" module__input_dim=module__input_dim,\n",
" module__output_dim=module__output_dim,\n",
" **kwargs,\n",
" )\n",
" \n",
" def initialize_module(self):\n",
" \"\"\"Setup the network and explain matrix.\"\"\"\n",
" kwargs = self.get_params_for('module')\n",
"\n",
" self.module_ = TabNet(**kwargs).to(self.device)\n",
"\n",
" self.reducing_matrix_ = create_explain_matrix(\n",
" self.module_.input_dim,\n",
" self.module_.cat_emb_dim,\n",
" self.module_.cat_idxs,\n",
" self.module_.post_embed_dim,\n",
" )\n",
" \n",
" def compute_feature_importances(self, X):\n",
" \"\"\"Compute global feature importance.\"\"\" \n",
" feature_importances_ = np.zeros((self.module_.post_embed_dim))\n",
" \n",
" for (M_explain, masks) in self.forward_masks_iter(X):\n",
" feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()\n",
"\n",
" feature_importances_ = csc_matrix.dot(\n",
" feature_importances_, self.reducing_matrix_,\n",
" )\n",
" return feature_importances_ / np.sum(feature_importances_)\n",
" \n",
" def on_train_end(self, net, X, **kwargs):\n",
" self.feature_importances_ = self.compute_feature_importances(X)\n",
" super().on_train_end(net, X=X, **kwargs)\n",
" \n",
" def forward_masks_iter(self, X, training=False, device='cpu'):\n",
" dataset = self.get_dataset(X)\n",
" iterator = self.get_iterator(dataset, training=training)\n",
" for data in iterator:\n",
" Xi = skorch.dataset.unpack_data(data)[0]\n",
" Xi = skorch.utils.to_device(Xi, self.device)\n",
" with torch.set_grad_enabled(False):\n",
" yp = self.module_.forward_masks(Xi)\n",
" yield skorch.utils.to_device(yp, device=device)\n",
" \n",
" def explain(self, X):\n",
" \"\"\"\n",
" Return local explanation\n",
"\n",
" Parameters\n",
" ----------\n",
" X : tensor: `torch.Tensor`\n",
" Input data\n",
"\n",
" Returns\n",
" -------\n",
" M_explain : matrix\n",
" Importance per sample, per columns.\n",
" masks : matrix\n",
" Sparse matrix showing attention masks used by network.\n",
" \"\"\"\n",
" res_explain = []\n",
" \n",
" for i, (M_explain, masks) in enumerate(self.forward_masks_iter(X)):\n",
" for key, value in masks.items():\n",
" masks[key] = csc_matrix.dot(\n",
" value.cpu().detach().numpy(), self.reducing_matrix_\n",
" )\n",
"\n",
" res_explain.append(\n",
" csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix_)\n",
" )\n",
"\n",
" if i == 0:\n",
" res_masks = masks\n",
" else:\n",
" for key, value in masks.items():\n",
" res_masks[key] = np.vstack([res_masks[key], value])\n",
" \n",
" res_explain = np.vstack(res_explain)\n",
" return res_explain, res_masks\n",
" \n",
" def predict(self, X):\n",
" y_proba = self.predict_proba(X)\n",
" return y_proba.argmax(-1)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LabelEncoder()"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label_encoder = LabelEncoder()\n",
"label_encoder.fit(y_train)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"y_train_enc = label_encoder.transform(y_train)\n",
"y_valid_enc = label_encoder.transform(y_valid)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"class CrossEntropySparsityLoss(torch.nn.CrossEntropyLoss):\n",
" def __init__(self, lambda_sparse=1e-3):\n",
" super().__init__()\n",
" self.lambda_sparse = lambda_sparse\n",
" \n",
" def forward(self, y_pred, y_true):\n",
" output, M_loss = y_pred\n",
"\n",
" loss = super().forward(output, y_true)\n",
" \n",
" # Add the overall sparsity loss\n",
" loss -= self.lambda_sparse * M_loss\n",
" \n",
" return loss"
]
},
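{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sanity check of the combined loss, assuming the `(logits, M_loss)` tuple that `TabNet.forward` returns; the tensors below are made up purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical tensors, just to exercise the loss once\n",
"logits = torch.randn(4, 2)\n",
"M_loss = torch.tensor(0.5)\n",
"y = torch.tensor([0, 1, 0, 1])\n",
"\n",
"CrossEntropySparsityLoss(lambda_sparse=1e-3)((logits, M_loss), y)"
]
},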
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"skorch_clf = SkorchTabModel(\n",
" criterion=CrossEntropySparsityLoss,\n",
" \n",
" module__input_dim=X_train.shape[-1],\n",
" module__output_dim=infer_output_dim(y_train)[0],\n",
" module__cat_idxs=cat_idxs,\n",
" module__cat_dims=cat_dims,\n",
" module__cat_emb_dim=1,\n",
" module__mask_type='entmax', # \"sparsemax\"\n",
" module__virtual_batch_size=128,\n",
" \n",
" optimizer=torch.optim.Adam,\n",
" optimizer__lr=2e-2,\n",
" \n",
" batch_size=1024,\n",
" iterator_train__num_workers=0,\n",
" iterator_train__drop_last=False,\n",
" iterator_valid__num_workers=0,\n",
" iterator_valid__drop_last=False,\n",
" \n",
" callbacks=[\n",
" skorch.callbacks.LRScheduler(\n",
" policy=torch.optim.lr_scheduler.StepLR,\n",
" step_size=50,\n",
" gamma=0.9,\n",
" ),\n",
" skorch.callbacks.EarlyStopping(patience=20),\n",
" skorch.callbacks.GradientNormClipping(gradient_clip_value=1.),\n",
" skorch.callbacks.EpochScoring('roc_auc'),\n",
" ],\n",
" train_split=predefined_split(skorch.dataset.Dataset(X_valid, y_valid_enc)),\n",
" max_epochs=max_epochs,\n",
" \n",
" device='cuda',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"skorch_clf.initialize()\n",
"skorch_clf.load_params('ble.pt')"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Automatic pdb calling has been turned ON\n",
"Re-initializing optimizer because the following parameters were re-set: lr.\n",
" epoch roc_auc train_loss valid_loss lr dur\n",
"------- --------- ------------ ------------ ------ ------\n",
" 1 \u001b[36m0.7798\u001b[0m \u001b[32m0.5105\u001b[0m \u001b[35m0.6061\u001b[0m 0.0200 1.6014\n",
" 2 0.8398 \u001b[32m0.4094\u001b[0m \u001b[35m0.5071\u001b[0m 0.0200 1.6054\n",
" 3 0.8557 \u001b[32m0.3883\u001b[0m \u001b[35m0.4586\u001b[0m 0.0200 1.5067\n",
" 4 0.8740 \u001b[32m0.3762\u001b[0m \u001b[35m0.4217\u001b[0m 0.0200 1.5827\n",
" 5 0.8743 \u001b[32m0.3664\u001b[0m \u001b[35m0.3860\u001b[0m 0.0200 1.5149\n",
" 6 0.8841 \u001b[32m0.3616\u001b[0m \u001b[35m0.3518\u001b[0m 0.0200 1.5379\n",
" 7 0.8913 \u001b[32m0.3562\u001b[0m \u001b[35m0.3398\u001b[0m 0.0200 1.5068\n",
" 8 0.8944 \u001b[32m0.3425\u001b[0m 0.3400 0.0200 1.5603\n",
" 9 0.8916 \u001b[32m0.3336\u001b[0m 0.4340 0.0200 1.5260\n"
]
},
{
"data": {
"text/plain": [
"<class '__main__.SkorchTabModel'>[initialized](\n",
" module_=TabNet(\n",
" (embedder): EmbeddingGenerator(\n",
" (embeddings): ModuleList(\n",
" (0): Embedding(73, 1)\n",
" (1): Embedding(9, 1)\n",
" (2): Embedding(16, 1)\n",
" (3): Embedding(16, 1)\n",
" (4): Embedding(7, 1)\n",
" (5): Embedding(15, 1)\n",
" (6): Embedding(6, 1)\n",
" (7): Embedding(5, 1)\n",
" (8): Embedding(2, 1)\n",
" (9): Embedding(119, 1)\n",
" (10): Embedding(92, 1)\n",
" (11): Embedding(94, 1)\n",
" (12): Embedding(42, 1)\n",
" )\n",
" )\n",
" (tabnet): TabNetNoEmbeddings(\n",
" (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
" (encoder): TabNetEncoder(\n",
" (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n",
" (initial_splitter): FeatTransformer(\n",
" (shared): GLU_Block(\n",
" (shared_layers): ModuleList(\n",
" (0): Linear(in_features=14, out_features=32, bias=False)\n",
" (1): Linear(in_features=16, out_features=32, bias=False)\n",
" )\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=14, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (specifics): GLU_Block(\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (feat_transformers): ModuleList(\n",
" (0): FeatTransformer(\n",
" (shared): GLU_Block(\n",
" (shared_layers): ModuleList(\n",
" (0): Linear(in_features=14, out_features=32, bias=False)\n",
" (1): Linear(in_features=16, out_features=32, bias=False)\n",
" )\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=14, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (specifics): GLU_Block(\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (1): FeatTransformer(\n",
" (shared): GLU_Block(\n",
" (shared_layers): ModuleList(\n",
" (0): Linear(in_features=14, out_features=32, bias=False)\n",
" (1): Linear(in_features=16, out_features=32, bias=False)\n",
" )\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=14, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (specifics): GLU_Block(\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (2): FeatTransformer(\n",
" (shared): GLU_Block(\n",
" (shared_layers): ModuleList(\n",
" (0): Linear(in_features=14, out_features=32, bias=False)\n",
" (1): Linear(in_features=16, out_features=32, bias=False)\n",
" )\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=14, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (specifics): GLU_Block(\n",
" (glu_layers): ModuleList(\n",
" (0): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): GLU_Layer(\n",
" (fc): Linear(in_features=16, out_features=32, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (att_transformers): ModuleList(\n",
" (0): AttentiveTransformer(\n",
" (fc): Linear(in_features=8, out_features=14, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" (selector): Entmax15()\n",
" )\n",
" (1): AttentiveTransformer(\n",
" (fc): Linear(in_features=8, out_features=14, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" (selector): Entmax15()\n",
" )\n",
" (2): AttentiveTransformer(\n",
" (fc): Linear(in_features=8, out_features=14, bias=False)\n",
" (bn): GBN(\n",
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n",
" )\n",
" (selector): Entmax15()\n",
" )\n",
" )\n",
" )\n",
" (final_mapping): Linear(in_features=8, out_features=2, bias=False)\n",
" )\n",
" ),\n",
")"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%pdb on\n",
"skorch_clf.fit(\n",
" X_train, \n",
" y_train_enc,\n",
" #weights=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"0.8094909975251289"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sklearn.metrics.roc_auc_score(y_train_enc, skorch_clf.predict(X_train))"
]
},
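{
"cell_type": "markdown",
"metadata": {},
"source": [
"`predict` returns encoded class indices; `label_encoder.inverse_transform` maps them back to the original string labels (a quick sketch, not part of the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"label_encoder.inverse_transform(skorch_clf.predict(X_train[:5]))"
]
},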
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"0.809482523487759"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sklearn.metrics.roc_auc_score(y_valid_enc, skorch_clf.predict(X_valid))"
]
},
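{
"cell_type": "markdown",
"metadata": {},
"source": [
"The held-out test split from above can be scored the same way (a sketch, not executed in the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_test_enc = label_encoder.transform(y_test)\n",
"sklearn.metrics.roc_auc_score(y_test_enc, skorch_clf.predict(X_test))"
]
},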
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0.00000000e+00, 3.16206515e-01, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [0.00000000e+00, 4.00576532e-01, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [0.00000000e+00, 4.18307304e-01, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 1.17498322e-03, 0.00000000e+00],\n",
" ...,\n",
" [0.00000000e+00, 3.26685965e-01, 0.00000000e+00, ...,\n",
" 1.23882025e-01, 5.39707905e-03, 0.00000000e+00],\n",
" [1.30514791e-02, 2.11704329e-01, 0.00000000e+00, ...,\n",
" 1.28650689e+00, 3.36281955e-03, 1.37906140e-02],\n",
" [1.27431333e-01, 1.56031281e-01, 1.24305105e-02, ...,\n",
" 2.07836628e-01, 1.33876745e-02, 9.61675271e-02]]),\n",
" {0: array([[0.00000000e+00, 9.32369530e-02, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [0.00000000e+00, 4.57307428e-01, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
" [0.00000000e+00, 6.30757153e-01, 0.00000000e+00, ...,\n",
" 0.00000000e+00, 1.77173363e-03, 0.00000000e+00],\n",
" ...,\n",
" [0.00000000e+00, 1.68548390e-01, 0.00000000e+00, ...,\n",
" 6.39149472e-02, 2.78453645e-03, 0.00000000e+00],\n",
" [2.71768179e-02, 1.00498842e-02, 0.00000000e+00, ...,\n",
" 1.84161370e-04, 7.00232806e-03, 2.87159029e-02],\n",
" [1.20953389e-01, 1.07856750e-01, 1.24875447e-02, ...,\n",
" 1.23726524e-01, 1.34490998e-02, 9.66087654e-02]]),\n",
" 1: array([[0. , 0. , 0. , ..., 0. , 0. ,\n",
" 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. ,\n",
" 0. ],\n",
" [0. , 0. , 0. , ..., 0. , 0. ,\n",
" 0. ],\n",
" ...,\n",
" [0. , 0. , 0. , ..., 0. , 0. ,\n",
" 0. ],\n",
" [0. , 0.02491164, 0. , ..., 0.15490678, 0. ,\n",
" 0. ],\n",
" [0.01362148, 0.09429348, 0. , ..., 0.1640597 , 0. ,\n",
" 0. ]]),\n",
" 2: array([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]])})"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"skorch_clf.explain(X_valid)"
]
},
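{
"cell_type": "markdown",
"metadata": {},
"source": [
"`on_train_end` stored global feature importances on the net; a quick look at the top entries (a sketch, assuming the `features` list from above lines up with the model's input columns):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sorted(zip(features, skorch_clf.feature_importances_), key=lambda t: -t[1])[:5]"
]
},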
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"skorch_clf.save_params(f_params='ble.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}