Created March 1, 2017 15:55 by arogozhnikov (gist b6efe2c6c6abf66512a639e99fd253e8)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Astropy.Table bug with train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "from astropy.table import Table\n",
    "\n",
    "n_samples = 8\n",
    "X = Table(dict(x1=numpy.arange(1, 1+n_samples), x2=numpy.arange(2, 2 + n_samples)))\n",
    "y = numpy.arange(n_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<Table length=8>\n",
       "<table id=\"table4555606032\" class=\"table-striped table-bordered table-condensed\">\n",
       "<thead><tr><th>x2</th><th>x1</th></tr></thead>\n",
       "<thead><tr><th>int64</th><th>int64</th></tr></thead>\n",
       "<tr><td>2</td><td>1</td></tr>\n",
       "<tr><td>3</td><td>2</td></tr>\n",
       "<tr><td>4</td><td>3</td></tr>\n",
       "<tr><td>5</td><td>4</td></tr>\n",
       "<tr><td>6</td><td>5</td></tr>\n",
       "<tr><td>7</td><td>6</td></tr>\n",
       "<tr><td>8</td><td>7</td></tr>\n",
       "<tr><td>9</td><td>8</td></tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<Table length=8>\n",
       " x2 x1 \n",
       "int64 int64\n",
       "----- -----\n",
       " 2 1\n",
       " 3 2\n",
       " 4 3\n",
       " 5 4\n",
       " 6 5\n",
       " 7 6\n",
       " 8 7\n",
       " 9 8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# sklearn is a widely used library for machine learning\n",
    "# train_test_split is used very frequently: it splits several arrays consistently in a single call\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem\n",
    "`train_test_split` correctly splits numpy arrays, scipy.sparse matrices, lists and pandas DataFrames, but returns an incorrect result for an astropy `Table`:\n",
    "\n",
    "the selected subset is not a table but a list of rows.\n",
    "\n",
    "\n",
    "Failing with an error (\"not supported\", ...) would be fine; silently producing an incorrect result is not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<Row index=1>\n",
       " x2 x1 \n",
       " int64 int64\n",
       " ----- -----\n",
       " 3 2, <Row index=5>\n",
       " x2 x1 \n",
       " int64 int64\n",
       " ----- -----\n",
       " 7 6]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Expected result\n",
    "A table with the same rows, obtained here by splitting the row indices and then indexing the table with them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_indices, test_indices = train_test_split(numpy.arange(len(y)), random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<Table length=2>\n",
       "<table id=\"table4591256272\" class=\"table-striped table-bordered table-condensed\">\n",
       "<thead><tr><th>x2</th><th>x1</th></tr></thead>\n",
       "<thead><tr><th>int64</th><th>int64</th></tr></thead>\n",
       "<tr><td>3</td><td>2</td></tr>\n",
       "<tr><td>7</td><td>6</td></tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<Table length=2>\n",
       " x2 x1 \n",
       "int64 int64\n",
       "----- -----\n",
       " 3 2\n",
       " 7 6"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X[test_indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reason: safe_indexing\n",
    "\n",
    "I've dug into the implementation:\n",
    "- sklearn's indexing uses duck typing and relies on the `safe_indexing` function;\n",
    "- a `Table` has neither an `iloc` nor a `shape` attribute, so there is no sign that smart indexing is supported, and the function falls back to the list implementation (last line).\n",
    "\n",
    "```python\n",
    "def safe_indexing(X, indices):\n",
    "    \"\"\"Return items or rows from X using indices.\n",
    "\n",
    "    Allows simple indexing of lists or arrays.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like, sparse-matrix, list.\n",
    "        Data from which to sample rows or items.\n",
    "\n",
    "    indices : array-like, list\n",
    "        Indices according to which X will be subsampled.\n",
    "    \"\"\"\n",
    "    if hasattr(X, \"iloc\"):\n",
    "        # Pandas Dataframes and Series\n",
    "        try:\n",
    "            return X.iloc[indices]\n",
    "        except ValueError:\n",
    "            # Cython typed memoryviews internally used in pandas do not support\n",
    "            # readonly buffers.\n",
    "            warnings.warn(\"Copying input dataframe for slicing.\",\n",
    "                          DataConversionWarning)\n",
    "            return X.copy().iloc[indices]\n",
    "    elif hasattr(X, \"shape\"):\n",
    "        if hasattr(X, 'take') and (hasattr(indices, 'dtype') and\n",
    "                                    indices.dtype.kind == 'i'):\n",
    "            # This is often substantially faster than X[indices]\n",
    "            return X.take(indices, axis=0)\n",
    "        else:\n",
    "            return X[indices]\n",
    "    else:\n",
    "        return [X[idx] for idx in indices]\n",
    "```\n",
    "\n",
    "\n",
    "### Question: is there a reliable duck-typing way to detect that an astropy `Table` supports smart (array) indexing, so that sklearn can use it?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
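The fallback described in the notebook's last cell can be reproduced without astropy or sklearn. Below is a minimal sketch that needs only numpy. `safe_indexing_sketch` is a simplified, hypothetical copy of the `safe_indexing` dispatch; the pandas copy-on-error path and the `take` fast path are omitted. It is applied to `FakeTable`, a hypothetical stand-in class. Like an astropy `Table`, `FakeTable` returns a new table for `X[index_array]` but has neither `iloc` nor `shape`.

```python
import numpy


def safe_indexing_sketch(X, indices):
    """Simplified, hypothetical copy of the dispatch in sklearn's safe_indexing."""
    if hasattr(X, "iloc"):
        # pandas DataFrame / Series
        return X.iloc[indices]
    elif hasattr(X, "shape"):
        # numpy arrays, scipy.sparse matrices
        return X[indices]
    else:
        # Fallback for everything else: index item by item, which always yields a list
        return [X[idx] for idx in indices]


class FakeTable(object):
    """Hypothetical stand-in for astropy.table.Table: no .iloc and no .shape,
    but X[index_array] returns a new table with the selected rows."""

    def __init__(self, **columns):
        self.columns = {name: numpy.asarray(col) for name, col in columns.items()}

    def __len__(self):
        return len(next(iter(self.columns.values())))

    def __getitem__(self, item):
        if isinstance(item, (int, numpy.integer)):
            # A single row, returned here as a dict (astropy returns a Row object)
            return {name: col[item] for name, col in self.columns.items()}
        # An index array: a new table containing the selected rows
        return FakeTable(**{name: col[item] for name, col in self.columns.items()})


X = FakeTable(x1=numpy.arange(1, 9), x2=numpy.arange(2, 10))
test_indices = numpy.array([1, 5])

via_fallback = safe_indexing_sketch(X, test_indices)  # the path train_test_split takes
via_fancy = X[test_indices]                            # the path of the "Expected result" cell

print(type(via_fallback).__name__)  # list
print(type(via_fancy).__name__)     # FakeTable
```

Both calls receive the same object and the same indices. Only the `hasattr` dispatch decides whether the result is a table or a list of rows, which matches what the notebook observes for `X_test`.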
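The notebook's question could be approached with a duck-typing probe: index the object with a one-element integer array and check whether the result keeps the original type. This is only a hypothetical heuristic, not a reliable or official check. It relies on the behaviour shown in the `X[test_indices]` cell, where a `Table` indexed by an integer array returns a `Table`. The name `supports_array_indexing` is made up for illustration. The probe costs one indexing call and can be fooled by types that return some other object of the same type.

```python
import numpy


def supports_array_indexing(X):
    """Hypothetical heuristic: True if X[integer_array] returns an object of X's own type."""
    if len(X) == 0:
        return False  # nothing to probe with
    try:
        subset = X[numpy.array([0])]
    except (TypeError, IndexError, KeyError, ValueError):
        # e.g. plain lists reject array indices with a TypeError
        return False
    return type(subset) is type(X) and len(subset) == 1


print(supports_array_indexing(numpy.arange(5)))  # True
print(supports_array_indexing([1, 2, 3]))        # False
```

In sklearn's dispatch, such a probe would only make sense after the `iloc` check: indexing a pandas DataFrame with `[...]` selects columns, not rows.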