from functools import lru_cache
import os
import collections

import ujson
import pandas as pd
import numpy as np
import fsspec

def get_variables(keys):
    """Get list of variable names from references.

    Parameters
    ----------
    keys : list of str
        Kerchunk reference keys.

    Returns
    -------
    fields : list of str
        List of variable names.
    """
    fields = []
    for k in keys:
        if '/' in k:
            name, chunk = k.split('/')
            if name not in fields:
                fields.append(name)
            else:
                continue
        else:
            fields.append(k)
    return fields
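

# For example (keys below are purely illustrative),
#     get_variables(['.zgroup', 'temp/.zarray', 'temp/0.0', 'temp/0.1'])
# returns ['.zgroup', 'temp']: top-level metadata keys pass through as-is and
# each variable name is listed once, however many chunk keys it has.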


def normalize_json(json_obj):
    """Normalize a JSON representation to bytes.

    Parameters
    ----------
    json_obj : str, bytes, dict, list
        JSON data to encode.

    Returns
    -------
    bytes
        UTF-8 encoded JSON data.
    """
    # Bytes input is passed through unchanged; other inputs are serialized
    # and/or encoded as needed.
    if not isinstance(json_obj, (str, bytes)):
        json_obj = ujson.dumps(json_obj)
    if not isinstance(json_obj, bytes):
        json_obj = json_obj.encode()
    return json_obj
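

# For instance (value chosen arbitrarily), normalize_json({'a': 1}) returns the
# UTF-8 encoded bytes of the serialized dict (exact whitespace depends on
# ujson), while bytes input comes back unchanged.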


def write_json(fname, json_obj):
    """Write a JSON object (e.g. zarr metadata) to a file as bytes.

    Parameters
    ----------
    fname : str
        Output filename.
    json_obj : str, bytes, dict, list
        JSON data to write.
    """
    json_obj = normalize_json(json_obj)
    with open(fname, 'wb') as f:
        f.write(json_obj)


def make_parquet_store(store_name, refs, row_group_size=1000, compression='zstd',
                       engine='fastparquet', **kwargs):
    """Write references as a store of parquet files with multiple row groups.

    The directory structure mimics a normal zarr store, but instead of standard
    chunk keys, references are saved as parquet dataframes with multiple row
    groups.

    Parameters
    ----------
    store_name : str
        Name of parquet store.
    refs : dict
        Kerchunk references.
    row_group_size : int, optional
        Number of references to store in each row group of a variable's
        reference file (default 1000).
    compression : str, optional
        Compression to pass to the parquet engine (default 'zstd').
    engine : {'fastparquet', 'pyarrow'}
        Library to use for writing parquet files.
    **kwargs : dict, optional
        Additional keyword arguments passed to the parquet engine of choice.
    """
    if not os.path.exists(store_name):
        os.makedirs(store_name)
    if 'refs' in refs:
        refs = refs['refs']
    write_json(os.path.join(store_name, '.row_group_size'),
               dict(row_group_size=row_group_size))
    fields = get_variables(refs)
    for field in fields:
        # Initialize dataframe columns
        paths, offsets, sizes, raws = [], [], [], []
        field_path = os.path.join(store_name, field)
        if field.startswith('.'):
            # zarr metadata keys (.zgroup, .zmetadata, etc)
            write_json(field_path, refs[field])
        else:
            if not os.path.exists(field_path):
                os.makedirs(field_path)
            # Read the variable's zarray metadata to determine the number of chunks
            zarray = ujson.loads(refs[f'{field}/.zarray'])
            chunk_sizes = np.array(zarray['shape']) / np.array(zarray['chunks'])
            chunk_numbers = [np.arange(n) for n in chunk_sizes]
            if chunk_sizes.size != 0:
                nums = np.asarray(pd.MultiIndex.from_product(chunk_numbers).codes).T
            else:
                nums = np.array([0])
            nchunks = nums.shape[0]
            nmissing = 0
            for metakey in ['.zarray', '.zattrs']:
                key = f'{field}/{metakey}'
                write_json(os.path.join(field_path, metakey), refs[key])
            for i in range(nchunks):
                chunk_id = '.'.join(nums[i].astype(str))
                key = f'{field}/{chunk_id}'
                # Make note if the expected number of chunks differs from the
                # actual number found in the references
                if key not in refs:
                    nmissing += 1
                    paths.append(None)
                    offsets.append(0)
                    sizes.append(0)
                    raws.append(None)
                else:
                    data = refs[key]
                    if isinstance(data, list):
                        paths.append(data[0])
                        offsets.append(data[1])
                        sizes.append(data[2])
                        raws.append(None)
                    else:
                        paths.append(None)
                        offsets.append(0)
                        sizes.append(0)
                        raws.append(data)
            if nmissing:
                print(f'Warning: Chunks missing for field {field}. '
                      f'Expected: {nchunks}, Found: {nchunks - nmissing}')
            # Pad extra rows so the total number of rows divides row_group_size
            # evenly (no padding is added when it already divides evenly)
            extra_rows = -nchunks % row_group_size
            for i in range(extra_rows):
                paths.append(None)
                offsets.append(0)
                sizes.append(0)
                raws.append(None)
            # The convention for parquet files is
            # <store_name>/<field_name>/refs.parq
            out_path = os.path.join(field_path, 'refs.parq')
            df = pd.DataFrame(
                dict(path=paths,
                     offset=offsets,
                     size=sizes,
                     raw=raws)
            )
            # Engine specific kwarg conventions. Set stats to False since
            # those are currently unneeded. Both engines need the row group
            # size set explicitly so the reader can index row groups.
            if engine == 'pyarrow':
                kwargs.update(row_group_size=row_group_size,
                              write_statistics=False)
            else:
                kwargs.update(row_group_offsets=row_group_size,
                              stats=False)
            df.to_parquet(out_path, engine=engine,
                          compression=compression, **kwargs)
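

# A minimal usage sketch for the writer above. The file name 'combined.json'
# and the store name 'example_refs.parq' are illustrative only (not part of
# this module); a kerchunk reference set is assumed to have been generated
# elsewhere and saved as JSON.
def _demo_build_store():
    with open('combined.json') as f:  # hypothetical kerchunk JSON output
        refs = ujson.load(f)
    make_parquet_store('example_refs.parq', refs, row_group_size=1000,
                       compression='zstd', engine='fastparquet')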


class LazyReferenceMapper(collections.abc.MutableMapping):
    """Interface to read a parquet store as if it were a standard kerchunk
    references dict."""
    # The import is class level to avoid making numpy a hard requirement
    # for fsspec
    import numpy as np

    def __init__(self, root, fs=None, engine='fastparquet', cache_size=128,
                 categorical_urls=True):
        """
        Parameters
        ----------
        root : str
            Root of the parquet store.
        fs : fsspec.AbstractFileSystem
            fsspec filesystem object, default is the local filesystem.
        engine : {'fastparquet', 'pyarrow'}
            Library to use for reading parquet files.
        cache_size : int
            Maximum size of the LRU cache, where cache_size*row_group_size is
            the total number of references that can be loaded in memory at once.
        categorical_urls : bool
            Whether to use pandas.Categorical to encode urls. This can greatly
            reduce memory usage for reference sets whose URLs are shared by
            many keys (e.g., when a single variable has many chunks), in
            exchange for a bit of additional overhead when loading references.
        """
        self.root = root
        self.chunk_sizes = {}
        self._items = {}
        self.pfs = {}
        self.engine = engine
        self.categorical_urls = categorical_urls
        self.fs = fsspec.filesystem('file') if fs is None else fs

        # Define a function to open and decompress row group data and store
        # the result in an LRU cache
        if self.engine == 'pyarrow':
            import pyarrow.parquet as pq
            self.pf_cls = pq.ParquetFile

            @lru_cache(maxsize=cache_size)
            def open_row_group(pf, row_group):
                rg = pf.read_row_group(row_group)
                refs = {c: rg[c].to_numpy() for c in rg.column_names}
                if self.categorical_urls:
                    refs['path'] = pd.Categorical(refs['path'])
                return refs
        else:
            import fastparquet
            self.pf_cls = fastparquet.ParquetFile

            @lru_cache(maxsize=cache_size)
            def open_row_group(pf, row_group):
                rg = pf.row_groups[row_group]
                refs = {
                    c: self.np.empty(rg.num_rows, dtype=dtype)
                    for c, dtype in pf.dtypes.items()
                }
                fastparquet.core.read_row_group_arrays(pf.open(), rg, pf.columns,
                                                       pf.categories, pf.schema,
                                                       pf.cats, assign=refs)
                if self.categorical_urls:
                    refs['path'] = pd.Categorical(refs['path'])
                return refs

        self.open_row_group = open_row_group

    def listdir(self, basename=True):
        listing = self.fs.ls(self.root)
        if basename:
            listing = [os.path.basename(path) for path in listing]
        return listing

    def join(self, *args):
        return self.fs.sep.join(args)

    @property
    def row_group_size(self):
        if not hasattr(self, '_row_group_size'):
            with self.fs.open(self.join(self.root, '.row_group_size')) as f:
                self._row_group_size = ujson.load(f)['row_group_size']
        return self._row_group_size

    def _load_one_key(self, key):
        if '/' not in key:
            if key not in self.listdir():
                raise KeyError(key)
            return self._get_and_cache_metadata(key)
        else:
            field, sub_key = key.split('/')
            if sub_key.startswith('.'):
                # zarr metadata keys are always cached
                return self._get_and_cache_metadata(key)
            # Chunk keys are loaded from their row group and kept in the
            # LRU cache
            row_group, row_number = self._key_to_row_group(key)
            if field not in self.pfs:
                pf_path = self.join(self.root, field, 'refs.parq')
                self.pfs[field] = self.pf_cls(self.fs.open(pf_path))
            pf = self.pfs[field]
            refs = self.open_row_group(pf, row_group)
            columns = ['path', 'offset', 'size', 'raw']
            selection = [refs[c][row_number] for c in columns]
            raw = selection[-1]
            if raw is not None:
                if isinstance(raw, bytes):
                    raw = raw.decode()
                return raw
            return selection[:-1]

    def _get_and_cache_metadata(self, key):
        with self.fs.open(self.join(self.root, key), 'rb') as f:
            data = f.read()
        self._items[key] = data
        return data

    def _key_to_row_group(self, key):
        field, chunk = key.split('/')
        chunk_sizes = self._get_chunk_sizes(field)
        if chunk_sizes.size == 0:
            return 0, 0
        chunk_idx = self.np.array([int(c) for c in chunk.split('.')])
        chunk_number = self.np.ravel_multi_index(chunk_idx, chunk_sizes)
        row_group = chunk_number // self.row_group_size
        row_number = chunk_number % self.row_group_size
        return row_group, row_number
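
    # Worked example of the row-group arithmetic above (shapes are purely
    # illustrative): for a field whose chunk grid is (10, 20) and with
    # row_group_size=1000, key 'temp/3.5' gives chunk_idx [3, 5],
    # chunk_number = 3 * 20 + 5 = 65, hence row_group 0 and row_number 65.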

    def _get_chunk_sizes(self, field):
        if field not in self.chunk_sizes:
            zarray = ujson.loads(self.__getitem__(f'{field}/.zarray'))
            size_ratio = self.np.array(zarray['shape']) / self.np.array(zarray['chunks'])
            self.chunk_sizes[field] = self.np.ceil(size_ratio).astype(int)
        return self.chunk_sizes[field]

    def __getitem__(self, key):
        if key in self._items:
            return self._items[key]
        return self._load_one_key(key)

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __len__(self):
        # Caveat: this counts expected references, not the actual number found
        count = 0
        for field in self.listdir():
            if field.startswith('.'):
                count += 1
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = self.np.prod(chunk_sizes)
                # .zarray and .zattrs plus one reference per expected chunk
                count += 2 + nchunks
        return count

    def __iter__(self):
        # Caveat: this generates all expected keys, but does not account for
        # reference keys that are missing.
        for field in self.listdir():
            if field.startswith('.'):
                yield field
            else:
                chunk_sizes = self._get_chunk_sizes(field)
                nchunks = self.np.prod(chunk_sizes)
                yield '/'.join([field, '.zarray'])
                yield '/'.join([field, '.zattrs'])
                inds = self.np.asarray(
                    self.np.unravel_index(self.np.arange(nchunks), chunk_sizes)).T
                for ind in inds:
                    yield field + '/' + '.'.join(ind.astype(str))
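

# A minimal read-side sketch, mirroring _demo_build_store above. The store path
# and the variable name 'temp' are illustrative; any variable present in the
# store would work the same way.
def _demo_read_store():
    refs = LazyReferenceMapper('example_refs.parq', engine='fastparquet')
    print(refs['.zgroup'])        # zarr group metadata, returned as bytes
    print(refs['temp/.zarray'])   # per-variable zarr array metadata
    print(refs['temp/0.0'])       # one chunk reference: [path, offset, size]
    # In principle the mapper can also be handed to fsspec's reference
    # filesystem as its ``fo`` argument, though support for mapping (rather
    # than plain dict) inputs may depend on the fsspec version in use.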