Created September 22, 2016 16:33
-
-
Save chris-b1/d28c6b8e78bf65ef7eb97e1095bc87f2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
from collections import OrderedDict
from itertools import chain

import numba
import numpy as np
import xarray as xr
from xarray.core.computation import UFuncSignature
def _map_sig(sig, dim_map): | |
return [tuple(dim_map[x] for x in t) | |
for t in sig] | |
def xarray_gufunc(f):
    """Wrap a generalized ufunc so it can be applied to xarray objects.

    The core dimensions in ``f.signature`` (e.g. ``'(n)->()'``) are
    placeholder names. The returned wrapper takes a keyword-only ``dims``
    argument that assigns a concrete dimension name to each distinct
    placeholder. Placeholders are ordered by first appearance across the
    input core dims and then the output core dims.

    Parameters
    ----------
    f : callable
        A gufunc exposing a ``signature`` string attribute, such as the
        result of ``numba.guvectorize``.

    Returns
    -------
    callable
        ``inner(*args, dims=())``. It forwards ``args`` to ``xr.apply``
        together with a signature whose placeholders are renamed to ``dims``.

    Raises
    ------
    ValueError
        From ``inner``, when ``dims`` does not contain exactly one name per
        distinct placeholder.
    """
    sig = UFuncSignature.from_string(f.signature)
    # OrderedDict.fromkeys acts as an ordered set: distinct placeholder
    # names in order of first appearance.
    placeholders = list(OrderedDict.fromkeys(
        chain(*sig.input_core_dims, *sig.output_core_dims)))

    def make_sig(dims):
        """Build a UFuncSignature with the placeholders renamed to ``dims``."""
        # TODO: accept a mapping {placeholder: name} as well as a sequence?
        dims = tuple(dims)
        if len(dims) != len(placeholders):
            raise ValueError(
                'dims must give {} dimension name(s) for placeholders {!r}, '
                'got {}: {!r}'.format(
                    len(placeholders), tuple(placeholders), len(dims), dims))
        # Fresh mapping per call, so calls never share or clobber state.
        dim_map = dict(zip(placeholders, dims))
        return UFuncSignature(
            _map_sig(sig.input_core_dims, dim_map),
            _map_sig(sig.output_core_dims, dim_map))

    # Copy the wrapped gufunc's name/docstring onto the wrapper.
    # TODO: also document the ``dims`` keyword in the wrapper's signature.
    @functools.wraps(f)
    def inner(*args, dims=()):
        return xr.apply(f, *args, signature=make_sig(dims))

    return inner
# Demo input: a 100x100x100 array of standard-normal samples with named dims.
arr = xr.DataArray(np.random.randn(100, 100, 100), dims=('x', 'y', 'z'))


@xarray_gufunc
@numba.guvectorize(['void(f8[:], f8[:])'], '(n)->()')
def std_gufunc(arr, out):
    """Write the standard deviation of the 1-D core slice ``arr`` to ``out[0]``."""
    out[0] = np.std(arr)


# Bind the single core dim ``n`` to ``x``, then to ``y``.
std_gufunc(arr, dims=('x',))
std_gufunc(arr, dims=('y',))


@xarray_gufunc
@numba.guvectorize(['void(f8[:,:], f8[:], f8[:])'],
                   '(m,n),()->()')
def reduce_scalar_gufunc(arr, scalar_arr, out):
    """Write ``sum(arr) + scalar`` over the 2-D core slice to ``out[0]``."""
    # The scalar input is declared as ``f8[:]`` above, so read its only element.
    scalar = scalar_arr[0]
    out[0] = np.sum(arr) + scalar


# Bind core dims (m, n) to ('y', 'x'). This previously called the undefined
# name ``add_scalar_gufunc``, which raised NameError.
reduce_scalar_gufunc(arr, 100., dims=('y', 'x'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.