Skip to content

Instantly share code, notes, and snippets.

@amoodie
Created July 28, 2020 16:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amoodie/26cbf4871e8951b6f221852b44065068 to your computer and use it in GitHub Desktop.
Save amoodie/26cbf4871e8951b6f221852b44065068 to your computer and use it in GitHub Desktop.
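# Minimal working example (MWE): test whether the random state used by
# Numba-jitted functions can be saved and restored, so a resumed run
# reproduces an uninterrupted one.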
import numpy as np
import numba
@numba.njit
def set_random_seed(_seed):
    """Seed Numba's internal NumPy-compatible random generator.

    Inside compiled code, ``np.random.seed`` reaches Numba's own generator
    rather than NumPy's global one. ``njit`` forces nopython mode: with plain
    ``jit`` on older Numba releases, an unsupported argument could silently
    fall back to object mode. In that mode the call would seed NumPy's global
    generator instead, and jitted draws would stay unseeded.

    Parameters
    ----------
    _seed : int
        Seed value (Numba supports only an integer argument here).
    """
    np.random.seed(_seed)
@numba.njit
def get_random_number():
    """Draw one sample from U(0, 1) using Numba's internal generator.

    Compiled with ``njit`` so the draw always comes from the same generator
    that ``set_random_seed`` seeds. An object-mode fallback would draw from
    NumPy's global generator instead.

    Returns
    -------
    float
        A uniform sample on [0, 1).
    """
    return np.random.uniform(0, 1)
def get_random_state():
    """Return a snapshot of Numba's internal NumPy-compatible RNG state.

    Depends on the private ``numba._helperlib`` module, so it may break
    between Numba releases.

    Returns
    -------
    tuple
        The state reported by ``rnd_get_state`` for the NumPy state pointer.
        NOTE(review): presumably an ``(index, key_list)`` Mersenne Twister
        pair; confirm against the installed Numba. Pass it back to
        ``set_random_state`` to restore it.
    """
    helperlib = numba._helperlib
    return helperlib.rnd_get_state(helperlib.rnd_get_np_state_ptr())
def set_random_state(_state_tuple):
    """Overwrite Numba's internal NumPy-compatible RNG state.

    Counterpart of ``get_random_state``. It also relies on the private
    ``numba._helperlib`` module, and it returns nothing.

    Parameters
    ----------
    _state_tuple : tuple
        A state previously obtained from ``get_random_state``.
    """
    helperlib = numba._helperlib
    state_ptr = helperlib.rnd_get_np_state_ptr()
    helperlib.rnd_set_state(state_ptr, _state_tuple)
class ComplexObject:
    """Toy model whose value is a running sum of jitted random draws.

    Used to check that a checkpoint (model value plus Numba RNG state) can
    be saved and restored, so that a resumed run reproduces an uninterrupted
    one.
    """

    def __init__(self, seed):
        """Seed Numba's internal RNG with *seed* and start the sum at 0."""
        set_random_seed(seed)
        self.value = 0

    def update(self):
        """Add one uniform draw from the jitted generator to ``self.value``."""
        val = get_random_number()
        self.value += val

    @staticmethod
    def _pack_rng_state(rng_state):
        """Split an ``(index, key)`` RNG state into two homogeneous arrays.

        The raw state is a ragged tuple (a scalar plus a list). On
        NumPy >= 1.24 ``np.savez_compressed`` raises ValueError for it. On
        older NumPy it is stored only as a pickled object array. Two plain
        numeric arrays avoid both problems.
        """
        index, key = rng_state
        return {
            'rng_index': np.asarray(index, dtype=np.int64),
            'rng_key': np.asarray(key, dtype=np.uint32),
        }

    @staticmethod
    def _unpack_rng_state(rng_index, rng_key):
        """Rebuild the ``(index, key)`` tuple from the stored arrays.

        Elements are converted to builtin ``int`` and the key to a ``list``,
        mirroring the tuple that ``get_random_state`` produced.
        NOTE(review): Numba's private setter is assumed to require exactly
        these builtin types; confirm against the installed Numba.
        """
        return int(rng_index), [int(k) for k in rng_key]

    def save_object_state(self, path='chkp_file.npz'):
        """Write ``self.value`` and the current Numba RNG state to *path*."""
        np.savez_compressed(path, value=self.value,
                            **self._pack_rng_state(get_random_state()))

    def load_object_state(self, path='chkp_file.npz'):
        """Restore ``self.value`` and the Numba RNG state from *path*.

        Also accepts legacy checkpoints that stored the whole state under a
        single pickled ``rng_state`` object array.
        """
        # allow_pickle is needed only for the legacy object-array entry.
        # Never load checkpoints from untrusted sources.
        with np.load(path, allow_pickle=True) as checkpoint:
            # .item() turns the stored 0-d array back into a Python number,
            # matching the type ``update`` produces in a fresh object.
            value = checkpoint['value'].item()
            if 'rng_key' in checkpoint.files:
                rng_state = self._unpack_rng_state(checkpoint['rng_index'],
                                                   checkpoint['rng_key'])
            else:
                rng_state = tuple(checkpoint['rng_state'])
        self.value = value
        set_random_state(rng_state)
def main():
    """Check that a checkpointed-and-resumed run matches an uninterrupted one."""
    # step 1, run model for 10 steps total
    co1 = ComplexObject(42)
    for _ in range(10):
        co1.update()
    print(co1.value)

    # step 2, run model for 5 steps then save
    co2 = ComplexObject(42)
    for _ in range(5):
        co2.update()
    print(co2.value)
    co2.save_object_state()

    # step 3, load up a new model and run for 5 steps. It is seeded
    # differently, so a match can only come from the restored RNG state.
    co3 = ComplexObject(123)
    co3.load_object_state()
    for _ in range(5):
        co3.update()
    print(co3.value)

    # The same draws are added in the same order, so exact equality is expected.
    print('resumed run matches uninterrupted run:', co1.value == co3.value)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment