Skip to content

Instantly share code, notes, and snippets.

@amoodie
Last active July 28, 2020 17:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amoodie/1683dbc910af9d13e29fdd61f227f301 to your computer and use it in GitHub Desktop.
Save amoodie/1683dbc910af9d13e29fdd61f227f301 to your computer and use it in GitHub Desktop.
# Minimal working example (MWE): test saving and restoring the NumPy RNG state used inside numba-jitted functions
import numpy as np
import numba
@numba.jit
def set_random_seed(_seed):
    """Seed the random generator that numba-compiled code draws from.

    Per the numba documentation, numba keeps its own generator state,
    separate from the interpreter's ``np.random`` state. Seeding therefore
    has to happen inside jitted code to affect ``get_random_number``.
    """
    seed_value = _seed
    np.random.seed(seed_value)
@numba.jit
def get_random_number():
    """Draw one sample from the uniform distribution on [0, 1).

    The draw uses numba's internal generator state, i.e. the state that
    ``set_random_seed`` seeds and ``get_random_state`` snapshots.
    """
    low, high = 0, 1
    return np.random.uniform(low, high)
def get_random_state():
    """Return a snapshot of numba's NumPy-compatible RNG state.

    Relies on the private ``numba._helperlib`` module, so it may break
    between numba releases. The returned tuple is what ``rnd_get_state``
    produces; in current numba it looks like ``(index, [key, ...])``
    (TODO confirm for the installed version). It can be handed back to
    ``set_random_state`` unchanged.
    """
    helperlib = numba._helperlib
    state_ptr = helperlib.rnd_get_np_state_ptr()
    return helperlib.rnd_get_state(state_ptr)
def set_random_state(_state_tuple):
    """Restore numba's RNG state from a snapshot made by ``get_random_state``.

    Uses the private ``numba._helperlib`` module. The argument must have the
    exact structure ``rnd_get_state`` returned.
    """
    helperlib = numba._helperlib
    helperlib.rnd_set_state(helperlib.rnd_get_np_state_ptr(), _state_tuple)
class ComplexObject():
    """Toy model whose ``value`` grows by uniform draws from numba's RNG.

    A checkpoint stores both ``value`` and the RNG state. A model restored
    from disk therefore continues exactly where the saved one stopped.
    """

    # Default checkpoint path; override per instance or pass ``file=``.
    chkp_file = 'chkp_file.npz'

    def __init__(self, seed):
        # Seed inside jitted code so it reaches numba's generator state.
        set_random_seed(seed)
        self.value = 0

    def update(self):
        """Advance one step by adding a U[0, 1) draw to ``value``."""
        self.value += get_random_number()

    def save_object_state(self):
        """Write ``value`` and the current RNG state to ``self.chkp_file``."""
        index, keys = get_random_state()
        # Store the RNG state as two homogeneous arrays. Packing the ragged
        # (int, list) tuple into one array fails on NumPy >= 1.24, and on
        # older NumPy it only works through a pickled object array.
        np.savez_compressed(self.chkp_file,
                            value=self.value,
                            rng_index=index,
                            rng_keys=np.asarray(keys, dtype=np.uint64))

    def load_object_state(self, file=None):
        """Restore ``value`` and the RNG state from a checkpoint.

        Parameters
        ----------
        file : str or path-like, optional
            Checkpoint to read. Defaults to ``self.chkp_file``.
        """
        path = self.chkp_file if file is None else file
        # allow_pickle is only needed for legacy checkpoints that stored the
        # state as one object array. Unpickling can execute arbitrary code,
        # so only load checkpoint files you trust.
        with np.load(path, allow_pickle=True) as checkpoint:
            if 'rng_state' in checkpoint.files:
                # Legacy layout: object array holding (index, key list).
                rng_state = tuple(checkpoint['rng_state'])
            else:
                # rnd_set_state expects a tuple of (int, list).
                rng_state = (int(checkpoint['rng_index']),
                             checkpoint['rng_keys'].tolist())
            # .item() turns the saved 0-d array back into a Python scalar.
            self.value = checkpoint['value'].item()
        set_random_state(rng_state)
if __name__ == '__main__':
    # Reference: ten uninterrupted steps from seed 42.
    reference = ComplexObject(42)
    for _step in range(10):
        reference.update()
    print(reference.value)

    # Same seed, stopped after five steps and checkpointed to disk.
    first_half = ComplexObject(42)
    for _step in range(5):
        first_half.update()
    print(first_half.value)
    first_half.save_object_state()

    # Differently seeded model, restored from the checkpoint and advanced
    # five more steps.
    second_half = ComplexObject(123)
    second_half.load_object_state()
    for _step in range(5):
        second_half.update()
    print(second_half.value)
# Minimal working example (MWE), parallel variant: test saving and restoring the NumPy RNG state used inside numba-jitted functions
import numpy as np
import numba
import multiprocessing
@numba.jit
def set_random_seed(_seed):
    """Seed the random generator that numba-compiled code draws from.

    Per the numba documentation, numba keeps its own generator state,
    separate from the interpreter's ``np.random`` state. Seeding therefore
    has to happen inside jitted code to affect ``get_random_number``.
    """
    seed_value = _seed
    np.random.seed(seed_value)
@numba.jit
def get_random_number():
    """Draw one sample from the uniform distribution on [0, 1).

    The draw uses numba's internal generator state, i.e. the state that
    ``set_random_seed`` seeds and ``get_random_state`` snapshots.
    """
    low, high = 0, 1
    return np.random.uniform(low, high)
def get_random_state():
    """Return a snapshot of numba's NumPy-compatible RNG state.

    Relies on the private ``numba._helperlib`` module, so it may break
    between numba releases. The returned tuple is what ``rnd_get_state``
    produces; in current numba it looks like ``(index, [key, ...])``
    (TODO confirm for the installed version). It can be handed back to
    ``set_random_state`` unchanged.
    """
    helperlib = numba._helperlib
    state_ptr = helperlib.rnd_get_np_state_ptr()
    return helperlib.rnd_get_state(state_ptr)
def set_random_state(_state_tuple):
    """Restore numba's RNG state from a snapshot made by ``get_random_state``.

    Uses the private ``numba._helperlib`` module. The argument must have the
    exact structure ``rnd_get_state`` returned.
    """
    helperlib = numba._helperlib
    helperlib.rnd_set_state(helperlib.rnd_get_np_state_ptr(), _state_tuple)
class ComplexObject():
    """Toy model whose ``value`` grows by uniform draws from numba's RNG.

    A checkpoint stores both ``value`` and the RNG state. A model restored
    from disk therefore continues exactly where the saved one stopped.
    """

    def __init__(self, seed, _id):
        # Seed inside jitted code so it reaches numba's generator state.
        set_random_seed(seed)
        self.value = 0
        self.id = _id
        # One checkpoint file per id, so objects with different ids (e.g.
        # the parallel workers in this script) do not overwrite each other.
        self.chkp_file = 'chkp_file_' + str(self.id) + '.npz'

    def update(self):
        """Advance one step by adding a U[0, 1) draw to ``value``."""
        self.value += get_random_number()

    def save_object_state(self):
        """Write ``value`` and the current RNG state to ``self.chkp_file``."""
        index, keys = get_random_state()
        # Store the RNG state as two homogeneous arrays. Packing the ragged
        # (int, list) tuple into one array fails on NumPy >= 1.24, and on
        # older NumPy it only works through a pickled object array.
        np.savez_compressed(self.chkp_file,
                            value=self.value,
                            rng_index=index,
                            rng_keys=np.asarray(keys, dtype=np.uint64))

    def load_object_state(self, file=None):
        """Restore ``value`` and the RNG state from a checkpoint.

        Parameters
        ----------
        file : str or path-like, optional
            Checkpoint to read. Defaults to this object's ``chkp_file``.
        """
        path = self.chkp_file if file is None else file
        # allow_pickle is only needed for legacy checkpoints that stored the
        # state as one object array. Unpickling can execute arbitrary code,
        # so only load checkpoint files you trust.
        with np.load(path, allow_pickle=True) as checkpoint:
            if 'rng_state' in checkpoint.files:
                # Legacy layout: object array holding (index, key list).
                rng_state = tuple(checkpoint['rng_state'])
            else:
                # rnd_set_state expects a tuple of (int, list).
                rng_state = (int(checkpoint['rng_index']),
                             checkpoint['rng_keys'].tolist())
            # .item() turns the saved 0-d array back into a Python scalar.
            self.value = checkpoint['value'].item()
        set_random_state(rng_state)
def parallel_func(i, n_steps=500):
    """Run the checkpoint/restore round trip inside one worker process.

    Steps performed:

    1. Advance a model seeded with 42 for ``n_steps`` steps and checkpoint it.
    2. Restore that checkpoint into a model seeded with 123.
    3. Advance the restored model another ``n_steps`` steps.

    Both intermediate values are printed.

    Parameters
    ----------
    i : int
        Worker id. It is also the model id, which selects the checkpoint
        file ``chkp_file_<i>.npz``.
    n_steps : int, optional
        Steps to run before and after the checkpoint (default 500).
    """
    coA = ComplexObject(42, i)
    for _ in range(n_steps):
        coA.update()
    # Flush so the output is not left in the worker's stdout buffer.
    print(i, coA.value, flush=True)
    coA.save_object_state()

    # Load the checkpoint into a differently seeded model and run the
    # remaining steps. If the RNG state is restored, seed 123 has no effect.
    coB = ComplexObject(123, i)
    coB.load_object_state()
    for _ in range(n_steps):
        coB.update()
    print(i, coB.value, flush=True)
if __name__ == '__main__':
    n_half = 500  # steps run before and after the checkpoint

    # step 1: reference run, 2 * n_half steps without interruption
    co1 = ComplexObject(42, 1)
    for _ in range(2 * n_half):
        co1.update()
    print(co1.value)

    # step 2: same seed, run n_half steps, then checkpoint
    co2 = ComplexObject(42, 2)
    for _ in range(n_half):
        co2.update()
    print(co2.value)
    co2.save_object_state()

    # step 3: restore co2's checkpoint into a differently seeded model and
    # run the remaining n_half steps; the result should match step 1
    co3 = ComplexObject(123, 3)
    co3.load_object_state(co2.chkp_file)
    for _ in range(n_half):
        co3.update()
    print(co3.value)

    print("\nParallel version:")
    # Close and join explicitly so the workers exit normally (flushing their
    # output) instead of being terminated by Pool.__exit__.
    with multiprocessing.Pool(processes=3) as pool:
        pool.map(parallel_func, range(3))
        pool.close()
        pool.join()
@amoodie
Copy link
Author

amoodie commented Jul 28, 2020

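Output from rng_mwe.py (the first and third values match, showing that the model restored from the checkpoint finishes exactly where the uninterrupted run does):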

5.201367359526748
2.811925491708157
5.201367359526748

Output from rng_mwe_parallel.py (the serial check passes as before, and every worker process reproduces the same pair of values):

490.25655332013355
249.28085611700678
490.25655332013355

Parallel version:
0 249.28085611700678
1 249.28085611700678
0 490.25655332013355
1 490.25655332013355
2 249.28085611700678
2 490.25655332013355

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment