Skip to content

Instantly share code, notes, and snippets.

@jhamman
Created October 17, 2019 21:52
Show Gist options
  • Save jhamman/0e52d9bb29f679e26b0878c58bb813d2 to your computer and use it in GitHub Desktop.
Save jhamman/0e52d9bb29f679e26b0878c58bb813d2 to your computer and use it in GitHub Desktop.
apply_along_axis (not working)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jhamman/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/distributed/utils.py:138: RuntimeWarning: Couldn't detect a suitable IP address for reaching '8.8.8.8', defaulting to '127.0.0.1': [Errno 51] Network is unreachable\n",
" RuntimeWarning,\n"
]
}
],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"\n",
"\n",
"def diff_mean(X, y):\n",
" ''' a function that only works on 1d arrays that are different lengths'''\n",
" assert X.ndim == 1, X.ndim\n",
" assert y.ndim == 1, y.ndim\n",
" assert len(X) != len(y), X\n",
" return X.mean() - y.mean()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<xarray.DataArray (t: 10, x: 4, y: 5)>\n",
"dask.array<xarray-<this-array>, shape=(10, 4, 5), dtype=float64, chunksize=(10, 2, 2)>\n",
"Dimensions without coordinates: t, x, y\n",
"<xarray.DataArray (t: 6, x: 4, y: 5)>\n",
"dask.array<xarray-<this-array>, shape=(6, 4, 5), dtype=float64, chunksize=(6, 2, 2)>\n",
"Dimensions without coordinates: t, x, y\n"
]
}
],
"source": [
"# sample data: two arrays, different dimensions on the 0th axis (t dimension)\n",
"\n",
"X = np.random.random((10, 4, 5))\n",
"y = np.random.random((6, 4, 5))\n",
"\n",
"Xda = xr.DataArray(X, dims=('t', 'x', 'y')).chunk({'t': -1, 'x': 2, 'y': 2})\n",
"yda = xr.DataArray(y, dims=('t', 'x', 'y')).chunk({'t': -1, 'x': 2, 'y': 2})\n",
"\n",
"print(Xda)\n",
"print(yda)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0.03026252, -0.22892136, -0.24585945, -0.35186523, -0.10500887],\n",
" [-0.02123284, 0.03543719, -0.06187094, -0.00041126, -0.3472726 ],\n",
" [-0.05295272, -0.01558743, 0.02846658, 0.12608249, 0.07336431],\n",
" [ 0.06709901, 0.1152109 , -0.04191652, 0.00946914, -0.12678981]]),\n",
" (4, 5))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This is the manual way to accomplish what we're going for\n",
"out = np.empty((4, 5))\n",
"for index in np.ndindex(out.shape):\n",
" i = (Ellipsis, ) + index\n",
" out[index] = diff_mean(X[i], y[i])\n",
"out, out.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "arguments without labels along dimension 't' cannot be aligned because they have different dimension sizes: {10, 6}",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-e90cf6fba482>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mdask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"parallelized\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0moutput_dtypes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0minput_core_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m't'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m't'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m )\n",
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/computation.py\u001b[0m in \u001b[0;36mapply_ufunc\u001b[0;34m(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, *args)\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0mexclude_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude_dims\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0mkeep_attrs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeep_attrs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m )\n\u001b[1;32m 1046\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/computation.py\u001b[0m in \u001b[0;36mapply_dataarray_vfunc\u001b[0;34m(func, signature, join, exclude_dims, keep_attrs, *args)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m args = deep_align(\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude_dims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraise_on_invalid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m )\n\u001b[1;32m 226\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/alignment.py\u001b[0m in \u001b[0;36mdeep_align\u001b[0;34m(objects, join, copy, indexes, exclude, raise_on_invalid, fill_value)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindexes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 404\u001b[0m \u001b[0mexclude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 405\u001b[0;31m \u001b[0mfill_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfill_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 406\u001b[0m )\n\u001b[1;32m 407\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/xarray/core/alignment.py\u001b[0m in \u001b[0;36malign\u001b[0;34m(join, copy, indexes, exclude, fill_value, *objects)\u001b[0m\n\u001b[1;32m 321\u001b[0m \u001b[0;34m\"arguments without labels along dimension %r cannot be \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0;34m\"aligned because they have different dimension sizes: %r\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 323\u001b[0;31m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 324\u001b[0m )\n\u001b[1;32m 325\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: arguments without labels along dimension 't' cannot be aligned because they have different dimension sizes: {10, 6}"
]
}
],
"source": [
"# ideally we could use xarray.apply_ufunc\n",
"# xarray is enforcing that the dim sizes match along the core dimension\n",
"\n",
"out = xr.apply_ufunc(\n",
" diff_mean,\n",
" Xda,\n",
" yda,\n",
" vectorize=True,\n",
" dask=\"parallelized\",\n",
" output_dtypes=[np.float],\n",
" input_core_dims=[['t'], ['t']],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "3",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-7641527cd056>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# It would be nice if wee can use numpy's apply_along_axis but it only iterates over the first argument\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_along_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdiff_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mapply_along_axis\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/xarray-ml/lib/python3.7/site-packages/numpy/lib/shape_base.py\u001b[0m in \u001b[0;36mapply_along_axis\u001b[0;34m(func1d, axis, arr, *args, **kwargs)\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Cannot apply_along_axis when any iteration dimensions are 0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 379\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc1d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minarr_view\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mind0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 380\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;31m# build a buffer for storing evaluations of func1d.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-1-9608598b783a>\u001b[0m in \u001b[0;36mdiff_mean\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m''' a function that only works on 1d arrays that are different lengths'''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: 3"
]
}
],
"source": [
"# It would be nice if wee can use numpy's apply_along_axis but it only iterates over the first argument\n",
"axis = 0\n",
"np.apply_along_axis(diff_mean, axis, X, y)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment