Skip to content

Instantly share code, notes, and snippets.

@crusaderky
Last active January 23, 2017 00:28
Show Gist options
  • Save crusaderky/62832a5ffc72ccb3e0954021b0996fdf to your computer and use it in GitHub Desktop.
xarray Fast Weighted Sum
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import dask.array\n",
"import numpy\n",
"import xarray\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def fastwsum(arrays, weights, blocksize=64):\n",
"    \"\"\"Weighted sum of arrays.\n",
"\n",
"    :param arrays:\n",
"        sequence of xarray.DataArray objects\n",
"    :param weights:\n",
"        sequence of scalars of the same length as arrays\n",
"    :param blocksize:\n",
"        number of arrays to add together at once in dask (see below)\n",
"    :returns:\n",
"        single xarray.DataArray\n",
"    :raises ValueError:\n",
"        if arrays is empty\n",
"\n",
"    this function is functionally equivalent to::\n",
"\n",
"        sum(a * w for a, w in zip(arrays, weights))\n",
"\n",
"    but it is potentially much faster because:\n",
"\n",
"    - the xarray broadcast/align magic is executed once instead of being repeated for every\n",
"      subtotal\n",
"    - attempt to broadcast as late as possible by calculating subtotals by dims\n",
"    - 1 and 0 weights are optimized away\n",
"    - there is one dask operation every <blocksize> arrays, instead of two per array (one to\n",
"      multiply by the weight and another to add to the subtotal).\n",
"    - in case of mixed dask/numpy addends, the numpy-only subtotal is added only once to the\n",
"      dask graph\n",
"\n",
"    The downside is that <blocksize> inputs must be made available by dask at the same time,\n",
"    with consecutive increase in RAM occupation and (in case of dask.distributed) potentially\n",
"    higher-than-needed data transfers over the network.\n",
"    \"\"\"\n",
"    assert len(arrays) == len(weights)\n",
"    if not arrays:\n",
"        # Fail loudly instead of an obscure IndexError on arrays[0] further down\n",
"        raise ValueError(\"fastwsum requires at least one array\")\n",
"\n",
"    # Attempt to broadcast as late as possible by calculating subtotals by dims.\n",
"    group_by_dims = defaultdict(lambda: ([], []))\n",
"    for a, w in zip(arrays, weights):\n",
"        group_by_dims[a.dims][0].append(a)\n",
"        group_by_dims[a.dims][1].append(w)\n",
"\n",
"    if len(group_by_dims) > 1 and any(len(v[0]) > 1 for v in group_by_dims.values()):\n",
"        # Sum each group of same-dims arrays first; the subtotals are then\n",
"        # broadcast against each other only once, in the recursive call below.\n",
"        subtotals = [\n",
"            fastwsum(arrays, weights, blocksize=blocksize)\n",
"            for (arrays, weights) in group_by_dims.values()\n",
"        ]\n",
"        return fastsum(subtotals, blocksize=blocksize)\n",
"\n",
"    arrays = xarray.broadcast(*arrays)\n",
"\n",
"    # numpy_total stays None until the first numpy addend is accumulated. A None\n",
"    # sentinel replaces the original 'numpy_total is 0' test, which relied on\n",
"    # CPython small-int interning and raises a SyntaxWarning on modern Python.\n",
"    numpy_total = None\n",
"    dask_data = []\n",
"    dask_weights = []\n",
"    for array, weight in zip(arrays, weights):\n",
"        if weight == 0:\n",
"            # 0 weights are optimized away. This must be 'continue', not 'pass':\n",
"            # falling through would process the addend anyway and make the\n",
"            # all-weights-0 shortcut below unreachable.\n",
"            continue\n",
"        if isinstance(array.data, dask.array.Array):\n",
"            # xarray with dask backend\n",
"            dask_data.append(array.data)\n",
"            dask_weights.append(weight)\n",
"        else:\n",
"            # xarray with numpy backend or scalar; 1 weights skip the multiplication\n",
"            addend = array.data if weight == 1 else array.data * weight\n",
"            numpy_total = addend if numpy_total is None else numpy_total + addend\n",
"\n",
"    if numpy_total is None and not dask_data:\n",
"        # All weights are 0; keep only the scalar (dimensionless) coords\n",
"        scalar_coords = {\n",
"            k: v for k, v in arrays[0].coords.items()\n",
"            if v.shape == ()\n",
"        }\n",
"        return xarray.DataArray(0, coords=scalar_coords)\n",
"\n",
"    # Reduce the dask addends <blocksize> at a time until a single,\n",
"    # already-weighted (weight 1) array is left.\n",
"    while len(dask_data) > 1 or (len(dask_weights) > 0 and dask_weights[0] != 1):\n",
"        assert len(dask_data) == len(dask_weights)\n",
"        subtotals = []\n",
"        for offset in range(0, len(dask_data), blocksize):\n",
"            dask_data_slice = dask_data[offset:offset + blocksize]\n",
"            dask_weights_slice = dask_weights[offset:offset + blocksize]\n",
"            dtype = numpy.result_type(*[d.dtype for d in dask_data_slice])\n",
"            subtotals.append(dask.array.map_blocks(\n",
"                _fastwsum_kernel, *dask_data_slice, weights=dask_weights_slice, dtype=dtype))\n",
"        dask_data = subtotals\n",
"        dask_weights = [1] * len(dask_data)\n",
"\n",
"    if numpy_total is not None:\n",
"        # Add the numpy-only subtotal to the dask graph exactly once\n",
"        dask_data.append(numpy_total)\n",
"    assert len(dask_data) in (1, 2)\n",
"\n",
"    return xarray.DataArray(sum(dask_data), dims=arrays[0].dims, coords=arrays[0].coords)\n",
"\n",
"\n",
"def _fastwsum_kernel(*arrays, weights):\n",
"    \"\"\"Sum the given chunks, each scaled by its weight (weight-1 chunks are added as-is).\"\"\"\n",
"    result = 0\n",
"    for chunk, w in zip(arrays, weights):\n",
"        result = result + (chunk if w == 1 else chunk * w)\n",
"    return result\n",
"\n",
"\n",
"def fastsum(arrays, blocksize=64):\n",
"    \"\"\"Functionally equivalent to sum(arrays), but faster.\n",
"\n",
"    All items of arrays must be :class:`xarray.DataArray` objects.\n",
"    All notes for :func:`fastwsum` apply.\n",
"    \"\"\"\n",
"    return fastwsum(arrays, [1] * len(arrays), blocksize=blocksize)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Benchmarks"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def rand_addend(use_dask):\n",
"    \"\"\"Return a random addend: usually a 2-D (time, scenario) array, ~10% a 1-D numpy one.\"\"\"\n",
"    time_coord = {'time': ['A', 'B', 'C']}\n",
"    if random.randint(0, 9) == 0:\n",
"        # ~10% of addends are plain 1-D arrays, exercising the late-broadcast path\n",
"        return xarray.DataArray([1.1, 2.2, 3.3], dims=['time'], coords=time_coord)\n",
"    if use_dask:\n",
"        data = dask.array.random.random((3, 100), chunks=10)\n",
"    else:\n",
"        data = numpy.random.random(300).reshape(3, 100)\n",
"    return xarray.DataArray(data, dims=['time', 'scenario'], coords=time_coord)\n",
"\n",
"\n",
"def rand_weight():\n",
"    \"\"\"Draw a weight in [0, 1): ~10% exactly 0, ~10% exactly 1, the rest uniform.\"\"\"\n",
"    draw = random.random()\n",
"    if draw < 0.1:\n",
"        return 0\n",
"    if draw < 0.2:\n",
"        return 1\n",
"    return draw\n",
"\n",
"\n",
"addends = [rand_addend(True) for _ in range(4000)]\n",
"weights = [rand_weight() for _ in range(4000)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mixed dask+numpy - plain sum"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 20.4 s, sys: 104 ms, total: 20.5 s\n",
"Wall time: 20.5 s\n",
"CPU times: user 14.6 s, sys: 1.08 s, total: 15.7 s\n",
"Wall time: 14.6 s\n"
]
},
{
"data": {
"text/plain": [
"75712"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time x = sum(addends)\n",
"%time x.compute()\n",
"len(x.data.dask)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.8 s, sys: 20 ms, total: 2.82 s\n",
"Wall time: 2.81 s\n",
"CPU times: user 8.87 s, sys: 760 ms, total: 9.63 s\n",
"Wall time: 8.86 s\n"
]
},
{
"data": {
"text/plain": [
"72052"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time x = fastsum(addends)\n",
"%time x.compute()\n",
"len(x.data.dask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mixed dask+numpy - weighted sum"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 32.5 s, sys: 56 ms, total: 32.5 s\n",
"Wall time: 32.6 s\n",
"CPU times: user 17.2 s, sys: 964 ms, total: 18.2 s\n",
"Wall time: 17.1 s\n"
]
},
{
"data": {
"text/plain": [
"112092"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time x = sum(a * w for a, w in zip(addends, weights))\n",
"%time x.compute()\n",
"len(x.data.dask)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.8 s, sys: 8 ms, total: 2.81 s\n",
"Wall time: 2.81 s\n",
"CPU times: user 8.93 s, sys: 728 ms, total: 9.66 s\n",
"Wall time: 8.87 s\n"
]
},
{
"data": {
"text/plain": [
"72052"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time x = fastwsum(addends, weights)\n",
"%time x.compute()\n",
"len(x.data.dask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## numpy only"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"addends = [rand_addend(False) for _ in range(4000)]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.67 s, sys: 0 ns, total: 1.67 s\n",
"Wall time: 1.67 s\n"
]
}
],
"source": [
"%time x = sum(a * w for a, w in zip(addends, weights))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 856 ms, sys: 4 ms, total: 860 ms\n",
"Wall time: 856 ms\n"
]
}
],
"source": [
"%time x = fastwsum(addends, weights)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment