Skip to content

Instantly share code, notes, and snippets.

@mivade
Last active August 21, 2018 15:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mivade/db0f280ff845d51979a86f87d8ec6d03 to your computer and use it in GitHub Desktop.
Save mivade/db0f280ff845d51979a86f87d8ec6d03 to your computer and use it in GitHub Desktop.
import codecs
import json
from typing import Union
import h5py
import numpy as np
import pandas as pd
# Element-wise helpers used by ``save``/``load``.
#
# ``otypes`` is given explicitly for two reasons: without it ``np.vectorize``
# calls the wrapped function once on the first element just to infer the
# output type, and it raises ``ValueError`` on size-0 input because there is
# no element to probe.

# Length of every element (character count for ``str`` elements).
vlen = np.vectorize(len, otypes=[int])
# ``codecs.encode`` with its default encoding (UTF-8): str -> bytes.
vencode = np.vectorize(codecs.encode, otypes=[bytes])
# ``codecs.decode`` with its default encoding (UTF-8): bytes -> str.
vdecode = np.vectorize(codecs.decode, otypes=[str])
def save(hfile: h5py.File, where: str, data: Union[pd.DataFrame, np.ndarray]):
    """Save record array-like data to HDF5.

    Unicode (``U``) and object fields are encoded with ``vencode`` and stored
    as fixed-width byte strings. The names of the encoded fields and the
    string form of the input's type are recorded as attributes on the dataset
    (``utf8_encoded_fields`` as a JSON list, and ``original_type``).

    Parameters
    ----------
    hfile
        Opened HDF5 file object.
    where
        Dataset name.
    data
        The data to write. A DataFrame is converted with ``to_records()``;
        anything that is not already a record array goes through
        ``np.rec.array``.
    """
    original_type = str(type(data))
    if isinstance(data, pd.DataFrame):
        data = data.to_records()
    if not isinstance(data, np.recarray):
        data = np.rec.array(data)

    dtype = []
    encoded = {}  # field name -> already-encoded byte-string column
    for name in data.dtype.names:
        column = data[name]
        if column.dtype == object or column.dtype.char == "U":
            if column.size:
                # Size the field from the *encoded* bytes: a character may
                # occupy several bytes, so a width based on len(str) would
                # silently truncate non-ASCII text.
                as_bytes = vencode(column)
            else:
                # Nothing to encode; the vectorized encoder may reject
                # size-0 input, so build the empty column directly.
                as_bytes = np.empty(column.shape, dtype="|S1")
            encoded[name] = as_bytes
            dtype.append((name, as_bytes.dtype))
        else:
            dtype.append((name, column.dtype))

    sanitized = np.recarray(data.shape, dtype=dtype)
    for name, _ in dtype:
        sanitized[name] = encoded[name] if name in encoded else data[name]

    hfile[where] = sanitized
    hfile[where].attrs["utf8_encoded_fields"] = json.dumps(list(encoded))
    hfile[where].attrs["original_type"] = original_type
def load(hfile: h5py.File, where: str) -> Union[pd.DataFrame, np.recarray]:
    """Load data stored with :func:`save`.

    Parameters
    ----------
    hfile
        Open HDF5 file object.
    where
        Key to load data from.

    Returns
    -------
    pandas.DataFrame or numpy.recarray
        A DataFrame when the dataset's ``original_type`` attribute contains
        ``"DataFrame"``; otherwise a record array holding only the stored
        fields (no synthetic index field is added).
    """
    dataset = hfile[where]
    data = pd.DataFrame(dataset[:])
    encoded = json.loads(dataset.attrs["utf8_encoded_fields"])

    columns = dict(data.items())
    for name in encoded:
        raw = columns[name]
        # Skip the vectorized decoder for empty columns; np.vectorize may
        # reject size-0 input.
        columns[name] = vdecode(raw) if len(raw) else np.array([], dtype=str)

    df = pd.DataFrame(columns)
    if "DataFrame" not in dataset.attrs["original_type"]:
        # index=False: the default would prepend the frame's RangeIndex as an
        # extra "index" field that was never part of the stored data.
        return df.to_records(index=False)
    return df
if __name__ == "__main__":
    # Round-trip demo: write one DataFrame and one record array to an HDF5
    # file, then read both back and print them.
    demo_path = "test.h5"
    samples = {
        "dataframe": pd.DataFrame({
            "string": ["a", "string"],
            "integer": [1, 2],
            "float": [1.0, 2.0],
        }),
        "recarray": np.rec.array(
            [("a", 1), ("longer string", 2)],
            dtype=[("description", "<U32"), ("number", int)],
        ),
    }

    with h5py.File(demo_path, "w") as hfile:
        for key, value in samples.items():
            save(hfile, key, value)

    with h5py.File(demo_path, "r") as hfile:
        for key in samples:
            print(load(hfile, key))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment