Created
March 9, 2021 18:23
-
-
Save ottonemo/f8771c3a7f0f6abf6afb8ae157b673ba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[[package]] | |
name = "atomicwrites" | |
version = "1.4.0" | |
description = "Atomic file writes." | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
[[package]] | |
name = "attrs" | |
version = "20.3.0" | |
description = "Classes Without Boilerplate" | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
[package.extras] | |
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] | |
docs = ["furo", "sphinx", "zope.interface"] | |
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] | |
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] | |
[[package]] | |
name = "colorama" | |
version = "0.4.4" | |
description = "Cross-platform colored terminal text." | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" | |
[[package]] | |
name = "more-itertools" | |
version = "8.7.0" | |
description = "More routines for operating on iterables, beyond itertools" | |
category = "dev" | |
optional = false | |
python-versions = ">=3.5" | |
[[package]] | |
name = "packaging" | |
version = "20.9" | |
description = "Core utilities for Python packages" | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
[package.dependencies] | |
pyparsing = ">=2.0.2" | |
[[package]] | |
name = "pluggy" | |
version = "0.13.1" | |
description = "plugin and hook calling mechanisms for python" | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
[package.extras] | |
dev = ["pre-commit", "tox"] | |
[[package]] | |
name = "py" | |
version = "1.10.0" | |
description = "library with cross-python path, ini-parsing, io, code, log facilities" | |
category = "dev" | |
optional = false | |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" | |
[[package]] | |
name = "pyparsing" | |
version = "2.4.7" | |
description = "Python parsing module" | |
category = "dev" | |
optional = false | |
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" | |
[[package]] | |
name = "pytest" | |
version = "5.4.3" | |
description = "pytest: simple powerful testing with Python" | |
category = "dev" | |
optional = false | |
python-versions = ">=3.5" | |
[package.dependencies] | |
atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} | |
attrs = ">=17.4.0" | |
colorama = {version = "*", markers = "sys_platform == \"win32\""} | |
more-itertools = ">=4.0.0" | |
packaging = "*" | |
pluggy = ">=0.12,<1.0" | |
py = ">=1.5.0" | |
wcwidth = "*" | |
[package.extras] | |
checkqa-mypy = ["mypy (==v0.761)"] | |
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] | |
[[package]] | |
name = "wcwidth" | |
version = "0.2.5" | |
description = "Measures the displayed width of unicode strings in a terminal" | |
category = "dev" | |
optional = false | |
python-versions = "*" | |
[metadata] | |
lock-version = "1.1" | |
python-versions = "^3.8" | |
content-hash = "c27944f25b55067b06883f1cea204be7d97841a4b8228fab69b91895347494ad" | |
[metadata.files] | |
atomicwrites = [ | |
{file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, | |
{file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, | |
] | |
attrs = [ | |
{file = "attrs-20.3.0-py2.py3-none-any.whl", hash = "sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6"}, | |
{file = "attrs-20.3.0.tar.gz", hash = "sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"}, | |
] | |
colorama = [ | |
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, | |
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, | |
] | |
more-itertools = [ | |
{file = "more-itertools-8.7.0.tar.gz", hash = "sha256:c5d6da9ca3ff65220c3bfd2a8db06d698f05d4d2b9be57e1deb2be5a45019713"}, | |
{file = "more_itertools-8.7.0-py3-none-any.whl", hash = "sha256:5652a9ac72209ed7df8d9c15daf4e1aa0e3d2ccd3c87f8265a0673cd9cbc9ced"}, | |
] | |
packaging = [ | |
{file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"}, | |
{file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"}, | |
] | |
pluggy = [ | |
{file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, | |
{file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, | |
] | |
py = [ | |
{file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, | |
{file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, | |
] | |
pyparsing = [ | |
{file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, | |
{file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, | |
] | |
pytest = [ | |
{file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, | |
{file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, | |
] | |
wcwidth = [ | |
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, | |
{file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, | |
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Example from https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# tabnet example code" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.preprocessing import LabelEncoder\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"\n", | |
"from pytorch_tabnet.tab_model import TabNetClassifier\n", | |
"\n", | |
"import torch\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"import os\n", | |
"import wget\n", | |
"from pathlib import Path\n", | |
"\n", | |
"from matplotlib import pyplot as plt\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", | |
"dataset_name = 'census-income'\n", | |
"out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"File already exists.\n" | |
] | |
} | |
], | |
"source": [ | |
"out.parent.mkdir(parents=True, exist_ok=True)\n", | |
"if out.exists():\n", | |
" print(\"File already exists.\")\n", | |
"else:\n", | |
" print(\"Downloading file...\")\n", | |
" wget.download(url, out.as_posix())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train = pd.read_csv(out)\n", | |
"target = ' <=50K'\n", | |
"if \"Set\" not in train.columns:\n", | |
" train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", | |
"\n", | |
"train_indices = train[train.Set==\"train\"].index\n", | |
"valid_indices = train[train.Set==\"valid\"].index\n", | |
"test_indices = train[train.Set==\"test\"].index" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"39 73\n", | |
" State-gov 9\n", | |
" Bachelors 16\n", | |
" 13 16\n", | |
" Never-married 7\n", | |
" Adm-clerical 15\n", | |
" Not-in-family 6\n", | |
" White 5\n", | |
" Male 2\n", | |
" 2174 119\n", | |
" 0 92\n", | |
" 40 94\n", | |
" United-States 42\n", | |
" <=50K 2\n", | |
"Set 3\n" | |
] | |
} | |
], | |
"source": [ | |
"nunique = train.nunique()\n", | |
"types = train.dtypes\n", | |
"\n", | |
"categorical_columns = []\n", | |
"categorical_dims = {}\n", | |
"for col in train.columns:\n", | |
" if types[col] == 'object' or nunique[col] < 200:\n", | |
" print(col, train[col].nunique())\n", | |
" l_enc = LabelEncoder()\n", | |
" train[col] = train[col].fillna(\"VV_likely\")\n", | |
" train[col] = l_enc.fit_transform(train[col].values)\n", | |
" categorical_columns.append(col)\n", | |
" categorical_dims[col] = len(l_enc.classes_)\n", | |
" else:\n", | |
" train.fillna(train.loc[train_indices, col].mean(), inplace=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# check that pipeline accepts strings\n", | |
"train.loc[train[target]==0, target] = \"wealthy\"\n", | |
"train.loc[train[target]==1, target] = \"not_wealthy\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"unused_feat = ['Set']\n", | |
"\n", | |
"features = [ col for col in train.columns if col not in unused_feat+[target]] \n", | |
"\n", | |
"cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", | |
"\n", | |
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Device used : cuda\n" | |
] | |
} | |
], | |
"source": [ | |
"clf = TabNetClassifier(cat_idxs=cat_idxs,\n", | |
" cat_dims=cat_dims,\n", | |
" cat_emb_dim=1,\n", | |
" optimizer_fn=torch.optim.Adam,\n", | |
" optimizer_params=dict(lr=2e-2),\n", | |
" scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n", | |
" \"gamma\":0.9},\n", | |
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n", | |
" mask_type='entmax' # \"sparsemax\"\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train = train[features].values[train_indices]\n", | |
"y_train = train[target].values[train_indices]\n", | |
"\n", | |
"X_valid = train[features].values[valid_indices]\n", | |
"y_valid = train[target].values[valid_indices]\n", | |
"\n", | |
"X_test = train[features].values[test_indices]\n", | |
"y_test = train[target].values[test_indices]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch 0 | loss: 0.668 | train_auc: 0.75705 | valid_auc: 0.7551 | 0:00:02s\n", | |
"epoch 1 | loss: 0.52031 | train_auc: 0.81912 | valid_auc: 0.82696 | 0:00:04s\n", | |
"epoch 2 | loss: 0.47527 | train_auc: 0.84816 | valid_auc: 0.85195 | 0:00:06s\n", | |
"epoch 3 | loss: 0.45715 | train_auc: 0.86756 | valid_auc: 0.86571 | 0:00:08s\n", | |
"epoch 4 | loss: 0.43029 | train_auc: 0.88064 | valid_auc: 0.87487 | 0:00:10s\n", | |
"epoch 5 | loss: 0.41997 | train_auc: 0.89128 | valid_auc: 0.8849 | 0:00:12s\n", | |
"epoch 6 | loss: 0.40586 | train_auc: 0.898 | valid_auc: 0.88995 | 0:00:14s\n", | |
"epoch 7 | loss: 0.40141 | train_auc: 0.90266 | valid_auc: 0.89769 | 0:00:16s\n", | |
"epoch 8 | loss: 0.39187 | train_auc: 0.90459 | valid_auc: 0.8956 | 0:00:18s\n", | |
"epoch 9 | loss: 0.37791 | train_auc: 0.91019 | valid_auc: 0.90593 | 0:00:20s\n", | |
"epoch 10 | loss: 0.37631 | train_auc: 0.91394 | valid_auc: 0.90945 | 0:00:22s\n", | |
"epoch 11 | loss: 0.36412 | train_auc: 0.91093 | valid_auc: 0.90707 | 0:00:24s\n", | |
"epoch 12 | loss: 0.3587 | train_auc: 0.91243 | valid_auc: 0.90965 | 0:00:26s\n", | |
"epoch 13 | loss: 0.35557 | train_auc: 0.915 | valid_auc: 0.90905 | 0:00:29s\n", | |
"epoch 14 | loss: 0.34672 | train_auc: 0.9182 | valid_auc: 0.91487 | 0:00:31s\n", | |
"epoch 15 | loss: 0.35145 | train_auc: 0.92211 | valid_auc: 0.91805 | 0:00:33s\n", | |
"epoch 16 | loss: 0.34199 | train_auc: 0.92471 | valid_auc: 0.92013 | 0:00:35s\n", | |
"epoch 17 | loss: 0.3372 | train_auc: 0.9272 | valid_auc: 0.92226 | 0:00:37s\n", | |
"epoch 18 | loss: 0.34344 | train_auc: 0.92886 | valid_auc: 0.92452 | 0:00:39s\n", | |
"epoch 19 | loss: 0.34549 | train_auc: 0.92919 | valid_auc: 0.92233 | 0:00:41s\n", | |
"epoch 20 | loss: 0.33269 | train_auc: 0.93105 | valid_auc: 0.92654 | 0:00:43s\n", | |
"epoch 21 | loss: 0.32923 | train_auc: 0.93199 | valid_auc: 0.92505 | 0:00:45s\n", | |
"epoch 22 | loss: 0.33069 | train_auc: 0.93208 | valid_auc: 0.92693 | 0:00:47s\n", | |
"epoch 23 | loss: 0.3301 | train_auc: 0.93287 | valid_auc: 0.92766 | 0:00:49s\n", | |
"epoch 24 | loss: 0.33326 | train_auc: 0.93347 | valid_auc: 0.92745 | 0:00:51s\n", | |
"epoch 25 | loss: 0.32665 | train_auc: 0.93452 | valid_auc: 0.92802 | 0:00:53s\n", | |
"epoch 26 | loss: 0.32089 | train_auc: 0.93444 | valid_auc: 0.92747 | 0:00:55s\n", | |
"epoch 27 | loss: 0.32657 | train_auc: 0.93284 | valid_auc: 0.92749 | 0:00:57s\n", | |
"epoch 28 | loss: 0.32863 | train_auc: 0.93331 | valid_auc: 0.92529 | 0:00:59s\n", | |
"epoch 29 | loss: 0.32456 | train_auc: 0.93459 | valid_auc: 0.92775 | 0:01:01s\n", | |
"epoch 30 | loss: 0.3245 | train_auc: 0.93506 | valid_auc: 0.92776 | 0:01:03s\n", | |
"epoch 31 | loss: 0.31973 | train_auc: 0.93558 | valid_auc: 0.92732 | 0:01:05s\n", | |
"epoch 32 | loss: 0.32807 | train_auc: 0.9334 | valid_auc: 0.92574 | 0:01:07s\n", | |
"epoch 33 | loss: 0.32806 | train_auc: 0.93508 | valid_auc: 0.92774 | 0:01:09s\n", | |
"epoch 34 | loss: 0.31981 | train_auc: 0.93656 | valid_auc: 0.93014 | 0:01:11s\n", | |
"epoch 35 | loss: 0.31738 | train_auc: 0.93678 | valid_auc: 0.92766 | 0:01:13s\n", | |
"epoch 36 | loss: 0.3209 | train_auc: 0.93637 | valid_auc: 0.92766 | 0:01:15s\n", | |
"epoch 37 | loss: 0.31531 | train_auc: 0.93336 | valid_auc: 0.92297 | 0:01:17s\n", | |
"epoch 38 | loss: 0.3231 | train_auc: 0.93368 | valid_auc: 0.92438 | 0:01:19s\n", | |
"epoch 39 | loss: 0.31914 | train_auc: 0.93741 | valid_auc: 0.92685 | 0:01:22s\n", | |
"epoch 40 | loss: 0.31784 | train_auc: 0.93709 | valid_auc: 0.92647 | 0:01:24s\n", | |
"epoch 41 | loss: 0.32154 | train_auc: 0.93775 | valid_auc: 0.92521 | 0:01:26s\n", | |
"epoch 42 | loss: 0.31726 | train_auc: 0.93814 | valid_auc: 0.92743 | 0:01:28s\n", | |
"epoch 43 | loss: 0.31768 | train_auc: 0.93822 | valid_auc: 0.9265 | 0:01:30s\n", | |
"epoch 44 | loss: 0.31297 | train_auc: 0.93664 | valid_auc: 0.92333 | 0:01:32s\n", | |
"epoch 45 | loss: 0.31219 | train_auc: 0.93833 | valid_auc: 0.92682 | 0:01:34s\n", | |
"epoch 46 | loss: 0.31816 | train_auc: 0.93877 | valid_auc: 0.92526 | 0:01:36s\n", | |
"epoch 47 | loss: 0.3168 | train_auc: 0.93903 | valid_auc: 0.92521 | 0:01:38s\n", | |
"epoch 48 | loss: 0.31014 | train_auc: 0.93864 | valid_auc: 0.92364 | 0:01:40s\n", | |
"epoch 49 | loss: 0.31637 | train_auc: 0.93793 | valid_auc: 0.92628 | 0:01:42s\n", | |
"epoch 50 | loss: 0.31441 | train_auc: 0.9398 | valid_auc: 0.92782 | 0:01:44s\n", | |
"epoch 51 | loss: 0.30673 | train_auc: 0.94062 | valid_auc: 0.92624 | 0:01:46s\n", | |
"epoch 52 | loss: 0.30835 | train_auc: 0.94006 | valid_auc: 0.92509 | 0:01:48s\n", | |
"epoch 53 | loss: 0.30838 | train_auc: 0.94081 | valid_auc: 0.92882 | 0:01:50s\n", | |
"epoch 54 | loss: 0.31133 | train_auc: 0.94049 | valid_auc: 0.92622 | 0:01:52s\n", | |
"\n", | |
"Early stopping occurred at epoch 54 with best_epoch = 34 and best_valid_auc = 0.93014\n", | |
"Best weights from best epoch are automatically used!\n" | |
] | |
} | |
], | |
"source": [ | |
"clf.fit(\n", | |
" X_train=X_train, y_train=y_train,\n", | |
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", | |
" eval_name=['train', 'valid'],\n", | |
" eval_metric=['auc'],\n", | |
" max_epochs=max_epochs , patience=20,\n", | |
" batch_size=1024, virtual_batch_size=128,\n", | |
" num_workers=0,\n", | |
" weights=1,\n", | |
" drop_last=False\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "raw", | |
"metadata": {}, | |
"source": [ | |
"y_pred = clf.predict(X_train)\n", | |
"y_pred_enc = label_encoder.transform(y_pred)\n", | |
"roc_auc_score(y_train_enc, y_pred_enc)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# skorch tabnet port" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sklearn\n", | |
"\n", | |
"import skorch\n", | |
"from skorch.helper import predefined_split\n", | |
"\n", | |
"import pytorch_tabnet\n", | |
"from pytorch_tabnet.multiclass_utils import infer_output_dim\n", | |
"from pytorch_tabnet.tab_network import TabNet\n", | |
"from pytorch_tabnet.utils import (\n", | |
" PredictDataset,\n", | |
" create_explain_matrix,\n", | |
" validate_eval_set,\n", | |
" create_dataloaders,\n", | |
" define_device,\n", | |
" ComplexEncoder,\n", | |
")\n", | |
"\n", | |
"from torch.nn import CrossEntropyLoss\n", | |
"\n", | |
"from scipy.sparse import csc_matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SkorchTabModel(skorch.NeuralNet):\n", | |
" def __init__(\n", | |
" self,\n", | |
" criterion,\n", | |
" module=TabNet, \n", | |
" module__input_dim=100, \n", | |
" module__output_dim=5,\n", | |
" **kwargs,\n", | |
" ):\n", | |
" super().__init__(\n", | |
" module,\n", | |
" criterion,\n", | |
" module__input_dim=module__input_dim,\n", | |
" module__output_dim=module__output_dim,\n", | |
" **kwargs,\n", | |
" )\n", | |
" \n", | |
" def initialize_module(self):\n", | |
" \"\"\"Setup the network and explain matrix.\"\"\"\n", | |
" kwargs = self.get_params_for('module')\n", | |
"\n", | |
" self.module_ = TabNet(**kwargs).to(self.device)\n", | |
"\n", | |
" self.reducing_matrix_ = create_explain_matrix(\n", | |
" self.module_.input_dim,\n", | |
" self.module_.cat_emb_dim,\n", | |
" self.module_.cat_idxs,\n", | |
" self.module_.post_embed_dim,\n", | |
" )\n", | |
" \n", | |
" def compute_feature_importances(self, X):\n", | |
" \"\"\"Compute global feature importance.\"\"\" \n", | |
" feature_importances_ = np.zeros((self.module_.post_embed_dim))\n", | |
" \n", | |
" for (M_explain, masks) in self.forward_masks_iter(X):\n", | |
" feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()\n", | |
"\n", | |
" feature_importances_ = csc_matrix.dot(\n", | |
" feature_importances_, self.reducing_matrix_,\n", | |
" )\n", | |
" return feature_importances_ / np.sum(feature_importances_)\n", | |
" \n", | |
" def on_train_end(self, net, X, **kwargs):\n", | |
" self.feature_importances_ = self.compute_feature_importances(X)\n", | |
" super().on_train_end(net, X=X, **kwargs)\n", | |
" \n", | |
" def forward_masks_iter(self, X, training=False, device='cpu'):\n", | |
" dataset = self.get_dataset(X)\n", | |
" iterator = self.get_iterator(dataset, training=training)\n", | |
" for data in iterator:\n", | |
" Xi = skorch.dataset.unpack_data(data)[0]\n", | |
" Xi = skorch.utils.to_device(Xi, self.device)\n", | |
" with torch.set_grad_enabled(False):\n", | |
" yp = self.module_.forward_masks(Xi)\n", | |
" yield skorch.utils.to_device(yp, device=device)\n", | |
" \n", | |
" def explain(self, X):\n", | |
" \"\"\"\n", | |
" Return local explanation\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" X : tensor: `torch.Tensor`\n", | |
" Input data\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" M_explain : matrix\n", | |
" Importance per sample, per columns.\n", | |
" masks : matrix\n", | |
" Sparse matrix showing attention masks used by network.\n", | |
" \"\"\"\n", | |
" res_explain = []\n", | |
" \n", | |
" for i, (M_explain, masks) in enumerate(self.forward_masks_iter(X)):\n", | |
" for key, value in masks.items():\n", | |
" masks[key] = csc_matrix.dot(\n", | |
" value.cpu().detach().numpy(), self.reducing_matrix_\n", | |
" )\n", | |
"\n", | |
" res_explain.append(\n", | |
" csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix_)\n", | |
" )\n", | |
"\n", | |
" if i == 0:\n", | |
" res_masks = masks\n", | |
" else:\n", | |
" for key, value in masks.items():\n", | |
" res_masks[key] = np.vstack([res_masks[key], value])\n", | |
" \n", | |
" res_explain = np.vstack(res_explain)\n", | |
" return res_explain, res_masks\n", | |
" \n", | |
" def predict(self, X):\n", | |
" y_proba = self.predict_proba(X)\n", | |
" return y_proba.argmax(-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"LabelEncoder()" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"label_encoder = LabelEncoder()\n", | |
"label_encoder.fit(y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_train_enc = label_encoder.transform(y_train)\n", | |
"y_valid_enc = label_encoder.transform(y_valid)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CrossEntropySparsityLoss(torch.nn.CrossEntropyLoss):\n", | |
" def __init__(self, lambda_sparse=1e-3):\n", | |
" super().__init__()\n", | |
" self.lambda_sparse = lambda_sparse\n", | |
" \n", | |
" def forward(self, y_pred, y_true):\n", | |
" output, M_loss = y_pred\n", | |
"\n", | |
" loss = super().forward(output, y_true)\n", | |
" \n", | |
" # Add the overall sparsity loss\n", | |
" loss -= self.lambda_sparse * M_loss\n", | |
" \n", | |
" return loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"skorch_clf = SkorchTabModel(\n", | |
" criterion=CrossEntropySparsityLoss,\n", | |
" \n", | |
" module__input_dim=X_train.shape[-1],\n", | |
" module__output_dim=infer_output_dim(y_train)[0],\n", | |
" module__cat_idxs=cat_idxs,\n", | |
" module__cat_dims=cat_dims,\n", | |
" module__cat_emb_dim=1,\n", | |
" module__mask_type='entmax', # \"sparsemax\"\n", | |
" module__virtual_batch_size=128,\n", | |
" \n", | |
" optimizer=torch.optim.Adam,\n", | |
" optimizer__lr=2e-2,\n", | |
" \n", | |
" batch_size=1024,\n", | |
" iterator_train__num_workers=0,\n", | |
" iterator_train__drop_last=False,\n", | |
" iterator_valid__num_workers=0,\n", | |
" iterator_valid__drop_last=False,\n", | |
" \n", | |
" callbacks=[\n", | |
" skorch.callbacks.LRScheduler(\n", | |
" policy=torch.optim.lr_scheduler.StepLR,\n", | |
" step_size=50,\n", | |
" gamma=0.9,\n", | |
" ),\n", | |
" skorch.callbacks.EarlyStopping(patience=20),\n", | |
" skorch.callbacks.GradientNormClipping(gradient_clip_value=1.),\n", | |
" skorch.callbacks.EpochScoring('roc_auc'),\n", | |
" ],\n", | |
" train_split=predefined_split(skorch.dataset.Dataset(X_valid, y_valid_enc)),\n", | |
" max_epochs=max_epochs,\n", | |
" \n", | |
" device='cuda',\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"skorch_clf.initialize()\n", | |
"skorch_clf.load_params('ble.pt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Automatic pdb calling has been turned ON\n", | |
"Re-initializing optimizer because the following parameters were re-set: lr.\n", | |
" epoch roc_auc train_loss valid_loss lr dur\n", | |
"------- --------- ------------ ------------ ------ ------\n", | |
" 1 \u001b[36m0.7798\u001b[0m \u001b[32m0.5105\u001b[0m \u001b[35m0.6061\u001b[0m 0.0200 1.6014\n", | |
" 2 0.8398 \u001b[32m0.4094\u001b[0m \u001b[35m0.5071\u001b[0m 0.0200 1.6054\n", | |
" 3 0.8557 \u001b[32m0.3883\u001b[0m \u001b[35m0.4586\u001b[0m 0.0200 1.5067\n", | |
" 4 0.8740 \u001b[32m0.3762\u001b[0m \u001b[35m0.4217\u001b[0m 0.0200 1.5827\n", | |
" 5 0.8743 \u001b[32m0.3664\u001b[0m \u001b[35m0.3860\u001b[0m 0.0200 1.5149\n", | |
" 6 0.8841 \u001b[32m0.3616\u001b[0m \u001b[35m0.3518\u001b[0m 0.0200 1.5379\n", | |
" 7 0.8913 \u001b[32m0.3562\u001b[0m \u001b[35m0.3398\u001b[0m 0.0200 1.5068\n", | |
" 8 0.8944 \u001b[32m0.3425\u001b[0m 0.3400 0.0200 1.5603\n", | |
" 9 0.8916 \u001b[32m0.3336\u001b[0m 0.4340 0.0200 1.5260\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<class '__main__.SkorchTabModel'>[initialized](\n", | |
" module_=TabNet(\n", | |
" (embedder): EmbeddingGenerator(\n", | |
" (embeddings): ModuleList(\n", | |
" (0): Embedding(73, 1)\n", | |
" (1): Embedding(9, 1)\n", | |
" (2): Embedding(16, 1)\n", | |
" (3): Embedding(16, 1)\n", | |
" (4): Embedding(7, 1)\n", | |
" (5): Embedding(15, 1)\n", | |
" (6): Embedding(6, 1)\n", | |
" (7): Embedding(5, 1)\n", | |
" (8): Embedding(2, 1)\n", | |
" (9): Embedding(119, 1)\n", | |
" (10): Embedding(92, 1)\n", | |
" (11): Embedding(94, 1)\n", | |
" (12): Embedding(42, 1)\n", | |
" )\n", | |
" )\n", | |
" (tabnet): TabNetNoEmbeddings(\n", | |
" (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n", | |
" (encoder): TabNetEncoder(\n", | |
" (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)\n", | |
" (initial_splitter): FeatTransformer(\n", | |
" (shared): GLU_Block(\n", | |
" (shared_layers): ModuleList(\n", | |
" (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
" )\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (specifics): GLU_Block(\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (feat_transformers): ModuleList(\n", | |
" (0): FeatTransformer(\n", | |
" (shared): GLU_Block(\n", | |
" (shared_layers): ModuleList(\n", | |
" (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
" )\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (specifics): GLU_Block(\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (1): FeatTransformer(\n", | |
" (shared): GLU_Block(\n", | |
" (shared_layers): ModuleList(\n", | |
" (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
" )\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (specifics): GLU_Block(\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (2): FeatTransformer(\n", | |
" (shared): GLU_Block(\n", | |
" (shared_layers): ModuleList(\n", | |
" (0): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (1): Linear(in_features=16, out_features=32, bias=False)\n", | |
" )\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=14, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (specifics): GLU_Block(\n", | |
" (glu_layers): ModuleList(\n", | |
" (0): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): GLU_Layer(\n", | |
" (fc): Linear(in_features=16, out_features=32, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (att_transformers): ModuleList(\n", | |
" (0): AttentiveTransformer(\n", | |
" (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (selector): Entmax15()\n", | |
" )\n", | |
" (1): AttentiveTransformer(\n", | |
" (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (selector): Entmax15()\n", | |
" )\n", | |
" (2): AttentiveTransformer(\n", | |
" (fc): Linear(in_features=8, out_features=14, bias=False)\n", | |
" (bn): GBN(\n", | |
" (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (selector): Entmax15()\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (final_mapping): Linear(in_features=8, out_features=2, bias=False)\n", | |
" )\n", | |
" ),\n", | |
")" | |
] | |
}, | |
"execution_count": 68, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%pdb on\n", | |
"skorch_clf.fit(\n", | |
" X_train, \n", | |
" y_train_enc,\n", | |
" #weights=1,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8094909975251289" | |
] | |
}, | |
"execution_count": 69, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sklearn.metrics.roc_auc_score(y_train_enc, skorch_clf.predict(X_train))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.809482523487759" | |
] | |
}, | |
"execution_count": 70, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sklearn.metrics.roc_auc_score(y_valid_enc, skorch_clf.predict(X_valid))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([[0.00000000e+00, 3.16206515e-01, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
" [0.00000000e+00, 4.00576532e-01, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
" [0.00000000e+00, 4.18307304e-01, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 1.17498322e-03, 0.00000000e+00],\n", | |
" ...,\n", | |
" [0.00000000e+00, 3.26685965e-01, 0.00000000e+00, ...,\n", | |
" 1.23882025e-01, 5.39707905e-03, 0.00000000e+00],\n", | |
" [1.30514791e-02, 2.11704329e-01, 0.00000000e+00, ...,\n", | |
" 1.28650689e+00, 3.36281955e-03, 1.37906140e-02],\n", | |
" [1.27431333e-01, 1.56031281e-01, 1.24305105e-02, ...,\n", | |
" 2.07836628e-01, 1.33876745e-02, 9.61675271e-02]]),\n", | |
" {0: array([[0.00000000e+00, 9.32369530e-02, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
" [0.00000000e+00, 4.57307428e-01, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n", | |
" [0.00000000e+00, 6.30757153e-01, 0.00000000e+00, ...,\n", | |
" 0.00000000e+00, 1.77173363e-03, 0.00000000e+00],\n", | |
" ...,\n", | |
" [0.00000000e+00, 1.68548390e-01, 0.00000000e+00, ...,\n", | |
" 6.39149472e-02, 2.78453645e-03, 0.00000000e+00],\n", | |
" [2.71768179e-02, 1.00498842e-02, 0.00000000e+00, ...,\n", | |
" 1.84161370e-04, 7.00232806e-03, 2.87159029e-02],\n", | |
" [1.20953389e-01, 1.07856750e-01, 1.24875447e-02, ...,\n", | |
" 1.23726524e-01, 1.34490998e-02, 9.66087654e-02]]),\n", | |
" 1: array([[0. , 0. , 0. , ..., 0. , 0. ,\n", | |
" 0. ],\n", | |
" [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
" 0. ],\n", | |
" [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
" 0. ],\n", | |
" ...,\n", | |
" [0. , 0. , 0. , ..., 0. , 0. ,\n", | |
" 0. ],\n", | |
" [0. , 0.02491164, 0. , ..., 0.15490678, 0. ,\n", | |
" 0. ],\n", | |
" [0.01362148, 0.09429348, 0. , ..., 0.1640597 , 0. ,\n", | |
" 0. ]]),\n", | |
" 2: array([[0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" ...,\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.]])})" | |
] | |
}, | |
"execution_count": 71, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"skorch_clf.explain(X_valid)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "raw", | |
"metadata": {}, | |
"source": [ | |
"skorch_clf.save_params(f_params='ble.pt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment