Skip to content

Instantly share code, notes, and snippets.

@bmerry
Created July 26, 2017 07:56
Show Gist options
  • Save bmerry/a254a69597dce9b6957fe4470c1bbf84 to your computer and use it in GitHub Desktop.
Save bmerry/a254a69597dce9b6957fe4470c1bbf84 to your computer and use it in GitHub Desktop.
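Benchmarks of different implementations of dask.array.where. All figures are in seconds and are given as mean ± standard error of the mean: graph construction is timed over 200 calls and compute() over 20 calls.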
No broadcasting or scalars
where_orig : 0.000482s ± 0.000005 (construct), 0.165047s ± 0.001631s (compute)
where_where: 0.000273s ± 0.000002 (construct), 0.070438s ± 0.000439s (compute)
where_new : 0.000289s ± 0.000004 (construct), 0.069269s ± 0.000712s (compute)
Scalar condition
where_orig : 0.000013s ± 0.000000 (construct), 0.024283s ± 0.000316s (compute)
where_where: 0.000012s ± 0.000000 (construct), 0.024131s ± 0.000329s (compute)
where_new : 0.000017s ± 0.000000 (construct), 0.023551s ± 0.000195s (compute)
Broadcasting
where_orig : 0.000810s ± 0.000008 (construct), 0.168275s ± 0.001371s (compute)
where_where: 0.000593s ± 0.000003 (construct), 0.061382s ± 0.000654s (compute)
where_new : 0.000392s ± 0.000002 (construct), 0.058013s ± 0.000645s (compute)
Broadcasting, many chunks
where_orig : 0.022362s ± 0.000423 (construct), 1.387578s ± 0.004326s (compute)
where_where: 0.022167s ± 0.000406 (construct), 1.315636s ± 0.004534s (compute)
where_new : 0.016981s ± 0.000361 (construct), 0.875134s ± 0.002577s (compute)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, division

import time
import timeit

import dask.array as da
import numpy as np
import scipy.stats
# Original version function from master (0ef9424)
def where_orig(condition, x=None, y=None):
    """Select from ``x`` where ``condition`` is true and from ``y`` elsewhere.

    Baseline implementation: ``x`` and ``y`` are first converted to dask
    arrays, broadcast to their common shape and cast to their promoted dtype.
    A scalar ``condition`` then picks one of the two whole; an array
    ``condition`` is cast to bool and passed to ``da.choose``.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if y is None or x is None:
        raise TypeError(da.core.where_error_message)

    x_arr, y_arr = da.asarray(x), da.asarray(y)
    common_shape = da.core.broadcast_shapes(x_arr.shape, y_arr.shape)
    common_dtype = np.promote_types(x_arr.dtype, y_arr.dtype)
    x_arr, y_arr = [
        da.broadcast_to(arr, common_shape).astype(common_dtype)
        for arr in (x_arr, y_arr)
    ]

    if not np.isscalar(condition):
        mask = da.asarray(condition).astype('bool')
        # choose() indexes the choice list with the mask: False -> y, True -> x.
        return da.choose(mask, [y_arr, x_arr])
    return x_arr if condition else y_arr
# Variant of where_orig that uses np.where instead of np.choose underneath
def where_where(condition, x=None, y=None):
    """Like ``where_orig``, but select per block with ``np.where``.

    ``x`` and ``y`` are converted to dask arrays, broadcast to their common
    shape and cast to their promoted dtype. A scalar ``condition`` returns one
    of them whole; otherwise ``np.where`` is applied block-wise through
    ``da.core.elemwise``.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if x is None or y is None:
        raise TypeError(da.core.where_error_message)

    operands = (da.asarray(x), da.asarray(y))
    out_shape = da.core.broadcast_shapes(operands[0].shape, operands[1].shape)
    out_dtype = np.promote_types(operands[0].dtype, operands[1].dtype)
    x_b, y_b = (da.broadcast_to(op, out_shape).astype(out_dtype)
                for op in operands)

    if np.isscalar(condition):
        return x_b if condition else y_b
    # Array condition: np.where does the element-wise selection in each block.
    return da.core.elemwise(np.where, condition, x_b, y_b)
# New version from my branch (20d7682): result_type helper plus where_new
def result_type(*args):
    """Return ``np.result_type`` of ``args``.

    Arguments that ``da.core.is_scalar_for_elemwise`` classifies as scalars
    are passed to ``np.result_type`` as values. Every other argument
    contributes only its ``.dtype``.
    """
    dtypes_or_scalars = []
    for arg in args:
        if da.core.is_scalar_for_elemwise(arg):
            dtypes_or_scalars.append(arg)
        else:
            dtypes_or_scalars.append(arg.dtype)
    return np.result_type(*dtypes_or_scalars)
def where_new(condition, x=None, y=None):
    """Select from ``x`` where ``condition`` is true and from ``y`` elsewhere.

    Candidate implementation:

    * Scalar ``condition``: returns ``x`` or ``y`` whole, broadcast to the
      common shape of ``x`` and ``y`` and cast to ``result_type(x, y)``.
    * Array ``condition``: delegates to ``np.where`` block-wise via
      ``da.core.elemwise``, which handles broadcasting and dtype itself.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if x is None or y is None:
        # Use the qualified name: a bare ``where_error_message`` is not defined
        # in this module and would raise NameError instead of TypeError.
        raise TypeError(da.core.where_error_message)
    if np.isscalar(condition):
        # Computed from the original arguments, before asarray() wraps them,
        # because result_type passes scalar arguments through as values.
        dtype = result_type(x, y)
        x = da.asarray(x)
        y = da.asarray(y)
        shape = da.core.broadcast_shapes(x.shape, y.shape)
        out = x if condition else y
        return da.broadcast_to(out, shape).astype(dtype)
    return da.core.elemwise(np.where, condition, x, y)
def time_function(func, passes, clock=timeit.default_timer):
    """Time repeated calls of ``func``.

    ``func`` is called once untimed (presumably to keep one-off first-call
    costs out of the statistics), then ``passes`` more times, each call timed
    individually with ``clock``.

    Args:
        func: zero-argument callable to benchmark; its return value is ignored.
        passes: number of timed calls. Must be at least 1. With a single pass
            the standard error is nan, because there is no spread to estimate.
        clock: zero-argument callable returning a time in seconds. Defaults to
            ``timeit.default_timer``, the standard benchmarking clock
            (``time.perf_counter`` on Python 3). Unlike ``time.time``, it
            does not follow wall-clock adjustments.

    Returns:
        A ``(mean, sem)`` tuple: the mean duration in seconds and its standard
        error of the mean (``scipy.stats.sem``).

    Raises:
        ValueError: if ``passes`` is less than 1.
    """
    if passes < 1:
        raise ValueError('passes must be at least 1, got {}'.format(passes))
    func()  # Warmup
    times = []
    for _ in range(passes):
        start = clock()
        func()
        times.append(clock() - start)
    return np.mean(times), scipy.stats.sem(times)
def timer(name, func, passes=20):
    """Benchmark graph construction and computation for one implementation.

    Args:
        name: label printed at the start of the result line.
        func: zero-argument callable that builds and returns an object with a
            ``compute()`` method (a dask array in this script).
        passes: number of timed ``compute()`` calls. Construction is timed
            over ``passes * 10`` calls to ``func``, presumably because it is
            much cheaper than computing.

    Prints one line of the form
    ``name: <mean>s ± <sem>s (construct), <mean>s ± <sem>s (compute)``.
    The ± figures are standard errors of the mean, as returned by
    ``time_function``.
    """
    construct_mean, construct_sem = time_function(func, passes * 10)
    array = func()
    compute_mean, compute_sem = time_function(array.compute, passes)
    # Both spreads are in seconds, so both carry the "s" unit. The original
    # format omitted it on the construct spread.
    print('{}: {:.6f}s ± {:.6f}s (construct), {:.6f}s ± {:.6f}s (compute)'.format(
        name, construct_mean, construct_sem, compute_mean, compute_sem))
def time_all(*args, **kwargs):
    """Run ``timer`` on every ``where`` implementation with the same arguments.

    The labels are the function names, left-justified to a common width so the
    printed result lines line up. The original hand-written label
    ``'where_new '`` was one character short.
    """
    implementations = [where_orig, where_where, where_new]
    width = max(len(impl.__name__) for impl in implementations)
    for impl in implementations:
        # Bind ``impl`` as a default argument so each closure keeps its own
        # function instead of late-binding to the loop variable.
        timer(impl.__name__.ljust(width),
              lambda impl=impl: impl(*args, **kwargs))
# Benchmark scenarios. Every operand has _SIZE elements except ``b2``, which
# has a single element and must be broadcast against the others.
_SIZE = 10000000
_CHUNK = 1000000

print('No broadcasting or scalars')
a = da.zeros(_SIZE, chunks=_CHUNK)
b = da.ones(_SIZE, chunks=_CHUNK)
c = da.random.randint(0, 2, _SIZE, chunks=_CHUNK)  # random 0/1 condition
time_all(c, a, b)

print('Scalar condition')
time_all(True, a, b)

print('Broadcasting')
b2 = da.ones(1, chunks=1)
time_all(c, a, b2)

print('Broadcasting, many chunks')
# Chunks 100x smaller than those of ``c``, so the operands' chunkings differ.
a2 = da.zeros(_SIZE, chunks=_CHUNK // 100)
time_all(c, a2, b2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment