Static class to manage distributed (virtual) HDF5 databases
#!/usr/bin/env python
# -*- coding:utf-8 -*-

"""Static class to manage distributed (virtual) HDF5 databases."""

import os
import h5py


class DistributedHDF5:
    """
    In general, multiple processes are not allowed to concurrently open and
    write to HDF5 files. To allow for distributed, but coherent writing, this
    class creates many individual HDF5 files and then 'virtually' merges them
    into a single coherent dataset.

    Unfortunately, most operating systems limit how many files a single
    process can hold open at once, and each sub-file here counts as one. As a
    result, large virtual datasets will reach a breaking point, after which
    contents appear to be empty. For those cases, this class also provides a
    ``merge_all`` method to efficiently convert the virtual dataset into a
    non-virtual one, merging all composing files into a single one.

    In the virtual mode, all files are created and must remain in the same
    directory. Since this class is naturally intended for multiple concurrent
    processes/devices, it is designed in the form of a static class.

    Usage example (see distributed decompositions for more)::

      # create the empty separate HDF5 files on disk, and the "virtual" merger
      out_path = "/tmp/my_dataset_{}.h5"
      each_shape, num_files = (1000,), 5
      h5_path, h5_subpaths = DistributedHDF5.create(
          out_path, num_files, each_shape, torch_dtype_as_str(op_dtype)
      )
      # in a (possibly distributed & parallelized) loop, load and write parts
      for i in range(num_files):
          vals, flag, h5 = DistributedHDF5.load(h5_subpaths[i])
          vals[:] += 1
          flag[0] = "done"
          h5.close()  # remember to close file handles when done!
      # merged data can be used as a (1000, 5) matrix by a separate machine
      all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
      print(scipy.diag(all_data))
      print(all_flags[::2])
      all_h5.close()
      # convert the virtual database into a single monolithic one
      DistributedHDF5.merge_all(
          h5_path,
          delete_subfiles_while_merging=True,
      )
      all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
      ...
    """

    MAIN_PATH = "ALL"
    SUBPATHS_FORMAT = "{0:010d}"
    DATA_NAME = "data"
    FLAG_NAME = "flags"
    FLAG_DTYPE = h5py.string_dtype()
    INITIAL_FLAG = "initialized"

    @classmethod
    def create(
        cls,
        base_path,
        num_files,
        shape,
        dtype,
        compression="lzf",
        filedim_last=True,
    ):
        """
        Creates ``num_files`` separate HDF5 sub-files, each holding a ``data``
        dataset of the given ``shape`` and ``dtype`` plus a string ``flags``
        dataset, as well as a virtual HDF5 file that merges them all.

        :param base_path: Format string with one placeholder, e.g.
          ``"/tmp/my_dataset_{}.h5"``, used to name the virtual file and the
          sub-files.
        :param filedim_last: If true, the file index is the last dimension of
          the merged dataset; otherwise it is the first.
        :returns: ``(all_path, subpaths)`` with the path of the virtual
          merger file and the list of sub-file paths.
        """
        all_path = base_path.format(cls.MAIN_PATH)
        subpaths = [
            base_path.format(cls.SUBPATHS_FORMAT.format(i))
            for i in range(num_files)
        ]
        # create virtual dataset to hold everything together via softlinks.
        # use relative paths to just assume everything is in the same dir
        if filedim_last:
            data_shape = shape + (num_files,)
        else:
            data_shape = (num_files,) + shape
        data_lyt = h5py.VirtualLayout(shape=data_shape, dtype=dtype)
        flag_lyt = h5py.VirtualLayout(shape=(num_files,), dtype=cls.FLAG_DTYPE)
        for i, p in enumerate(subpaths):
            p = os.path.basename(p)  # relative path
            vs = h5py.VirtualSource(p, cls.DATA_NAME, shape=shape)
            if filedim_last:
                data_lyt[..., i] = vs
            else:
                data_lyt[i] = vs
            flag_lyt[i] = h5py.VirtualSource(p, cls.FLAG_NAME, shape=(1,))
        all_h5 = h5py.File(all_path, "w")  # , libver="latest")
        all_h5.create_virtual_dataset(cls.DATA_NAME, data_lyt)
        all_h5.create_virtual_dataset(cls.FLAG_NAME, flag_lyt)
        all_h5.close()
        # create separate HDF5 files
        for p in subpaths:
            h5f = h5py.File(p, "w")
            h5f.create_dataset(
                cls.DATA_NAME,
                shape=shape,
                maxshape=shape,
                dtype=dtype,
                compression=compression,
                chunks=shape,
            )
            h5f.create_dataset(
                cls.FLAG_NAME,
                shape=(1,),
                maxshape=(1,),
                compression=compression,
                dtype=cls.FLAG_DTYPE,
                chunks=(1,),
            )
            h5f[cls.FLAG_NAME][:] = cls.INITIAL_FLAG
            h5f.close()
        #
        return all_path, subpaths

    @classmethod
    def load(cls, path, filemode="r+"):
        """
        Opens an individual HDF5 file created by ``create``.

        :param filemode: Default is 'r+', read/write, file must preexist. See
          the documentation of ``h5py.File`` for more details.
        :returns: ``(data, flag, h5f)``, where ``data`` is the dataset
          for the numerical measurements, ``flag`` is the dataset for state
          tracking, and ``h5f`` is the (open) HDF5 file handle.

        Remember to ``h5f.close()`` once done with this file.
        """
        h5f = h5py.File(path, filemode)
        data = h5f[cls.DATA_NAME]
        flag = h5f[cls.FLAG_NAME]
        return data, flag, h5f

    @classmethod
    def load_virtual(cls, all_path, filemode="r+"):
        """
        Opens the virtual merger file and returns ``(data, flags, h5f)``,
        where the datasets span all sub-files at once.

        Remember to ``h5f.close()`` once done with this file.
        """
        all_h5f = h5py.File(all_path, filemode)
        data = all_h5f[cls.DATA_NAME]
        flags = all_h5f[cls.FLAG_NAME]
        return data, flags, all_h5f

    @classmethod
    def merge_all(
        cls,
        all_path,
        out_path=None,
        compression="lzf",
        check_success_flag=None,
        delete_subfiles_while_merging=False,
    ):
        """
        Merges all sub-files of a virtual dataset into a single, monolithic
        HDF5 file.

        :param out_path: If None, the merged dataset will be written over the
          given ``all_path``.
        :param check_success_flag: If given, assert that every sub-file flag
          equals this value before and during merging.
        :param delete_subfiles_while_merging: If true, each sub-file is
          removed from the filesystem right after its contents are merged.
        """
        # load virtual HDF5 and grab shapes and dtypes
        data, flags, h5f = cls.load_virtual(all_path, filemode="r")
        data_shape, flags_shape = data.shape, flags.shape
        data_dtype, flags_dtype = data.dtype, flags.dtype
        if check_success_flag is not None:
            assert all(
                f.decode() == check_success_flag for f in flags
            ), f"Some flags don't equal {check_success_flag}"
        # figure out involved paths and their respective indices in virtual
        abspath = h5f.filename
        rootdir = os.path.dirname(abspath)
        data_map, flag_map = {}, {}
        data_subshapes, flag_subshapes = [], []
        for vs in h5f[cls.DATA_NAME].virtual_sources():
            subpath = os.path.join(rootdir, vs.file_name)
            begs_ends = tuple(
                (b, e + 1) for b, e in zip(*vs.vspace.get_select_bounds())
            )
            shape = tuple(e - b for b, e in begs_ends)
            data_map[begs_ends] = subpath
            data_subshapes.append(shape)
        for vs in h5f[cls.FLAG_NAME].virtual_sources():
            subpath = os.path.join(rootdir, vs.file_name)
            begs_ends = tuple(
                (b, e + 1) for b, e in zip(*vs.vspace.get_select_bounds())
            )
            shape = tuple(e - b for b, e in begs_ends)
            assert shape == (1,), "Flags expected to have shape (1,)!"
            flag_map[begs_ends] = subpath
            flag_subshapes.append(shape)
        subpaths = set(data_map.values())
        # figure out whether the file dimension comes first or last
        is_filedim = [
            (a - b) == 0 for a, b in zip(data_shape, data_subshapes[0])
        ]
        assert sum(is_filedim) == 1, "Only one running dimension supported!"
        filedim_idx = is_filedim.index(False)
        # sanity check and close virtual
        data_beginnings = {k[filedim_idx][0] for k in data_map.keys()}
        assert len(data_beginnings) == len(
            h5f[cls.DATA_NAME].virtual_sources()
        ), "Repeated file_idx beginnings in data?"
        assert len(flag_map) == len(
            h5f[cls.FLAG_NAME].virtual_sources()
        ), "Repeated indices in flags?"
        for sp in subpaths:
            assert os.path.isfile(sp), f"Subpath doesn't exist! {sp}"
        for sp2 in flag_map.values():
            assert sp2 in subpaths, "Flag subpaths different to data subpaths!"
        assert (
            len(set(data_subshapes)) == 1
        ), "Heterogeneous data shapes in virtual dataset not supported!"
        assert (
            len(set(flag_subshapes)) == 1
        ), "Heterogeneous flag shapes in virtual dataset not supported!"
        h5f.close()
        # create empty HDF5 that we will iteratively expand
        if out_path is None:
            out_path = all_path
        h5f = h5py.File(out_path, "w")
        #
        init_shape = list(data_shape)
        init_shape[filedim_idx] = 0
        h5f.create_dataset(
            cls.DATA_NAME,
            shape=init_shape,
            maxshape=data_shape,
            dtype=data_dtype,
            compression=compression,
            chunks=data_subshapes[0],
        )
        h5f.create_dataset(
            cls.FLAG_NAME,
            shape=(0,),
            maxshape=flags_shape,
            dtype=flags_dtype,
            compression=compression,
            chunks=flag_subshapes[0],
        )
        # iterate over contents in sorted order and extend h5f with them
        sorted_data = sorted(data_map, key=lambda x: x[filedim_idx][0])
        for begs_ends in sorted_data:
            subpath = data_map[begs_ends]
            subdata, subflag, h5 = cls.load(subpath, filemode="r")
            if check_success_flag is not None:
                assert (
                    subflag[0].decode() == check_success_flag
                ), f"Subfile flag not equal {check_success_flag}!"
            # increment size of h5f by 1 entry
            data_shape = list(h5f[cls.DATA_NAME].shape)
            data_shape[filedim_idx] += 1
            h5f[cls.DATA_NAME].resize(data_shape)
            h5f[cls.FLAG_NAME].resize((len(h5f[cls.FLAG_NAME]) + 1,))
            # write subdata and subflags to h5f, flush and close subfile
            target_slices = tuple(slice(*be) for be in begs_ends)
            h5f[cls.DATA_NAME][target_slices] = subdata[:].reshape(
                data_subshapes[0]
            )
            h5f[cls.FLAG_NAME][-1:] = subflag[:]
            h5f.flush()
            h5.close()
            # optionally, delete subfile
            if delete_subfiles_while_merging:
                os.remove(subpath)
        #
        h5f.close()
        return out_path
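

# ---------------------------------------------------------------------------
# Minimal usage sketch: a single-process walkthrough of the
# create -> write -> load_virtual -> merge_all workflow described in the
# class docstring. The temporary directory, shapes, "float32" dtype string
# and the "done" flag value are illustrative assumptions, not fixed by the
# class; in real use the write loop would be spread across processes/devices.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        out_path = os.path.join(tmpdir, "my_dataset_{}.h5")
        each_shape, num_files = (1000,), 5
        # create the empty sub-files plus the "virtual" merger file
        h5_path, h5_subpaths = DistributedHDF5.create(
            out_path, num_files, each_shape, "float32"
        )
        # fill each sub-file (sequentially here, for illustration only)
        for i, subpath in enumerate(h5_subpaths):
            vals, flag, h5 = DistributedHDF5.load(subpath)
            vals[:] = float(i)  # broadcast a scalar into the whole sub-array
            flag[0] = "done"
            h5.close()  # remember to close file handles when done!
        # read everything back through the virtual merger as a (1000, 5) array
        all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
        print("merged shape:", all_data.shape)
        print("flags:", [f.decode() for f in all_flags])
        all_h5.close()
        # collapse the virtual dataset into one monolithic HDF5 file,
        # checking flags and deleting sub-files as they are merged
        DistributedHDF5.merge_all(
            h5_path,
            check_success_flag="done",
            delete_subfiles_while_merging=True,
        )
        all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
        print("monolithic shape:", all_data.shape)
        all_h5.close()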