Skip to content

Instantly share code, notes, and snippets.

@rabernat
Created October 29, 2021 18:50
Show Gist options
  • Save rabernat/0149106f271ad384af0c9ad102196621 to your computer and use it in GitHub Desktop.
Save rabernat/0149106f271ad384af0c9ad102196621 to your computer and use it in GitHub Desktop.
Expose MITgcm MDS files as Zarr arrays
from collections.abc import Mapping
import json
from typing import Tuple, List, Optional, Any
from pydantic import validator
from pydantic.dataclasses import dataclass
from fsspec.implementations.reference import ReferenceFileSystem
import numpy as np
import zarr
# https://github.com/numpy/numpy/issues/20236
# from numpy.typing import DTypeLike
DTypeLike = Any
class DataFileConfig:
    """Pydantic configuration used by the ``DataFile`` dataclass."""

    # Lets pydantic accept field types it has no built-in validator for;
    # the ``dtype`` field is annotated with ``DTypeLike`` (currently ``Any``)
    # and ends up holding a ``numpy.dtype`` instance.
    arbitrary_types_allowed = True
@dataclass(frozen=True, config=DataFileConfig)
class DataFile(Mapping):
    """Read-only mapping that presents one raw binary file as a zarr v2 array.

    The mapping always has exactly three keys:

    * ``'.zarray'`` -- JSON array metadata (single chunk covering the whole
      array, no compressor, no filters, C order),
    * ``'.zattrs'`` -- JSON user attributes,
    * the single chunk key (``'0.0'`` for a 2-D array) -- a
      ``[path, offset, nbytes]`` reference to the bytes of ``path``.

    Attributes:
        path: Location of the raw data file.
        shape: Shape of the array stored in the file.
        dtype: Anything ``numpy.dtype`` accepts; stored as a ``numpy.dtype``.
    """

    path: str
    shape: Tuple[int, ...]
    dtype: DTypeLike

    @validator('dtype')
    def valid_numpy_dtype(cls, v):
        """Coerce the supplied ``dtype`` value into a ``numpy.dtype``."""
        return np.dtype(v)

    def _chunk_key(self) -> str:
        """Return the key of the array's only chunk.

        zarr v2 names chunks by their dot-joined grid indices, so the first
        (and here only) chunk of an N-d array is ``'0.0...0'`` with N zeros.
        A 0-d array uses the key ``'0'``.
        """
        return '.'.join('0' for _ in self.shape) or '0'

    def _zarray(self) -> str:
        """Return the ``.zarray`` metadata document as a JSON string."""
        dims = list(self.shape)
        return json.dumps({
            # One chunk spanning the full array: chunks == shape.
            "chunks": dims,
            "compressor": None,
            "dtype": self.dtype.str,
            "fill_value": None,
            "filters": None,
            "order": "C",
            "shape": dims,
            "zarr_format": 2
        })

    def _zattrs(self) -> str:
        """Return the ``.zattrs`` attribute document as a JSON string."""
        return json.dumps({
            'foo': 'bar'
        })

    def __getitem__(self, key):
        """Return the metadata JSON or the chunk reference stored under ``key``.

        Raises:
            KeyError: If ``key`` is not one of the three keys of the mapping.
        """
        # TODO: replace with case in python 3.10
        if key == ".zarray":
            return self._zarray()
        if key == ".zattrs":
            return self._zattrs()
        return self._chunk_reference_for(key)

    def __iter__(self):
        """Yield the three keys: both metadata documents, then the chunk."""
        yield '.zarray'
        yield '.zattrs'
        yield self._chunk_key()

    def __contains__(self, item):
        """Return True exactly for the keys produced by ``__iter__``."""
        return item in ('.zarray', '.zattrs', self._chunk_key())

    def __len__(self):
        """Return the number of keys, which is always 3."""
        return 3

    def _chunk_reference_for(self, key: str) -> list:
        """Return ``[path, offset, nbytes]`` locating the chunk ``key``.

        The reference covers the file from byte 0 for as many bytes as the
        whole array occupies (number of elements times the item size).

        Raises:
            KeyError: If ``key`` is not the array's single chunk key.
        """
        if key != self._chunk_key():
            raise KeyError(key)
        offset = 0
        # Integer product (np.prod of an empty shape would otherwise be the
        # float 1.0), converted to a plain int because numpy scalars are not
        # JSON-serializable.
        n_items = int(np.prod(self.shape, dtype=np.int64))
        chunksize = n_items * self.dtype.itemsize
        return [self.path, offset, chunksize]
# Demo / smoke test: wrap one raw MITgcm output file in a DataFile and check
# that it can be read back through fsspec's reference filesystem and zarr.
# NOTE(review): reading the data presumably requires the file below to exist
# relative to the working directory -- confirm before running.
df = DataFile(
    path='barotropic_gyre/Eta.0000000010.data',
    shape=(60, 60),
    dtype='>f4',
)

# The mapping itself exposes two metadata documents plus a single chunk.
assert len(df) == 3
assert list(df) == ['.zarray', '.zattrs', '0.0']

# Hand the mapping to fsspec as the reference set; the mapper rooted at '/'
# should list the same three keys.
fs = ReferenceFileSystem(fo=df, target_protocol='file')
mapper = fs.get_mapper('/')
assert list(mapper) == ['.zarray', '.zattrs', '0.0']

# Open the mapper with zarr and verify the attributes round-trip.
arr = zarr.open(mapper)
assert arr.attrs['foo'] == 'bar'

# Load the whole array and verify its shape and dtype.
data = arr[:]
assert data.shape == (60, 60)
assert data.dtype == np.dtype('>f4')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment