Last active
May 24, 2019 17:46
-
-
Save 0x0L/ef78c80a42892c0f832c91357914a5a4 to your computer and use it in GitHub Desktop.
xarray string accessors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pandas as pd
import xarray as xr
@xr.register_dataarray_accessor('str')
class StringAccessor:
    """Element-wise string methods for a ``DataArray``, exposed as ``.str``.

    Each method flattens the wrapped array, hands it to a string routine
    looked up on ``pd.core.strings``, reshapes the result back to the
    original shape and returns it as a new ``DataArray`` with the same
    dims, coords and name as the source object.

    NOTE(review): ``pd.core.strings.str_*`` and ``pd.core.strings._na_map``
    are pandas internals rather than public API -- confirm they exist in the
    pandas version this is deployed against.
    """

    def __init__(self, xarray_obj):
        """Remember the wrapped object and its underlying data array."""
        self._obj = xarray_obj
        self._data = xarray_obj.data

    def _wrap(self, values):
        """Reshape flat ``values`` to the source shape and wrap them.

        The result reuses the dims, coords and name of the wrapped object.
        """
        shaped = np.reshape(values, self._data.shape)
        return xr.DataArray(
            shaped,
            dims=self._obj.dims,
            coords=self._obj.coords,
            name=self._obj.name,
        )

    def _str_func(self, name, *args, **kwargs):
        """Apply ``pd.core.strings.str_<name>`` to the flattened data.

        Extra positional and keyword arguments are forwarded unchanged to
        the pandas routine; the output is wrapped by ``_wrap``.
        """
        func = getattr(pd.core.strings, f'str_{name}')
        return self._wrap(func(self._data.flat, *args, **kwargs))

    def _na_map(fn, dtype=object):
        """Build an argument-less accessor method from a per-element ``fn``.

        Class-body helper (deliberately takes no ``self``): the returned
        function is what gets bound as the method.  It maps ``fn`` over the
        flattened data through ``pd.core.strings._na_map`` using ``dtype``
        for the result array.
        """
        def f(self):
            mapped = pd.core.strings._na_map(fn, self._data.flat, dtype=dtype)
            return self._wrap(mapped)
        return f

    def __getitem__(self, key):
        """Support ``da.str[i]`` (via ``get``) and ``da.str[a:b:c]`` (via ``slice``)."""
        if isinstance(key, slice):
            return self.slice(start=key.start, stop=key.stop, step=key.step)
        return self.get(key)

    # The right-hand ``len`` is still the builtin here: the class-level name
    # is only bound once this assignment completes.
    len = _na_map(len, dtype=int)

    # Case conversion, one Python ``str`` method per element.
    lower = _na_map(lambda x: x.lower())
    upper = _na_map(lambda x: x.upper())
    title = _na_map(lambda x: x.title())
    capitalize = _na_map(lambda x: x.capitalize())
    swapcase = _na_map(lambda x: x.swapcase())

    # Character-class predicates, one Python ``str`` method per element.
    isalnum = _na_map(lambda x: x.isalnum())
    isalpha = _na_map(lambda x: x.isalpha())
    isdigit = _na_map(lambda x: x.isdigit())
    isspace = _na_map(lambda x: x.isspace())
    islower = _na_map(lambda x: x.islower())
    isupper = _na_map(lambda x: x.isupper())
    istitle = _na_map(lambda x: x.istitle())
    isnumeric = _na_map(lambda x: x.isnumeric())
    isdecimal = _na_map(lambda x: x.isdecimal())

    def pad(self, width, side='left', fillchar=' '):
        """Pad each string to ``width`` with ``fillchar`` on the given ``side``."""
        return self._str_func('pad', width, side=side, fillchar=fillchar)

    def center(self, width, fillchar=' '):
        """Pad on both sides (``pad`` with ``side='both'``)."""
        return self.pad(width, side='both', fillchar=fillchar)

    def ljust(self, width, fillchar=' '):
        """Left-justify: pad on the right (``pad`` with ``side='right'``)."""
        return self.pad(width, side='right', fillchar=fillchar)

    def rjust(self, width, fillchar=' '):
        """Right-justify: pad on the left (``pad`` with ``side='left'``)."""
        return self.pad(width, side='left', fillchar=fillchar)

    def count(self, pat, flags=0):
        """Delegate to pandas ``str_count`` with ``pat`` and ``flags``."""
        return self._str_func('count', pat, flags=flags)

    def startswith(self, pat, na=np.nan):
        """Delegate to pandas ``str_startswith``; ``na`` is forwarded as-is."""
        return self._str_func('startswith', pat, na=na)

    def endswith(self, pat, na=np.nan):
        """Delegate to pandas ``str_endswith``; ``na`` is forwarded as-is."""
        return self._str_func('endswith', pat, na=na)

    def slice(self, start=None, stop=None, step=None):
        """Delegate to pandas ``str_slice`` with ``start``, ``stop``, ``step``."""
        return self._str_func('slice', start, stop, step)

    def slice_replace(self, start=None, stop=None, repl=None):
        """Delegate to pandas ``str_slice_replace`` with ``start``, ``stop``, ``repl``."""
        return self._str_func('slice_replace', start, stop, repl)

    def decode(self, encoding, errors="strict"):
        """Delegate to pandas ``str_decode`` with ``encoding`` and ``errors``."""
        return self._str_func('decode', encoding, errors)

    def encode(self, encoding, errors="strict"):
        """Delegate to pandas ``str_encode`` with ``encoding`` and ``errors``."""
        return self._str_func('encode', encoding, errors)

    def find(self, sub, start=0, end=None):
        """Delegate to pandas ``str_find`` with ``side='left'``."""
        return self._str_func('find', sub, start=start, end=end, side='left')

    def rfind(self, sub, start=0, end=None):
        """Delegate to pandas ``str_find`` with ``side='right'``."""
        return self._str_func('find', sub, start=start, end=end, side='right')

    def index(self, sub, start=0, end=None):
        """Delegate to pandas ``str_index`` with ``side='left'``."""
        return self._str_func('index', sub, start=start, end=end, side='left')

    def rindex(self, sub, start=0, end=None):
        """Delegate to pandas ``str_index`` with ``side='right'``."""
        return self._str_func('index', sub, start=start, end=end, side='right')

    def repeat(self, repeats):
        """Delegate to pandas ``str_repeat`` with ``repeats``."""
        return self._str_func('repeat', repeats)

    def get(self, i):
        """Delegate to pandas ``str_get`` with index ``i``."""
        return self._str_func('get', i)

    def contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
        """Delegate to pandas ``str_contains``; all options are forwarded."""
        return self._str_func(
            'contains', pat, case=case, flags=flags, na=na, regex=regex)

    def match(self, pat, case=True, flags=0, na=np.nan):
        """Delegate to pandas ``str_match``; all options are forwarded."""
        return self._str_func('match', pat, case=case, flags=flags, na=na)

    def strip(self, to_strip=None):
        """Delegate to pandas ``str_strip`` with ``side='both'``."""
        return self._str_func('strip', to_strip, side='both')

    def lstrip(self, to_strip=None):
        """Delegate to pandas ``str_strip`` with ``side='left'``."""
        return self._str_func('strip', to_strip, side='left')

    def rstrip(self, to_strip=None):
        """Delegate to pandas ``str_strip`` with ``side='right'``."""
        return self._str_func('strip', to_strip, side='right')

    def wrap(self, width, **kwargs):
        """Delegate to pandas ``str_wrap``; ``kwargs`` are forwarded unchanged."""
        return self._str_func('wrap', width, **kwargs)

    def translate(self, table, deletechars=None):
        """Delegate to pandas ``str_translate`` with ``table`` and ``deletechars``.

        NOTE(review): ``deletechars`` is passed positionally -- confirm the
        installed pandas ``str_translate`` still accepts it.
        """
        return self._str_func('translate', table, deletechars)

    def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
        """Delegate to pandas ``str_replace``; all options are forwarded."""
        return self._str_func(
            'replace', pat, repl, n=n, case=case, flags=flags, regex=regex)

    def zfill(self, width):
        """Left-pad each string with ``'0'`` up to ``width``.

        Implemented through the ``pad`` routine (``side='left'``,
        ``fillchar='0'``), not through Python's ``str.zfill``.
        """
        return self._str_func('pad', width, side='left', fillchar='0')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment