Static class to manage distributed (virtual) HDF5 databases
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Static class to manage distributed (virtual) HDF5 databases.
"""
import os
import h5py


class DistributedHDF5:
""" | |
In general, multiple processes are not allowed to concurrently open and | |
write to HDF5 files. In order to allow for distributed, but coherent | |
writing, this class allows to create many individual HDF5 files, and then | |
'virtually' merge them into a single coherent dataset. | |
Unfortunately, most OS don't allow a single process to open many files at | |
once, and each sub-file here counts as one. As a result, large virtual | |
datasets will reach a breaking point, after which contents are apparently | |
empty. For those cases, this class also provides a ``merge_all`` method, | |
to efficiently convert the virtual dataset into a non-virtual one, merging | |
all composing files into one. | |
In the virtual mode, all files are created and must remain in the same | |
directory. Since this class is naturally intended for multiple concurrent | |
processes/devices, it is designed in the form of a static class. | |
Usage example (see distributed decompositions for more):: | |
# create the empty separate HDF5 files on disk, and the "virtual" merger | |
out_path = "/tmp/my_dataset_{}.h5" | |
each_shape, num_files = (1000,), 5 | |
h5_path, h5_subpaths = DistributedHDF5.create( | |
out_path, num_files, each_shape, torch_dtype_as_str(op_dtype) | |
) | |
# in a (possibly distributed & parallelized) loop, load and write parts | |
for i in range(num_files): | |
vals, flag, h5 = DistributedHDF5.load(h5_subpaths[i]) | |
vals[:] += 1 | |
flag[0] = 'done' | |
h5.close() # remember to close file handles when done! | |
# merged data can be used as a (1000, 50) matrix by a separate machine | |
all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path) | |
print(scipy.diag(all_data)) | |
print(all_flags[::2]) | |
all_h5.close() | |
# convert virtual database into single monolithic one | |
DistributedHDF5.merge_all( | |
h5_path, | |
delete_subfiles_while_merging=True, | |
) | |
all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path) | |
... | |
""" | |

    MAIN_PATH = "ALL"
    SUBPATHS_FORMAT = "{0:010d}"
    DATA_NAME = "data"
    FLAG_NAME = "flags"
    FLAG_DTYPE = h5py.string_dtype()
    INITIAL_FLAG = "initialized"

    @classmethod
    def create(
        cls,
        base_path,
        num_files,
        shape,
        dtype,
        compression="lzf",
        filedim_last=True,
    ):
        """
        Creates ``num_files`` individual HDF5 files plus a virtual dataset
        that merges them along an extra file dimension.

        :param base_path: Format string with one placeholder, e.g.
          ``"/tmp/dataset_{}.h5"``, used to name the virtual file and each
          sub-file.
        :param num_files: Number of sub-files to create.
        :param shape: Shape of the data array in each sub-file.
        :param dtype: Data type of the data arrays (as accepted by h5py).
        :param compression: Compression applied to each sub-file.
        :param filedim_last: If true, the merged dataset has shape
          ``shape + (num_files,)``, otherwise ``(num_files,) + shape``.
        :returns: ``(all_path, subpaths)`` with the path of the virtual file
          and the list of sub-file paths.
        """
        all_path = base_path.format(cls.MAIN_PATH)
        subpaths = [
            base_path.format(cls.SUBPATHS_FORMAT.format(i))
            for i in range(num_files)
        ]
        # create virtual dataset to hold everything together via softlinks
        # use relative paths to just assume everything is in the same dir
        if filedim_last:
            data_shape = shape + (num_files,)
        else:
            data_shape = (num_files,) + shape
        data_lyt = h5py.VirtualLayout(shape=data_shape, dtype=dtype)
        flag_lyt = h5py.VirtualLayout(shape=(num_files,), dtype=cls.FLAG_DTYPE)
        for i, p in enumerate(subpaths):
            p = os.path.basename(p)  # relative path
            vs = h5py.VirtualSource(p, cls.DATA_NAME, shape=shape)
            if filedim_last:
                data_lyt[..., i] = vs
            else:
                data_lyt[i] = vs
            flag_lyt[i] = h5py.VirtualSource(p, cls.FLAG_NAME, shape=(1,))
        all_h5 = h5py.File(all_path, "w")  # , libver="latest")
        all_h5.create_virtual_dataset(cls.DATA_NAME, data_lyt)
        all_h5.create_virtual_dataset(cls.FLAG_NAME, flag_lyt)
        all_h5.close()
        # create separate HDF5 files
        for p in subpaths:
            h5f = h5py.File(p, "w")
            h5f.create_dataset(
                cls.DATA_NAME,
                shape=shape,
                maxshape=shape,
                dtype=dtype,
                compression=compression,
                chunks=shape,
            )
            h5f.create_dataset(
                cls.FLAG_NAME,
                shape=(1,),
                maxshape=(1,),
                compression=compression,
                dtype=cls.FLAG_DTYPE,
                chunks=(1,),
            )
            h5f[cls.FLAG_NAME][:] = cls.INITIAL_FLAG
            h5f.close()
        #
        return all_path, subpaths

    @classmethod
    def load(cls, path, filemode="r+"):
        """
        :param filemode: Default is 'r+', read/write, file must preexist. See
          documentation of ``h5py.File`` for more details.
        :returns: ``(data, flag, h5f)``, where ``data`` is the dataset
          for the numerical measurements, ``flag`` is the dataset for state
          tracking, and ``h5f`` is the (open) HDF5 file handle.

        Remember to ``h5f.close()`` once done with this file.
        """
        h5f = h5py.File(path, filemode)
        data = h5f[cls.DATA_NAME]
        flag = h5f[cls.FLAG_NAME]
        return data, flag, h5f

    @classmethod
    def load_virtual(cls, all_path, filemode="r+"):
        """
        Like ``load``, but for the virtual HDF5 file that merges all
        sub-files, as produced by ``create``.

        Remember to ``h5f.close()`` once done with this file.
        """
        all_h5f = h5py.File(all_path, filemode)
        data = all_h5f[cls.DATA_NAME]
        flags = all_h5f[cls.FLAG_NAME]
        return data, flags, all_h5f

    @classmethod
    def merge_all(
        cls,
        all_path,
        out_path=None,
        compression="lzf",
        check_success_flag=None,
        delete_subfiles_while_merging=False,
    ):
        """
        Merges all sub-files of the virtual dataset at ``all_path`` into a
        single, monolithic (non-virtual) HDF5 file.

        :param out_path: Path of the merged output file. If None, the merged
          dataset will be written over the given ``all_path``, replacing the
          virtual file.
        :param check_success_flag: If given, assert that every flag equals
          this string before and while merging.
        :param delete_subfiles_while_merging: If true, each sub-file is
          deleted right after its contents have been merged.
        """
        # load virtual HDF5 and grab shapes and dtypes
        data, flags, h5f = DistributedHDF5.load_virtual(all_path, filemode="r")
        data_shape, flags_shape = data.shape, flags.shape
        data_dtype, flags_dtype = data.dtype, flags.dtype
        if check_success_flag is not None:
            assert all(
                f.decode() == check_success_flag for f in flags
            ), f"Some flags don't equal {check_success_flag}"
        # figure out involved paths and their respective indices in virtual
        abspath = h5f.filename
        rootdir = os.path.dirname(abspath)
        data_map, flag_map = {}, {}
        data_subshapes, flag_subshapes = [], []
        for vs in h5f[cls.DATA_NAME].virtual_sources():
            subpath = os.path.join(rootdir, vs.file_name)
            begs_ends = tuple(
                (b, e + 1) for b, e in zip(*vs.vspace.get_select_bounds())
            )
            shape = tuple(e - b for b, e in begs_ends)
            data_map[begs_ends] = subpath
            data_subshapes.append(shape)
        for vs in h5f[cls.FLAG_NAME].virtual_sources():
            subpath = os.path.join(rootdir, vs.file_name)
            begs_ends = tuple(
                (b, e + 1) for b, e in zip(*vs.vspace.get_select_bounds())
            )
            shape = tuple(e - b for b, e in begs_ends)
            assert shape == (1,), "Flags expected to have shape (1,)!"
            flag_map[begs_ends] = subpath
            flag_subshapes.append(shape)
        subpaths = set(data_map.values())
        # figure out whether the file dimension comes first or last: it is
        # the single dimension where merged and per-file shapes differ
        same_dim = [
            (a - b) == 0 for a, b in zip(data_shape, data_subshapes[0])
        ]
        assert same_dim.count(False) == 1, "Exactly one file dimension expected!"
        filedim_idx = same_dim.index(False)
        # sanity check and close virtual
        data_beginnings = {k[filedim_idx][0] for k in data_map.keys()}
        assert len(data_beginnings) == len(
            h5f[cls.DATA_NAME].virtual_sources()
        ), "Repeated file_idx beginnings in data?"
        assert len(flag_map) == len(
            h5f[cls.FLAG_NAME].virtual_sources()
        ), "Repeated indices in flags?"
        for sp in subpaths:
            assert os.path.isfile(sp), f"Subpath doesn't exist! {sp}"
        for sp2 in flag_map.values():
            assert sp2 in subpaths, "Flag subpaths different to data subpaths!"
        assert (
            len(set(data_subshapes)) == 1
        ), "Heterogeneous data shapes in virtual dataset not supported!"
        assert (
            len(set(flag_subshapes)) == 1
        ), "Heterogeneous flag shapes in virtual dataset not supported!"
        h5f.close()
        # create empty HDF5 that we will iteratively expand
        if out_path is None:
            out_path = all_path
        h5f = h5py.File(out_path, "w")
        #
        init_shape = list(data_shape)
        init_shape[filedim_idx] = 0
        h5f.create_dataset(
            cls.DATA_NAME,
            shape=init_shape,
            maxshape=data_shape,
            dtype=data_dtype,
            compression=compression,
            chunks=data_subshapes[0],
        )
        h5f.create_dataset(
            cls.FLAG_NAME,
            shape=(0,),
            maxshape=flags_shape,
            dtype=flags_dtype,
            compression=compression,
            chunks=flag_subshapes[0],
        )
        # iterate over contents in sorted order and extend h5f with them
        sorted_data = sorted(data_map, key=lambda x: x[filedim_idx][0])
        for begs_ends in sorted_data:
            subpath = data_map[begs_ends]
            subdata, subflag, h5 = cls.load(subpath, filemode="r")
            if check_success_flag is not None:
                assert (
                    subflag[0].decode() == check_success_flag
                ), f"Subfile flag not equal {check_success_flag}!"
            # increment size of h5f by 1 entry
            data_shape = list(h5f[cls.DATA_NAME].shape)
            data_shape[filedim_idx] += 1
            h5f[cls.DATA_NAME].resize(data_shape)
            h5f[cls.FLAG_NAME].resize((len(h5f[cls.FLAG_NAME]) + 1,))
            # write subdata and subflags to h5f, flush and close subfile
            target_slices = tuple(slice(*be) for be in begs_ends)
            h5f[cls.DATA_NAME][target_slices] = subdata[:].reshape(
                data_subshapes[0]
            )
            h5f[cls.FLAG_NAME][-1:] = subflag[:]
            h5f.flush()
            h5.close()
            # optionally, delete subfile
            if delete_subfiles_while_merging:
                os.remove(subpath)
        #
        h5f.close()
        return out_path
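

# Minimal usage sketch (not part of the gist's original API): under the
# assumption that the class above is used as-is, this demo creates a small
# distributed dataset in a temporary directory, fills each sub-file as a
# worker loop would, reads the virtual merger, and finally collapses it into
# a single monolithic file. Paths, shapes and counts are illustrative only.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        base_path = os.path.join(tmpdir, "demo_{}.h5")
        each_shape, num_files = (8,), 3
        # create sub-files plus the virtual merger of shape (8, 3)
        h5_path, h5_subpaths = DistributedHDF5.create(
            base_path, num_files, each_shape, "float32"
        )
        # simulate the distributed loop: each "worker" fills one sub-file
        for i, subpath in enumerate(h5_subpaths):
            vals, flag, h5 = DistributedHDF5.load(subpath)
            vals[:] = i  # write this worker's results
            flag[0] = "done"  # mark this sub-file as finished
            h5.close()
        # read the merged view: column i holds the contents of sub-file i
        all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
        print("virtual data:\n", all_data[:])
        print("flags:", [f.decode() for f in all_flags])
        all_h5.close()
        # collapse the virtual dataset into one monolithic HDF5 file
        DistributedHDF5.merge_all(
            h5_path,
            check_success_flag="done",
            delete_subfiles_while_merging=True,
        )
        all_data, all_flags, all_h5 = DistributedHDF5.load_virtual(h5_path)
        print("merged data:\n", all_data[:])
        all_h5.close()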