"""
Adapted from
https://github.com/funkey/gunpowder/blob/master/gunpowder/nodes/hdf5like_write_base.py
and https://github.com/funkey/gunpowder/blob/master/gunpowder/nodes/zarr_write.py .
"""
import logging
import os
import zarr
import numcodecs
numcodecs.blosc.use_threads = False
from gunpowder.nodes import BatchFilter # noqa
from gunpowder.batch_request import BatchRequest # noqa
from gunpowder.roi import Roi # noqa
from gunpowder.coordinate import Coordinate # noqa
from gunpowder.ext import ZarrFile # noqa
from gunpowder.compat import ensure_str # noqa
logger = logging.getLogger(__name__)
class ZarrWrite(BatchFilter):
'''Assemble arrays of passing batches in one zarr container. This is useful
to store chunks produced by :class:`Scan` on disk without keeping the
larger array in memory. The ROIs of the passing arrays will be used to
determine the position where to store the data in the dataset.
Args:
dataset_names (``dict``, :class:`ArrayKey` -> ``string``):
A dictionary from array keys to names of the datasets to store them
in.
output_dir (``string``):
The directory in which to save the zarr container. It will be created
if it does not exist.
output_filename (``string``):
The output filename of the container. It will be created if it does
not exist; otherwise data is overwritten in the existing container.
dataset_dtypes (``dict``, :class:`ArrayKey` -> data type):
A dictionary from array keys to datatype (e.g. ``np.int8``). If
given, arrays are stored using this type. The original arrays
within the pipeline remain unchanged.
chunks (``tuple`` of ``int``, or ``bool``):
Chunk shape for output datasets. Set to ``True`` for auto-chunking,
set to ``False`` to obtain a chunk equal to the dataset size.
Defaults to ``True``.
Note: compression is fixed to ``blosc``; unlike the HDF5-like writer,
this node does not take a ``compression_type`` argument.
'''
def __init__(
self,
dataset_names,
output_dir='.',
output_filename='output.hdf',
dataset_dtypes=None,
chunks=True):
self.dataset_names = dataset_names
self.output_dir = output_dir
self.output_filename = output_filename
self.compression_type = 'blosc'
if dataset_dtypes is None:
self.dataset_dtypes = {}
else:
self.dataset_dtypes = dataset_dtypes
self.chunks = chunks
self.dataset_offsets = {}
def setup(self):
for key in self.dataset_names.keys():
self.updates(key, self.spec[key])
self.enable_autoskip()
def prepare(self, request):
deps = BatchRequest()
for key in self.dataset_names.keys():
deps[key] = request[key]
return deps
def init_datasets(self, batch):
filename = os.path.join(self.output_dir, self.output_filename)
logger.debug("Initializing container %s", filename)
os.makedirs(self.output_dir, exist_ok=True)
for (array_key, dataset_name) in self.dataset_names.items():
logger.debug("Initializing dataset for %s", array_key)
assert array_key in self.spec, (
"Asked to store %s, but is not provided upstream." % array_key)
assert array_key in batch.arrays, (
"Asked to store %s, but is not part of batch." % array_key)
array = batch.arrays[array_key]
dims = array.spec.roi.dims()
batch_shape = array.data.shape
with self._open_file(filename) as data_file:
# if a dataset already exists, read its meta-information (if
# present)
if dataset_name in data_file:
offset = self._get_offset(
data_file[dataset_name]) or Coordinate(
(0,) * dims)
else:
provided_roi = self.spec[array_key].roi
if provided_roi is None:
raise RuntimeError(
"Dataset %s does not exist in %s, and no ROI is "
"provided for %s. I don't know how to initialize "
"the dataset." % (dataset_name, filename, array_key))
offset = provided_roi.get_offset()
voxel_size = array.spec.voxel_size
data_shape = provided_roi.get_shape() // voxel_size
logger.debug("Shape in voxels: %s", data_shape)
# add channel dimensions (if present)
data_shape = batch_shape[:-dims] + data_shape
logger.debug(
"Shape with channel dimensions: %s", data_shape)
if array_key in self.dataset_dtypes:
dtype = self.dataset_dtypes[array_key]
else:
dtype = batch.arrays[array_key].data.dtype
logger.debug(
"create_dataset: %s, %s, %s, %s, offset=%s, resolution=%s",
dataset_name, data_shape, self.compression_type, dtype,
offset, voxel_size)
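# A zarr ProcessSynchronizer keeps per-chunk file locks in a 'sync'
# directory inside the container, so multiple worker processes can
# write to the same dataset safely.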
dataset = data_file.create_dataset(
name=dataset_name,
shape=data_shape,
compression=self.compression_type,
dtype=dtype,
chunks=self.chunks,
synchronizer=zarr.ProcessSynchronizer(
os.path.join(filename, 'sync')),
)
self._set_offset(dataset, offset)
self._set_voxel_size(dataset, voxel_size)
logger.debug(
"%s (%s in %s) has offset %s",
array_key,
dataset_name,
filename,
offset)
self.dataset_offsets[array_key] = offset
def process(self, batch, request):
filename = os.path.join(self.output_dir, self.output_filename)
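# Lazily create the output datasets when the first batch arrives.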
if not self.dataset_offsets:
self.init_datasets(batch)
with self._open_file(filename) as data_file:
for (array_key, dataset_name) in self.dataset_names.items():
dataset = data_file[dataset_name]
array_roi = batch.arrays[array_key].spec.roi
voxel_size = self.spec[array_key].voxel_size
dims = array_roi.dims()
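# Leading (non-spatial) channel dimensions, if any, are written in full.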
channel_slices = (slice(None),) * max(0, len(dataset.shape) - dims)
dataset_roi = Roi(
self.dataset_offsets[array_key],
Coordinate(dataset.shape[-dims:]) * voxel_size)
common_roi = array_roi.intersect(dataset_roi)
if common_roi.empty():
logger.warning(
"array %s with ROI %s lies outside of dataset ROI %s, "
"skipping writing" % (
array_key,
array_roi,
dataset_roi))
continue
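# Translate the overlapping region into voxel coordinates, relative to
# the dataset and to the batch array, respectively.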
dataset_voxel_roi = (
common_roi - self.dataset_offsets[array_key]) // voxel_size
dataset_voxel_slices = dataset_voxel_roi.to_slices()
array_voxel_roi = (
common_roi - array_roi.get_offset()) // voxel_size
array_voxel_slices = array_voxel_roi.to_slices()
logger.debug(
"writing %s to voxel coordinates %s" % (
array_key,
dataset_voxel_roi))
data = batch.arrays[array_key].data[channel_slices +
array_voxel_slices]
dataset[channel_slices + dataset_voxel_slices] = data
def _get_voxel_size(self, dataset):
if 'resolution' not in dataset.attrs:
return None
if self.output_filename.endswith('.n5'):
return Coordinate(dataset.attrs['resolution'][::-1])
else:
return Coordinate(dataset.attrs['resolution'])
def _get_offset(self, dataset):
if 'offset' not in dataset.attrs:
return None
if self.output_filename.endswith('.n5'):
return Coordinate(dataset.attrs['offset'][::-1])
else:
return Coordinate(dataset.attrs['offset'])
def _set_voxel_size(self, dataset, voxel_size):
if self.output_filename.endswith('.n5'):
dataset.attrs['resolution'] = voxel_size[::-1]
else:
dataset.attrs['resolution'] = voxel_size
def _set_offset(self, dataset, offset):
if self.output_filename.endswith('.n5'):
dataset.attrs['offset'] = offset[::-1]
else:
dataset.attrs['offset'] = offset
def _open_file(self, filename):
return ZarrFile(ensure_str(filename), mode='a')
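# ----------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the original gist). It
# assumes a hypothetical upstream `source` node that provides an
# ArrayKey called RAW, and uses gunpowder's Scan node to iterate over a
# large volume in chunks; adapt keys, ROIs and worker counts as needed.
#
# import gunpowder as gp
#
# raw = gp.ArrayKey('RAW')
#
# scan_request = gp.BatchRequest()
# scan_request[raw] = gp.Roi((0, 0, 0), (64, 64, 64))
#
# pipeline = (
#     source  # any gunpowder source node providing `raw`
#     + ZarrWrite(
#         dataset_names={raw: 'volumes/raw'},
#         output_dir='.',
#         output_filename='output.zarr')
#     + gp.Scan(scan_request, num_workers=4)
# )
#
# with gp.build(pipeline):
#     pipeline.request_batch(gp.BatchRequest())
# ----------------------------------------------------------------------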