from functools import lru_cache
import os
import collections

import ujson
import pandas as pd
import numpy as np
import fsspec


def get_variables(keys):
    """Get list of variable names from references.

    Parameters
    ----------
    keys : list of str
        Kerchunk reference keys.

    Returns
    -------
    fields : list of str
        List of variable names.
    """
    fields = []
    for k in keys:
        if '/' in k:
            name, chunk = k.split('/')
            if name not in fields:
                fields.append(name)
        else:
            fields.append(k)
    return fields
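
# Illustrative example (hypothetical keys): for a reference set containing
# ['.zgroup', 'temp/.zarray', 'temp/.zattrs', 'temp/0.0', 'temp/0.1'],
# get_variables returns ['.zgroup', 'temp'] -- top-level metadata keys pass
# through unchanged, while per-chunk keys collapse to their variable name.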


def normalize_json(json_obj):
    """Normalize a JSON-serializable object to bytes.

    Parameters
    ----------
    json_obj : str, bytes, dict, list
        JSON data to be written.
    """
    if isinstance(json_obj, bytes):
        # already encoded
        return json_obj
    if not isinstance(json_obj, str):
        json_obj = ujson.dumps(json_obj)
    return json_obj.encode()


def write_json(fname, json_obj):
    """Write JSON reference data to a file.

    Parameters
    ----------
    fname : str
        Output filename.
    json_obj : str, bytes, dict, list
        JSON data to be written.
    """
    json_obj = normalize_json(json_obj)
    with open(fname, 'wb') as f:
        f.write(json_obj)


def make_parquet_store(store_name, refs, row_group_size=1000, compression='zstd',
                       engine='fastparquet', **kwargs):
    """Write references as a store of parquet files with multiple row groups.

    The directory structure should mimic a normal zarr store but instead of
    standard chunk keys, references are saved as parquet dataframes with
    multiple row groups.

    Parameters
    ----------
    store_name : str
        Name of parquet store.
    refs : dict
        Kerchunk references
    row_group_size : int, optional
        Number of references to store in each reference file (default 1000)
    compression : str, optional
        Compression information to pass to parquet engine, default is zstd.
    engine : {'fastparquet', 'pyarrow'}
        Library to use for writing parquet files.
    **kwargs : dict, optional
        Additional keyword arguments passed to parquet engine of choice.
    """
    if not os.path.exists(store_name):
        os.makedirs(store_name)
    if 'refs' in refs:
        refs = refs['refs']
    write_json(os.path.join(store_name, '.row_group_size'),
               dict(row_group_size=row_group_size))

    fields = get_variables(refs)
    for field in fields:
        # Initialize dataframe columns
        paths, offsets, sizes, raws = [], [], [], []
        field_path = os.path.join(store_name, field)
        if field.startswith('.'):
            # zarr metadata keys (.zgroup, .zmetadata, etc)
            write_json(field_path, refs[field])
        else:
            if not os.path.exists(field_path):
                os.makedirs(field_path)
            # Read the variable zarray metadata to determine number of chunks
            zarray = ujson.loads(refs[f'{field}/.zarray'])
            chunk_sizes = np.array(zarray['shape']) / np.array(zarray['chunks'])
            chunk_numbers = [np.arange(n) for n in chunk_sizes]
            if chunk_sizes.size != 0:
                nums = np.asarray(pd.MultiIndex.from_product(chunk_numbers).codes).T
            else:
                nums = np.array([0])
            nchunks = nums.shape[0]
            nmissing = 0
            for metakey in ['.zarray', '.zattrs']:
                key = f'{field}/{metakey}'
                write_json(os.path.join(field_path, metakey), refs[key])
            for i in range(nchunks):
                chunk_id = '.'.join(nums[i].astype(str))
                key = f'{field}/{chunk_id}'
                # Make note if expected number of chunks differs from actual
                # number found in references
                if key not in refs:
                    nmissing += 1
                    paths.append(None)
                    offsets.append(0)
                    sizes.append(0)
                    raws.append(None)
                else:
                    data = refs[key]
                    if isinstance(data, list):
                        paths.append(data[0])
                        offsets.append(data[1])
                        sizes.append(data[2])
                        raws.append(None)
                    else:
                        paths.append(None)
                        offsets.append(0)
                        sizes.append(0)
                        raws.append(data)
            if nmissing:
                print(f'Warning: Chunks missing for field {field}. '
                      f'Expected: {nchunks}, Found: {nchunks - nmissing}')
            # Need to pad extra rows so total number divides row_group_size evenly
            extra_rows = row_group_size - nchunks % row_group_size
            for i in range(extra_rows):
                paths.append(None)
                offsets.append(0)
                sizes.append(0)
                raws.append(None)
            # The convention for parquet files is
            # <store_name>/<field_name>/refs.parq
            out_path = os.path.join(field_path, 'refs.parq')
            df = pd.DataFrame(
                dict(path=paths,
                     offset=offsets,
                     size=sizes,
                     raw=raws)
            )
            # Engine specific kwarg conventions. Set stats to false since
            # those are currently unneeded.
            if engine == 'pyarrow':
                kwargs.update(write_statistics=False)
            else:
                kwargs.update(row_group_offsets=row_group_size,
                              stats=False)
            df.to_parquet(out_path, engine=engine,
                          compression=compression, **kwargs)
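
# Usage sketch (the reference filename and store path here are hypothetical;
# any kerchunk reference dict, e.g. the output of kerchunk.hdf.SingleHdf5ToZarr
# or kerchunk.combine.MultiZarrToZarr, can be passed as refs):
#
#     import ujson
#     with open('combined.json') as f:
#         refs = ujson.load(f)
#     make_parquet_store('refs_store', refs, row_group_size=1000,
#                        engine='fastparquet')
#
# This produces <store>/<variable>/refs.parq for each variable, plain JSON
# files for the zarr metadata keys, and a top-level .row_group_size file.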


class LazyReferenceMapper(collections.abc.MutableMapping):
    """Interface to read a parquet store as if it were a standard kerchunk
    references dict."""
    # import is class level to prevent numpy dep requirement for fsspec
    import numpy as np

    def __init__(self, root, fs=None, engine='fastparquet', cache_size=128,
                 categorical_urls=True):
        """
        Parameters
        ----------
        root : str
            Root of parquet store.
        fs : fsspec.AbstractFileSystem
            fsspec filesystem object, default is local filesystem.
        engine : {'fastparquet', 'pyarrow'}
            Library to use for reading parquet files.
        cache_size : int
            Maximum size of LRU cache, where cache_size*row_group_size denotes
            the total number of references that can be loaded in memory at once.
        categorical_urls : bool
            Whether to use pandas.Categorical to encode urls. This can greatly
            reduce memory usage for reference sets with URLs that are used many
            times by multiple keys (eg, when a single variable has many chunks)
            in exchange for a bit of additional overhead when loading references.
        """
        self.root = root
        self.chunk_sizes = {}
        self._items = {}
        self.pfs = {}
        self.engine = engine
        self.categorical_urls = categorical_urls
        self.fs = fsspec.filesystem('file') if fs is None else fs

        # Define function to open and decompress row group data and store
        # in LRU cache
        if self.engine == 'pyarrow':
            import pyarrow.parquet as pq
            self.pf_cls = pq.ParquetFile

            @lru_cache(maxsize=cache_size)
            def open_row_group(pf, row_group):
                rg = pf.read_row_group(row_group)
                refs = {c: rg[c].to_numpy() for c in rg.column_names}
                if self.categorical_urls:
                    refs['path'] = pd.Categorical(refs['path'])
                return refs
        else:
            import fastparquet
            self.pf_cls = fastparquet.ParquetFile

            @lru_cache(maxsize=cache_size)
            def open_row_group(pf, row_group):
                rg = pf.row_groups[row_group]
                refs = {
                    c: self.np.empty(rg.num_rows, dtype=dtype)
                    for c, dtype in pf.dtypes.items()
                }
                fastparquet.core.read_row_group_arrays(pf.open(), rg, pf.columns,
                                                       pf.categories, pf.schema,
                                                       pf.cats, assign=refs)
                if self.categorical_urls:
                    refs['path'] = pd.Categorical(refs['path'])
                return refs

        self.open_row_group = open_row_group
    def listdir(self, basename=True):
        listing = self.fs.ls(self.root)
        if basename:
            listing = [os.path.basename(path) for path in listing]
        return listing

    def join(self, *args):
        return self.fs.sep.join(args)

    @property
    def row_group_size(self):
        if not hasattr(self, '_row_group_size'):
            with self.fs.open(self.join(self.root, '.row_group_size')) as f:
                self._row_group_size = ujson.load(f)['row_group_size']
        return self._row_group_size
    def _load_one_key(self, key):
        if '/' not in key:
            if key not in self.listdir():
                raise KeyError(key)
            return self._get_and_cache_metadata(key)
        else:
            field, sub_key = key.split('/')
            if sub_key.startswith('.'):
                # zarr metadata keys are always cached
                return self._get_and_cache_metadata(key)
            # Chunk keys can be loaded from row group and cached in LRU cache
            row_group, row_number = self._key_to_row_group(key)
            if field not in self.pfs:
                pf_path = self.join(self.root, field, 'refs.parq')
                self.pfs[field] = self.pf_cls(self.fs.open(pf_path))
            pf = self.pfs[field]
            refs = self.open_row_group(pf, row_group)
            columns = ['path', 'offset', 'size', 'raw']
            selection = [refs[c][row_number] for c in columns]
            raw = selection[-1]
            if raw is not None:
                if isinstance(raw, bytes):
                    raw = raw.decode()
                return raw
            return selection[:-1]
    def _get_and_cache_metadata(self, key):
        with self.fs.open(self.join(self.root, key), 'rb') as f:
            data = f.read()
        self._items[key] = data
        return data

    def _key_to_row_group(self, key):
        field, chunk = key.split('/')
        chunk_sizes = self._get_chunk_sizes(field)
        if chunk_sizes.size == 0:
            return 0, 0
        chunk_idx = self.np.array([int(c) for c in chunk.split('.')])
        chunk_number = self.np.ravel_multi_index(chunk_idx, chunk_sizes)
        row_group = chunk_number // self.row_group_size
        row_number = chunk_number % self.row_group_size
        return row_group, row_number
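    # Worked example for _key_to_row_group (hypothetical values): a variable
    # with zarr shape [10, 10] and chunks [5, 5] has chunk_sizes [2, 2], so the
    # key 'temp/1.1' maps to chunk_idx [1, 1] and chunk_number
    # ravel_multi_index([1, 1], [2, 2]) == 3; with row_group_size=1000 that
    # reference sits in row group 0 at row 3.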
    def _get_chunk_sizes(self, field):
        if field not in self.chunk_sizes:
            zarray = ujson.loads(self.__getitem__(f'{field}/.zarray'))
            size_ratio = self.np.array(zarray['shape']) / self.np.array(zarray['chunks'])
            self.chunk_sizes[field] = self.np.ceil(size_ratio).astype(int)
        return self.chunk_sizes[field]
    def __getitem__(self, key):
        if key in self._items:
            return self._items[key]
        return self._load_one_key(key)

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __len__(self):
        # Caveat: This counts expected references, not actual
        count = 0
        for field in self.listdir():
            if field.startswith('.'):
                count += 1
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = self.np.prod(chunk_sizes)
                count += 2 + nchunks
        return count
    def __iter__(self):
        # Caveat: Note that this generates all expected keys, but does not
        # account for reference keys that are missing.
        for field in self.listdir():
            if field.startswith('.'):
                yield field
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = self.np.prod(chunk_sizes)
                yield '/'.join([field, '.zarray'])
                yield '/'.join([field, '.zattrs'])
                inds = self.np.asarray(self.np.unravel_index(self.np.arange(nchunks), chunk_sizes)).T
                for ind in inds:
                    yield field + '/' + '.'.join(ind.astype(str))
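
# Minimal end-to-end sketch (assumes fastparquet and a zstd codec are
# installed; the store path and reference values below are made up for
# illustration): write a tiny reference set as a parquet store, then read
# individual references back through the lazy mapper.
if __name__ == '__main__':
    demo_refs = {
        '.zgroup': '{"zarr_format": 2}',
        'temp/.zarray': ujson.dumps({
            'chunks': [1, 2], 'compressor': None, 'dtype': '<f8',
            'fill_value': None, 'filters': None, 'order': 'C',
            'shape': [2, 2], 'zarr_format': 2,
        }),
        'temp/.zattrs': '{}',
        'temp/0.0': ['s3://example-bucket/data.nc', 0, 16],
        'temp/1.0': ['s3://example-bucket/data.nc', 16, 16],
    }
    make_parquet_store('demo_refs_store', demo_refs, row_group_size=10)
    mapper = LazyReferenceMapper('demo_refs_store')
    print(mapper['temp/.zarray'])  # raw zarr metadata bytes
    print(mapper['temp/0.0'])      # [path, offset, size] for one chunk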