Created
May 22, 2020 12:59
-
-
Save iwatobipen/c3674d951891a2f8b5c1b0e94cd652e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import os\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"from rdkit import Chem\n", | |
"from rdkit.Chem import RDConfig\n", | |
"from rdkit.Chem import DataStructs\n", | |
"from rdkit.Chem import AllChem\n", | |
"from rdkit.Chem.Draw import IPythonConsole\n", | |
"from rdkit.Chem import Draw\n", | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"from nonconformist.nc import ClassifierNc\n", | |
"from nonconformist.nc import ClassifierAdapter\n", | |
"from nonconformist.icp import IcpClassifier\n", | |
"from nonconformist.evaluation import ClassIcpCvHelper\n", | |
"from nonconformist.evaluation import cross_val_score" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')\n", | |
"test = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainmol = [m for m in Chem.SDMolSupplier(train)]\n", | |
"testmol = [m for m in Chem.SDMolSupplier(test)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'(A) low', '(C) high', '(B) medium'}\n" | |
] | |
} | |
], | |
"source": [ | |
"labels = set([m.GetProp('SOL_classification') for m in trainmol])\n", | |
"print(labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"label2cls = {'(A) low':0, '(B) medium':1, '(C) high':2}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def fp2arr(fp):\n", | |
" arr = np.zeros((1,))\n", | |
" DataStructs.ConvertToNumpyArray(fp, arr)\n", | |
" return arr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainfps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, 1024) for m in trainmol]\n", | |
"trainfps = np.array([fp2arr(fp) for fp in trainfps])\n", | |
"\n", | |
"testfps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, 1024) for m in testmol]\n", | |
"testfps = np.array([fp2arr(fp) for fp in testfps])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(1025, 1024) (1025,) (257, 1024) (257,)\n" | |
] | |
} | |
], | |
"source": [ | |
"train_cls = [label2cls[m.GetProp('SOL_classification')] for m in trainmol]\n", | |
"train_cls = np.array(train_cls)\n", | |
"test_cls = [label2cls[m.GetProp('SOL_classification')] for m in testmol]\n", | |
"test_cls = np.array(test_cls)\n", | |
"print(trainfps.shape, train_cls.shape, testfps.shape, test_cls.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#train data is devided to train and calibration data\n", | |
"ids = np.random.permutation(train_cls.size)\n", | |
"# Use first 700 data for train and second set is used for calibration\n", | |
"trainX, calibX = trainfps[ids[:700],:],trainfps[ids[700:],:] \n", | |
"trainY, calibY = train_cls[ids[:700]],train_cls[ids[700:]] " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"testX = testfps\n", | |
"testY = test_cls" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rf = RandomForestClassifier(n_estimators=500, random_state=794)\n", | |
"nc = ClassifierNc(ClassifierAdapter(rf))\n", | |
"icp = IcpClassifier(nc)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fit on the training split, then calibrate on the separate calibration split.\n", | |
"icp.fit(trainX, trainY)\n", | |
"icp.calibrate(calibX, calibY)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Without a significance level the output holds one score per class (argmax is taken\n", | |
"# over it below); presumably these are conformal p-values; confirm in the nonconformist docs.\n", | |
"pred = icp.predict(testX)\n", | |
"# Per-class predictions at significance 0.05 and 0.2, cast to int32 indicators.\n", | |
"pred95 = icp.predict(testX, significance=0.05).astype(np.int32)\n", | |
"pred80 = icp.predict(testX, significance=0.2).astype(np.int32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from nonconformist.evaluation import class_avg_c, class_n_correct" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"244" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Number of correct test predictions at significance 0.05 (per nonconformist's class_n_correct).\n", | |
"class_n_correct(pred, testY, significance=0.05)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"232" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Same count at the looser significance level 0.1.\n", | |
"class_n_correct(pred, testY, significance=0.1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 2 False [0 1 1] : [1 1 1]\n", | |
"0 1 False [0 1 1] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [1 1 0]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"1 0 False [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"1 2 False [0 0 1] : [0 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"1 1 True [1 1 1] : [1 1 1]\n", | |
"1 2 False [0 0 1] : [0 1 1]\n", | |
"1 2 False [0 0 1] : [0 0 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"1 2 False [0 1 1] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [0 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 0]\n", | |
"2 2 True [0 1 1] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"2 1 False [0 1 1] : [1 1 1]\n", | |
"2 0 False [1 1 1] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"1 2 False [0 0 1] : [0 0 1]\n", | |
"1 2 False [0 0 1] : [0 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"1 2 False [0 1 1] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 0]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"1 1 True [1 1 1] : [1 1 1]\n", | |
"2 1 False [0 1 1] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"0 2 False [1 1 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"1 1 True [1 1 1] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [1 1 1]\n", | |
"1 2 False [1 1 1] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"2 1 False [0 1 1] : [1 1 1]\n", | |
"2 2 True [0 1 1] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"0 1 False [1 1 0] : [1 1 0]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"2 1 False [0 1 0] : [0 1 1]\n", | |
"0 1 False [1 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"2 1 False [0 1 1] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"2 2 True [0 1 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 1 0] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 1] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"1 2 False [0 0 1] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [0 0 1]\n", | |
"2 2 True [0 1 1] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [0 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"1 2 False [0 1 1] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 2 False [0 0 1] : [1 1 1]\n", | |
"2 2 True [1 0 1] : [1 1 1]\n", | |
"0 1 False [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 1] : [1 1 1]\n", | |
"1 2 False [0 0 1] : [0 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 2 False [0 0 1] : [0 1 1]\n", | |
"1 2 False [0 1 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 0 0]\n", | |
"1 0 False [1 0 0] : [1 1 0]\n", | |
"0 1 False [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 0 False [1 1 1] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 1]\n", | |
"1 0 False [1 1 1] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 1]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 1 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 1 0] : [1 1 0]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"1 0 False [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 1 0] : [1 1 0]\n", | |
"0 0 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"1 1 True [0 1 0] : [0 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 0]\n", | |
"0 1 False [0 1 0] : [1 1 0]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 0 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"1 1 True [0 1 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 1]\n", | |
"1 1 True [1 1 0] : [1 1 1]\n", | |
"0 2 False [0 0 1] : [0 1 1]\n", | |
"0 2 False [0 1 1] : [1 1 1]\n", | |
"1 0 False [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"2 2 True [0 0 1] : [1 1 1]\n", | |
"0 0 True [1 0 0] : [1 1 0]\n", | |
"0 0 True [1 0 0] : [1 0 0]\n" | |
] | |
} | |
], | |
"source": [ | |
"tp = 0\n", | |
"for idx, j in enumerate(testY):\n", | |
" print(j, np.argmax(pred[idx]), j == np.argmax(pred[idx]) , pred80[idx], \":\", pred95[idx])\n", | |
" if j == np.argmax(pred[idx]):\n", | |
" tp += 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.6848249027237354\n" | |
] | |
} | |
], | |
"source": [ | |
"# Fraction of test samples whose arg-max prediction matched the true class.\n", | |
"print(tp/testY.size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.0155642023346303\n", | |
"1.2996108949416343\n" | |
] | |
} | |
], | |
"source": [ | |
"# class_avg_c at significance 0.05 and 0.2; presumably the average number of classes\n", | |
"# per prediction set; confirm in the nonconformist docs.\n", | |
"print(class_avg_c(pred, testY, significance=0.05))\n", | |
"print(class_avg_c(pred, testY, significance=0.2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"icpmodel = ClassIcpCvHelper(icp)\n", | |
"res = cross_val_score(icpmodel,\n", | |
" trainfps,\n", | |
" train_cls,\n", | |
" iterations=10,\n", | |
" scoring_funcs=[class_avg_c],\n", | |
" significance_levels=[0.05, 0.1, 0.2],\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>iter</th>\n", | |
" <th>fold</th>\n", | |
" <th>significance</th>\n", | |
" <th>class_avg_c</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.05</td>\n", | |
" <td>2.135922</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.10</td>\n", | |
" <td>1.708738</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.20</td>\n", | |
" <td>1.330097</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0.05</td>\n", | |
" <td>2.495146</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0.10</td>\n", | |
" <td>1.922330</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0.20</td>\n", | |
" <td>1.281553</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>0.05</td>\n", | |
" <td>2.194175</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>0.10</td>\n", | |
" <td>1.582524</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>0.20</td>\n", | |
" <td>1.271845</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>0.05</td>\n", | |
" <td>2.000000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" iter fold significance class_avg_c\n", | |
"0 0 0 0.05 2.135922\n", | |
"1 0 0 0.10 1.708738\n", | |
"2 0 0 0.20 1.330097\n", | |
"3 0 1 0.05 2.495146\n", | |
"4 0 1 0.10 1.922330\n", | |
"5 0 1 0.20 1.281553\n", | |
"6 0 2 0.05 2.194175\n", | |
"7 0 2 0.10 1.582524\n", | |
"8 0 2 0.20 1.271845\n", | |
"9 0 3 0.05 2.000000" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Preview the first ten rows of the cross-validation results.\n", | |
"res.head(10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment