Skip to content

Instantly share code, notes, and snippets.

@tsvikas
Created January 20, 2020 14:59
Show Gist options
  • Save tsvikas/d2925d215ee3f9b3d30e975cba726cc3 to your computer and use it in GitHub Desktop.
Save tsvikas/d2925d215ee3f9b3d30e975cba726cc3 to your computer and use it in GitHub Desktop.
Register a module (e.g. `scipy.ndimage`) as an xarray DataArray accessor.
import functools
import xarray as xr
def register_dataarray_module(module, accessor_name=None):
    """Expose the callable attributes of *module* as an xarray DataArray accessor.

    After registration, ``da.<accessor_name>.<func>(*args, **kwargs)`` calls
    ``module.<func>(da.values, *args, **kwargs)``.

    If the result is already a DataArray, it is returned as-is. If it is
    another array-like with exactly the same shape as ``da``, it is wrapped
    in a new DataArray that carries ``da``'s dims, coords and attrs. Any
    other result (scalars, tuples, ``None``, differently shaped arrays) is
    returned unchanged, because ``da``'s labels cannot be attached to it.

    Parameters
    ----------
    module : module
        Module whose callable attributes become accessor methods.
    accessor_name : str, optional
        Name of the accessor on DataArray. Defaults to the last dotted
        component of ``module.__name__``.

    Returns
    -------
    type
        The accessor class that was registered with xarray.
    """
    if accessor_name is None:
        accessor_name = module.__name__.split(".")[-1]
    # Snapshot of the module's callable names, taken once at registration
    # time. A frozenset keeps the per-attribute-access membership test O(1).
    callable_funcs = frozenset(
        name for name in dir(module) if callable(getattr(module, name))
    )

    @xr.register_dataarray_accessor(accessor_name)
    class ModuleAccessor:
        def __init__(self, ar):
            # The DataArray this accessor instance is bound to.
            self._ar = ar

        def __getattr__(self, name):
            # Only reached when normal attribute lookup on the accessor fails,
            # so the accessor's own methods are never shadowed by the module.
            if name not in callable_funcs:
                raise AttributeError(
                    f"module {accessor_name!r} has no callable attribute {name!r}"
                )
            orig_func = getattr(module, name)

            @functools.wraps(orig_func)
            def func(*args, **kwargs):
                res = orig_func(self._ar.values, *args, **kwargs)
                if isinstance(res, xr.DataArray):
                    return res
                # Re-attach labels only when they still fit. Results with a
                # different (or no) shape previously made the DataArray
                # constructor raise instead of returning the computed value.
                if getattr(res, "shape", None) != self._ar.shape:
                    return res
                return xr.DataArray(
                    res,
                    dims=self._ar.dims,
                    coords=self._ar.coords,
                    attrs=self._ar.attrs,
                )

            return func

        def __dir__(self):
            # Fresh list each call so callers cannot mutate the shared names.
            return sorted(callable_funcs)

    return ModuleAccessor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment