Created May 2, 2019 19:17
Save golobor/9273e6617bffe54f3088fe4d4b3554f7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from https://stackoverflow.com/questions/22712292/how-to-use-numpy-savez-in-a-loop-for-save-more-than-one-array/22716159#22716159
import numpy as np | |
import tempfile | |
class my_savez(object):
    """Incrementally write NumPy arrays into a single ``.npz`` archive.

    Unlike ``numpy.savez``, which needs every array in one call, this keeps
    the zip archive open so arrays can be appended over several
    :meth:`savez` calls (e.g. inside a loop).  Call :meth:`close` (or use
    the object as a context manager) to finalise the archive.  Each array
    is stored as ``<name>.npy``, so the result can be read with
    ``numpy.load``.
    """

    def __init__(self, file, compression=None, compresslevel=None):
        """Open *file* as a zip archive for writing.

        Parameters
        ----------
        file : str or file-like object
            Destination.  A str gets ``.npz`` appended when it does not
            already end with it; anything else is handed to
            ``zipfile.ZipFile`` unchanged.
        compression : {None, 'deflate', 'bz2', 'lzma'}
            Compression applied to every member; None stores uncompressed.
        compresslevel : int or None
            Forwarded to ``zipfile.ZipFile``.

        Raises
        ------
        KeyError
            If *compression* is not one of the names listed above.
        """
        # Imports are deferred to keep importing this module cheap.
        import zipfile

        if isinstance(file, str) and not file.endswith('.npz'):
            file = file + '.npz'
        compression = {
            None: zipfile.ZIP_STORED,
            'deflate': zipfile.ZIP_DEFLATED,
            'bz2': zipfile.ZIP_BZIP2,
            'lzma': zipfile.ZIP_LZMA,
        }[compression]
        self.zip = self.zipfile_factory(
            file, mode="w", compression=compression, compresslevel=compresslevel)
        # Counter for auto-generated names of positional arrays; it keeps
        # increasing across savez() calls so names stay unique.
        self.i = 0
        # Member names already written, used to reject duplicates.
        self._names = set()

    def zipfile_factory(self, *args, **kwargs):
        """Return ``zipfile.ZipFile(*args, **kwargs)`` with ZIP64 enabled.

        ZIP64 extensions are required once the archive grows past 4 GiB,
        which is easy to hit when appending many large arrays.
        """
        import zipfile
        kwargs['allowZip64'] = True
        return zipfile.ZipFile(*args, **kwargs)

    def savez(self, *args, **kwds):
        """Append arrays to the archive.

        Positional arrays are named ``arr_0``, ``arr_1``, ... continuing the
        numbering from earlier calls; keyword arrays are named after their
        keyword.  Each array is written as member ``<name>.npy``.

        Raises
        ------
        ValueError
            If an auto-generated name collides with a keyword in the same
            call, or if a name was already written by an earlier call.
        """
        import numpy.lib.format as format

        namedict = kwds
        for val in args:
            key = 'arr_%d' % self.i
            if key in namedict:
                raise ValueError("Cannot use un-named variables and keyword %s" % key)
            namedict[key] = val
            self.i += 1
        # Validate every name before writing, so a rejected call adds nothing.
        for key in namedict:
            if key in self._names:
                raise ValueError("Array name %s was already written to the archive" % key)
        for key, val in namedict.items():
            # Stream straight into the archive member; no temporary file is
            # needed.  force_zip64 because the member size is not known up front.
            with self.zip.open(key + '.npy', 'w', force_zip64=True) as fid:
                format.write_array(fid, np.asanyarray(val))
            self._names.add(key)

    def close(self):
        """Close the archive; required for its closing zip records to be written."""
        self.zip.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
        return False
# Demo: append ten random integer arrays, one savez() call each, to a
# deflate-compressed archive written into a temporary file that is deleted
# when it is closed.
with tempfile.NamedTemporaryFile() as tmp:
    f = my_savez(tmp, compression='deflate', compresslevel=5)
    try:
        for i in range(10):
            array = np.random.randint(0, int(1e9), int(1e6))
            f.savez(array)
    finally:
        # Finalise the zip even if a write fails, before tmp is closed.
        f.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.