@arogozhnikov
Created March 1, 2017 15:55
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## astropy `Table` bug with sklearn's `train_test_split`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy\n",
"from astropy.table import Table\n",
"\n",
"n_samples = 8\n",
"X = Table(dict(x1=numpy.arange(1, 1+n_samples), x2=numpy.arange(2, 2 + n_samples)))\n",
"y = numpy.arange(n_samples)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"&lt;Table length=8&gt;\n",
"<table id=\"table4555606032\" class=\"table-striped table-bordered table-condensed\">\n",
"<thead><tr><th>x2</th><th>x1</th></tr></thead>\n",
"<thead><tr><th>int64</th><th>int64</th></tr></thead>\n",
"<tr><td>2</td><td>1</td></tr>\n",
"<tr><td>3</td><td>2</td></tr>\n",
"<tr><td>4</td><td>3</td></tr>\n",
"<tr><td>5</td><td>4</td></tr>\n",
"<tr><td>6</td><td>5</td></tr>\n",
"<tr><td>7</td><td>6</td></tr>\n",
"<tr><td>8</td><td>7</td></tr>\n",
"<tr><td>9</td><td>8</td></tr>\n",
"</table>"
],
"text/plain": [
"<Table length=8>\n",
" x2 x1 \n",
"int64 int64\n",
"----- -----\n",
" 2 1\n",
" 3 2\n",
" 4 3\n",
" 5 4\n",
" 6 5\n",
" 7 6\n",
" 8 7\n",
" 9 8"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# sklearn is a widely used library for machine learning\n",
"# train_test_split is used very frequently: it splits several arrays at once, consistently\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Problem\n",
"`train_test_split` correctly splits numpy arrays, scipy.sparse matrices, lists and pandas DataFrames, but returns an incorrect result for an astropy `Table`:\n",
"\n",
"the selected subset is not a table but a list of rows.\n",
"\n",
"Failing with an error (such as \"not supported\") would be acceptable; silently producing an incorrect result is not."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[<Row index=1>\n",
" x2 x1 \n",
" int64 int64\n",
" ----- -----\n",
" 3 2, <Row index=5>\n",
" x2 x1 \n",
" int64 int64\n",
" ----- -----\n",
" 7 6]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Expected result\n",
"A `Table` containing the same rows, which is what indexing the table directly with the test indices returns:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_indices, test_indices = train_test_split(numpy.arange(len(y)), random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"&lt;Table length=2&gt;\n",
"<table id=\"table4591256272\" class=\"table-striped table-bordered table-condensed\">\n",
"<thead><tr><th>x2</th><th>x1</th></tr></thead>\n",
"<thead><tr><th>int64</th><th>int64</th></tr></thead>\n",
"<tr><td>3</td><td>2</td></tr>\n",
"<tr><td>7</td><td>6</td></tr>\n",
"</table>"
],
"text/plain": [
"<Table length=2>\n",
" x2 x1 \n",
"int64 int64\n",
"----- -----\n",
" 3 2\n",
" 7 6"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[test_indices]"
]
},
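{
"cell_type": "markdown",
"metadata": {},
"source": [
"Until sklearn handles `Table` directly, the approach above can be wrapped in a helper: split the row indices with `train_test_split`, then index the `Table` with the resulting integer arrays, which returns a `Table`. This is a sketch under stated assumptions: the name `split_table` is hypothetical (it is not part of sklearn or astropy), and every extra array is assumed to have the same length as the table. Because the split of row indices depends only on the number of samples and `random_state`, it should select the same test rows as `X[test_indices]` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Hypothetical helper, not part of sklearn or astropy: split row indices,\n",
"# then index the Table with integer arrays, which returns a Table.\n",
"def split_table(table, *arrays, **kwargs):\n",
"    indices = numpy.arange(len(table))\n",
"    # kwargs (e.g. test_size, random_state) are passed straight to train_test_split\n",
"    train_idx, test_idx = train_test_split(indices, **kwargs)\n",
"    result = [table[train_idx], table[test_idx]]\n",
"    for array in arrays:\n",
"        array = numpy.asarray(array)\n",
"        result += [array[train_idx], array[test_idx]]\n",
"    return result\n",
"\n",
"X_train, X_test, y_train, y_test = split_table(X, y, random_state=42)\n",
"X_test"
]
},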
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reason: safe_indexing\n",
"\n",
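"I dug into the implementation:\n",
"- sklearn's indexing uses duck typing and relies on the `safe_indexing` function;\n",
"- `X` passes neither the `iloc` check nor the `shape` check, so nothing signals that smart (array) indexing is supported, and the function falls back to the list implementation (the last line), which returns a list of `Row` objects.\n",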
"\n",
"```python\n",
"def safe_indexing(X, indices):\n",
"    \"\"\"Return items or rows from X using indices.\n",
"\n",
"    Allows simple indexing of lists or arrays.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    X : array-like, sparse-matrix, list.\n",
"        Data from which to sample rows or items.\n",
"\n",
"    indices : array-like, list\n",
"        Indices according to which X will be subsampled.\n",
"    \"\"\"\n",
"    if hasattr(X, \"iloc\"):\n",
"        # Pandas Dataframes and Series\n",
"        try:\n",
"            return X.iloc[indices]\n",
"        except ValueError:\n",
"            # Cython typed memoryviews internally used in pandas do not support\n",
"            # readonly buffers.\n",
"            warnings.warn(\"Copying input dataframe for slicing.\",\n",
"                          DataConversionWarning)\n",
"            return X.copy().iloc[indices]\n",
"    elif hasattr(X, \"shape\"):\n",
"        if hasattr(X, 'take') and (hasattr(indices, 'dtype') and\n",
"                                   indices.dtype.kind == 'i'):\n",
"            # This is often substantially faster than X[indices]\n",
"            return X.take(indices, axis=0)\n",
"        else:\n",
"            return X[indices]\n",
"    else:\n",
"        return [X[idx] for idx in indices]\n",
"```\n",
"\n",
"\n",
"### Question: is there a reliable duck-typing way to detect that an astropy object supports smart (array) indexing, so that sklearn can use it?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}