Skip to content

Instantly share code, notes, and snippets.

@conradry
Created October 5, 2022 15:30
Show Gist options
  • Save conradry/7376e7504456e9c194638c78d009523e to your computer and use it in GitHub Desktop.
Save conradry/7376e7504456e9c194638c78d009523e to your computer and use it in GitHub Desktop.
Run-length encode a zarr labelmap
import os
import zarr
import argparse
import numpy as np
from empanada.inference.tracker import InstanceTracker
from empanada.inference.rle import pan_seg_to_rle_seg, rle_seg_to_pan_seg
from tqdm import tqdm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('zarr_store', type=str, metavar='zarr_store', help='Path to zarr directory')
parser.add_argument('zarr_key', type=str, metavar='zarr_key',
help='Name of dataset in the zarr store (e.g. panoptic_xy)')
parser.add_argument('label_divisor', type=int, metavar='label_divisor',
help='Label divisor to separate classes (max number objects 3D in napari)')
args = parser.parse_args()
# load the zarr and segmentation volume
data = zarr.open(args.zarr_store, mode='r')
seg = data[args.zarr_key]
# depending on how the zarr array was chunked
# it can be very slow to encode images slice by slice
# if possible, uncomment the next line to load the full seg
# into memory
#seg = np.array(seg)
# assuming mitonet
labels = [1]
thing_list = [1]
# create an instance tracker
tracker = InstanceTracker(
class_id=labels[0], label_divisor=args.label_divisor,
shape3d=seg.shape, axis='xy'
)
# run length encode segmentation slice by slice
for index, mask2d in tqdm(enumerate(np.split(seg[...], len(seg), axis=0)), total=len(seg)):
mask2d = np.squeeze(mask2d)
rle_seg = pan_seg_to_rle_seg(
mask2d, labels, args.label_divisor, thing_list, force_connected=False
)
assert np.allclose(mask2d, rle_seg_to_pan_seg(rle_seg, mask2d.shape)), \
"RLE segmentation is wrong; are you sure label_divisor is correct?"
tracker.update(rle_seg[labels[0]], index)
# end tracking
tracker.finish()
# save the run length encoded segmentation to json
tracker.write_to_json(os.path.join(args.zarr_store, f'{args.zarr_key}.json'))
print('Finished.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment