Skip to content

Instantly share code, notes, and snippets.

@chris-b1
Created September 22, 2016 16:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chris-b1/d28c6b8e78bf65ef7eb97e1095bc87f2 to your computer and use it in GitHub Desktop.
Save chris-b1/d28c6b8e78bf65ef7eb97e1095bc87f2 to your computer and use it in GitHub Desktop.
from collections import OrderedDict
from collections.abc import Mapping
import functools
from itertools import chain

import numba
import numpy as np
import xarray as xr
from xarray.core.computation import UFuncSignature
def _map_sig(sig, dim_map):
return [tuple(dim_map[x] for x in t)
for t in sig]
def xarray_gufunc(f):
    """Wrap a generalized ufunc so its core dims bind to named xarray dims.

    ``f`` must carry a ``signature`` string attribute such as ``'(n)->()'``
    (``numba.guvectorize`` produces one).  The returned wrapper accepts the
    same positional arguments as ``f`` plus a keyword-only ``dims`` naming
    the xarray dimension each core dimension should be applied over.

    Parameters
    ----------
    f : callable
        Generalized ufunc exposing ``f.signature``.

    Returns
    -------
    callable
        ``inner(*args, dims=())``.  ``dims`` may be

        * a sequence of names, matched positionally against the signature's
          unique core dimensions in first-appearance order (inputs, then
          outputs); e.g. ``('y', 'x')`` for ``'(m,n),()->()'`` maps
          ``m -> 'y'`` and ``n -> 'x'``;
        * a single string, shorthand for a one-element sequence;
        * a mapping from core dimension name to xarray dimension name.

    Raises
    ------
    ValueError
        Raised by ``inner`` when ``dims`` does not match the core dimensions
        (wrong length for a sequence, or mapping keys that differ from the
        core dimension names).
    """
    base_sig = UFuncSignature.from_string(f.signature)
    # Unique core dimension names in first-seen order (an ordered set).
    core_dims = tuple(OrderedDict.fromkeys(
        chain(*base_sig.input_core_dims, *base_sig.output_core_dims)))

    def make_sig(dims):
        # Build a fresh mapping per call rather than mutating shared closure
        # state, so re-entrant or concurrent calls cannot interfere.
        if isinstance(dims, str):
            # A bare string is one dimension name, not a sequence of chars.
            dims = (dims,)
        if isinstance(dims, Mapping):
            if set(dims) != set(core_dims):
                raise ValueError(
                    'dims mapping keys {!r} must be exactly the core '
                    'dimensions {!r} of signature {!r}'.format(
                        tuple(dims), core_dims, f.signature))
            dim_map = {old: dims[old] for old in core_dims}
        else:
            dims = tuple(dims)
            if len(dims) != len(core_dims):
                raise ValueError(
                    'signature {!r} has {} core dimension(s) {!r}, but dims '
                    '{!r} has {}'.format(f.signature, len(core_dims),
                                         core_dims, dims, len(dims)))
            dim_map = dict(zip(core_dims, dims))
        return UFuncSignature(
            _map_sig(base_sig.input_core_dims, dim_map),
            _map_sig(base_sig.output_core_dims, dim_map))

    # updated=() keeps f's instance __dict__ (e.g. numba internals) off the
    # wrapper; only the standard WRAPPER_ASSIGNMENTS (name, doc, ...) and
    # __wrapped__ are carried over.
    @functools.wraps(f, updated=())
    def inner(*args, dims=()):
        # NOTE(review): xr.apply(..., signature=...) matches the xarray
        # development API this was written against; released xarray exposes
        # xr.apply_ufunc(..., input_core_dims=..., output_core_dims=...)
        # instead -- confirm against the installed version.
        return xr.apply(f, *args, signature=make_sig(dims))

    return inner
# Demo input: a 100 x 100 x 100 array of standard-normal samples with named
# dimensions 'x', 'y', 'z', consumed by the example calls in this script.
arr = xr.DataArray(
    np.random.randn(100, 100, 100),
    dims=('x', 'y', 'z'),
)
@xarray_gufunc
@numba.guvectorize(['void(f8[:], f8[:])'], '(n)->()')
def std_gufunc(arr, out):
    """Write the standard deviation (np.std, ddof=0) of one core slice.

    Core signature ``(n)->()``: ``arr`` is a 1-D slice along the core
    dimension and ``out`` is the length-1 output buffer that receives the
    scalar result in ``out[0]``.
    """
    result = np.std(arr)
    out[0] = result
# Standard deviation with the core dimension n bound to 'x' ...
std_gufunc(arr, dims=('x',))
# ... and the same gufunc reused with n bound to 'y' instead.
std_gufunc(arr, dims=('y',))
@xarray_gufunc
@numba.guvectorize(['void(f8[:,:], f8[:], f8[:])'],
                   '(m,n),()->()')
def reduce_scalar_gufunc(arr, scalar_arr, out):
    """Sum a 2-D core slice and add a scalar offset.

    Core signature ``(m,n),()->()``: ``arr`` is an (m, n) slice,
    ``scalar_arr`` is the length-1 buffer holding the scalar input, and the
    result ``sum(arr) + scalar`` is written to ``out[0]``.
    """
    out[0] = np.sum(arr) + scalar_arr[0]
# Sum over each (y, x) slice (core dims m -> 'y', n -> 'x') and add 100.
# Fixed: this previously called the undefined name `add_scalar_gufunc`.
reduce_scalar_gufunc(arr, 100., dims=('y', 'x'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment