# Gist by @martindurant, last active February 12, 2017:
# save/load dask arrays and xarray objects to/from zarr stores.
import collections
import dask.array as da
from dask.utils import infer_storage_options
import numpy as np
import pickle
import xarray as xr
import zarr


def _get_chunks(darr):
    for c in darr.chunks:
        if len(set(c)) != 1:
            # I believe arbitrary (irregular) chunking is not possible in zarr
            raise ValueError("Must use regular chunking for zarr; %s"
                             % str(darr.chunks))
    return [c[0] for c in darr.chunks]
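

# Added illustrative sketch (not part of the original gist): shows the kind of
# chunk specification _get_chunks produces from a regularly-chunked dask
# array, and that irregular chunking is rejected. Array sizes are arbitrary.
def demo_get_chunks():
    x = da.ones((100, 200), chunks=(50, 100))
    assert _get_chunks(x) == [50, 100]    # one chunk size per dimension
    y = da.ones(100, chunks=((60, 40),))  # irregular chunks along axis 0
    try:
        _get_chunks(y)
    except ValueError:
        pass  # expected: zarr needs uniform chunk sizes per dimension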


def dask_to_zarr(darr, url, compressor='default', ret=False,
                 storage_options=None):
    """
    Save a dask array to a zarr store

    Parameters
    ----------
    darr: dask array
    url: location
        May include protocol, e.g., ``s3://mybucket/mykey.zarr``
    compressor: string ['default']
        Compression to use, see [zarr compressors](http://zarr.readthedocs.io/en/latest/api/codecs.html)
    ret: bool (False)
        To also return the created zarr array.
    storage_options: dict
        Extra parameters for the file-system backend, if the URL includes a
        protocol.
    """
    url = make_mapper(url, storage_options)
    chunks = _get_chunks(darr)
    out = zarr.open_array(url, mode='w', shape=darr.shape,
                          chunks=chunks, dtype=darr.dtype,
                          compressor=compressor)
    da.store(darr, out)
    if ret:
        return out


def dask_from_zarr(url, ret=False, storage_options=None):
    """
    Load zarr data into a dask array

    Parameters
    ----------
    url: location
        May include protocol, e.g., ``s3://mybucket/mykey.zarr``
    ret: bool (False)
        To also return the raw zarr array.
    storage_options: dict
        Extra parameters for the file-system backend, if the URL includes a
        protocol.
    """
    url = make_mapper(url, storage_options)
    d = zarr.open_array(url)
    out = da.from_array(d, chunks=d.chunks)
    if ret:
        return out, d
    return out
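

# Added usage sketch (not part of the original gist): round-trip a dask array
# through a local zarr store; 'demo.zarr' is a hypothetical path in the
# current directory.
def demo_dask_roundtrip():
    x = da.random.random((1000, 1000), chunks=(250, 500))
    dask_to_zarr(x, 'demo.zarr', compressor=zarr.Blosc())
    y = dask_from_zarr('demo.zarr')
    assert y.chunks == x.chunks                  # on-disc chunking preserved
    assert np.allclose(y.compute(), x.compute())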


def xarray_to_zarr(arr, url, storage_options=None, **kwargs):
    """
    Save an xarray.DataArray to a zarr store

    This is a simplified method, where all metadata, including coordinates,
    is stored into a special key within the zarr.

    Parameters
    ----------
    arr: data to store
    url: location to store into
    kwargs: passed on to zarr
    """
    coorddict = [(name, arr.coords[name].values) for name in arr.dims]
    z = dask_to_zarr(arr.data, url,
                     compressor=kwargs.get('compressor', 'default'),
                     ret=True, storage_options=storage_options)
    z.store['.xarray'] = pickle.dumps({'coords': coorddict, 'attrs': arr.attrs,
                                       'name': arr.name, 'dims': arr.dims}, -1)


def xarray_from_zarr(url, storage_options=None):
    """
    Load an xarray.DataArray from a zarr store written by ``xarray_to_zarr``

    Parameters
    ----------
    url: location of zarr

    Returns
    -------
    xarray.DataArray instance
    """
    z, d = dask_from_zarr(url, ret=True, storage_options=storage_options)
    meta = pickle.loads(d.store['.xarray'])
    out = xr.DataArray(z, **meta)
    return out
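

# Added usage sketch (not part of the original gist): round-trip a small
# DataArray, showing that name, dims, coords and attrs survive via the
# pickled '.xarray' key; 'demo.xarr' is a hypothetical local path.
def demo_dataarray_roundtrip():
    arr = xr.DataArray(da.ones((4, 6), chunks=(2, 3)),
                       dims=['y', 'x'],
                       coords={'y': np.arange(4), 'x': np.arange(6)},
                       name='ones', attrs={'units': 'dimensionless'})
    xarray_to_zarr(arr, 'demo.xarr')
    out = xarray_from_zarr('demo.xarr')
    assert out.name == 'ones' and out.attrs['units'] == 'dimensionless'
    assert (out == arr).all().values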


def _safe_attrs(attrs, order=True):
    """
    Rationalize numpy contents of attributes for serialization

    Since numbers in xarray attributes are often numpy scalars or arrays,
    which will not JSON serialize, replace them with simple values or lists.

    Parameters
    ----------
    attrs: attributes set, dict-like
    order: bool, whether to produce an ordered or simple dict.

    Returns
    -------
    Altered attribute set
    """
    out = collections.OrderedDict() if order else {}
    for k, v in attrs.items():
        if isinstance(v, (np.number, np.ndarray)):
            out[k] = v.tolist()
        else:
            out[k] = v
    return out
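

# Added illustrative sketch (not part of the original gist): numpy scalars
# and arrays in attributes become plain Python values and lists, which zarr
# can then JSON-serialize.
def demo_safe_attrs():
    attrs = {'scale': np.float64(0.5), 'shape': np.array([2, 3]), 'name': 'x'}
    assert _safe_attrs(attrs, order=False) == {'scale': 0.5,
                                               'shape': [2, 3],
                                               'name': 'x'}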


def dataset_to_zarr(ds, url, path_in_dataset='/', storage_options=None,
                    **kwargs):
    """
    Save an xarray.Dataset to a zarr store

    All coordinates, variables and their attributes will be saved. If the
    variables are backed by dask arrays, their chunking will be preserved;
    otherwise, zarr will guess a suitable chunking scheme. The user may wish
    to define the chunking manually by calling ``ds.chunk`` first.

    Parameters
    ----------
    ds: data
    url: location to save to
    path_in_dataset: string ('/')
        If only writing to some sub-set of a larger dataset, in the sense
        used by netCDF
    kwargs: passed on to zarr
    """
    url = make_mapper(url, storage_options)
    comp = kwargs.get('compressor', 'default')
    if path_in_dataset == '/':
        root = zarr.open_group(url, 'w')
    else:
        root = zarr.open_group(url, 'a')
        root = root.create_group(path_in_dataset, overwrite=True)
    attrs = {'coords': {}, 'variables': {}, 'dims': {}}
    attrs['attrs'] = _safe_attrs(ds.attrs)
    coords = root.create_group('coords')
    for coord in ds.coords:
        coords.create_dataset(name=coord, data=np.asarray(ds.coords[coord]))
        attrs['coords'][coord] = _safe_attrs(ds.coords[coord].attrs)
    variables = root.create_group('variables')
    for variable in set(ds.variables) - set(ds.coords):
        v = ds.variables[variable]
        if isinstance(v.data, da.Array):
            chunks = _get_chunks(v)
            out = variables.create_dataset(variable, shape=v.shape,
                                           chunks=chunks, dtype=v.dtype,
                                           compressor=comp)
            da.store(v.data, out)
        else:
            variables.create_dataset(name=variable, data=np.asarray(v),
                                     compressor=comp)
        attrs['dims'][variable] = v.dims
        attrs['variables'][variable] = _safe_attrs(v.attrs)
    root.attrs.update(attrs)
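

# Added usage sketch (not part of the original gist): chunk a Dataset
# explicitly before saving, so the zarr store uses your chosen chunking
# rather than one guessed by zarr. Variable names and sizes are arbitrary,
# and 'demo.xzarr' is a hypothetical local path.
def demo_dataset_to_zarr_chunked():
    ds = xr.Dataset({'temp': (['t', 'x'], np.random.random((8, 10)))},
                    coords={'t': np.arange(8), 'x': np.arange(10)})
    ds = ds.chunk({'t': 4, 'x': 5})   # dask-backed, regular chunks
    dataset_to_zarr(ds, 'demo.xzarr', compressor=zarr.Blosc())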


def dataset_from_zarr(url, path_in_dataset='/', storage_options=None):
    """
    Load a zarr store into an xarray.Dataset

    Variables and coordinates will be loaded, with applicable attributes.
    Variables will load lazily and respect the on-disc chunking scheme;
    coordinates will be loaded eagerly into memory.

    Parameters
    ----------
    url: location to load from
    path_in_dataset: string ('/')
        For some sub-set of a larger dataset, in the sense used by netCDF

    Returns
    -------
    xarray.Dataset instance
    """
    url = make_mapper(url, storage_options)
    root = zarr.open_group(url, 'r')
    if path_in_dataset != '/':
        root = root[path_in_dataset]
    attrs = dict(root.attrs)
    coords = {}
    for coord in root['coords']:
        coords[coord] = (coord, np.array(root['coords'][coord]),
                         attrs['coords'][coord])
    out = {}
    for variable in root['variables']:
        d = root['variables'][variable]
        out[variable] = (attrs['dims'][variable],
                         da.from_array(d, chunks=d.chunks),
                         attrs['variables'][variable])
    ds = xr.Dataset(out, coords, attrs=attrs['attrs'])
    return ds
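

# Added usage sketch (not part of the original gist): write a Dataset into a
# sub-group of a store and read it back lazily; 'demo_groups.xzarr' and the
# group name 'run1' are hypothetical.
def demo_dataset_group_roundtrip():
    ds = xr.Dataset({'v': (['x'], np.arange(6.))}).chunk({'x': 3})
    dataset_to_zarr(ds, 'demo_groups.xzarr', path_in_dataset='run1')
    out = dataset_from_zarr('demo_groups.xzarr', path_in_dataset='run1')
    assert isinstance(out.v.data, da.Array)   # variables load lazily
    assert (out.v == ds.v).all().values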


def s3mapper(url, key=None, username=None, secret=None, password=None,
             path=None, host=None, s3=None, **kwargs):
    import s3fs
    if username is not None:
        key = username
    if key is not None:
        kwargs['key'] = key
    if password is not None:
        secret = password
    if secret is not None:
        kwargs['secret'] = secret
    options = infer_storage_options(url, kwargs)
    s3 = s3fs.S3FileSystem(**kwargs)
    return s3fs.S3Map(options['host'] + options['path'], s3)


mappers = {'file': lambda x, **kw: x,
           's3': s3mapper}


def make_mapper(url, storage_options=None):
    options = infer_storage_options(url, storage_options)
    protocol = options.pop('protocol')
    return mappers[protocol](url, **options)
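

# Added illustrative sketch (not part of the original gist): make_mapper
# dispatches on the URL protocol. Plain local paths pass straight through,
# while s3:// URLs (not exercised here, since they need s3fs and credentials)
# would be turned into an s3fs key/value mapping by s3mapper above.
def demo_make_mapper():
    assert make_mapper('out.zarr') == 'out.zarr'

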
# Test pieces
testfile = '/Users/mdurant/Downloads/smith_sandwell_topo_v8_2.nc'


def test_dask_roundtrip():
    arr = xr.open_dataset(testfile, chunks={'latitude': 6336//11,
                                            'longitude': 10800//15}).ROSE
    darr = arr.data
    dask_to_zarr(darr, 'out.zarr', compressor=zarr.Blosc())
    assert dask_from_zarr('out.zarr').mean().compute() == darr.mean().compute()


def test_xarray_roundtrip():
    arr = xr.open_dataset(testfile, chunks={'latitude': 6336//11,
                                            'longitude': 10800//15}).ROSE
    darr = arr.data
    xarray_to_zarr(arr, 'out.xarr')
    out = xarray_from_zarr('out.xarr')
    assert out.mean().values == darr.mean().compute()


def test_dataset_roundtrip():
    ds = xr.open_dataset(testfile, chunks={'latitude': 6336//11,
                                           'longitude': 10800//15})
    dataset_to_zarr(ds, 'out.xzarr')
    out = dataset_from_zarr('out.xzarr')
    assert isinstance(out.ROSE.data, da.Array)
    for c in out.coords:
        assert (ds.coords[c].data == out.coords[c].data).all()
        assert ds.coords[c].attrs == out.coords[c].attrs
    for c in out.ROSE.coords:
        assert (ds.coords[c].data == out.coords[c].data).all()
        assert ds.coords[c].attrs == out.coords[c].attrs
    assert _safe_attrs(ds.ROSE.attrs, False) == _safe_attrs(
        out.ROSE.attrs, False)
    assert _safe_attrs(ds.attrs, False) == _safe_attrs(out.attrs, False)
    assert (ds.ROSE == out.ROSE).all().values