Loading kerchunk references as split records
import collections
import os
from functools import lru_cache

import numpy as np
import pandas as pd
import ujson
from numcodecs import blosc
def get_variables(keys):
    """Get list of variable names from references.

    Parameters
    ----------
    keys : list of str
        Kerchunk reference keys.

    Returns
    -------
    fields : list of str
        List of variable names.
    """
    fields = []
    for k in keys:
        if '/' in k:
            name, _ = k.split('/')
            if name not in fields:
                fields.append(name)
        else:
            fields.append(k)
    return fields
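
# Illustrative example (hypothetical reference keys):
#
#     get_variables(['.zgroup', 'temp/.zarray', 'temp/0.0', 'temp/0.1'])
#     -> ['.zgroup', 'temp']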
def write_record(fname, json_obj, **kwargs):
    """Write references into a record file.

    Parameters
    ----------
    fname : str
        Output filename for record.
    json_obj : str, bytes, dict, list
        JSON data for record file to be written.
    **kwargs : dict, optional
        Keyword arguments passed to numcodecs.blosc.compress. If not provided,
        no compression is done and the output is standard JSON.
    """
    if not isinstance(json_obj, (str, bytes)):
        json_obj = ujson.dumps(json_obj)
    if not isinstance(json_obj, bytes):
        json_obj = json_obj.encode()
    if kwargs:
        json_obj = blosc.compress(json_obj, **kwargs)
    with open(fname, 'wb') as f:
        f.write(json_obj)
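
# Illustrative calls (paths and compression settings are hypothetical; note
# that numcodecs.blosc.compress expects the codec name as bytes):
#
#     write_record('store/.zgroup', refs['.zgroup'])                     # plain JSON
#     write_record('store/temp/0', chunk_refs, cname=b'zstd', clevel=5)  # blosc-compressed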
def make_records_store(store_name, refs, record_size=1000, **kwargs):
    """Write references as a store of multiple record files. The directory
    structure mimics a normal zarr store, but instead of standard chunk
    keys, references are split into record files each containing multiple
    references.

    Parameters
    ----------
    store_name : str
        Name of record store.
    refs : dict
        Kerchunk references.
    record_size : int, optional
        Number of references to store in each record file (default 1000).
    **kwargs : dict, optional
        Keyword arguments passed to numcodecs.blosc.compress. If not provided,
        no compression is done and the output is standard JSON.
    """
    if not os.path.exists(store_name):
        os.makedirs(store_name)
    # Save record_size in a separate JSON file so the mapper can automatically
    # interpret the correct record_size.
    write_record(os.path.join(store_name, '.record_size'),
                 dict(record_size=record_size))
    if 'refs' in refs:
        refs = refs['refs']
    fields = get_variables(refs)
    for field in fields:
        field_path = os.path.join(store_name, field)
        if field.startswith('.'):
            # zarr metadata keys (.zgroup, .zmetadata, etc.) are written as-is
            write_record(field_path, refs[field])
        else:
            if not os.path.exists(field_path):
                os.makedirs(field_path)
            # Read the variable's .zarray metadata to determine the number of
            # chunks along each dimension (ceil handles shapes that are not an
            # exact multiple of the chunk size).
            zarray = ujson.loads(refs[f'{field}/.zarray'])
            chunk_sizes = np.ceil(
                np.array(zarray['shape']) / np.array(zarray['chunks'])).astype(int)
            chunk_numbers = [np.arange(n) for n in chunk_sizes]
            nums = np.asarray(pd.MultiIndex.from_product(chunk_numbers).codes).T
            nchunks = nums.shape[0]
            nmissing = 0
            for metakey in ['.zarray', '.zattrs']:
                key = f'{field}/{metakey}'
                if key in refs:
                    write_record(os.path.join(field_path, metakey), refs[key])
            # The convention for record files is
            # <store_name>/<field_name>/<record_id>
            out = {}
            record_id = 0
            out_path = os.path.join(field_path, str(record_id))
            for i in range(nchunks):
                # Start a new record file every record_size chunks so that the
                # mapper can locate chunk i in record i // record_size.
                if i > 0 and i % record_size == 0:
                    write_record(out_path, out, **kwargs)
                    out = {}
                    record_id += 1
                    out_path = os.path.join(field_path, str(record_id))
                chunk_id = '.'.join(nums[i].astype(str))
                key = f'{field}/{chunk_id}'
                # Make note if the expected number of chunks differs from the
                # actual number found in the references.
                if key not in refs:
                    nmissing += 1
                    continue
                out[key] = refs[key]
            write_record(out_path, out, **kwargs)
            if nmissing:
                print(f'Warning: Chunks missing for field {field}. '
                      f'Expected: {nchunks}, Found: {nchunks - nmissing}')
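
# Example layout produced for a store written with record_size=1000
# (store and variable names are hypothetical):
#
#     store.records/
#         .record_size        # {"record_size": 1000}
#         .zgroup             # copied zarr metadata
#         temp/
#             .zarray
#             .zattrs
#             0               # references for chunks 0..999
#             1               # references for chunks 1000..1999
#             ...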
class ReferenceRecordMapper(collections.abc.MutableMapping):
    """Interface to read a record store as if it were a standard kerchunk
    references dict."""

    def __init__(self, root, cache_size=128):
        """
        Parameters
        ----------
        root : str
            Root of record store.
        cache_size : int
            Maximum size of the LRU cache, where cache_size * record_size is
            the total number of references that can be held in memory at once.
        """
        self.root = root
        self.chunk_sizes = {}
        self._items = {}

        # Function to open and decompress record data, memoized in an LRU
        # cache so repeated chunk lookups within the same record hit memory.
        @lru_cache(maxsize=cache_size)
        def open_reference_record(record_file):
            with open(record_file, 'rb') as f:
                data = f.read()
            try:
                # Records written with compression kwargs are blosc-compressed.
                data = blosc.decompress(data)
            except Exception:
                # Fall back to plain JSON records written without compression.
                pass
            return ujson.loads(data)

        self.open_reference_record = open_reference_record
    @property
    def record_size(self):
        if not hasattr(self, '_record_size'):
            with open(os.path.join(self.root, '.record_size')) as f:
                self._record_size = ujson.load(f)['record_size']
        return self._record_size
    def _load_one_key(self, key):
        if '/' not in key:
            if key not in os.listdir(self.root):
                raise KeyError(key)
            return self._get_and_cache_metadata(key)
        field, sub_key = key.split('/')
        if sub_key.startswith('.'):
            # zarr metadata keys are always cached
            return self._get_and_cache_metadata(key)
        # Chunk keys are loaded from their record file and cached in the LRU cache
        record_id = self._key_to_record_id(key)
        record_file = os.path.join(self.root, field, str(record_id))
        record = self.open_reference_record(record_file)
        return record[key]

    def _get_and_cache_metadata(self, key):
        with open(os.path.join(self.root, key), 'rb') as f:
            data = f.read()
        self._items[key] = data
        return data
    def _key_to_record_id(self, key):
        field, chunk = key.split('/')
        chunk_sizes = self._get_chunk_sizes(field)
        chunk_idx = np.array([int(c) for c in chunk.split('.')])
        chunk_number = np.ravel_multi_index(chunk_idx, chunk_sizes)
        record_id = chunk_number // self.record_size
        return record_id

    def _get_chunk_sizes(self, field):
        if field not in self.chunk_sizes:
            zarray = ujson.loads(self[f'{field}/.zarray'])
            size_ratio = np.array(zarray['shape']) / np.array(zarray['chunks'])
            self.chunk_sizes[field] = np.ceil(size_ratio).astype(int)
        return self.chunk_sizes[field]
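
    # Worked example of the key -> record mapping (hypothetical numbers): for
    # a variable with chunk grid [10, 20] and record_size=50, the chunk key
    # 'temp/3.7' has flat index ravel_multi_index([3, 7], [10, 20]) =
    # 3*20 + 7 = 67, so it lives in record file 'temp/1' (67 // 50 == 1).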
    def __getitem__(self, key):
        if key in self._items:
            return self._items[key]
        return self._load_one_key(key)

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __len__(self):
        # Caveat: this counts expected references, not the actual number found
        count = 0
        for field in os.listdir(self.root):
            if field.startswith('.'):
                count += 1
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = np.prod(chunk_sizes)
                count += 2 + nchunks
        return count
    def __iter__(self):
        # Caveat: this generates all expected keys, but does not account for
        # reference keys that are missing from the records.
        for field in os.listdir(self.root):
            if field.startswith('.'):
                yield field
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = np.prod(chunk_sizes)
                yield '/'.join([field, '.zarray'])
                yield '/'.join([field, '.zattrs'])
                inds = np.asarray(np.unravel_index(np.arange(nchunks), chunk_sizes)).T
                for ind in inds:
                    yield field + '/' + '.'.join(ind.astype(str))
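

if __name__ == '__main__':
    # Minimal round-trip sketch. The input filename, store name, and variable
    # name below are hypothetical, and the blosc settings assume that
    # numcodecs.blosc.compress accepts the codec name as bytes.
    with open('combined.json') as f:
        refs = ujson.load(f)
    make_records_store('combined.records', refs, record_size=1000,
                       cname=b'zstd', clevel=5)
    mapper = ReferenceRecordMapper('combined.records')
    print(mapper['.zgroup'])    # zarr group metadata (raw bytes)
    print(mapper['temp/0.0'])   # a single chunk reference, loaded lazily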