Skip to content

Instantly share code, notes, and snippets.

@coderforlife
Created July 15, 2020 08:27
Show Gist options
  • Save coderforlife/e3a5fffa17be71c2f97779e0c41fb5a1 to your computer and use it in GitHub Desktop.
Save coderforlife/e3a5fffa17be71c2f97779e0c41fb5a1 to your computer and use it in GitHub Desktop.
Testing `cupyx.scipy.ndimage.generic_filter()`
import numpy, cupy
import scipy.ndimage as ndi
import cupyx.scipy.ndimage as cp_ndi
from scipy import LowLevelCallable
from numba import cfunc, types, carray
##### Root Mean Squared #####
# Note: despite the "rms" names, these compute the mean square (no square root is taken).
# Mean square as a raw CUDA kernel for cupyx generic_filter: it reads the
# `filter_size` values of one filter window from `x` and writes the result
# to y[0].
rms_raw = cupy.RawKernel(
    code='''extern "C" __global__
void rms(const double* x, int filter_size, double* y) {
double ss = 0;
for (int i = 0; i < filter_size; ++i) { ss += x[i]*x[i]; }
y[0] = ss/filter_size;
}''',
    name='rms',
)
# The same computation as a reduction kernel: square each element (map),
# sum the squares (reduce), then divide by the input size (post-map).
rms_red = cupy.ReductionKernel(
    in_params='X x',
    out_params='Y y',
    map_expr='x*x',
    reduce_expr='a + b',
    post_map_expr='y = a/_in_ind.size()',
    identity='0',
    name='rms',
)
def rms_fuse_wrapper(filter_size):
    """Build a function returning the mean square of one filter window.

    The returned function only receives the window values, so the divisor
    (`filter_size`, the number of elements in the window) is captured in a
    closure.
    """
    def rms_fuse(x):
        squares = x * x
        return squares.sum() / filter_size
    return rms_fuse
# C callback for scipy.ndimage.generic_filter via LowLevelCallable, with the
# signature int(double *buffer, intp filter_size, double *return_value,
# void *user_data).
@cfunc(types.intc(types.CPointer(types.double), types.intp,
                  types.CPointer(types.double), types.voidptr))
def rms_numba(x, filter_size, y, _):
    """Store the mean square of the `filter_size` values at `x` in y[0].

    Always returns 1, which scipy's callback protocol treats as success.
    The last argument (user data) is unused.
    """
    window = carray(x, filter_size)  # view the raw pointer as a 1-D array
    total = 0
    for value in window:
        total += value * value
    y[0] = total / filter_size
    return 1


rms_llc = LowLevelCallable(rms_numba.ctypes)
def rms_pyfunc(x):
    """Return the mean square of the window values `x` (a 1-D array)."""
    squared = x * x
    count = len(x)
    return squared.sum() / count
##### Less-Than Middle #####
# Raw CUDA kernel: count the window values strictly less than the window's
# middle element x[filter_size/2]; the count is written to y[0].
lt_raw = cupy.RawKernel(
    code='''extern "C" __global__
void lt(const double* x, int filter_size, double* y) {
int n = 0;
double c = x[filter_size/2];
for (int i = 0; i < filter_size; ++i) { n += c>x[i]; }
y[0] = n;
}''',
    name='lt',
)
# Reduction-kernel form: map each element to (middle > x) by indexing the raw
# input `_raw_x`, sum those as ints, and store the count.
lt_red = cupy.ReductionKernel(
    in_params='X x',
    out_params='Y y',
    map_expr='_raw_x[_in_ind.size()/2]>x',
    reduce_expr='a + b',
    post_map_expr='y = a',
    identity='0',
    name='lt',
    reduce_type='int',
)
def lt_fuse_wrapper(filter_size):
    """Build a function counting window values less than the middle element.

    The middle index depends on the window size, which the returned function
    cannot see, so `filter_size` is captured in a closure.
    """
    def lt_fuse(x):
        centre = x[filter_size // 2]
        return (centre > x).sum()
    return lt_fuse
# C callback for scipy.ndimage.generic_filter via LowLevelCallable, with the
# signature int(double *buffer, intp filter_size, double *return_value,
# void *user_data).
@cfunc(types.intc(types.CPointer(types.double), types.intp,
                  types.CPointer(types.double), types.voidptr))
def lt_numba(x, filter_size, y, _):
    """Store in y[0] how many of the `filter_size` values at `x` are less than
    the middle one.

    Always returns 1 (success for scipy's callback protocol); user data unused.
    """
    window = carray(x, filter_size)  # view the raw pointer as a 1-D array
    centre = window[filter_size // 2]
    count = 0
    for value in window:
        count += centre > value
    y[0] = count
    return 1


lt_llc = LowLevelCallable(lt_numba.ctypes)
def lt_pyfunc(x):
    """Count the values in window `x` that are less than its middle element."""
    middle_value = x[len(x) // 2]
    below = middle_value > x
    return below.sum()
##### All #####
# "All": 1 if every value in the filter window is non-zero, else 0, matching
# cupy.all (all_fuse) and numpy.all (all_pyfunc). The raw, reduction and numba
# versions count the non-zero values and compare the count against the window
# size; returning the bare count would disagree with cupy.all/numpy.all for
# any window larger than one element.
all_raw = cupy.RawKernel('''extern "C" __global__
void all(const double* x, int filter_size, double* y) {
    int n = 0;
    for (int i = 0; i < filter_size; ++i) { n += x[i] != 0; }
    y[0] = (n == filter_size) ? 1.0 : 0.0;
}''', 'all')
all_red = cupy.ReductionKernel(
    'X x', 'Y y',
    'x!=0',                       # map: 1 for each non-zero element
    'a + b',                      # reduce: count the non-zero elements
    'y = (a == _in_ind.size())',  # post-map: 1 only if all were non-zero
    '0', 'all', reduce_type='int')
all_fuse = cupy.all


# C callback for scipy.ndimage.generic_filter via LowLevelCallable, with the
# signature int(double *buffer, intp filter_size, double *return_value,
# void *user_data).
@cfunc(types.intc(types.CPointer(types.double), types.intp,
                  types.CPointer(types.double), types.voidptr))
def all_numba(x, filter_size, y, _):
    """Store 1.0 in y[0] if all `filter_size` values at `x` are non-zero, else 0.0.

    Always returns 1 (success for scipy's callback protocol); user data unused.
    """
    n = 0
    for i in range(filter_size):
        n += x[i] != 0
    y[0] = 1.0 if n == filter_size else 0.0
    return 1


all_llc = LowLevelCallable(all_numba.ctypes)
all_pyfunc = numpy.all
###### Setup for running tests ######
# One entry per test:
#   [test name,
#    CuPy implementations (in `cp_names` order),
#    SciPy implementations (in `sp_names` order)].
# The timing loop uses the first SciPy implementation (numba) as the reference
# output. The `*_fuse_wrapper` entries are factories that the loop calls with
# the window's element count before use.
funcs = [
    [
        'rms',
        [rms_raw, rms_red, rms_fuse_wrapper],
        [rms_llc, rms_pyfunc],
    ],
    [
        'lt',
        [lt_raw, lt_red, lt_fuse_wrapper],
        [lt_llc, lt_pyfunc],
    ],
    [
        'all',
        [all_raw, all_red, all_fuse],
        [all_llc, all_pyfunc],
    ],
]
cp_names = 'raw red fuse'.split()
sp_names = 'numba py'.split()
###### Set up and run timing tests ######
sp_data = numpy.random.rand(1000, 1000)
cp_data = cupy.array(sp_data)
for size in [3, 15, 25]:
for name, cp_funcs, sp_funcs in funcs:
print(name, '%dx%d' % (size, size))
for name, func in zip(cp_names, cp_funcs):
if func in (rms_fuse_wrapper, lt_fuse_wrapper): func = func(size*size)
out = cp_ndi.generic_filter(cp_data, func, size)
ref = ndi.generic_filter(sp_data, sp_funcs[0], size)
if numpy.allclose(out.get(), ref):
print(name, end=' ')
else:
print(name, '*', end=' ') # asterisks means bad result
%timeit cp_ndi.generic_filter(cp_data, func, size); cupy.cuda.Stream.null.synchronize()
for name, func in zip(sp_names, sp_funcs):
ndi.generic_filter(sp_data, func, size)
print(name, end=' ')
%timeit ndi.generic_filter(sp_data, func, size)
print('----------------------------------------')
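Tested on a system with an Intel Xeon Gold 5122 CPU @ 3.60GHz and a Titan V GPU.
An asterisk (`*`) after an implementation name means its output did not match the reference. Here it marks `fuse` for `all`. The reference (numba) implementation counted the non-zero elements in each window instead of testing whether all of them were non-zero, while `cupy.all` returns a boolean, so the two outputs disagree. This issue still needs to be fixed.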
rms 3x3
raw 308 µs ± 893 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
red 308 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fuse 2.06 ms ± 4.13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
numba 14.1 ms ± 48.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
py 2.96 s ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
lt 3x3
raw 337 µs ± 940 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
red 338 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fuse 1.73 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
numba 16.3 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
py 3.91 s ± 391 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
all 3x3
raw 339 µs ± 487 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
red 339 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
fuse * 604 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
numba 16.2 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
py 2.72 s ± 5.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
rms 15x15
raw 6.83 ms ± 806 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 6.83 ms ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse 15.5 ms ± 54.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
numba 371 ms ± 486 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
py 3.5 s ± 6.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
lt 15x15
raw 6.8 ms ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 6.8 ms ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse 9.36 ms ± 4.28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
numba 156 ms ± 676 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
py 3.84 s ± 5.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
all 15x15
raw 6.88 ms ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 6.88 ms ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse * 7.12 ms ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
numba 158 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
py 3.17 s ± 4.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
rms 25x25
raw 18.7 ms ± 5.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 18.7 ms ± 5.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse 39.1 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
numba 1.02 s ± 416 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
py 4.25 s ± 25.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
lt 25x25
raw 19.4 ms ± 9.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 19.4 ms ± 8.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse 24.8 ms ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
numba 406 ms ± 559 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
py 4.82 s ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
all 25x25
raw 19.5 ms ± 2.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
red 19.5 ms ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fuse * 19.7 ms ± 4.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
numba 405 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
py 3.89 s ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment