Skip to content

Instantly share code, notes, and snippets.

@clementpoiret
Last active March 25, 2020 14:28
Show Gist options
Save clementpoiret/0dc4021a3e81f0c60fa49f9d5dc7c602 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pingouin as pg\n",
"import seaborn as sns\n",
"import tensorflow as tf\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
"from tensorflow.keras.activations import relu, sigmoid\n",
"from tensorflow.keras.callbacks import CSVLogger\n",
"from tensorflow.keras.layers import BatchNormalization, Dense, Input\n",
"from tensorflow.keras.models import Model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_dataset(path, pheno, shape, kind, standardize=False):\n",
" connectivities = np.load(path, allow_pickle=True)\n",
"\n",
" connectivity = connectivities[kind] \n",
"\n",
" pheno = pd.read_csv(pheno)\n",
" pheno = pheno[pheno.iloc[:, 0].isin(connectivity.keys())]\n",
"\n",
" #subjects = pheno.iloc[:, 0]\n",
" subjects = np.array(list(connectivity.keys()))\n",
" subjects = subjects[np.isin(subjects, pheno.iloc[:, 0])]\n",
" \n",
" X = np.array([\n",
" connectivity[subject][np.tril_indices(shape, k=-1)] for subject in subjects\n",
" ])\n",
" X = np.hstack((X, np.array([pheno[pheno.iloc[:, 0]==ID].Sex.values[0] for ID in subjects]).reshape(-1,1)))\n",
" \n",
" sc_X = StandardScaler()\n",
" \n",
" if standardize:\n",
" X = sc_X.fit_transform(X)\n",
"\n",
" y = np.array([pheno[pheno.iloc[:, 0]==ID].Age.values[0] for ID in subjects]).reshape(-1,1)\n",
" sc_y = MinMaxScaler()\n",
" y = sc_y.fit_transform(y)\n",
" \n",
" return (X, y, sc_X, sc_y)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_model():\n",
" input = Input(shape=(X.shape[1]))\n",
" a = Dense(128,\n",
" activation='relu',\n",
" kernel_initializer='he_uniform')(input)\n",
" a = tf.keras.layers.Dropout(0.5)(a)\n",
" a = Dense(128,\n",
" activation='relu',\n",
" kernel_initializer='he_uniform')(a)\n",
" a = tf.keras.layers.Dropout(0.5)(a)\n",
" a = Dense(128,\n",
" activation='relu',\n",
" kernel_initializer='he_uniform')(a)\n",
" a = tf.keras.layers.Dropout(0.5)(a)\n",
" a = Dense(1, activation='sigmoid')(a)\n",
"\n",
" model = Model(inputs=[input], outputs=[a])\n",
"\n",
" model.compile(optimizer='adam',\n",
" loss=\"mean_squared_error\",\n",
" metrics=[\"mean_absolute_error\"])\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'correlation': {'old': 'data/cbic_r7_msdl_correlation_matrices.pkl',\n",
" 'new': 'data/1584804874451_64subjects_correlation.pkl'},\n",
" 'partial_correlation': {'old': 'data/cbic_r7_msdl_partial correlation_matrices.pkl',\n",
" 'new': 'data/1584806176292_64subjects_partial-correlation.pkl'}}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets = {'correlation': {'old': 'data/cbic_r7_msdl_correlation_matrices.pkl',\n",
" 'new': 'data/1584804874451_64subjects_correlation.pkl'},\n",
" 'partial_correlation': {'old': 'data/cbic_r7_msdl_partial correlation_matrices.pkl',\n",
" 'new': 'data/1584806176292_64subjects_partial-correlation.pkl'}}\n",
"datasets"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'data/cbic_r7_msdl_correlation_matrices.pkl'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets['correlation']['old']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": [
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend",
"outputPrepend"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 0\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 802us/sample - loss: 0.0740 - mean_absolute_error: 0.2326\n",
"7/7 [==============================] - 0s 805us/sample - loss: 0.0702 - mean_absolute_error: 0.1735\n",
"7/7 [==============================] - 0s 802us/sample - loss: 0.1075 - mean_absolute_error: 0.3009\n",
"6/6 [==============================] - 0s 933us/sample - loss: 0.0457 - mean_absolute_error: 0.1847\n",
"6/6 [==============================] - 0s 907us/sample - loss: 0.0119 - mean_absolute_error: 0.0988\n",
"6/6 [==============================] - 0s 943us/sample - loss: 0.0806 - mean_absolute_error: 0.2346\n",
"6/6 [==============================] - 0s 883us/sample - loss: 0.0602 - mean_absolute_error: 0.2078\n",
"6/6 [==============================] - 0s 789us/sample - loss: 0.0282 - mean_absolute_error: 0.1338\n",
"6/6 [==============================] - 0s 506us/sample - loss: 0.0584 - mean_absolute_error: 0.1904\n",
"6/6 [==============================] - 0s 916us/sample - loss: 0.0532 - mean_absolute_error: 0.1885\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 785us/sample - loss: 0.0883 - mean_absolute_error: 0.2447\n",
"7/7 [==============================] - 0s 785us/sample - loss: 0.0461 - mean_absolute_error: 0.1828\n",
"7/7 [==============================] - 0s 842us/sample - loss: 0.0419 - mean_absolute_error: 0.1640\n",
"6/6 [==============================] - 0s 966us/sample - loss: 0.0486 - mean_absolute_error: 0.1937\n",
"6/6 [==============================] - 0s 938us/sample - loss: 0.0674 - mean_absolute_error: 0.2075\n",
"6/6 [==============================] - 0s 904us/sample - loss: 0.0595 - mean_absolute_error: 0.1881\n",
"6/6 [==============================] - 0s 893us/sample - loss: 0.0333 - mean_absolute_error: 0.1623\n",
"6/6 [==============================] - 0s 933us/sample - loss: 0.0336 - mean_absolute_error: 0.1463\n",
"6/6 [==============================] - 0s 938us/sample - loss: 0.0708 - mean_absolute_error: 0.2342\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.1105 - mean_absolute_error: 0.2601\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 782us/sample - loss: 0.0452 - mean_absolute_error: 0.1810\n",
"7/7 [==============================] - 0s 715us/sample - loss: 0.0952 - mean_absolute_error: 0.2774\n",
"7/7 [==============================] - 0s 762us/sample - loss: 0.0333 - mean_absolute_error: 0.1549\n",
"6/6 [==============================] - 0s 344us/sample - loss: 0.0703 - mean_absolute_error: 0.2336\n",
"6/6 [==============================] - 0s 951us/sample - loss: 0.0772 - mean_absolute_error: 0.1817\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0630 - mean_absolute_error: 0.2378\n",
"6/6 [==============================] - 0s 305us/sample - loss: 0.0098 - mean_absolute_error: 0.0855\n",
"6/6 [==============================] - 0s 931us/sample - loss: 0.0500 - mean_absolute_error: 0.1930\n",
"6/6 [==============================] - 0s 905us/sample - loss: 0.1243 - mean_absolute_error: 0.2955\n",
"6/6 [==============================] - 0s 889us/sample - loss: 0.0777 - mean_absolute_error: 0.2629\n",
"Trial 1\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 762us/sample - loss: 0.1238 - mean_absolute_error: 0.2961\n",
"7/7 [==============================] - 0s 802us/sample - loss: 0.0361 - mean_absolute_error: 0.1706\n",
"7/7 [==============================] - 0s 825us/sample - loss: 0.1038 - mean_absolute_error: 0.2672\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.0821 - mean_absolute_error: 0.2716\n",
"6/6 [==============================] - 0s 938us/sample - loss: 0.0294 - mean_absolute_error: 0.1569\n",
"6/6 [==============================] - 0s 926us/sample - loss: 0.1056 - mean_absolute_error: 0.2988\n",
"6/6 [==============================] - 0s 911us/sample - loss: 0.0360 - mean_absolute_error: 0.1593\n",
"6/6 [==============================] - 0s 934us/sample - loss: 0.0575 - mean_absolute_error: 0.1981\n",
"6/6 [==============================] - 0s 309us/sample - loss: 0.1032 - mean_absolute_error: 0.2538\n",
"6/6 [==============================] - 0s 930us/sample - loss: 0.0441 - mean_absolute_error: 0.1750\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 730us/sample - loss: 0.0221 - mean_absolute_error: 0.1278\n",
"7/7 [==============================] - 0s 733us/sample - loss: 0.0759 - mean_absolute_error: 0.2464\n",
"7/7 [==============================] - 0s 752us/sample - loss: 0.0741 - mean_absolute_error: 0.1657\n",
"6/6 [==============================] - 0s 923us/sample - loss: 0.0677 - mean_absolute_error: 0.2115\n",
"6/6 [==============================] - 0s 922us/sample - loss: 0.0376 - mean_absolute_error: 0.1511\n",
"6/6 [==============================] - 0s 891us/sample - loss: 0.0641 - mean_absolute_error: 0.2233\n",
"6/6 [==============================] - 0s 322us/sample - loss: 0.0643 - mean_absolute_error: 0.1987\n",
"6/6 [==============================] - 0s 928us/sample - loss: 0.0734 - mean_absolute_error: 0.2381\n",
"6/6 [==============================] - 0s 921us/sample - loss: 0.0311 - mean_absolute_error: 0.1577\n",
"6/6 [==============================] - 0s 893us/sample - loss: 0.0793 - mean_absolute_error: 0.2568\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 754us/sample - loss: 0.0481 - mean_absolute_error: 0.1973\n",
"7/7 [==============================] - 0s 798us/sample - loss: 0.0259 - mean_absolute_error: 0.1466\n",
"7/7 [==============================] - 0s 351us/sample - loss: 0.0630 - mean_absolute_error: 0.2013\n",
"6/6 [==============================] - 0s 920us/sample - loss: 0.0704 - mean_absolute_error: 0.2326\n",
"6/6 [==============================] - 0s 899us/sample - loss: 0.0474 - mean_absolute_error: 0.1910\n",
"6/6 [==============================] - 0s 905us/sample - loss: 0.1100 - mean_absolute_error: 0.2748\n",
"6/6 [==============================] - 0s 917us/sample - loss: 0.1202 - mean_absolute_error: 0.2881\n",
"6/6 [==============================] - 0s 907us/sample - loss: 0.0515 - mean_absolute_error: 0.1951\n",
"6/6 [==============================] - 0s 921us/sample - loss: 0.0786 - mean_absolute_error: 0.2596\n",
"6/6 [==============================] - 0s 898us/sample - loss: 0.0766 - mean_absolute_error: 0.2505\n",
"Trial 2\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 804us/sample - loss: 0.0785 - mean_absolute_error: 0.2542\n",
"7/7 [==============================] - 0s 765us/sample - loss: 0.0352 - mean_absolute_error: 0.1529\n",
"7/7 [==============================] - 0s 783us/sample - loss: 0.0515 - mean_absolute_error: 0.1866\n",
"6/6 [==============================] - 0s 252us/sample - loss: 0.0649 - mean_absolute_error: 0.2156\n",
"6/6 [==============================] - 0s 906us/sample - loss: 0.0866 - mean_absolute_error: 0.2406\n",
"6/6 [==============================] - 0s 900us/sample - loss: 0.0538 - mean_absolute_error: 0.1873\n",
"6/6 [==============================] - 0s 897us/sample - loss: 0.0658 - mean_absolute_error: 0.1828\n",
"6/6 [==============================] - 0s 901us/sample - loss: 0.0787 - mean_absolute_error: 0.2125\n",
"6/6 [==============================] - 0s 939us/sample - loss: 0.0793 - mean_absolute_error: 0.2321\n",
"6/6 [==============================] - 0s 942us/sample - loss: 0.0350 - mean_absolute_error: 0.1530\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 839us/sample - loss: 0.0771 - mean_absolute_error: 0.2391\n",
"7/7 [==============================] - 0s 768us/sample - loss: 0.0798 - mean_absolute_error: 0.2191\n",
"7/7 [==============================] - 0s 801us/sample - loss: 0.0309 - mean_absolute_error: 0.1515\n",
"6/6 [==============================] - 0s 908us/sample - loss: 0.0226 - mean_absolute_error: 0.1262\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.0439 - mean_absolute_error: 0.1897\n",
"6/6 [==============================] - 0s 925us/sample - loss: 0.0489 - mean_absolute_error: 0.1738\n",
"6/6 [==============================] - 0s 662us/sample - loss: 0.0052 - mean_absolute_error: 0.0625\n",
"6/6 [==============================] - 0s 900us/sample - loss: 0.0965 - mean_absolute_error: 0.2774\n",
"6/6 [==============================] - 0s 901us/sample - loss: 0.0950 - mean_absolute_error: 0.2928\n",
"6/6 [==============================] - 0s 936us/sample - loss: 0.0595 - mean_absolute_error: 0.2185\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 803us/sample - loss: 0.0730 - mean_absolute_error: 0.1900\n",
"7/7 [==============================] - 0s 798us/sample - loss: 0.0362 - mean_absolute_error: 0.1450\n",
"7/7 [==============================] - 0s 793us/sample - loss: 0.1019 - mean_absolute_error: 0.2417\n",
"6/6 [==============================] - 0s 917us/sample - loss: 0.0257 - mean_absolute_error: 0.1355\n",
"6/6 [==============================] - 0s 273us/sample - loss: 0.1171 - mean_absolute_error: 0.2872\n",
"6/6 [==============================] - 0s 907us/sample - loss: 0.0692 - mean_absolute_error: 0.2468\n",
"6/6 [==============================] - 0s 931us/sample - loss: 0.0412 - mean_absolute_error: 0.1778\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.0288 - mean_absolute_error: 0.1488\n",
"6/6 [==============================] - 0s 309us/sample - loss: 0.0788 - mean_absolute_error: 0.2624\n",
"6/6 [==============================] - 0s 921us/sample - loss: 0.0733 - mean_absolute_error: 0.2248\n",
"Trial 3\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 808us/sample - loss: 0.0776 - mean_absolute_error: 0.2312\n",
"7/7 [==============================] - 0s 844us/sample - loss: 0.0333 - mean_absolute_error: 0.1527\n",
"7/7 [==============================] - 0s 798us/sample - loss: 0.0661 - mean_absolute_error: 0.2083\n",
"6/6 [==============================] - 0s 946us/sample - loss: 0.0506 - mean_absolute_error: 0.1817\n",
"6/6 [==============================] - 0s 923us/sample - loss: 0.1007 - mean_absolute_error: 0.2543\n",
"6/6 [==============================] - 0s 934us/sample - loss: 0.0421 - mean_absolute_error: 0.1752\n",
"6/6 [==============================] - 0s 900us/sample - loss: 0.0408 - mean_absolute_error: 0.1799\n",
"6/6 [==============================] - 0s 891us/sample - loss: 0.0941 - mean_absolute_error: 0.2434\n",
"6/6 [==============================] - 0s 940us/sample - loss: 0.1099 - mean_absolute_error: 0.3062\n",
"6/6 [==============================] - 0s 924us/sample - loss: 0.0742 - mean_absolute_error: 0.2227\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 861us/sample - loss: 0.0772 - mean_absolute_error: 0.2595\n",
"7/7 [==============================] - 0s 800us/sample - loss: 0.0260 - mean_absolute_error: 0.1158\n",
"7/7 [==============================] - 0s 798us/sample - loss: 0.1019 - mean_absolute_error: 0.2297\n",
"6/6 [==============================] - 0s 276us/sample - loss: 0.0542 - mean_absolute_error: 0.1761\n",
"6/6 [==============================] - 0s 933us/sample - loss: 0.0415 - mean_absolute_error: 0.1789\n",
"6/6 [==============================] - 0s 892us/sample - loss: 0.0743 - mean_absolute_error: 0.2625\n",
"6/6 [==============================] - 0s 894us/sample - loss: 0.0607 - mean_absolute_error: 0.2246\n",
"6/6 [==============================] - 0s 906us/sample - loss: 0.0321 - mean_absolute_error: 0.1591\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.0234 - mean_absolute_error: 0.1234\n",
"6/6 [==============================] - 0s 862us/sample - loss: 0.0686 - mean_absolute_error: 0.2183\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 784us/sample - loss: 0.0683 - mean_absolute_error: 0.2378\n",
"7/7 [==============================] - 0s 870us/sample - loss: 0.0739 - mean_absolute_error: 0.2548\n",
"7/7 [==============================] - 0s 796us/sample - loss: 0.0588 - mean_absolute_error: 0.1967\n",
"6/6 [==============================] - 0s 897us/sample - loss: 0.1128 - mean_absolute_error: 0.2793\n",
"6/6 [==============================] - 0s 987us/sample - loss: 0.0604 - mean_absolute_error: 0.2233\n",
"6/6 [==============================] - 0s 275us/sample - loss: 0.1029 - mean_absolute_error: 0.2946\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0799 - mean_absolute_error: 0.2140\n",
"6/6 [==============================] - 0s 244us/sample - loss: 0.0361 - mean_absolute_error: 0.1595\n",
"6/6 [==============================] - 0s 404us/sample - loss: 0.0291 - mean_absolute_error: 0.1317\n",
"6/6 [==============================] - 0s 293us/sample - loss: 0.0160 - mean_absolute_error: 0.1152\n",
"Trial 4\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 225us/sample - loss: 0.1204 - mean_absolute_error: 0.2923\n",
"7/7 [==============================] - 0s 750us/sample - loss: 0.0650 - mean_absolute_error: 0.2194\n",
"7/7 [==============================] - 0s 804us/sample - loss: 0.0174 - mean_absolute_error: 0.1125\n",
"6/6 [==============================] - 0s 927us/sample - loss: 0.0272 - mean_absolute_error: 0.1321\n",
"6/6 [==============================] - 0s 288us/sample - loss: 0.0299 - mean_absolute_error: 0.1379\n",
"6/6 [==============================] - 0s 923us/sample - loss: 0.0836 - mean_absolute_error: 0.1900\n",
"6/6 [==============================] - 0s 955us/sample - loss: 0.0769 - mean_absolute_error: 0.2227\n",
"6/6 [==============================] - 0s 262us/sample - loss: 0.0485 - mean_absolute_error: 0.2060\n",
"6/6 [==============================] - 0s 869us/sample - loss: 0.0422 - mean_absolute_error: 0.1765\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0628 - mean_absolute_error: 0.1743\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 741us/sample - loss: 0.0609 - mean_absolute_error: 0.2071\n",
"7/7 [==============================] - 0s 263us/sample - loss: 0.0746 - mean_absolute_error: 0.2587\n",
"7/7 [==============================] - 0s 735us/sample - loss: 0.0500 - mean_absolute_error: 0.1986\n",
"6/6 [==============================] - 0s 902us/sample - loss: 0.0572 - mean_absolute_error: 0.2031\n",
"6/6 [==============================] - 0s 834us/sample - loss: 0.0385 - mean_absolute_error: 0.1600\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.1265 - mean_absolute_error: 0.2646\n",
"6/6 [==============================] - 0s 306us/sample - loss: 0.0753 - mean_absolute_error: 0.2264\n",
"6/6 [==============================] - 0s 988us/sample - loss: 0.0160 - mean_absolute_error: 0.1005\n",
"6/6 [==============================] - 0s 563us/sample - loss: 0.0601 - mean_absolute_error: 0.2056\n",
"6/6 [==============================] - 0s 862us/sample - loss: 0.0522 - mean_absolute_error: 0.1890\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 276us/sample - loss: 0.0480 - mean_absolute_error: 0.1842\n",
"7/7 [==============================] - 0s 362us/sample - loss: 0.0456 - mean_absolute_error: 0.1633\n",
"7/7 [==============================] - 0s 677us/sample - loss: 0.0541 - mean_absolute_error: 0.2021\n",
"6/6 [==============================] - 0s 318us/sample - loss: 0.0171 - mean_absolute_error: 0.1197\n",
"6/6 [==============================] - 0s 679us/sample - loss: 0.1223 - mean_absolute_error: 0.2916\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.1706 - mean_absolute_error: 0.3696\n",
"6/6 [==============================] - 0s 988us/sample - loss: 0.0442 - mean_absolute_error: 0.1880\n",
"6/6 [==============================] - 0s 997us/sample - loss: 0.0425 - mean_absolute_error: 0.1914\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0802 - mean_absolute_error: 0.2344\n",
"6/6 [==============================] - 0s 232us/sample - loss: 0.0398 - mean_absolute_error: 0.1740\n",
"Trial 5\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 981us/sample - loss: 0.1415 - mean_absolute_error: 0.2885\n",
"7/7 [==============================] - 0s 864us/sample - loss: 0.0434 - mean_absolute_error: 0.1790\n",
"7/7 [==============================] - 0s 327us/sample - loss: 0.0764 - mean_absolute_error: 0.2321\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0347 - mean_absolute_error: 0.1629\n",
"6/6 [==============================] - 0s 284us/sample - loss: 0.0768 - mean_absolute_error: 0.2075\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0673 - mean_absolute_error: 0.2276\n",
"6/6 [==============================] - 0s 279us/sample - loss: 0.0300 - mean_absolute_error: 0.1449\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0605 - mean_absolute_error: 0.1756\n",
"6/6 [==============================] - 0s 255us/sample - loss: 0.0710 - mean_absolute_error: 0.2401\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0468 - mean_absolute_error: 0.1870\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 251us/sample - loss: 0.0821 - mean_absolute_error: 0.2667\n",
"7/7 [==============================] - 0s 893us/sample - loss: 0.0741 - mean_absolute_error: 0.2365\n",
"7/7 [==============================] - 0s 416us/sample - loss: 0.0875 - mean_absolute_error: 0.1939\n",
"6/6 [==============================] - 0s 224us/sample - loss: 0.0956 - mean_absolute_error: 0.2819\n",
"6/6 [==============================] - 0s 674us/sample - loss: 0.0955 - mean_absolute_error: 0.2713\n",
"6/6 [==============================] - 0s 290us/sample - loss: 0.0228 - mean_absolute_error: 0.1305\n",
"6/6 [==============================] - 0s 2ms/sample - loss: 0.0119 - mean_absolute_error: 0.0880\n",
"6/6 [==============================] - 0s 906us/sample - loss: 0.0565 - mean_absolute_error: 0.1853\n",
"6/6 [==============================] - 0s 260us/sample - loss: 0.0283 - mean_absolute_error: 0.1247\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0577 - mean_absolute_error: 0.2079\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 267us/sample - loss: 0.0533 - mean_absolute_error: 0.1992\n",
"7/7 [==============================] - 0s 533us/sample - loss: 0.0654 - mean_absolute_error: 0.1836\n",
"7/7 [==============================] - 0s 936us/sample - loss: 0.1033 - mean_absolute_error: 0.2484\n",
"6/6 [==============================] - 0s 572us/sample - loss: 0.0505 - mean_absolute_error: 0.1915\n",
"6/6 [==============================] - 0s 322us/sample - loss: 0.0440 - mean_absolute_error: 0.1897\n",
"6/6 [==============================] - 0s 477us/sample - loss: 0.0793 - mean_absolute_error: 0.2582\n",
"6/6 [==============================] - 0s 917us/sample - loss: 0.0435 - mean_absolute_error: 0.1628\n",
"6/6 [==============================] - 0s 767us/sample - loss: 0.0535 - mean_absolute_error: 0.2175\n",
"6/6 [==============================] - 0s 280us/sample - loss: 0.0275 - mean_absolute_error: 0.1438\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0693 - mean_absolute_error: 0.2387\n",
"Trial 6\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 905us/sample - loss: 0.0713 - mean_absolute_error: 0.2121\n",
"7/7 [==============================] - 0s 716us/sample - loss: 0.0634 - mean_absolute_error: 0.1746\n",
"7/7 [==============================] - 0s 1ms/sample - loss: 0.0748 - mean_absolute_error: 0.2047\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0794 - mean_absolute_error: 0.2150\n",
"6/6 [==============================] - 0s 547us/sample - loss: 0.1270 - mean_absolute_error: 0.3197\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0259 - mean_absolute_error: 0.1341\n",
"6/6 [==============================] - 0s 499us/sample - loss: 0.0428 - mean_absolute_error: 0.1861\n",
"6/6 [==============================] - 0s 517us/sample - loss: 0.0234 - mean_absolute_error: 0.1297\n",
"6/6 [==============================] - 0s 686us/sample - loss: 0.0382 - mean_absolute_error: 0.1699\n",
"6/6 [==============================] - 0s 978us/sample - loss: 0.0785 - mean_absolute_error: 0.2470\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 959us/sample - loss: 0.0727 - mean_absolute_error: 0.2159\n",
"7/7 [==============================] - 0s 383us/sample - loss: 0.0526 - mean_absolute_error: 0.2008\n",
"7/7 [==============================] - 0s 847us/sample - loss: 0.0500 - mean_absolute_error: 0.1999\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0530 - mean_absolute_error: 0.1416\n",
"6/6 [==============================] - 0s 2ms/sample - loss: 0.0444 - mean_absolute_error: 0.1871\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.1124 - mean_absolute_error: 0.2961\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.1466 - mean_absolute_error: 0.2916\n",
"6/6 [==============================] - 0s 791us/sample - loss: 0.0546 - mean_absolute_error: 0.2056\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0428 - mean_absolute_error: 0.1833\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0532 - mean_absolute_error: 0.2207\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 945us/sample - loss: 0.0673 - mean_absolute_error: 0.2166\n",
"7/7 [==============================] - 0s 872us/sample - loss: 0.0995 - mean_absolute_error: 0.2476\n",
"7/7 [==============================] - 0s 1ms/sample - loss: 0.0450 - mean_absolute_error: 0.1961\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0191 - mean_absolute_error: 0.1074\n",
"6/6 [==============================] - 0s 884us/sample - loss: 0.0590 - mean_absolute_error: 0.1855\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0501 - mean_absolute_error: 0.1932\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0360 - mean_absolute_error: 0.1685\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0646 - mean_absolute_error: 0.2059\n",
"6/6 [==============================] - 0s 409us/sample - loss: 0.1237 - mean_absolute_error: 0.3014\n",
"6/6 [==============================] - 0s 919us/sample - loss: 0.0691 - mean_absolute_error: 0.2478\n",
"Trial 7\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 810us/sample - loss: 0.0586 - mean_absolute_error: 0.1809\n",
"7/7 [==============================] - 0s 811us/sample - loss: 0.1151 - mean_absolute_error: 0.2772\n",
"7/7 [==============================] - 0s 800us/sample - loss: 0.0574 - mean_absolute_error: 0.2294\n",
"6/6 [==============================] - 0s 926us/sample - loss: 0.0357 - mean_absolute_error: 0.1550\n",
"6/6 [==============================] - 0s 267us/sample - loss: 0.0334 - mean_absolute_error: 0.1803\n",
"6/6 [==============================] - 0s 897us/sample - loss: 0.1120 - mean_absolute_error: 0.2855\n",
"6/6 [==============================] - 0s 870us/sample - loss: 0.0152 - mean_absolute_error: 0.0787\n",
"6/6 [==============================] - 0s 939us/sample - loss: 0.0717 - mean_absolute_error: 0.2230\n",
"6/6 [==============================] - 0s 954us/sample - loss: 0.0525 - mean_absolute_error: 0.2049\n",
"6/6 [==============================] - 0s 939us/sample - loss: 0.0886 - mean_absolute_error: 0.2342\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 808us/sample - loss: 0.1724 - mean_absolute_error: 0.3516\n",
"7/7 [==============================] - 0s 837us/sample - loss: 0.0582 - mean_absolute_error: 0.2040\n",
"7/7 [==============================] - 0s 772us/sample - loss: 0.0280 - mean_absolute_error: 0.1036\n",
"6/6 [==============================] - 0s 969us/sample - loss: 0.0762 - mean_absolute_error: 0.2497\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0754 - mean_absolute_error: 0.2592\n",
"6/6 [==============================] - 0s 826us/sample - loss: 0.0360 - mean_absolute_error: 0.1747\n",
"6/6 [==============================] - 0s 925us/sample - loss: 0.0178 - mean_absolute_error: 0.1148\n",
"6/6 [==============================] - 0s 909us/sample - loss: 0.0539 - mean_absolute_error: 0.2069\n",
"6/6 [==============================] - 0s 906us/sample - loss: 0.0537 - mean_absolute_error: 0.2047\n",
"6/6 [==============================] - 0s 285us/sample - loss: 0.0418 - mean_absolute_error: 0.1557\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 843us/sample - loss: 0.0562 - mean_absolute_error: 0.2163\n",
"7/7 [==============================] - 0s 837us/sample - loss: 0.0686 - mean_absolute_error: 0.2267\n",
"7/7 [==============================] - 0s 807us/sample - loss: 0.0737 - mean_absolute_error: 0.2508\n",
"6/6 [==============================] - 0s 288us/sample - loss: 0.0822 - mean_absolute_error: 0.2367\n",
"6/6 [==============================] - 0s 938us/sample - loss: 0.0538 - mean_absolute_error: 0.1928\n",
"6/6 [==============================] - 0s 943us/sample - loss: 0.0615 - mean_absolute_error: 0.1890\n",
"6/6 [==============================] - 0s 257us/sample - loss: 0.0982 - mean_absolute_error: 0.2579\n",
"6/6 [==============================] - 0s 919us/sample - loss: 0.0496 - mean_absolute_error: 0.2069\n",
"6/6 [==============================] - 0s 941us/sample - loss: 0.0385 - mean_absolute_error: 0.1550\n",
"6/6 [==============================] - 0s 933us/sample - loss: 0.0532 - mean_absolute_error: 0.2237\n",
"Trial 8\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 772us/sample - loss: 0.0264 - mean_absolute_error: 0.1244\n",
"7/7 [==============================] - 0s 731us/sample - loss: 0.0594 - mean_absolute_error: 0.2268\n",
"7/7 [==============================] - 0s 855us/sample - loss: 0.0885 - mean_absolute_error: 0.2661\n",
"6/6 [==============================] - 0s 721us/sample - loss: 0.0791 - mean_absolute_error: 0.1958\n",
"6/6 [==============================] - 0s 505us/sample - loss: 0.0550 - mean_absolute_error: 0.1636\n",
"6/6 [==============================] - 0s 343us/sample - loss: 0.0356 - mean_absolute_error: 0.1718\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0685 - mean_absolute_error: 0.2064\n",
"6/6 [==============================] - 0s 880us/sample - loss: 0.0572 - mean_absolute_error: 0.1924\n",
"6/6 [==============================] - 0s 610us/sample - loss: 0.0783 - mean_absolute_error: 0.2085\n",
"6/6 [==============================] - 0s 916us/sample - loss: 0.0664 - mean_absolute_error: 0.2025\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 882us/sample - loss: 0.0728 - mean_absolute_error: 0.2641\n",
"7/7 [==============================] - 0s 825us/sample - loss: 0.1274 - mean_absolute_error: 0.2713\n",
"7/7 [==============================] - 0s 815us/sample - loss: 0.0502 - mean_absolute_error: 0.1512\n",
"6/6 [==============================] - 0s 912us/sample - loss: 0.0229 - mean_absolute_error: 0.1356\n",
"6/6 [==============================] - 0s 834us/sample - loss: 0.0431 - mean_absolute_error: 0.1871\n",
"6/6 [==============================] - 0s 888us/sample - loss: 0.0940 - mean_absolute_error: 0.2807\n",
"6/6 [==============================] - 0s 905us/sample - loss: 0.0362 - mean_absolute_error: 0.1667\n",
"6/6 [==============================] - 0s 926us/sample - loss: 0.0121 - mean_absolute_error: 0.0806\n",
"6/6 [==============================] - 0s 895us/sample - loss: 0.0706 - mean_absolute_error: 0.2234\n",
"6/6 [==============================] - 0s 938us/sample - loss: 0.0610 - mean_absolute_error: 0.2300\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 298us/sample - loss: 0.1504 - mean_absolute_error: 0.2743\n",
"7/7 [==============================] - 0s 248us/sample - loss: 0.0381 - mean_absolute_error: 0.1486\n",
"7/7 [==============================] - 0s 255us/sample - loss: 0.0737 - mean_absolute_error: 0.2052\n",
"6/6 [==============================] - 0s 926us/sample - loss: 0.0322 - mean_absolute_error: 0.1492\n",
"6/6 [==============================] - 0s 887us/sample - loss: 0.0626 - mean_absolute_error: 0.2235\n",
"6/6 [==============================] - 0s 941us/sample - loss: 0.0486 - mean_absolute_error: 0.1727\n",
"6/6 [==============================] - 0s 290us/sample - loss: 0.0663 - mean_absolute_error: 0.2332\n",
"6/6 [==============================] - 0s 901us/sample - loss: 0.0659 - mean_absolute_error: 0.2235\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0487 - mean_absolute_error: 0.2109\n",
"6/6 [==============================] - 0s 963us/sample - loss: 0.0609 - mean_absolute_error: 0.2249\n",
"Trial 9\n",
"Training on correlation\n",
"7/7 [==============================] - 0s 966us/sample - loss: 0.0525 - mean_absolute_error: 0.1725\n",
"7/7 [==============================] - 0s 823us/sample - loss: 0.0523 - mean_absolute_error: 0.1648\n",
"7/7 [==============================] - 0s 797us/sample - loss: 0.0355 - mean_absolute_error: 0.1643\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0191 - mean_absolute_error: 0.1130\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0453 - mean_absolute_error: 0.2010\n",
"6/6 [==============================] - 0s 373us/sample - loss: 0.0842 - mean_absolute_error: 0.2683\n",
"6/6 [==============================] - 0s 549us/sample - loss: 0.0745 - mean_absolute_error: 0.2271\n",
"6/6 [==============================] - 0s 940us/sample - loss: 0.0591 - mean_absolute_error: 0.2035\n",
"6/6 [==============================] - 0s 932us/sample - loss: 0.0873 - mean_absolute_error: 0.2229\n",
"6/6 [==============================] - 0s 944us/sample - loss: 0.0774 - mean_absolute_error: 0.2364\n",
"Training on partial correlation\n",
"7/7 [==============================] - 0s 899us/sample - loss: 0.0321 - mean_absolute_error: 0.1437\n",
"7/7 [==============================] - 0s 848us/sample - loss: 0.0598 - mean_absolute_error: 0.2197\n",
"7/7 [==============================] - 0s 811us/sample - loss: 0.0434 - mean_absolute_error: 0.1825\n",
"6/6 [==============================] - 0s 792us/sample - loss: 0.0642 - mean_absolute_error: 0.1745\n",
"6/6 [==============================] - 0s 932us/sample - loss: 0.1254 - mean_absolute_error: 0.2942\n",
"6/6 [==============================] - 0s 944us/sample - loss: 0.0487 - mean_absolute_error: 0.1717\n",
"6/6 [==============================] - 0s 904us/sample - loss: 0.1015 - mean_absolute_error: 0.2779\n",
"6/6 [==============================] - 0s 931us/sample - loss: 0.0362 - mean_absolute_error: 0.1620\n",
"6/6 [==============================] - 0s 936us/sample - loss: 0.0303 - mean_absolute_error: 0.1666\n",
"6/6 [==============================] - 0s 913us/sample - loss: 0.0694 - mean_absolute_error: 0.2388\n",
"Training on tangent\n",
"7/7 [==============================] - 0s 261us/sample - loss: 0.0426 - mean_absolute_error: 0.1846\n",
"7/7 [==============================] - 0s 766us/sample - loss: 0.0716 - mean_absolute_error: 0.2226\n",
"7/7 [==============================] - 0s 367us/sample - loss: 0.0968 - mean_absolute_error: 0.2530\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0288 - mean_absolute_error: 0.1373\n",
"6/6 [==============================] - 0s 937us/sample - loss: 0.0577 - mean_absolute_error: 0.2095\n",
"6/6 [==============================] - 0s 377us/sample - loss: 0.0716 - mean_absolute_error: 0.1919\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.1049 - mean_absolute_error: 0.2761\n",
"6/6 [==============================] - 0s 289us/sample - loss: 0.0429 - mean_absolute_error: 0.1990\n",
"6/6 [==============================] - 0s 1ms/sample - loss: 0.0726 - mean_absolute_error: 0.2170\n",
"6/6 [==============================] - 0s 267us/sample - loss: 0.0267 - mean_absolute_error: 0.1470\n"
]
}
],
"source": [
"model = None\n",
"\n",
"scores = pd.DataFrame()\n",
"\n",
"trials = 10\n",
"\n",
"kinds = ['correlation', 'partial correlation', 'tangent']\n",
"\n",
"for trial in range(trials):\n",
" print(f'Trial {trial}')\n",
" kfold = KFold(n_splits=10, shuffle=True)\n",
" \n",
" for kind in kinds:\n",
" print('Training on {}'.format(kind))\n",
" \n",
" X, y, sc_X, sc_y = get_dataset('data/1585070166511_63subjects_correlation_partial-correlation_tangent.pkl',\n",
" 'data/HBN_R7_Pheno.csv',\n",
" 39,\n",
" kind=kind)\n",
" for split, (train, test) in enumerate(kfold.split(X, y)):\n",
" tf.keras.backend.clear_session()\n",
"\n",
" if model:\n",
" model.reset_states()\n",
" del model\n",
"\n",
" model = get_model()\n",
"\n",
" history = model.fit(X[train], y[train],\n",
" batch_size=16,\n",
" epochs=256,\n",
" verbose=0,\n",
" validation_split=0.2)\n",
"\n",
" score = model.evaluate(X[test], y[test])\n",
" scores = scores.append([[kind,\n",
" trial,\n",
" split,\n",
" history.history['loss'][-1],\n",
" history.history['val_loss'][-1],\n",
" history.history['mean_absolute_error'][-1],\n",
" history.history['val_mean_absolute_error'][-1],\n",
" score[0],\n",
" score[1]]])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"scores.columns = ['kind',\n",
" 'trial',\n",
" 'split',\n",
" 'MSE',\n",
" 'val_MSE',\n",
" 'MAE',\n",
" 'val_MAE',\n",
" 'dev_MSE',\n",
" 'dev_MAE']\n",
"scores.to_csv('./scores.csv')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>kind</th>\n",
" <th>trial</th>\n",
" <th>split</th>\n",
" <th>MSE</th>\n",
" <th>val_MSE</th>\n",
" <th>MAE</th>\n",
" <th>val_MAE</th>\n",
" <th>dev_MSE</th>\n",
" <th>dev_MAE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>correlation</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.018553</td>\n",
" <td>0.074741</td>\n",
" <td>0.103415</td>\n",
" <td>0.225603</td>\n",
" <td>0.074033</td>\n",
" <td>0.232597</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>correlation</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.017185</td>\n",
" <td>0.066889</td>\n",
" <td>0.105498</td>\n",
" <td>0.201655</td>\n",
" <td>0.070227</td>\n",
" <td>0.173453</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>correlation</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0.015297</td>\n",
" <td>0.047304</td>\n",
" <td>0.093099</td>\n",
" <td>0.169098</td>\n",
" <td>0.107540</td>\n",
" <td>0.300915</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>correlation</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0.025425</td>\n",
" <td>0.032396</td>\n",
" <td>0.109083</td>\n",
" <td>0.133276</td>\n",
" <td>0.045701</td>\n",
" <td>0.184709</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>correlation</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0.017808</td>\n",
" <td>0.043696</td>\n",
" <td>0.104037</td>\n",
" <td>0.161209</td>\n",
" <td>0.011879</td>\n",
" <td>0.098843</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" kind trial split MSE val_MSE MAE val_MAE \\\n",
"0 correlation 0 0 0.018553 0.074741 0.103415 0.225603 \n",
"0 correlation 0 1 0.017185 0.066889 0.105498 0.201655 \n",
"0 correlation 0 2 0.015297 0.047304 0.093099 0.169098 \n",
"0 correlation 0 3 0.025425 0.032396 0.109083 0.133276 \n",
"0 correlation 0 4 0.017808 0.043696 0.104037 0.161209 \n",
"\n",
" dev_MSE dev_MAE \n",
"0 0.074033 0.232597 \n",
"0 0.070227 0.173453 \n",
"0 0.107540 0.300915 \n",
"0 0.045701 0.184709 \n",
"0 0.011879 0.098843 "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scores.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f2b68145590>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, (ax1, ax2) = plt.subplots(figsize=(16, 8), nrows=1, ncols=2)\n",
"sns.distplot(scores[(scores.kind == 'correlation')].dev_MSE, label='Correlation', ax=ax1)\n",
"sns.distplot(scores[(scores.kind == 'partial correlation')].dev_MSE, label='Partial Correlation', ax=ax1)\n",
"sns.distplot(scores[(scores.kind == 'tangent')].dev_MSE, label='Tangent', ax=ax1)\n",
"ax1.legend()\n",
"sns.distplot(scores[(scores.kind == 'correlation')].dev_MAE, label='Correlation', ax=ax2)\n",
"sns.distplot(scores[(scores.kind == 'partial correlation')].dev_MAE, label='Partial Correlation', ax=ax2)\n",
"sns.distplot(scores[(scores.kind == 'tangent')].dev_MAE, label='Tangent', ax=ax2)\n",
"ax2.legend()\n",
"\n",
"\n",
"\n",
"#plt.savefig('./comparison.pdf')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE - CORRELATION: 0.0631906935106963\n",
"MSE - PARTIAL CORRELATION: 0.060293166567571464\n",
"MSE - TANGENT: 0.06407719370909035\n",
"MAE - CORRELATION: 0.20249531373381616\n",
"MAE - PARTIAL CORRELATION: 0.20049512945115566\n",
"MAE - TANGENT: 0.20988476172089576\n"
]
}
],
"source": [
"print('MSE - CORRELATION: {}'.format(scores[(scores.kind == 'correlation')].dev_MSE.mean()))\n",
"print('MSE - PARTIAL CORRELATION: {}'.format(scores[(scores.kind == 'partial correlation')].dev_MSE.mean()))\n",
"print('MSE - TANGENT: {}'.format(scores[(scores.kind == 'tangent')].dev_MSE.mean()))\n",
"print('MAE - CORRELATION: {}'.format(scores[(scores.kind == 'correlation')].dev_MAE.mean()))\n",
"print('MAE - PARTIAL CORRELATION: {}'.format(scores[(scores.kind == 'partial correlation')].dev_MAE.mean()))\n",
"print('MAE - TANGENT: {}'.format(scores[(scores.kind == 'tangent')].dev_MAE.mean()))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Source</th>\n",
" <th>SS</th>\n",
" <th>DF</th>\n",
" <th>MS</th>\n",
" <th>F</th>\n",
" <th>p-unc</th>\n",
" <th>np2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>kind</td>\n",
" <td>0.001</td>\n",
" <td>2</td>\n",
" <td>0.000</td>\n",
" <td>0.471</td>\n",
" <td>0.625003</td>\n",
" <td>0.003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Within</td>\n",
" <td>0.247</td>\n",
" <td>297</td>\n",
" <td>0.001</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" <td>-</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Source SS DF MS F p-unc np2\n",
"0 kind 0.001 2 0.000 0.471 0.625003 0.003\n",
"1 Within 0.247 297 0.001 - - -"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"aov = pg.anova(dv='dev_MSE',\n",
" between=['kind'],\n",
" data=scores,\n",
" detailed=True,\n",
" ss_type=2)\n",
"aov.to_csv('./anova.csv')\n",
"aov"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Contrast</th>\n",
" <th>A</th>\n",
" <th>B</th>\n",
" <th>mean(A)</th>\n",
" <th>std(A)</th>\n",
" <th>mean(B)</th>\n",
" <th>std(B)</th>\n",
" <th>Paired</th>\n",
" <th>Parametric</th>\n",
" <th>T</th>\n",
" <th>dof</th>\n",
" <th>Tail</th>\n",
" <th>p-unc</th>\n",
" <th>p-corr</th>\n",
" <th>p-adjust</th>\n",
" <th>BF10</th>\n",
" <th>cohen</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>kind</td>\n",
" <td>correlation</td>\n",
" <td>partial correlation</td>\n",
" <td>0.063</td>\n",
" <td>0.027</td>\n",
" <td>0.060</td>\n",
" <td>0.030</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>0.717</td>\n",
" <td>198.0</td>\n",
" <td>two-sided</td>\n",
" <td>0.474118</td>\n",
" <td>0.711177</td>\n",
" <td>fdr_bh</td>\n",
" <td>0.196</td>\n",
" <td>0.101</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>kind</td>\n",
" <td>correlation</td>\n",
" <td>tangent</td>\n",
" <td>0.063</td>\n",
" <td>0.027</td>\n",
" <td>0.064</td>\n",
" <td>0.029</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>-0.221</td>\n",
" <td>198.0</td>\n",
" <td>two-sided</td>\n",
" <td>0.825269</td>\n",
" <td>0.825269</td>\n",
" <td>fdr_bh</td>\n",
" <td>0.157</td>\n",
" <td>-0.031</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>kind</td>\n",
" <td>partial correlation</td>\n",
" <td>tangent</td>\n",
" <td>0.060</td>\n",
" <td>0.030</td>\n",
" <td>0.064</td>\n",
" <td>0.029</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>-0.904</td>\n",
" <td>198.0</td>\n",
" <td>two-sided</td>\n",
" <td>0.367035</td>\n",
" <td>0.711177</td>\n",
" <td>fdr_bh</td>\n",
" <td>0.226</td>\n",
" <td>-0.128</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Contrast A B mean(A) std(A) \\\n",
"0 kind correlation partial correlation 0.063 0.027 \n",
"1 kind correlation tangent 0.063 0.027 \n",
"2 kind partial correlation tangent 0.060 0.030 \n",
"\n",
" mean(B) std(B) Paired Parametric T dof Tail p-unc \\\n",
"0 0.060 0.030 False True 0.717 198.0 two-sided 0.474118 \n",
"1 0.064 0.029 False True -0.221 198.0 two-sided 0.825269 \n",
"2 0.064 0.029 False True -0.904 198.0 two-sided 0.367035 \n",
"\n",
" p-corr p-adjust BF10 cohen \n",
"0 0.711177 fdr_bh 0.196 0.101 \n",
"1 0.825269 fdr_bh 0.157 -0.031 \n",
"2 0.711177 fdr_bh 0.226 -0.128 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"posthocs = pg.pairwise_ttests(dv='dev_MSE',\n",
" between=['kind'],\n",
" data=scores,\n",
" alpha=.5,\n",
" padjust='fdr_bh',\n",
" effsize='cohen',\n",
" return_desc=True)\n",
"posthocs.to_csv('./ttests.csv')\n",
"posthocs"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f2ba01f5e50>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predictions = model.predict(X[test])\n",
"plt.figure()\n",
"plt.plot(y[test], label='truth')\n",
"plt.plot(predictions, label='predictions')\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.6 64-bit ('base': conda)",
"language": "python",
"name": "python37664bitbaseconda7a3bdc022e3d40b690de10a054bae57d"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment