Created
July 16, 2018 07:03
-
-
Save liushapku/ca47de750197f42290e58ad92c46c7e2 to your computer and use it in GitHub Desktop.
Manually share NumPy memory-mapped (np.memmap) arrays with other processes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy.lib.stride_tricks import as_strided | |
from distutils.version import LooseVersion | |
import mmap | |
import tempfile | |
import shutil | |
import os | |
import uuid | |
import contextlib | |
def get_backing_memmap(a):
    """Return the object in ``a``'s ``base`` chain that directly wraps an mmap.

    Walks ``a``, ``a.base``, ``a.base.base``, ... and returns the first object
    whose ``base`` is an ``mmap.mmap`` instance (for a NumPy array this is the
    ``np.memmap`` the view ultimately comes from). Returns None when the chain
    ends (an object without a ``base``, or with ``base is None``) before an
    mmap is reached.
    """
    node = a
    while True:
        parent = getattr(node, 'base', None)
        if parent is None:
            return None
        if isinstance(parent, mmap.mmap):
            return node
        node = parent
def is_shareable(a):
    """Tell whether ``a`` is an ndarray whose memory comes from an mmap.

    Only such arrays can be wrapped in ``SharedArray``; everything else
    (non-arrays, or arrays owning ordinary heap memory) yields False.
    """
    if not isinstance(a, np.ndarray):
        return False
    return get_backing_memmap(a) is not None
class SharedArray:
    """Picklable handle to an ndarray view backed by a file-based np.memmap.

    Only the location of the data is stored (file name, byte offset, dtype,
    shape, strides, open mode), never the data itself, so pickling a
    SharedArray is cheap. ``asarray()`` re-opens the same file region, e.g.
    in another process, to obtain an array over the same bytes.

    Adapted from joblib's ``_reduce_memmap_backed`` / ``_strided_from_memmap``.
    """

    def __init__(self, a):
        """Record how to re-open the file region that backs ``a``.

        :param a: an ndarray whose ``base`` chain ends in an ``mmap.mmap``,
            i.e. an ``np.memmap`` or a view of one.
        :raises ValueError: if ``a`` is not backed by a memory map.
        """
        try:
            # NumPy >= 2.0 no longer exposes np.byte_bounds at the top level.
            from numpy.lib.array_utils import byte_bounds
        except ImportError:
            byte_bounds = np.byte_bounds

        m = get_backing_memmap(a)
        # Explicit raise rather than assert so the check survives ``python -O``.
        if m is None:
            raise ValueError('invalid shared array')

        a_start, a_end = byte_bounds(a)
        m_start = byte_bounds(m)[0]
        # File offset of the lowest byte touched by ``a``: its distance from
        # the start of the backing memmap plus that memmap's own file offset.
        offset = a_start - m_start + m.offset

        c_contiguous = a.flags['C_CONTIGUOUS']
        if c_contiguous or a.flags['F_CONTIGUOUS']:
            # Contiguous view: a memmap of the same shape reproduces it. The
            # memory order must come from ``a`` itself, not from the backing
            # memmap (``m.T`` of a C-ordered ``m`` is Fortran-contiguous).
            order = 'C' if c_contiguous else 'F'
            strides = None
            total_buffer_len = None
            buffer_nbytes = None
            first_item_offset = 0
        else:
            # Non-contiguous view: remember the byte span enclosing all of its
            # elements and where element a[0, ..., 0] sits inside that span
            # (past the span's start when some strides are negative).
            order = 'C'  # the enclosing span is mapped as a 1-D byte buffer
            strides = a.strides
            # Item count of the span; kept for backward compatibility only,
            # asarray() relies on buffer_nbytes instead.
            total_buffer_len = (a_end - a_start) // a.itemsize
            buffer_nbytes = a_end - a_start
            first_item_offset = a.__array_interface__['data'][0] - a_start

        self.filename = m.filename
        self.dtype = a.dtype
        self.mode = m.mode
        self.offset = offset
        self.order = order
        self.shape = a.shape
        self.strides = strides
        self.total_buffer_len = total_buffer_len
        self.buffer_nbytes = buffer_nbytes
        self.first_item_offset = first_item_offset

    def asarray(self):
        """Re-open the recorded file region and return an equivalent array.

        Contiguous views come back as an ``np.memmap``; non-contiguous views
        come back as a plain ndarray whose buffer is a memmap of the bytes
        enclosing the view.
        """
        # Re-opening with 'w+' would overwrite the existing file; 'r+' maps
        # the data that is already there.
        mode = 'r+' if self.mode == 'w+' else self.mode
        if self.strides is None:
            return np.memmap(self.filename, dtype=self.dtype, shape=self.shape,
                             mode=mode, offset=self.offset, order=self.order)
        # Map the raw bytes spanned by the view, then lay the strided array
        # over them. Unlike as_strided, np.ndarray verifies that every element
        # lies inside the buffer, and the explicit start offset makes negative
        # strides work.
        raw = np.memmap(self.filename, dtype=np.uint8,
                        shape=self.buffer_nbytes, mode=mode,
                        offset=self.offset)
        return np.ndarray(self.shape, dtype=self.dtype, buffer=raw,
                          offset=self.first_item_offset, strides=self.strides)

    def __getitem__(self, args):
        """Index the re-opened array and wrap the result as a SharedArray.

        ``args`` must select a view that is still memmap-backed (e.g. basic
        slicing); otherwise the constructor raises ValueError.
        """
        return SharedArray(self.asarray()[args])
def try_share(value):
    """Wrap ``value`` in a SharedArray when it is memmap-backed.

    Any other value (plain arrays, scalars, containers, None) is returned
    unchanged, so this is safe to apply to arbitrary arguments.
    """
    return SharedArray(value) if is_shareable(value) else value
def share(value):
    """Recursively replace memmap-backed arrays in ``value`` by SharedArray.

    Accepts None, a shareable ndarray, or (nested) tuple/list/dict containers
    of those. Tuples and lists come back as plain tuple/list; dicts keep their
    keys.

    :raises ValueError: for any other leaf value.
    """
    if value is None:
        return None
    if is_shareable(value):
        return SharedArray(value)
    if isinstance(value, tuple):
        return tuple(map(share, value))
    if isinstance(value, list):
        return list(map(share, value))
    if isinstance(value, dict):
        return dict((key, share(item)) for key, item in value.items())
    raise ValueError('can only share shareable array, tuple/list/dict of shareable array or None')
def try_asarray(value):
    """Undo ``try_share``: turn a SharedArray back into an array.

    A SharedArray is re-opened via its ``asarray()`` method; any other value
    is passed through untouched.
    """
    if not isinstance(value, SharedArray):
        return value
    return value.asarray()
def _share_call_args(args, kwargs):
    """Return ``(args, kwargs)`` with memmap-backed arrays wrapped as SharedArray."""
    shared_args = tuple(try_share(arg) for arg in args)
    shared_kwargs = {key: try_share(arg) for key, arg in kwargs.items()}
    return shared_args, shared_kwargs


def _wrap(func, args, kwargs):
    """Worker-side trampoline: unwrap shared inputs, call ``func``, share outputs.

    ``args`` is a tuple and ``kwargs`` a dict of (possibly SharedArray-wrapped)
    arguments, as produced by ``wrap``/``Wrapper``. Every SharedArray is
    re-opened as an array before ``func`` runs. In the result (a single value,
    or the items of a tuple or dict), memmap-backed arrays are wrapped in
    SharedArray so only their file location is sent back; the caller can
    restore them with ``try_asarray``.
    """
    call_args = tuple(try_asarray(arg) for arg in args)
    # Keyword arguments may carry SharedArray handles too, so unwrap them as
    # well before calling ``func``.
    call_kwargs = {key: try_asarray(arg) for key, arg in kwargs.items()}
    result = func(*call_args, **call_kwargs)
    if isinstance(result, tuple):
        return tuple(try_share(item) for item in result)
    if isinstance(result, dict):
        return {key: try_share(item) for key, item in result.items()}
    return try_share(result)


class Wrapper:
    """
    A helper class used for submitting functions to a pool. Usage:
    instead of doing pool.submit(func, *args, **kwargs), do
    pool.submit(*Wrapper(func)(*args, **kwargs))
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """Return ``(_wrap, func, shared_args, shared_kwargs)`` for ``pool.submit``.

        Memmap-backed positional AND keyword arguments are wrapped as
        SharedArray; ``_wrap`` unwraps both on the worker side.
        """
        shared_args, shared_kwargs = _share_call_args(args, kwargs)
        return _wrap, self.func, shared_args, shared_kwargs


def wrap(func, *args, **kwargs):
    """
    A helper function used for submitting functions to a pool. Usage:
    instead of doing pool.submit(func, *args, **kwargs), do
    pool.submit(*wrap(func, *args, **kwargs))

    Memmap-backed positional AND keyword arguments are wrapped as
    SharedArray; ``_wrap`` unwraps both on the worker side.
    """
    shared_args, shared_kwargs = _share_call_args(args, kwargs)
    return _wrap, func, shared_args, shared_kwargs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment