Skip to content

Instantly share code, notes, and snippets.

@liushapku
Created July 16, 2018 07:03
Show Gist options
  • Save liushapku/ca47de750197f42290e58ad92c46c7e2 to your computer and use it in GitHub Desktop.
Save liushapku/ca47de750197f42290e58ad92c46c7e2 to your computer and use it in GitHub Desktop.
Manually share NumPy memmap-backed arrays across processes
import numpy as np
from numpy.lib.stride_tricks import as_strided
from distutils.version import LooseVersion
import mmap
import tempfile
import shutil
import os
import uuid
import contextlib
def get_backing_memmap(a):
    """Find the array that sits directly on top of a raw ``mmap.mmap``.

    Follows the ``.base`` chain of *a* recursively. Returns the first object
    in that chain whose ``base`` is an ``mmap.mmap`` instance (typically the
    original ``np.memmap``), or None if the chain ends without one.
    """
    parent = getattr(a, 'base', None)
    if isinstance(parent, mmap.mmap):
        # ``a`` is the object mapped straight onto the OS-level mmap.
        return a
    # No base at all means no memmap ancestry; otherwise keep climbing.
    return None if parent is None else get_backing_memmap(parent)
def is_shareable(a):
    """Return True when *a* is an ndarray whose data lives in an ``np.memmap``."""
    if not isinstance(a, np.ndarray):
        return False
    return get_backing_memmap(a) is not None
def _byte_bounds(a):
    """Return ``(low, high)``: first and one-past-last byte addresses of *a*.

    ``np.byte_bounds`` was removed from the NumPy main namespace in 2.0 and
    now lives in ``numpy.lib.array_utils``; support both locations.
    """
    try:
        from numpy.lib.array_utils import byte_bounds
    except ImportError:  # NumPy < 2.0
        byte_bounds = np.byte_bounds
    return byte_bounds(a)


class SharedArray:
    """Picklable handle describing an ndarray view on a file-backed ``np.memmap``.

    The handle stores only the backing file name and the layout of the view,
    so it is small to send to another process. :meth:`asarray` re-opens the
    file and rebuilds the same view. Adapted from joblib's
    ``_reduce_memmap_backed`` / ``_strided_from_memmap``.

    Attributes
    ----------
    filename : str
        Path of the backing file (``np.memmap.filename``).
    dtype : numpy.dtype
        Element type of the view.
    mode : str
        ``np.memmap`` mode of the backing memmap.
    offset : int
        File offset, in bytes, of the lowest byte touched by the view.
    order : {'C', 'F'}
        Memory layout used to re-create a contiguous view.
    shape : tuple of int
        Shape of the view.
    strides : tuple of int or None
        None for contiguous views, otherwise the view's byte strides.
    buffer_nbytes : int or None
        Number of bytes spanned by a strided view (None when contiguous).
    first_item_offset : int
        Byte offset of the view's first element relative to ``offset``.
        It is non-zero only when the view has negative strides.
    total_buffer_len : int or None
        Item count of the spanned region. Kept for backward compatibility;
        not used by :meth:`asarray`.
    """

    def __init__(self, a):
        """Describe *a*, which must be a view on a file-backed ``np.memmap``.

        Raises
        ------
        ValueError
            If *a* has no ``np.memmap`` in its base chain, or the backing
            memmap has no file name that could be re-opened.
        """
        m = get_backing_memmap(a)
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if m is None:
            raise ValueError('invalid shared array')
        if m.filename is None:
            raise ValueError('invalid shared array: backing memmap has no filename')

        a_start, a_end = _byte_bounds(a)
        m_start = _byte_bounds(m)[0]
        # Position of ``a`` inside the mapped data, plus the memmap's own file offset.
        offset = m.offset + (a_start - m_start)

        if a.flags['C_CONTIGUOUS'] or a.flags['F_CONTIGUOUS']:
            # A contiguous view is re-created directly as an np.memmap. Its
            # order must come from ``a`` itself rather than from the backing
            # memmap: e.g. the transpose of a C-ordered memmap is F-contiguous.
            order = 'C' if a.flags['C_CONTIGUOUS'] else 'F'
            strides = None
            buffer_nbytes = None
            first_item_offset = 0
            total_buffer_len = None
        else:
            # The enclosing region is mapped as 1-D bytes, so order is irrelevant.
            order = 'C'
            strides = a.strides
            buffer_nbytes = a_end - a_start
            # With negative strides, the first element is not at the lowest address.
            first_item_offset = a.__array_interface__['data'][0] - a_start
            total_buffer_len = buffer_nbytes // a.itemsize

        self.filename = m.filename
        self.dtype = a.dtype
        self.mode = m.mode
        self.offset = offset
        self.order = order
        self.shape = a.shape
        self.strides = strides
        self.buffer_nbytes = buffer_nbytes
        self.first_item_offset = first_item_offset
        self.total_buffer_len = total_buffer_len

    def asarray(self):
        """Re-open the backing file and rebuild the described array view.

        Contiguous views are returned as ``np.memmap`` instances. Strided views
        are plain ndarrays laid over a byte memmap of the region they span.
        """
        # Re-opening with 'w+' would recreate (zero) the file; the data already exists.
        mode = 'r+' if self.mode == 'w+' else self.mode
        if self.strides is None:
            return np.memmap(self.filename, dtype=self.dtype, shape=self.shape,
                             mode=mode, offset=self.offset, order=self.order)
        # Map exactly the bytes spanned by the view, then lay the strided view
        # over them. Unlike ``as_strided``, ``np.ndarray(buffer=...)`` checks
        # that every element stays inside the buffer. The byte ``offset``
        # places the first element correctly when strides are negative.
        raw = np.memmap(self.filename, dtype=np.uint8, shape=self.buffer_nbytes,
                        mode=mode, offset=self.offset)
        return np.ndarray(self.shape, dtype=self.dtype, buffer=raw,
                          offset=self.first_item_offset, strides=self.strides)

    def __getitem__(self, args):
        """Return a SharedArray describing ``self.asarray()[args]``.

        Raises ValueError (from ``__init__``) when the indexing result is not
        backed by the memmap, e.g. a scalar or an index that produces a copy.
        """
        return SharedArray(self.asarray()[args])
def try_share(value):
    """Return a SharedArray for memmap-backed arrays; anything else passes through unchanged."""
    return SharedArray(value) if is_shareable(value) else value
def share(value):
    """Convert memmap-backed arrays into SharedArray handles, recursing into containers.

    ``None`` is returned as-is. A shareable array becomes a SharedArray.
    Tuples, lists and dicts are rebuilt with every element (dict: every
    value) shared recursively. Anything else raises ValueError.
    """
    if value is None:
        return None
    if is_shareable(value):
        return SharedArray(value)
    if isinstance(value, tuple):
        return tuple(map(share, value))
    if isinstance(value, list):
        return list(map(share, value))
    if isinstance(value, dict):
        return {k: share(v) for k, v in value.items()}
    raise ValueError('can only share shareable array, tuple/list/dict of shareable array or None')
def try_asarray(value):
    """Materialize a SharedArray via ``asarray()``; return any other value unchanged."""
    return value.asarray() if isinstance(value, SharedArray) else value
def _wrap(func, args, kwargs):
    """Worker-side trampoline: unwrap SharedArray inputs, call *func*, re-wrap outputs.

    Parameters
    ----------
    func : callable
        The user function to run.
    args : tuple
        Positional arguments; SharedArray entries are turned back into arrays.
    kwargs : dict
        Keyword arguments; SharedArray values are turned back into arrays.

    Returns
    -------
    The result of ``func``. A top-level tuple or dict result is processed
    element-wise (dict: per value), and any other result as a whole. In each
    case, memmap-backed arrays are replaced by SharedArray handles.
    """
    call_args = tuple(try_asarray(arg) for arg in args)
    call_kwargs = {key: try_asarray(val) for key, val in kwargs.items()}
    # Pass the unwrapped kwargs so keyword SharedArray handles reach func as arrays.
    result = func(*call_args, **call_kwargs)
    if isinstance(result, tuple):
        return tuple(try_share(item) for item in result)
    if isinstance(result, dict):
        return {key: try_share(item) for key, item in result.items()}
    return try_share(result)


def _prepare_call(func, args, kwargs):
    """Share memmap-backed positional and keyword arguments; return the pool.submit tuple."""
    shared_args = tuple(try_share(arg) for arg in args)
    shared_kwargs = {key: try_share(val) for key, val in kwargs.items()}
    return _wrap, func, shared_args, shared_kwargs


class Wrapper:
    """
    A helper class used for submitting functions to a pool. Usage:
    instead of doing pool.submit(func, *args, **kwargs), do
    pool.submit(*Wrapper(func)(*args, **kwargs))
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """Return ``(_wrap, func, shared_args, shared_kwargs)`` for ``pool.submit``.

        Memmap-backed arrays among the positional and keyword arguments are
        replaced by SharedArray handles.
        """
        return _prepare_call(self.func, args, kwargs)


def wrap(func, *args, **kwargs):
    """
    A helper function used for submitting functions to a pool. Usage:
    instead of doing pool.submit(func, *args, **kwargs), do
    pool.submit(*wrap(func, *args, **kwargs))

    Memmap-backed arrays among the positional and keyword arguments are
    replaced by SharedArray handles. The returned 4-tuple is
    ``(_wrap, func, shared_args, shared_kwargs)``.
    """
    return _prepare_call(func, args, kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment