Skip to content

Instantly share code, notes, and snippets.

@kkew3
Created April 10, 2019 02:56
Show Gist options
  • Save kkew3/6731356ffa9aaff82e5fc24e32526c5f to your computer and use it in GitHub Desktop.
Save kkew3/6731356ffa9aaff82e5fc24e32526c5f to your computer and use it in GitHub Desktop.
Write to an `npz` file incrementally to save memory
import zipfile
import io
import typing
import numpy as np
class IncrementalNpzWriter:
    """
    Write data to an ``npz`` file incrementally, one entry at a time,
    rather than computing everything first and writing it at once as
    ``np.savez`` does.

    The writer is a context manager (the archive is closed on exit) and can
    also be used with ``contextlib.closing``.
    """

    _VALID_MODES = ('x', 'w', 'a')

    def __init__(self, tofile: str, mode: str = 'x'):
        """
        :param tofile: the ``npz`` file to write
        :param mode: must be one of {'x', 'w', 'a'}. See
               https://docs.python.org/3/library/zipfile.html for detail
        :raise ValueError: if ``mode`` is not one of the accepted values
        """
        # Explicit membership test against a tuple: a substring test on
        # 'xwa' would also accept '' or 'xw', and ``assert`` is stripped
        # when Python runs with -O.
        if mode not in self._VALID_MODES:
            raise ValueError(
                "mode must be one of 'x', 'w' or 'a', got {!r}".format(mode))
        self.tofile = zipfile.ZipFile(
            tofile, mode=mode, compression=zipfile.ZIP_DEFLATED,
        )  # type: typing.Optional[zipfile.ZipFile]

    def __enter__(self) -> 'IncrementalNpzWriter':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    @staticmethod
    def _entry_name(key: str, is_npy_data: bool) -> str:
        """Return the archive member name used to store ``key``."""
        return key + '.npy' if is_npy_data else key

    def _require_open(self) -> zipfile.ZipFile:
        """Return the underlying archive, or raise if already closed."""
        if self.tofile is None:
            raise ValueError('Attempt to use a writer that was already closed')
        return self.tofile

    def write(self, key: str, data: typing.Union[np.ndarray, bytes],
              is_npy_data: bool = True) -> None:
        """
        Add a new entry to the archive, refusing to shadow an existing one.

        :param key: the name of data to write
        :param data: the data
        :param is_npy_data: if ``True``, ".npy" will be appended to ``key``,
               and ``data`` will be serialized by ``np.save``;
               otherwise, ``key`` will be treated as is, and ``data`` will be
               treated as binary data
        :raise KeyError: if the transformed ``key`` (as per ``is_npy_data``)
               already exists in ``self.tofile``
        :raise ValueError: if the writer has already been closed
        """
        archive = self._require_open()
        # The check must use the name actually stored in the archive,
        # i.e. with the ".npy" suffix when applicable.
        name = self._entry_name(key, is_npy_data)
        try:
            archive.getinfo(name)
        except KeyError:
            pass
        else:
            raise KeyError('Duplicate key "{}" already exists in "{}"'
                           .format(name, archive.filename))
        self.update(key, data, is_npy_data=is_npy_data)

    def update(self, key: str, data: typing.Union[np.ndarray, bytes],
               is_npy_data: bool = True) -> None:
        """
        Same as ``self.write`` but without the duplicate check.

        A zip archive cannot replace an entry in place: when the name is
        already present, a second entry with the same name is appended (and
        ``zipfile`` emits a ``UserWarning``). Readers that look entries up
        by name, as ``zipfile`` does, see the most recently written one; the
        older bytes remain in the file.

        :param key: the name of data to write
        :param data: the data
        :param is_npy_data: if ``True``, ".npy" will be appended to ``key``,
               and ``data`` will be serialized by ``np.save``;
               otherwise, ``key`` will be treated as is, and ``data`` will be
               treated as binary data
        :raise ValueError: if the writer has already been closed
        """
        archive = self._require_open()
        name = self._entry_name(key, is_npy_data)
        if is_npy_data:
            # Convert before opening the member so that a conversion
            # failure does not leave an empty entry behind.
            array = np.asanyarray(data)
            # Stream the serialized array straight into the zip member
            # instead of staging a full copy in memory first.
            with archive.open(name, mode='w', force_zip64=True) as outfile:
                np.save(outfile, array)
        else:
            with archive.open(name, mode='w', force_zip64=True) as outfile:
                outfile.write(data)

    def close(self) -> None:
        """Close the underlying archive. Calling it again is a no-op."""
        if self.tofile is not None:
            self.tofile.close()
            self.tofile = None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment