@agoodm
Last active February 1, 2023 17:06
Loading kerchunk references as split records
from functools import lru_cache
from numcodecs import blosc
import os
import collections
import ujson
import pandas as pd
import numpy as np


def get_variables(keys):
    """Get a list of variable names from references.

    Parameters
    ----------
    keys : list of str
        Kerchunk reference keys.

    Returns
    -------
    fields : list of str
        List of variable names.
    """
    fields = []
    for k in keys:
        if '/' in k:
            name, chunk = k.split('/')
            if name not in fields:
                fields.append(name)
        else:
            fields.append(k)
    return fields
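

# A minimal usage sketch (added for illustration, not part of the original
# gist). get_variables collects the unique name before '/' in each chunk or
# metadata key, plus any top-level keys; the sample keys below are made up.
def _demo_get_variables():
    keys = ['.zgroup', 'temp/.zarray', 'temp/.zattrs', 'temp/0.0', 'temp/0.1']
    print(get_variables(keys))  # ['.zgroup', 'temp']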


def write_record(fname, json_obj, **kwargs):
    """Write references into a record file.

    Parameters
    ----------
    fname : str
        Output filename for the record.
    json_obj : str, bytes, dict, list
        JSON data for the record file to be written.
    **kwargs : dict, optional
        Keyword arguments passed to numcodecs.blosc.compress. If not provided,
        no compression is done and the output is standard JSON.
    """
    # Serialize to JSON unless the caller already passed serialized data
    if not isinstance(json_obj, (str, bytes)):
        json_obj = ujson.dumps(json_obj)
    if not isinstance(json_obj, bytes):
        json_obj = json_obj.encode()
    if kwargs:
        json_obj = blosc.compress(json_obj, **kwargs)
    with open(fname, 'wb') as f:
        f.write(json_obj)
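

# A small sketch (added for illustration) of both write_record modes: plain
# JSON when no kwargs are given, and blosc-compressed bytes when compression
# kwargs are forwarded to numcodecs.blosc.compress. The cname/clevel keywords
# are assumptions about that compress signature; adjust for your numcodecs
# version if needed.
def _demo_write_record():
    import tempfile
    tmpdir = tempfile.mkdtemp()
    refs = {'temp/0.0': ['s3://bucket/file.nc', 8000, 48000]}  # placeholder reference
    # Uncompressed record: readable as ordinary JSON
    write_record(os.path.join(tmpdir, 'plain.json'), refs)
    # Compressed record: must be decompressed with numcodecs.blosc before parsing
    write_record(os.path.join(tmpdir, 'compressed.rec'), refs, cname=b'zstd', clevel=5)
    return tmpdir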


def make_records_store(store_name, refs, record_size=1000, **kwargs):
    """Write references as a store of multiple record files.

    The directory structure mimics a normal zarr store, but instead of
    standard chunk keys, references are split into record files, each
    containing multiple references.

    Parameters
    ----------
    store_name : str
        Name of the record store.
    refs : dict
        Kerchunk references.
    record_size : int, optional
        Number of references to store in each record file (default 1000).
    **kwargs : dict, optional
        Keyword arguments passed to numcodecs.blosc.compress. If not provided,
        no compression is done and the output is standard JSON.
    """
    if not os.path.exists(store_name):
        os.makedirs(store_name)

    # Save record_size in a separate JSON file so the mapper can automatically
    # determine the correct record_size.
    write_record(os.path.join(store_name, '.record_size'),
                 dict(record_size=record_size))

    if 'refs' in refs:
        refs = refs['refs']

    fields = get_variables(refs)
    for field in fields:
        field_path = os.path.join(store_name, field)
        if field.startswith('.'):
            # zarr metadata keys (.zgroup, .zmetadata, etc.)
            write_record(field_path, refs[field])
        else:
            if not os.path.exists(field_path):
                os.makedirs(field_path)

            # Read the variable's zarray metadata to determine the number of chunks
            zarray = ujson.loads(refs[f'{field}/.zarray'])
            chunk_sizes = np.array(zarray['shape']) / np.array(zarray['chunks'])
            chunk_numbers = [np.arange(n) for n in chunk_sizes]
            nums = np.asarray(pd.MultiIndex.from_product(chunk_numbers).codes).T
            nchunks = nums.shape[0]
            nmissing = 0

            for metakey in ['.zarray', '.zattrs']:
                key = f'{field}/{metakey}'
                write_record(os.path.join(field_path, metakey), refs[key])

            # The convention for record files is
            # <store_name>/<field_name>/<record_id>
            out = {}
            record_id = 0
            out_path = os.path.join(field_path, str(record_id))
            for i in range(nchunks):
                # Start a new record once the current one holds record_size
                # references, so chunk i always lands in record i // record_size,
                # matching ReferenceRecordMapper._key_to_record_id.
                if i and i % record_size == 0:
                    write_record(out_path, out, **kwargs)
                    out = {}
                    record_id += 1
                    out_path = os.path.join(field_path, str(record_id))
                chunk_id = '.'.join(nums[i].astype(str))
                key = f'{field}/{chunk_id}'
                # Make note if the expected number of chunks differs from the
                # number actually found in the references
                if key not in refs:
                    nmissing += 1
                    continue
                out[key] = refs[key]
            write_record(out_path, out, **kwargs)

            if nmissing:
                print(f'Warning: Chunks missing for field {field}. '
                      f'Expected: {nchunks}, Found: {nchunks - nmissing}')
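

# An end-to-end sketch (added for illustration): build a tiny synthetic
# kerchunk reference dict for one 4x4 variable stored as 2x2 chunks and split
# it into a record store with two references per record file. The store name,
# variable name, and s3 URL are placeholders, and the cname/clevel compression
# keywords are assumptions forwarded to numcodecs.blosc.compress.
def _demo_make_records_store(store_name='demo_store'):
    zarray = ujson.dumps({'shape': [4, 4], 'chunks': [2, 2], 'dtype': '<f4',
                          'compressor': None, 'fill_value': None,
                          'filters': None, 'order': 'C', 'zarr_format': 2})
    refs = {'.zgroup': ujson.dumps({'zarr_format': 2}),
            'temp/.zarray': zarray,
            'temp/.zattrs': ujson.dumps({'units': 'K'})}
    for i in range(2):
        for j in range(2):
            refs[f'temp/{i}.{j}'] = ['s3://bucket/file.nc', 1000 * (2 * i + j), 1000]
    make_records_store(store_name, refs, record_size=2, cname=b'zstd', clevel=5)
    return refs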


class ReferenceRecordMapper(collections.abc.MutableMapping):
    """Interface to read a record store as if it were a standard kerchunk
    references dict."""

    def __init__(self, root, cache_size=128):
        """
        Parameters
        ----------
        root : str
            Root directory of the record store.
        cache_size : int
            Maximum number of record files held in the LRU cache, so up to
            cache_size * record_size references can be loaded in memory at once.
        """
        self.root = root
        self.chunk_sizes = {}
        self._items = {}

        # Open and decompress a record file, memoizing the result in an LRU
        # cache. Records are assumed to have been written with blosc
        # compression (i.e. make_records_store was given compression kwargs).
        @lru_cache(maxsize=cache_size)
        def open_reference_record(record_file):
            with open(record_file, 'rb') as f:
                data = f.read()
            decompressed_data = blosc.decompress(data)
            return ujson.loads(decompressed_data)

        self.open_reference_record = open_reference_record

    @property
    def record_size(self):
        if not hasattr(self, '_record_size'):
            with open(os.path.join(self.root, '.record_size')) as f:
                self._record_size = ujson.load(f)['record_size']
        return self._record_size

    def _load_one_key(self, key):
        if '/' not in key:
            if key not in os.listdir(self.root):
                raise KeyError(key)
            return self._get_and_cache_metadata(key)
        else:
            field, sub_key = key.split('/')
            if sub_key.startswith('.'):
                # zarr metadata keys are always cached
                return self._get_and_cache_metadata(key)
            # Chunk keys are loaded from their record file and held in the LRU cache
            record_id = self._key_to_record_id(key)
            record_file = os.path.join(self.root, field, str(record_id))
            record = self.open_reference_record(record_file)
            return record[key]

    def _get_and_cache_metadata(self, key):
        with open(os.path.join(self.root, key), 'rb') as f:
            data = f.read()
        self._items[key] = data
        return data

    def _key_to_record_id(self, key):
        field, chunk = key.split('/')
        chunk_sizes = self._get_chunk_sizes(field)
        chunk_idx = np.array([int(c) for c in chunk.split('.')])
        chunk_number = np.ravel_multi_index(chunk_idx, chunk_sizes)
        record_id = chunk_number // self.record_size
        return record_id

    def _get_chunk_sizes(self, field):
        if field not in self.chunk_sizes:
            zarray = ujson.loads(self[f'{field}/.zarray'])
            size_ratio = np.array(zarray['shape']) / np.array(zarray['chunks'])
            self.chunk_sizes[field] = np.ceil(size_ratio).astype(int)
        return self.chunk_sizes[field]

    def __getitem__(self, key):
        if key in self._items:
            return self._items[key]
        return self._load_one_key(key)

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __len__(self):
        # Caveat: this counts expected references, not the ones actually present
        count = 0
        for field in os.listdir(self.root):
            if field.startswith('.'):
                count += 1
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = np.prod(chunk_sizes)
                count += 2 + nchunks
        return count

    def __iter__(self):
        # Caveat: this generates all expected keys, but does not account for
        # reference keys that are missing.
        for field in os.listdir(self.root):
            if field.startswith('.'):
                yield field
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = np.prod(chunk_sizes)
                yield '/'.join([field, '.zarray'])
                yield '/'.join([field, '.zattrs'])
                inds = np.asarray(np.unravel_index(np.arange(nchunks), chunk_sizes)).T
                for ind in inds:
                    yield field + '/' + '.'.join(ind.astype(str))
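

# A round-trip sketch (added for illustration): read the store written by the
# _demo_make_records_store sketch above back through ReferenceRecordMapper and
# check that chunk and metadata references match the original dict. Because
# open_reference_record always decompresses, the store must have been written
# with compression kwargs, as that demo does.
def _demo_mapper(store_name='demo_store'):
    refs = _demo_make_records_store(store_name)
    mapper = ReferenceRecordMapper(store_name)
    for key in ['temp/0.0', 'temp/1.1']:
        assert mapper[key] == refs[key]
    # Metadata keys come back as the raw bytes written on disk
    assert ujson.loads(mapper['temp/.zarray']) == ujson.loads(refs['temp/.zarray'])
    print('round trip OK')


if __name__ == '__main__':
    _demo_get_variables()
    _demo_mapper()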