Created
April 10, 2019 02:56
-
-
Save kkew3/6731356ffa9aaff82e5fc24e32526c5f to your computer and use it in GitHub Desktop.
Write incrementally to `npz` file to save memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import zipfile | |
import io | |
import typing | |
import numpy as np | |
class IncrementalNpzWriter:
    """
    Write data to an ``npz`` file incrementally, one entry at a time, rather
    than computing everything first and writing it all at once as
    ``np.savez`` does.

    The writer can be used as a context manager (``with`` statement) or
    wrapped in ``contextlib.closing`` to make sure the underlying archive is
    closed after use.
    """

    # The only ``zipfile.ZipFile`` modes that make sense for a writer.
    _VALID_MODES = ('x', 'w', 'a')

    def __init__(self, tofile: str, mode: str = 'x'):
        """
        :param tofile: the ``npz`` file to write
        :param mode: must be one of {'x', 'w', 'a'}. See
               https://docs.python.org/3/library/zipfile.html for detail
        :raise ValueError: if ``mode`` is not one of the accepted values
        """
        # Explicit membership test: a substring check against 'xwa' would
        # also accept '', 'xw' and 'wa', and an ``assert`` vanishes under -O.
        if mode not in self._VALID_MODES:
            raise ValueError(
                "mode must be one of 'x', 'w' or 'a', got {!r}".format(mode))
        self.tofile = zipfile.ZipFile(
            tofile, mode=mode, compression=zipfile.ZIP_DEFLATED
        )  # type: typing.Optional[zipfile.ZipFile]

    def __enter__(self) -> 'IncrementalNpzWriter':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    @staticmethod
    def _member_name(key: str, is_npy_data: bool) -> str:
        """Return the archive member name under which ``key`` is stored."""
        return key + '.npy' if is_npy_data else key

    def write(self, key: str, data: typing.Union[np.ndarray, bytes],
              is_npy_data: bool = True) -> None:
        """
        Add a new entry to the archive, refusing to duplicate an existing one.

        :param key: the name of data to write
        :param data: the data
        :param is_npy_data: if ``True``, ".npy" will be appended to ``key``,
               and ``data`` will be serialized by ``np.save``;
               otherwise, ``key`` will be treated as is, and ``data`` will be
               treated as binary data
        :raise KeyError: if the transformed ``key`` (as per ``is_npy_data``)
               already exists in ``self.tofile``
        """
        # The duplicate check must look at the name actually stored in the
        # archive (with the ".npy" suffix for array data), not the raw key.
        name = self._member_name(key, is_npy_data)
        try:
            self.tofile.getinfo(name)
        except KeyError:
            pass  # no such member yet, safe to write
        else:
            raise KeyError('Duplicate key "{}" already exists in "{}"'
                           .format(name, self.tofile.filename))
        self.update(key, data, is_npy_data=is_npy_data)

    def update(self, key: str, data: typing.Union[np.ndarray, bytes],
               is_npy_data: bool = True) -> None:
        """
        Same as ``self.write`` but without the duplicate-name check.

        NOTE: this method never removes or rewrites an existing member; it
        only opens a new member for writing. If the name already exists, the
        archive ends up with two entries of the same name.

        :param key: the name of data to write
        :param data: the data
        :param is_npy_data: if ``True``, ".npy" will be appended to ``key``,
               and ``data`` will be serialized by ``np.save``;
               otherwise, ``key`` will be treated as is, and ``data`` will be
               treated as binary data
        """
        name = self._member_name(key, is_npy_data)
        # force_zip64 allows a single member to exceed 2 GiB even though its
        # final size is not known when the member is opened.
        with self.tofile.open(name, mode='w', force_zip64=True) as outfile:
            if is_npy_data:
                # Serialize straight into the zip member; no intermediate
                # in-memory copy of the whole array is needed.
                np.save(outfile, data)
            else:
                outfile.write(data)

    def close(self):
        """Close the underlying archive; calling it again is a no-op."""
        if self.tofile is not None:
            self.tofile.close()
            self.tofile = None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment