Created
November 22, 2018 06:10
-
-
Save m1m0r1/d9004ec382223b4adf1ad60b9581f67c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Examples: | |
>>> a = lock_array(np.zeros((5, 5))) # locked array | |
>>> b = np.zeros((5, 5)) # default array | |
# >>> a[:, 3] = 1. # You can't update locked arrays | |
>>> b[:, 3] = 1. | |
>>> with unlocked(a, b): | |
>>> a[:, 2] = 1 # You can update unlocked arrays here | |
>>> b[:, 2] = 1 | |
""" | |
def lock_array(values):
    """Make ``values`` read-only in place and return that same object.

    The array's ``WRITEABLE`` flag is cleared with ``setflags(write=False)``,
    so later writes through it are rejected. Because the argument itself is
    returned, the call can wrap a constructor: ``lock_array(np.zeros(3))``.
    """
    frozen = values
    frozen.setflags(write=False)
    return frozen
def unlock_array(values):
    """Make ``values`` writeable in place and return that same object.

    This is the inverse of ``lock_array``: it sets the ``WRITEABLE`` flag via
    ``setflags(write=True)``. NumPy may refuse with ``ValueError`` when the
    array cannot be made writeable (for example, a view of a read-only base).
    """
    thawed = values
    thawed.setflags(write=True)
    return thawed
class unlocked:
    """Context manager that temporarily makes NumPy arrays writeable.

    On entry, the ``WRITEABLE`` flag of every array given to the constructor
    is set to True. On exit, each flag is put back to the value it had on
    entry. Exceptions raised inside the ``with`` body are not suppressed.

    List a base array before any view of it: NumPy refuses to make a view
    writeable while its base array is read-only.

    Args:
        *values: ``numpy.ndarray`` instances to unlock.

    Raises:
        TypeError: if any argument is not a ``numpy.ndarray``.
        ValueError: on entry, if NumPy cannot make one of the arrays
            writeable; arrays already unlocked are restored before the
            error propagates.
    """

    def __init__(self, *values):
        # Local import so the class does not rely on a module-level numpy import.
        import numpy as np

        for v in values:
            # An explicit raise (not ``assert``) survives ``python -O``.
            if not isinstance(v, np.ndarray):
                raise TypeError(
                    "unlocked() arguments must be numpy.ndarray, "
                    f"not {type(v).__name__}"
                )
        self._values = values

    def __enter__(self):
        # Snapshot every flag before changing any, so an array listed twice
        # is still restored to its true original state.
        self._org_flags = [v.flags.writeable for v in self._values]
        n_done = 0
        try:
            for v in self._values:
                v.setflags(write=True)
                n_done += 1
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # arrays already unlocked before propagating the error.
            self._restore(n_done)
            raise
        return self

    def __exit__(self, *args):
        self._restore(len(self._values))

    def _restore(self, count):
        """Restore the saved flags of the first ``count`` arrays, last first.

        Reverse (LIFO) order undoes the unlocking sequence. A view listed
        after its base is therefore restored while the base is still
        writeable.
        """
        pairs = list(zip(self._values[:count], self._org_flags[:count]))
        for v, flag in reversed(pairs):
            v.setflags(write=flag)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment