Last active
July 28, 2020 17:42
-
-
Save amoodie/1683dbc910af9d13e29fdd61f227f301 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MWE: test whether the RNG state used inside Numba-jitted functions can be saved and restored
import numpy as np | |
import numba | |
@numba.njit
def set_random_seed(_seed):
    """Seed the random generator used by Numba-compiled code.

    Numba keeps its own generator state, separate from NumPy's. Calling
    ``np.random.seed`` therefore only affects jitted draws when the call
    itself runs in nopython mode. ``njit`` guarantees that. Plain ``jit`` on
    older Numba releases could silently fall back to object mode, where the
    same call would seed NumPy's generator instead.

    Parameters
    ----------
    _seed : int
        Seed forwarded to ``np.random.seed``.
    """
    np.random.seed(_seed)
@numba.njit
def get_random_number():
    """Return one draw from U[0, 1) taken from Numba's internal generator.

    ``njit`` forces nopython mode. This guarantees the draw consumes Numba's
    generator, whose state ``get_random_state``/``set_random_state``
    checkpoint, rather than NumPy's.
    """
    return np.random.uniform(0, 1)
def get_random_state():
    """Snapshot the state of Numba's NumPy-compatible random generator.

    This relies on the private ``numba._helperlib`` module, so it may break
    between Numba releases. The returned object is exactly what
    ``set_random_state`` consumes.
    NOTE(review): it appears to be an ``(index, keys)`` tuple (Mersenne
    Twister position plus key list). Confirm for the Numba version in use.
    """
    helper = numba._helperlib
    return helper.rnd_get_state(helper.rnd_get_np_state_ptr())
def set_random_state(_state_tuple):
    """Restore Numba's NumPy-compatible random generator from a snapshot.

    Parameters
    ----------
    _state_tuple : tuple
        ``(index, keys)`` as returned by ``get_random_state``. ``keys`` may be
        any iterable of integers, e.g. a NumPy array read back from an
        ``.npz`` checkpoint.

    Notes
    -----
    Uses the private ``numba._helperlib`` API. Its ``rnd_set_state`` parses
    its argument strictly as a tuple of a Python ``int`` and a Python
    ``list`` of Python ``int``s. The state is normalised to that form here,
    so NumPy scalars and arrays are accepted too.
    NOTE(review): restoring appears to discard any cached Gaussian sample.
    Confirm this if normal draws are mixed in.
    """
    index, keys = _state_tuple
    state = (int(index), [int(k) for k in keys])
    ptr = numba._helperlib.rnd_get_np_state_ptr()
    numba._helperlib.rnd_set_state(ptr, state)
class ComplexObject():
    """Toy model whose ``value`` accumulates draws from Numba's RNG.

    The checkpoint methods save and restore both ``value`` and the generator
    state. A restored object therefore continues the same random sequence.
    """

    def __init__(self, seed):
        """Reseed Numba's generator with ``seed`` and start ``value`` at 0.

        The generator is shared rather than per instance, so constructing
        any instance reseeds it for everyone.
        """
        set_random_seed(seed)
        self.value = 0

    def update(self):
        """Add one U[0, 1) draw to ``value``."""
        val = get_random_number()
        self.value += val

    def save_object_state(self):
        """Write ``value`` and the current RNG state to ``chkp_file.npz``.

        The RNG state is stored as an integer ``rng_index`` plus a uint32
        ``rng_keys`` array, not as one ragged ``(int, list)`` tuple. NumPy
        >= 1.24 refuses to turn such a tuple into an array, and older NumPy
        could only store it as a pickled object array.
        """
        index, keys = get_random_state()
        np.savez_compressed('chkp_file.npz',
                            value=self.value,
                            rng_index=index,
                            rng_keys=np.asarray(keys, dtype=np.uint32))

    def load_object_state(self):
        """Restore ``value`` and the RNG state from ``chkp_file.npz``.

        Checkpoints in the older single ``rng_state`` layout are also read.
        """
        # allow_pickle is only needed for the legacy object-array layout.
        # Only load checkpoint files you produced yourself.
        # The `with` block closes the underlying file handle.
        with np.load('chkp_file.npz', allow_pickle=True) as checkpoint:
            if 'rng_state' in checkpoint.files:
                index, keys = checkpoint['rng_state']
            else:
                index, keys = checkpoint['rng_index'], checkpoint['rng_keys']
            # .item() turns the stored 0-d array back into a Python scalar.
            self.value = checkpoint['value'].item()
        # The private setter needs a Python (int, list[int]) tuple.
        set_random_state((int(index), [int(k) for k in keys]))
if __name__ == '__main__':
    def _advance(model, n_steps):
        """Call ``model.update()`` ``n_steps`` times, then print its value."""
        for _step in range(n_steps):
            model.update()
        print(model.value)

    # 1) Uninterrupted reference run: seed 42, ten steps.
    co1 = ComplexObject(42)
    _advance(co1, 10)

    # 2) Same seed, but stop after five steps and write a checkpoint.
    co2 = ComplexObject(42)
    _advance(co2, 5)
    co2.save_object_state()

    # 3) Fresh object with a different seed. It restores the checkpoint and
    #    takes the remaining five steps. The intent is that its final value
    #    matches the reference run from (1).
    co3 = ComplexObject(123)
    co3.load_object_state()
    _advance(co3, 5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MWE: test whether the RNG state used inside Numba-jitted functions can be saved and restored
import numpy as np | |
import numba | |
import multiprocessing | |
@numba.njit
def set_random_seed(_seed):
    """Seed the random generator used by Numba-compiled code.

    Numba keeps its own generator state, separate from NumPy's. Calling
    ``np.random.seed`` therefore only affects jitted draws when the call
    itself runs in nopython mode. ``njit`` guarantees that. Plain ``jit`` on
    older Numba releases could silently fall back to object mode, where the
    same call would seed NumPy's generator instead.

    Parameters
    ----------
    _seed : int
        Seed forwarded to ``np.random.seed``.
    """
    np.random.seed(_seed)
@numba.njit
def get_random_number():
    """Return one draw from U[0, 1) taken from Numba's internal generator.

    ``njit`` forces nopython mode. This guarantees the draw consumes Numba's
    generator, whose state ``get_random_state``/``set_random_state``
    checkpoint, rather than NumPy's.
    """
    return np.random.uniform(0, 1)
def get_random_state():
    """Snapshot the state of Numba's NumPy-compatible random generator.

    This relies on the private ``numba._helperlib`` module, so it may break
    between Numba releases. The returned object is exactly what
    ``set_random_state`` consumes.
    NOTE(review): it appears to be an ``(index, keys)`` tuple (Mersenne
    Twister position plus key list). Confirm for the Numba version in use.
    """
    helper = numba._helperlib
    return helper.rnd_get_state(helper.rnd_get_np_state_ptr())
def set_random_state(_state_tuple):
    """Restore Numba's NumPy-compatible random generator from a snapshot.

    Parameters
    ----------
    _state_tuple : tuple
        ``(index, keys)`` as returned by ``get_random_state``. ``keys`` may be
        any iterable of integers, e.g. a NumPy array read back from an
        ``.npz`` checkpoint.

    Notes
    -----
    Uses the private ``numba._helperlib`` API. Its ``rnd_set_state`` parses
    its argument strictly as a tuple of a Python ``int`` and a Python
    ``list`` of Python ``int``s. The state is normalised to that form here,
    so NumPy scalars and arrays are accepted too.
    NOTE(review): restoring appears to discard any cached Gaussian sample.
    Confirm this if normal draws are mixed in.
    """
    index, keys = _state_tuple
    state = (int(index), [int(k) for k in keys])
    ptr = numba._helperlib.rnd_get_np_state_ptr()
    numba._helperlib.rnd_set_state(ptr, state)
class ComplexObject():
    """Toy model whose ``value`` accumulates draws from Numba's RNG.

    Each instance checkpoints to its own ``chkp_file_<id>.npz``. The
    checkpoint holds both ``value`` and the generator state, so a restored
    object continues the same random sequence.
    """

    def __init__(self, seed, _id):
        """Reseed Numba's generator and set up the per-id checkpoint path.

        Parameters
        ----------
        seed : int
            Seed for ``set_random_seed``. The generator is shared rather than
            per instance, so this reseeds it for everyone.
        _id : object
            Identifier embedded in the checkpoint file name.
        """
        set_random_seed(seed)
        self.value = 0
        self.id = _id
        # Distinct file per id, so objects with different ids (e.g. one per
        # parallel worker) do not overwrite each other's checkpoints.
        self.chkp_file = 'chkp_file_' + str(self.id) + '.npz'

    def update(self):
        """Add one U[0, 1) draw to ``value``."""
        val = get_random_number()
        self.value += val

    def save_object_state(self):
        """Write ``value`` and the current RNG state to ``self.chkp_file``.

        The RNG state is stored as an integer ``rng_index`` plus a uint32
        ``rng_keys`` array, not as one ragged ``(int, list)`` tuple. NumPy
        >= 1.24 refuses to turn such a tuple into an array, and older NumPy
        could only store it as a pickled object array.
        """
        index, keys = get_random_state()
        np.savez_compressed(self.chkp_file,
                            value=self.value,
                            rng_index=index,
                            rng_keys=np.asarray(keys, dtype=np.uint32))

    def load_object_state(self, file=None):
        """Restore ``value`` and the RNG state from a checkpoint.

        Checkpoints in the older single ``rng_state`` layout are also read.

        Parameters
        ----------
        file : str or path-like, optional
            Checkpoint to read. Falls back to ``self.chkp_file`` when omitted
            (or falsy).
        """
        path = file if file else self.chkp_file
        # allow_pickle is only needed for the legacy object-array layout.
        # Only load checkpoint files you produced yourself.
        # The `with` block closes the underlying file handle.
        with np.load(path, allow_pickle=True) as checkpoint:
            if 'rng_state' in checkpoint.files:
                index, keys = checkpoint['rng_state']
            else:
                index, keys = checkpoint['rng_index'], checkpoint['rng_keys']
            # .item() turns the stored 0-d array back into a Python scalar.
            self.value = checkpoint['value'].item()
        # The private setter needs a Python (int, list[int]) tuple.
        set_random_state((int(index), [int(k) for k in keys]))
def parallel_func(i, n_steps=500):
    """Run the checkpoint/restore round trip inside one pool worker.

    Parameters
    ----------
    i : int
        Worker index. It is also used as the ComplexObject id, so each
        worker reads and writes its own ``chkp_file_<i>.npz``.
    n_steps : int, optional
        Steps run before the checkpoint and again after restoring it
        (default 500).
    """
    # First half: seed 42, advance n_steps, then checkpoint.
    coA = ComplexObject(42, i)
    for _ in range(n_steps):
        coA.update()
    print(i, coA.value)
    coA.save_object_state()
    # Second half: a differently seeded model (123) loads that checkpoint and
    # advances n_steps more. The intent is that it continues the seed-42
    # sequence as if no checkpoint had been taken.
    coB = ComplexObject(123, i)
    coB.load_object_state()
    for _ in range(n_steps):
        coB.update()
    print(i, coB.value)
if __name__ == '__main__':
    # Step 1: reference run, 1000 uninterrupted steps from seed 42.
    co1 = ComplexObject(42, 1)
    for _ in range(1000):
        co1.update()
    print(co1.value)
    # Step 2: same seed, run 500 steps, then checkpoint.
    co2 = ComplexObject(42, 2)
    for _ in range(500):
        co2.update()
    print(co2.value)
    co2.save_object_state()
    # Step 3: a differently seeded model restores co2's checkpoint and runs
    # the remaining 500 steps. Its value is expected to match step 1.
    co3 = ComplexObject(123, 3)
    co3.load_object_state(co2.chkp_file)
    for _ in range(500):
        co3.update()
    print(co3.value)

    print("\nParallel version:")
    # close() + join() let the workers finish and exit normally. Leaving the
    # with-block alone would call terminate(), and the original code never
    # shut the pool down at all.
    with multiprocessing.Pool(processes=3) as pool:
        pool.map(parallel_func, range(3))
        pool.close()
        pool.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output from rng_mwe.py:

Output from rng_mwe_parallel.py: