Last active
February 12, 2017 00:45
-
-
Save martindurant/06a1e98c91f0033c4649a48a2f943390 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import dask.array as da | |
from dask.utils import infer_storage_options | |
import numpy as np | |
import pickle | |
import xarray as xr | |
import zarr | |
def _get_chunks(darr): | |
for c in darr.chunks: | |
if len(set(c)) != 1: | |
# I believe arbitrary chunking is not possible | |
raise ValueError("Must use regular chunking for zarr; %s" | |
% darr.chunks) | |
return [c[0] for c in darr.chunks] | |
def dask_to_zarr(darr, url, compressor='default', ret=False,
                 storage_options=None):
    """
    Save dask array to a zarr

    Parameters
    ----------
    darr: dask array
    url: location
        May include protocol, e.g., ``s3://mybucket/mykey.zarr``
    compressor: string ['default']
        Compression to use, see [zarr compressors](http://zarr.readthedocs.io/en/latest/api/codecs.html)
    ret: bool (False)
        Whether to return the created zarr array.
    storage_options: dict or None
        Extra parameters (e.g., credentials) for the backing store,
        interpreted by ``make_mapper``.
    """
    chunks = _get_chunks(darr)
    # BUG FIX: storage_options was previously accepted but silently ignored,
    # so remote URLs (s3://...) could not work. Resolve the URL through
    # make_mapper, consistent with dataset_to_zarr; for local 'file' paths
    # this is the identity.
    store = make_mapper(url, storage_options)
    out = zarr.open_array(store, mode='w', shape=darr.shape,
                          chunks=chunks, dtype=darr.dtype,
                          compressor=compressor)
    da.store(darr, out)
    if ret:
        return out
def dask_from_zarr(url, ret=False, storage_options=None):
    """
    Load zarr data into a dask array

    Parameters
    ----------
    url: location
        May include protocol, e.g., ``s3://mybucket/mykey.zarr``
    ret: bool (False)
        To also return the raw zarr.
    storage_options: dict or None
        Extra parameters (e.g., credentials) for the backing store,
        interpreted by ``make_mapper``.
    """
    # BUG FIX: storage_options was previously accepted but silently ignored;
    # resolve the URL through make_mapper so remote protocols work, matching
    # dataset_from_zarr. Local paths pass through unchanged.
    store = make_mapper(url, storage_options)
    d = zarr.open_array(store)
    # Lazy view over the zarr, preserving its on-disc chunking.
    out = da.from_array(d, chunks=d.chunks)
    if ret:
        return out, d
    return out
def xarray_to_zarr(arr, url, storage_options=None, **kwargs):
    """
    Save xarray.DataArray to a zarr

    This is a simplified method, where all metadata, including coordinates,
    is stored into a special key within the zarr.

    Parameters
    ----------
    arr: data to store
    url: location to store into
    storage_options: dict or None
        Extra parameters for the backing store, forwarded to
        ``dask_to_zarr``.
    kwargs: passed on to zarr
    """
    # Coordinates are captured in dimension order so they can be
    # re-attached on load by xarray_from_zarr.
    coorddict = [(name, arr.coords[name].values) for name in arr.dims]
    # BUG FIX: storage_options was previously accepted but not forwarded.
    z = dask_to_zarr(arr.data, url,
                     compressor=kwargs.get('compressor', 'default'),
                     ret=True, storage_options=storage_options)
    # All DataArray metadata lives pickled under one special key in the
    # store; -1 selects the highest pickle protocol.
    z.store['.xarray'] = pickle.dumps({'coords': coorddict, 'attrs': arr.attrs,
                                       'name': arr.name, 'dims': arr.dims}, -1)
def xarray_from_zarr(url, storage_options=None):
    """
    Load xarray.DataArray from a zarr, stored using ``xarray_to_zarr``

    Parameters
    ----------
    url: location of zarr
    storage_options: dict or None
        Extra parameters for the backing store, forwarded to
        ``dask_from_zarr``.

    Returns
    -------
    xarray.DataArray instance
    """
    # BUG FIX: storage_options was previously accepted but not forwarded.
    z, d = dask_from_zarr(url, ret=True, storage_options=storage_options)
    # SECURITY NOTE: pickle.loads can execute arbitrary code — only open
    # zarrs written by a trusted party (i.e. via xarray_to_zarr).
    meta = pickle.loads(d.store['.xarray'])
    out = xr.DataArray(z, **meta)
    return out
def _safe_attrs(attrs, order=True): | |
""" | |
Rationalize numpy contents of attrbutes for serialization | |
Since number in xarray attributes are often numpy numbers or arrays, | |
which will not JSON serialize, replace them with simple values or lists. | |
Parameters | |
---------- | |
attrs: attributes set, dict-like | |
order: bool, whether to produce an ordered or simple dict. | |
Returns | |
------- | |
Altered attribute set | |
""" | |
out = collections.OrderedDict() if order else {} | |
for k, v in attrs.items(): | |
if isinstance(v, (np.number, np.ndarray)): | |
out[k] = v.tolist() | |
else: | |
out[k] = v | |
return out | |
def dataset_to_zarr(ds, url, path_in_dataset='/', storage_options=None,
                    **kwargs):
    """
    Save xarray.Dataset in to zarr

    All coordinates, variables and their attributes will be saved. If the
    variables are based on dask.Arrays, chunking will be preserved; otherwise,
    zarr will guess a suitable chunking scheme. The user may wish to define
    the chunking manually by calling ``ds.chunk`` first.

    Parameters
    ----------
    ds: data
    url: location to save to
    path_in_dataset: string ('/')
        If only writing to some sub-set of a larger dataset, in the sense
        used by netCDF
    kwargs: passed on to zarr
    """
    mapper = make_mapper(url, storage_options)
    compressor = kwargs.get('compressor', 'default')
    if path_in_dataset == '/':
        # Writing the whole store: truncate and start fresh.
        root = zarr.open_group(mapper, 'w')
    else:
        # Writing a sub-group: append to the store, replace just this group.
        parent = zarr.open_group(mapper, 'a')
        root = parent.create_group(path_in_dataset, overwrite=True)
    # Metadata (attrs, dims) is accumulated here and stored on root at the end.
    meta = {'coords': {}, 'variables': {}, 'dims': {},
            'attrs': _safe_attrs(ds.attrs)}
    coord_group = root.create_group('coords')
    for name in ds.coords:
        # Coordinates are small; materialize them eagerly.
        coord_group.create_dataset(name=name, data=np.asarray(ds.coords[name]))
        meta['coords'][name] = _safe_attrs(ds.coords[name].attrs)
    var_group = root.create_group('variables')
    for name in set(ds.variables) - set(ds.coords):
        v = ds.variables[name]
        if isinstance(v.data, da.Array):
            # dask-backed variable: keep its chunking scheme on disc.
            target = var_group.create_dataset(name, shape=v.shape,
                                              chunks=_get_chunks(v),
                                              dtype=v.dtype,
                                              compressor=compressor)
            da.store(v.data, target)
        else:
            # In-memory variable: let zarr pick a chunking.
            var_group.create_dataset(name=name, data=v, compressor=compressor)
        meta['dims'][name] = v.dims
        meta['variables'][name] = _safe_attrs(v.attrs)
    root.attrs.update(meta)
def dataset_from_zarr(url, path_in_dataset='/', storage_options=None):
    """
    Load a zarr into a xarray.Dataset.

    Variables and coordinates will be loaded, with applicable attributes.
    Variables will load lazily and respect the on-disc chunking scheme;
    coordinates will be loaded eagerly into memory.

    Parameters
    ----------
    url: location to load from
    path_in_dataset: string ('/')
        For some sub-set of a larger dataset, in the sense used by netCDF

    Returns
    -------
    xarray.Dataset instance
    """
    mapper = make_mapper(url, storage_options)
    root = zarr.open_group(mapper, 'r')
    if path_in_dataset != '/':
        root = root[path_in_dataset]
    meta = dict(root.attrs)
    # Coordinates: eager numpy arrays, each as (dim-name, values, attrs).
    coords = {name: (name, np.array(root['coords'][name]),
                     meta['coords'][name])
              for name in root['coords']}
    # Variables: lazy dask arrays respecting the on-disc chunking.
    variables = {}
    for name in root['variables']:
        z = root['variables'][name]
        variables[name] = (root.attrs['dims'][name],
                           da.from_array(z, chunks=z.chunks),
                           meta['variables'][name])
    return xr.Dataset(variables, coords, attrs=meta['attrs'])
def s3mapper(url, key=None, username=None, secret=None, password=None,
             path=None, host=None, s3=None, **kwargs):
    """Build a zarr-compatible key-value mapping for an ``s3://`` URL.

    ``username``/``password`` are aliases for ``key``/``secret``. An
    existing ``s3fs.S3FileSystem`` may be supplied via ``s3`` and will be
    reused; otherwise one is created from the remaining keyword arguments.
    """
    import s3fs
    # Normalize credential aliases into kwargs for S3FileSystem.
    if username is not None:
        key = username
    if key is not None:
        kwargs['key'] = key
    if password is not None:
        secret = password
    if secret is not None:
        kwargs['secret'] = secret
    # Recover bucket/key from the URL; credentials travel via kwargs.
    options = infer_storage_options(url, kwargs)
    # BUG FIX: a caller-supplied filesystem instance was previously
    # ignored and always overwritten by a fresh S3FileSystem.
    if s3 is None:
        s3 = s3fs.S3FileSystem(**kwargs)
    return s3fs.S3Map(options['host'] + options['path'], s3)
# Dispatch table from URL protocol to store-mapper factory. Local files
# need no mapping layer, so 'file' is the identity on the path.
mappers = {'file': lambda x, **kw: x,
           's3': s3mapper}
def make_mapper(url, storage_options=None):
    """Resolve *url* into a store object usable by zarr.

    The protocol (e.g. 'file', 's3') is inferred from the URL and used to
    pick the appropriate factory from ``mappers``; any remaining inferred
    options are forwarded to it.
    """
    opts = infer_storage_options(url, storage_options)
    protocol = opts.pop('protocol')
    factory = mappers[protocol]
    return factory(url, **opts)
# Test pieces
# NOTE(review): hard-coded local path — these round-trip tests only run on
# a machine where this netCDF file exists; it is not shipped alongside.
testfile = '/Users/mdurant/Downloads/smith_sandwell_topo_v8_2.nc'
def test_dask_roundtrip():
    """Round-trip a raw dask array through zarr; means must agree."""
    chunking = {'latitude': 6336 // 11, 'longitude': 10800 // 15}
    darr = xr.open_dataset(testfile, chunks=chunking).ROSE.data
    dask_to_zarr(darr, 'out.zarr', compressor=zarr.Blosc())
    assert dask_from_zarr('out.zarr').mean().compute() == darr.mean().compute()
def test_xarray_roundtrip():
    """Round-trip a DataArray (with metadata) through zarr."""
    chunking = {'latitude': 6336 // 11, 'longitude': 10800 // 15}
    arr = xr.open_dataset(testfile, chunks=chunking).ROSE
    darr = arr.data
    xarray_to_zarr(arr, 'out.xarr')
    restored = xarray_from_zarr('out.xarr')
    assert restored.mean().values == darr.mean().compute()
def test_dataset_roundtrip():
    """Round-trip a whole Dataset: laziness, coords, attrs and values."""
    chunking = {'latitude': 6336 // 11, 'longitude': 10800 // 15}
    ds = xr.open_dataset(testfile, chunks=chunking)
    dataset_to_zarr(ds, 'out.xzarr')
    out = dataset_from_zarr('out.xzarr')
    # Variables must come back lazily, as dask arrays.
    assert isinstance(out.ROSE.data, da.Array)
    # Coordinates (dataset-level and per-variable) survive with attrs.
    for name in list(out.coords) + list(out.ROSE.coords):
        assert (ds.coords[name].data == out.coords[name].data).all()
        assert ds.coords[name].attrs == out.coords[name].attrs
    assert _safe_attrs(ds.ROSE.attrs, False) == _safe_attrs(
        out.ROSE.attrs, False)
    assert _safe_attrs(ds.attrs, False) == _safe_attrs(out.attrs, False)
    assert (ds.ROSE == out.ROSE).all().values
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment