Created August 8, 2019, 07:42
-
-
Save dzil123/383fca3832aedd02f293d509845279a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing as mp

import numpy as np
# Create a np.ndarray that can be shared across processes.
# Array.np and Array.arr point to the same memory, and that memory is
# shared between the processes. The underlying mp.RawArray can be sent
# to a newly spawned process.
class Array:
    """A NumPy view over a ``multiprocessing.RawArray``.

    ``self.arr`` is the flat shared ctypes buffer and ``self.np`` is an
    ndarray of ``shape`` backed by that same memory, so a write through
    either one is visible through the other.

    Args:
        shape: Shape of the ndarray view (a tuple of ints; a single int
            also works since it is only passed to ``np.prod``/``reshape``).
        arr: An existing ``mp.RawArray`` to wrap. If ``None``, a new one
            with ``prod(shape)`` elements is allocated. If you have a
            synchronized ``mp.Array``, pass ``arr.get_obj()`` instead.
        type: Typecode used both for ``mp.RawArray`` and as the NumPy
            dtype; ``'B'`` is uint8 (0-255).
    """

    def __init__(self, shape, arr=None, type="B"):
        self.shape = shape
        self.type = type
        if arr is None:
            arr = self._make_arr()
        self.arr = arr
        self.np = self._make_np()

    def _make_arr(self):
        """Allocate a flat RawArray with room for ``prod(shape)`` elements."""
        # RawArray is 1-D, so squash the shape into an element count.
        # int() is required: np.prod returns a NumPy scalar (float64 for an
        # empty shape), and RawArray only treats a real int as a size --
        # anything else is taken as an initializer sequence and len() fails.
        num = int(np.prod(self.shape))
        return mp.RawArray(self.type, num)

    def _make_np(self):
        """Return an ndarray view of ``self.arr`` reshaped to ``self.shape``."""
        # frombuffer does not copy, so the result aliases self.arr's memory.
        return np.frombuffer(self.arr, self.type).reshape(self.shape)

    def __getstate__(self):
        # Pickle only the constructor arguments; pickling the ndarray itself
        # would copy its data instead of sharing it, so the view is rebuilt
        # from the RawArray on unpickle.
        return (self.shape, self.arr, self.type)

    def __setstate__(self, data):
        self.__init__(*data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.