Created
October 17, 2019 21:52
-
-
Save jhamman/0e52d9bb29f679e26b0878c58bb813d2 to your computer and use it in GitHub Desktop.
apply_along_axis (not working)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/jhamman/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/distributed/utils.py:138: RuntimeWarning: Couldn't detect a suitable IP address for reaching '8.8.8.8', defaulting to '127.0.0.1': [Errno 51] Network is unreachable\n", | |
" RuntimeWarning,\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"import xarray as xr\n", | |
"\n", | |
"\n", | |
"def diff_mean(X, y):\n", | |
" ''' a function that only works on 1d arrays that are different lengths'''\n", | |
" assert X.ndim == 1, X.ndim\n", | |
" assert y.ndim == 1, y.ndim\n", | |
" assert len(X) != len(y), X\n", | |
" return X.mean() - y.mean()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<xarray.DataArray (t: 10, x: 4, y: 5)>\n", | |
"dask.array<xarray-<this-array>, shape=(10, 4, 5), dtype=float64, chunksize=(10, 2, 2)>\n", | |
"Dimensions without coordinates: t, x, y\n", | |
"<xarray.DataArray (t: 6, x: 4, y: 5)>\n", | |
"dask.array<xarray-<this-array>, shape=(6, 4, 5), dtype=float64, chunksize=(6, 2, 2)>\n", | |
"Dimensions without coordinates: t, x, y\n" | |
] | |
} | |
], | |
"source": [ | |
"# sample data: two arrays, different dimensions on the 0th axis (t dimension)\n", | |
"\n", | |
"X = np.random.random((10, 4, 5))\n", | |
"y = np.random.random((6, 4, 5))\n", | |
"\n", | |
"Xda = xr.DataArray(X, dims=('t', 'x', 'y')).chunk({'t': -1, 'x': 2, 'y': 2})\n", | |
"yda = xr.DataArray(y, dims=('t', 'x', 'y')).chunk({'t': -1, 'x': 2, 'y': 2})\n", | |
"\n", | |
"print(Xda)\n", | |
"print(yda)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([[ 0.03026252, -0.22892136, -0.24585945, -0.35186523, -0.10500887],\n", | |
" [-0.02123284, 0.03543719, -0.06187094, -0.00041126, -0.3472726 ],\n", | |
" [-0.05295272, -0.01558743, 0.02846658, 0.12608249, 0.07336431],\n", | |
" [ 0.06709901, 0.1152109 , -0.04191652, 0.00946914, -0.12678981]]),\n", | |
" (4, 5))" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# This is the manual way to accomplish what we're going for\n", | |
"out = np.empty((4, 5))\n", | |
"for index in np.ndindex(out.shape):\n", | |
" i = (Ellipsis, ) + index\n", | |
" out[index] = diff_mean(X[i], y[i])\n", | |
"out, out.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "ValueError", | |
"evalue": "arguments without labels along dimension 't' cannot be aligned because they have different dimension sizes: {10, 6}", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-4-e90cf6fba482>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mdask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"parallelized\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0moutput_dtypes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0minput_core_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m't'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m't'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m )\n", | |
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/computation.py\u001b[0m in \u001b[0;36mapply_ufunc\u001b[0;34m(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, *args)\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0mexclude_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude_dims\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0mkeep_attrs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeep_attrs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m )\n\u001b[1;32m 1046\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/computation.py\u001b[0m in \u001b[0;36mapply_dataarray_vfunc\u001b[0;34m(func, signature, join, exclude_dims, keep_attrs, *args)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m args = deep_align(\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude_dims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraise_on_invalid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m )\n\u001b[1;32m 226\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/alignment.py\u001b[0m in \u001b[0;36mdeep_align\u001b[0;34m(objects, join, copy, indexes, exclude, raise_on_invalid, fill_value)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindexes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 404\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 405\u001b[0;31m \u001b[0mfill_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfill_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 406\u001b[0m )\n\u001b[1;32m 407\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/alignment.py\u001b[0m in \u001b[0;36malign\u001b[0;34m(join, copy, indexes, exclude, fill_value, *objects)\u001b[0m\n\u001b[1;32m 321\u001b[0m \u001b[0;34m\"arguments without labels along dimension %r cannot be \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0;34m\"aligned because they have different dimension sizes: %r\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 323\u001b[0;31m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 324\u001b[0m )\n\u001b[1;32m 325\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mValueError\u001b[0m: arguments without labels along dimension 't' cannot be aligned because they have different dimension sizes: {10, 6}" | |
] | |
} | |
], | |
"source": [ | |
"# ideally we could use xarray.apply_ufunc\n", | |
"# xarray is enforcing that the dim sizes match along the core dimension\n", | |
"\n", | |
"out = xr.apply_ufunc(\n", | |
" diff_mean,\n", | |
" Xda,\n", | |
" yda,\n", | |
" vectorize=True,\n", | |
" dask=\"parallelized\",\n", | |
" output_dtypes=[np.float],\n", | |
" input_core_dims=[['t'], ['t']],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "AssertionError", | |
"evalue": "3", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-5-7641527cd056>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# It would be nice if wee can use numpy's apply_along_axis but it only iterates over the first argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_along_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdiff_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mapply_along_axis\u001b[0;34m(*args, **kwargs)\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/numpy/lib/shape_base.py\u001b[0m in \u001b[0;36mapply_along_axis\u001b[0;34m(func1d, axis, arr, *args, **kwargs)\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Cannot apply_along_axis when any iteration dimensions are 0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 379\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc1d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minarr_view\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mind0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 380\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;31m# build a buffer for storing evaluations of func1d.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-1-9608598b783a>\u001b[0m in \u001b[0;36mdiff_mean\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m''' a function that only works on 1d arrays that are different lengths'''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mAssertionError\u001b[0m: 3" | |
] | |
} | |
], | |
"source": [ | |
"# It would be nice if wee can use numpy's apply_along_axis but it only iterates over the first argument\n", | |
"axis = 0\n", | |
"np.apply_along_axis(diff_mean, axis, X, y)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment