Created
July 28, 2020 16:45
-
-
Save amoodie/26cbf4871e8951b6f221852b44065068 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
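# Minimal working example (MWE): test whether the random state used inside numba-jitted functions can be read (get) and restored (set), so a run can be checkpointed and resumed reproducibly.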
import numpy as np | |
import numba | |
@numba.njit
def set_random_seed(_seed):
    """Seed the random generator used by numba-compiled code.

    Numba keeps its own generator for ``np.random`` calls made inside
    jitted functions, separate from NumPy's global generator, so it has to
    be seeded from inside a jitted function such as this one.

    ``njit`` (nopython mode) is used on purpose. On older numba releases a
    plain ``@numba.jit`` may fall back to object mode when the argument
    cannot be typed (e.g. ``None``). In object mode ``np.random.seed``
    seeds NumPy's global generator instead, silently leaving numba's
    stream unseeded. With ``njit`` such a call raises an error instead.

    Parameters
    ----------
    _seed : int
        Seed value. Presumably must be a non-negative integer that fits in
        32 bits, as for ``np.random.seed`` -- TODO confirm against the
        installed numba version.
    """
    np.random.seed(_seed)
@numba.njit
def get_random_number(low=0.0, high=1.0):
    """Draw one uniform sample from numba's internal random generator.

    The draw advances the generator seeded by ``set_random_seed``, not
    NumPy's global one, so it is the state captured by
    ``get_random_state``.

    Parameters
    ----------
    low, high : float, optional
        Bounds passed to ``np.random.uniform``. The defaults (0, 1)
        reproduce the original fixed-interval behaviour.

    Returns
    -------
    float
        A sample from ``np.random.uniform(low, high)``.
    """
    return np.random.uniform(low, high)
def get_random_state():
    """Return a snapshot of numba's internal NumPy-style random state.

    Relies on the private ``numba._helperlib`` module, which is not part
    of numba's public API and may change between releases.

    Returns
    -------
    tuple
        The value reported by ``rnd_get_state`` for numba's NumPy
        generator. Presumably an ``(index, key_words)`` pair -- verify
        against the installed numba version.
    """
    helperlib = numba._helperlib
    state_ptr = helperlib.rnd_get_np_state_ptr()
    return helperlib.rnd_get_state(state_ptr)
def set_random_state(_state_tuple):
    """Restore numba's internal NumPy-style random state from a snapshot.

    Parameters
    ----------
    _state_tuple : sequence
        A two-element ``(index, key)`` snapshot, as produced by
        ``get_random_state``. ``index`` may be any integer-like value and
        ``key`` any iterable of integers (list, tuple, numpy array).

    Notes
    -----
    NOTE(review): numba's private ``rnd_set_state`` appears to parse its
    argument strictly, as a real ``tuple`` holding an ``int`` and a
    ``list`` of Python ints. Numpy scalars or arrays (e.g. values read back
    from a checkpoint file) would make it raise TypeError. The snapshot is
    therefore normalised before the call; a snapshot that is already a
    plain ``(int, list)`` pair is passed through unchanged in value.
    """
    pos, key = _state_tuple
    # Convert to exactly the Python types the C helper parses.
    state = (int(pos), [int(word) for word in key])
    ptr = numba._helperlib.rnd_get_np_state_ptr()
    numba._helperlib.rnd_set_state(ptr, state)
class ComplexObject:
    """Toy model whose only state is a running sum of random draws.

    Stands in for a larger simulation whose progress, including the
    position in numba's random stream, must survive a checkpoint/restart.
    The random stream is module-level numba state, not owned by the
    instance: constructing any instance reseeds it.
    """

    def __init__(self, seed):
        """Reseed numba's random generator with *seed* and start at zero."""
        set_random_seed(seed)
        self.value = 0

    def update(self):
        """Advance one step by adding one uniform draw to ``value``."""
        self.value += get_random_number()

    def save_object_state(self, path='chkp_file.npz'):
        """Write ``value`` and numba's random state to a compressed ``.npz``.

        The random state is an ``(index, key)`` pair whose parts have very
        different shapes. Handing the pair to ``np.savez_compressed`` as a
        single item forces a ragged object array. Recent NumPy refuses to
        build that (ValueError), and older NumPy could only store it via
        pickle. The two parts are therefore stored as separate plain arrays.

        Parameters
        ----------
        path : str or os.PathLike, optional
            Destination file; defaults to ``'chkp_file.npz'`` in the
            current working directory.
        """
        pos, key = get_random_state()
        np.savez_compressed(path,
                            value=self.value,
                            rng_pos=pos,
                            # key words are 32-bit unsigned (MT19937 state)
                            rng_key=np.asarray(key, dtype=np.uint32))

    def load_object_state(self, path='chkp_file.npz'):
        """Restore ``value`` and numba's random state from *path*.

        Reads a file written by ``save_object_state``. Nothing in it is
        pickled, so ``np.load`` keeps its safe default ``allow_pickle=False``.

        Parameters
        ----------
        path : str or os.PathLike, optional
            Checkpoint file; defaults to ``'chkp_file.npz'``.
        """
        # ``with`` closes the underlying zip file handle promptly.
        with np.load(path) as checkpoint:
            # .item() turns the stored 0-d array back into a Python scalar,
            # so ``value`` has the same type it had before saving.
            self.value = checkpoint['value'].item()
            # tolist() yields plain Python ints, as numba's helper expects.
            rng_state = (int(checkpoint['rng_pos']),
                         checkpoint['rng_key'].tolist())
        set_random_state(rng_state)
if __name__ == '__main__':
    # Step 1: reference run -- 10 uninterrupted steps from seed 42.
    co1 = ComplexObject(42)
    for _ in range(10):
        co1.update()
    print(co1.value)

    # Step 2: same seed, run 5 steps, then checkpoint value + RNG state.
    co2 = ComplexObject(42)
    for _ in range(5):
        co2.update()
    print(co2.value)
    co2.save_object_state()

    # Step 3: fresh object with a *different* seed, so the continuation can
    # only match step 1 if loading really restores the saved RNG state.
    co3 = ComplexObject(123)
    co3.load_object_state()
    for _ in range(5):
        co3.update()
    print(co3.value)

    # Compare explicitly rather than relying on eyeballing printed floats.
    # The additions happen in the same order in both runs, so an exact
    # match is expected when the state is restored correctly.
    print('restart matches uninterrupted run:',
          bool(co1.value == co3.value))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment