Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created November 26, 2022 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/5b0c485979cc9c36a9685d8cf8e94565 to your computer and use it in GitHub Desktop.
Save shoyer/5b0c485979cc9c36a9685d8cf8e94565 to your computer and use it in GitHub Desktop.
xarray + Zarr via TensorStore
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import tensorstore
import json
import os.path
import fsspec
import xarray
import xarray.backends
def zarr_spec_from_path(path):
    """Build a TensorStore spec for a Zarr array on the local filesystem.

    Args:
        path: Filesystem path of the Zarr array directory.

    Returns:
        A JSON-compatible dict that can be passed to ``tensorstore.open``.
    """
    # The key-value store layer: plain files under ``path``.
    kvstore = dict(driver='file', path=path)
    return dict(driver='zarr', kvstore=kvstore)
def load_zarr_consolidated_metadata(path):
    """Read the consolidated metadata (``.zmetadata``) of a Zarr store.

    Args:
        path: Filesystem path to the root directory of the Zarr store.

    Returns:
        The ``metadata`` mapping stored in ``.zmetadata``, i.e. keys such as
        ``'.zattrs'`` or ``'<array>/.zarray'`` mapped to their parsed JSON.

    Raises:
        OSError: If ``.zmetadata`` cannot be opened (e.g. it does not exist).
        ValueError: If the file does not declare consolidated format 1.
        KeyError: If the file has no ``metadata`` entry.
    """
    metadata_path = os.path.join(path, '.zmetadata')
    # Zarr metadata is UTF-8 JSON; don't rely on the platform's locale
    # encoding, which can mis-decode non-ASCII attribute values.
    with open(metadata_path, 'r', encoding='utf-8') as f:
        contents = json.load(f)
    if contents.get('zarr_consolidated_format') != 1:
        raise ValueError('invalid .zmetadata')
    return contents['metadata']
def load_zattrs(group_path, array_names):
    """Read the ``.zattrs`` JSON file of several arrays in one batched call.

    Args:
        group_path: Filesystem path of the Zarr group containing the arrays.
        array_names: Iterable of array names (sub-paths of ``group_path``).

    Returns:
        Dict mapping each array name to its parsed ``.zattrs`` contents.
        An empty ``array_names`` yields an empty dict without touching disk.
    """
    # Materialize once: the names are iterated twice below, which would
    # silently produce an empty result for a one-shot iterator.
    array_names = list(array_names)
    if not array_names:
        # fsspec's expand_path raises FileNotFoundError for an empty path
        # list, so short-circuit instead.
        return {}
    paths = {k: os.path.join(group_path, k, '.zattrs') for k in array_names}
    fs = fsspec.filesystem('file')
    # Pass a real list: fsspec's ``cat`` only returns a {path: bytes} dict
    # for list input (or several paths); a one-element ``dict_values`` view
    # does not qualify and breaks single-array stores.
    data = fs.cat(list(paths.values()))
    # ``cat`` keys its result by normalized (expanded) path, so look each
    # array up by that same normalized form.
    expanded_paths = {k: fs.expand_path(v)[0] for k, v in paths.items()}
    return {k: json.loads(data[expanded_paths[k]].decode('utf8')) for k in array_names}
class ZarrTensorStoreDataStore(xarray.backends.AbstractDataStore):
    """xarray data store that simply hands back pre-built variables.

    ``open_zarr_via_tensorstore`` constructs the variables itself and passes
    this store to ``xarray.open_dataset(..., engine='store')``.
    """

    def __init__(self, variables, attrs):
        # variables: mapping of variable name -> xarray.Variable.
        # attrs: dataset-level attributes.
        self.variables, self.attrs = variables, attrs

    def load(self):
        """Return the stored ``(variables, attrs)`` pair unchanged."""
        pair = (self.variables, self.attrs)
        return pair
class TensorStoreWrapper(xarray.backends.BackendArray):
    """xarray backend array that reads data from a TensorStore array.

    Indexing performs a blocking read of the selected region and returns
    the materialized result.
    """

    def __init__(self, ts_array):
        self.ts_array = ts_array
        self.shape = ts_array.shape
        # xarray expects a NumPy dtype, not TensorStore's own dtype object.
        self.dtype = ts_array.dtype.numpy_dtype

    def __getitem__(self, key):
        """Read the region selected by an xarray explicit indexer.

        Args:
            key: An xarray ``OuterIndexer``, ``VectorizedIndexer`` or
                ``BasicIndexer``.

        Returns:
            The selected data, obtained by blocking on the TensorStore read.

        Raises:
            TypeError: If ``key`` is not one of the supported indexer types.
        """
        indexing = xarray.core.indexing
        if isinstance(key, indexing.OuterIndexer):
            # Outer (orthogonal) indexing maps to TensorStore's ``oindex``.
            indexed = self.ts_array.oindex[key.tuple]
        elif isinstance(key, indexing.VectorizedIndexer):
            # Vectorized (pointwise) indexing maps to ``vindex``.
            indexed = self.ts_array.vindex[key.tuple]
        elif isinstance(key, indexing.BasicIndexer):
            indexed = self.ts_array[key.tuple]
        else:
            # Previously an ``assert``: stripped under ``python -O``, it let
            # unknown indexers fall through to basic indexing silently.
            raise TypeError(f'unsupported indexer type: {type(key).__name__}')
        return indexed.read().result()

    def __repr__(self):
        return f'{type(self).__name__}({self.ts_array!r})'
def open_zarr_via_tensorstore(path):
    """Open a Zarr store as an xarray Dataset, reading arrays via TensorStore.

    Metadata and attributes are read eagerly; array data is read only when
    indexed, through ``TensorStoreWrapper``.

    Current limitations:
    1. The Zarr store must be stored with consolidated metadata.
    2. Only supports the "file" TensorStore driver.
    3. Every array must carry xarray's ``_ARRAY_DIMENSIONS`` attribute.

    Args:
        path: Filesystem path to the root of the Zarr store.

    Returns:
        An ``xarray.Dataset`` whose variables are backed by TensorStore.
    """
    zarray_suffix = '/.zarray'
    dims_key = '_ARRAY_DIMENSIONS'
    metadata = load_zarr_consolidated_metadata(path)  # blocking
    array_names = [
        k[:-len(zarray_suffix)] for k in metadata if k.endswith(zarray_suffix)
    ]
    specs = {k: zarr_spec_from_path(os.path.join(path, k)) for k in array_names}
    # Issue every open up front (tensorstore.open returns futures) so they can
    # proceed while the attributes are being read; results are awaited below.
    array_futures = {k: tensorstore.open(spec) for k, spec in specs.items()}
    array_zattrs = load_zattrs(path, array_names)  # blocking
    variables = {}
    for name in array_names:
        zattrs = array_zattrs[name]
        dims = zattrs[dims_key]
        data = TensorStoreWrapper(array_futures[name].result())
        attrs = {k: v for k, v in zattrs.items() if k != dims_key}
        variables[name] = xarray.Variable(dims, data, attrs)
    # Consolidated metadata only lists files that exist; a store without
    # group-level attributes has no '.zattrs' entry.
    store = ZarrTensorStoreDataStore(variables, metadata.get('.zattrs', {}))
    return xarray.open_dataset(store, engine='store')
def run_unit_test():
    """Round-trip a tutorial dataset through Zarr and compare the result.

    Writes the ``eraint_uvz`` tutorial dataset to a temporary Zarr store,
    reopens it with ``open_zarr_via_tensorstore``, and checks that the data
    is TensorStore-backed and identical to the original.
    """
    import tempfile

    ds = xarray.tutorial.load_dataset('eraint_uvz')
    # A fresh temporary directory keeps the test re-runnable (to_zarr refuses
    # to overwrite an existing store by default) and cleans up afterwards.
    with tempfile.TemporaryDirectory() as tmpdir:
        zarr_path = os.path.join(tmpdir, 'eraint_uvz.zarr')
        ds.to_zarr(zarr_path)
        roundtripped = open_zarr_via_tensorstore(zarr_path)
        assert 'TensorStore' in repr(roundtripped.variables['u']._data)
        # Compare while the store still exists: data is read lazily.
        xarray.testing.assert_identical(roundtripped, ds)


if __name__ == '__main__':
    run_unit_test()
@shoyer
Copy link
Author

shoyer commented May 18, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment