Created
January 20, 2020 14:59
-
-
Save tsvikas/d2925d215ee3f9b3d30e975cba726cc3 to your computer and use it in GitHub Desktop.
Register a module (e.g. scipy.ndimage) as an xarray DataArray accessor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import xarray as xr | |
def register_dataarray_module(module, accessor_name=None):
    """Register the callables of ``module`` as an xarray DataArray accessor.

    After registration, ``da.<accessor_name>.<func>(*args, **kwargs)`` calls
    ``module.<func>(da.values, *args, **kwargs)``. The result is handled as
    follows:

    * an ``xr.DataArray`` result is returned unchanged;
    * a result with the same ``shape`` as ``da`` is wrapped in a new
      ``xr.DataArray`` that reuses ``da``'s dims, coords and attrs;
    * any other result (scalars, tuples, lists, arrays of a different shape)
      is returned as-is, because ``da``'s labels cannot be attached to it.

    Parameters
    ----------
    module : module
        Module whose callable attributes become accessor methods.
    accessor_name : str, optional
        Attribute name of the accessor on DataArray. Defaults to the last
        dotted component of ``module.__name__``.

    Returns
    -------
    type
        The accessor class that was registered with xarray.
    """
    if accessor_name is None:
        accessor_name = module.__name__.split(".")[-1]
    # Snapshot taken once, at registration time: callables added to the
    # module afterwards are not reachable through the accessor.
    callable_funcs = [name for name in dir(module) if callable(getattr(module, name))]
    callable_set = frozenset(callable_funcs)  # O(1) lookup in __getattr__

    @xr.register_dataarray_accessor(accessor_name)
    class ModuleAccessor:
        def __init__(self, ar):
            self._ar = ar

        def __getattr__(self, name):
            # Only invoked when normal attribute lookup fails.
            if name not in callable_set:
                raise AttributeError(
                    f"module {accessor_name!r} has no callable attribute {name!r}"
                )
            orig_func = getattr(module, name)

            @functools.wraps(orig_func)
            def func(*args, **kwargs):
                res = orig_func(self._ar.values, *args, **kwargs)
                if isinstance(res, xr.DataArray):
                    return res
                # Re-attaching the input's dims/coords only makes sense when
                # the output has the same shape; otherwise the DataArray
                # constructor would raise, so hand the raw result back.
                if getattr(res, "shape", None) != self._ar.shape:
                    return res
                return xr.DataArray(
                    res,
                    dims=self._ar.dims,
                    coords=self._ar.coords,
                    attrs=self._ar.attrs,
                )

            return func

        def __dir__(self):
            # Return a copy so callers cannot mutate the shared snapshot.
            return list(callable_funcs)

    return ModuleAccessor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment