Skip to content

Instantly share code, notes, and snippets.

@cocoademon
Last active February 13, 2018 03:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cocoademon/b2ae6b8200d9b26de073b56707e9d7de to your computer and use it in GitHub Desktop.
Save cocoademon/b2ae6b8200d9b26de073b56707e9d7de to your computer and use it in GitHub Desktop.
import timeit
import random
import itertools
import numpy as np, numpy.random as rnd
# Q-table dimensions shared by every benchmark: one row per state and one
# column per action, i.e. num_states * num_actions cells in total.
num_states, num_actions = 200000, 10
def python_dict_test(n_states=None, n_actions=None):
    """Benchmark a Q-table stored in a dict keyed by "state,action" strings.

    Performs ``n_states * n_actions`` simulated Q-learning updates. Each update
    picks a random (state, action) cell, reads the maximum Q-value across all
    actions of that state, and writes that maximum plus a random increment
    back into the chosen cell.

    Args:
        n_states: Number of states. Falls back to the module-level
            ``num_states`` when omitted.
        n_actions: Number of actions per state. Falls back to the module-level
            ``num_actions`` when omitted.

    Returns:
        The final Q-table dict. ``timeit`` ignores it; it is returned so the
        result can be inspected.
    """
    if n_states is None:
        n_states = num_states
    if n_actions is None:
        n_actions = num_actions

    # Pre-populate every (state, action) pair so that no lookup ever misses.
    all_keys = ("%s,%s" % pair
                for pair in itertools.product(range(n_states), range(n_actions)))
    dict_q = dict.fromkeys(all_keys, 0.0)

    for _ in range(n_states * n_actions):
        # Uniformly random access order. This stresses raw lookup cost and is
        # probably not a close approximation of how a real agent visits states.
        state = random.randint(0, n_states - 1)
        action = random.randint(0, n_actions - 1)
        # A Q-learning update is a max over the state's actions plus one write.
        best_q = max(dict_q["%s,%s" % (state, x)] for x in range(n_actions))
        dict_q["%s,%s" % (state, action)] = best_q + rnd.random()
    return dict_q
def python_list_test(n_states=None, n_actions=None):
    """Benchmark a Q-table stored in a flat Python list in row-major order.

    Cell (state, action) lives at index ``state * n_actions + action``. Each of
    the ``n_states * n_actions`` updates picks a random cell, takes the maximum
    Q-value over that state's row, and writes the maximum plus a random
    increment back into the cell.

    Args:
        n_states: Number of states. Falls back to the module-level
            ``num_states`` when omitted.
        n_actions: Number of actions per state. Falls back to the module-level
            ``num_actions`` when omitted.

    Returns:
        The final flat Q-table list. ``timeit`` ignores it; it is returned so
        the result can be inspected.
    """
    if n_states is None:
        n_states = num_states
    if n_actions is None:
        n_actions = num_actions

    # Float zeros, to match the 0.0 initial values of the dict/NumPy variants.
    list_q = [0.0] * (n_states * n_actions)
    for _ in range(n_states * n_actions):
        state = random.randint(0, n_states - 1)
        action = random.randint(0, n_actions - 1)
        # The state's row starts at state * n_actions. The original read
        # list_q[state + x], which scanned the wrong cells.
        row = state * n_actions
        # A per-element generator keeps the access pattern parallel to the
        # dict benchmark.
        best_q = max(list_q[row + x] for x in range(n_actions))
        list_q[row + action] = best_q + rnd.random()
    return list_q
def numpy_test(n_states=None, n_actions=None):
    """Benchmark a Q-table stored in a 2-D NumPy array, updated one cell at a time.

    Each of the ``n_states * n_actions`` updates picks a random (state, action)
    cell, takes ``np.amax`` over that state's row, and writes the maximum plus
    a random increment back into the cell.

    Args:
        n_states: Number of states. Falls back to the module-level
            ``num_states`` when omitted.
        n_actions: Number of actions per state. Falls back to the module-level
            ``num_actions`` when omitted.

    Returns:
        The final ``(n_states, n_actions)`` array. ``timeit`` ignores it; it is
        returned so the result can be inspected.
    """
    if n_states is None:
        n_states = num_states
    if n_actions is None:
        n_actions = num_actions

    # dtype 'f' is float32, whereas the pure-Python variants hold Python
    # floats (float64), so every write here also narrows the value.
    numpy_q = np.zeros((n_states, n_actions), dtype='f')
    for _ in range(n_states * n_actions):
        # NumPy's randint excludes its upper bound, so these draw from
        # [0, n - 1], the same ranges as random.randint(0, n - 1).
        state = rnd.randint(0, n_states)
        action = rnd.randint(0, n_actions)
        best_q = np.amax(numpy_q[state])
        numpy_q[state, action] = best_q + rnd.random()
    return numpy_q
if __name__ == '__main__':
    print("Timing...")
    # Number of full-table runs per Q-table benchmark.
    num = 2
    # Both RNG micro-benchmarks draw from [0, num_states - 1]. NumPy's randint
    # excludes its upper bound, while random.randint includes it. The original
    # Python stmt used 200001, which drew from a wider range.
    print("Numpy rand: ", timeit.timeit('k = rnd.randint(0, %d)' % num_states,
                                        'import numpy.random as rnd'))
    print("Python rand: ", timeit.timeit('k = random.randint(0, %d)' % (num_states - 1),
                                         'import random'))
    print("Python dict time: ", timeit.timeit(python_dict_test, number=num))
    print("Python list time: ", timeit.timeit(python_list_test, number=num))
    print("Numpy time: ", timeit.timeit(numpy_test, number=num))
@cocoademon
Copy link
Author

cocoademon commented Feb 13, 2018

Results on a Core i5 laptop:

Timing...
Python dict time: 71.05109677843146
Python list time: 39.42484635248813
Numpy time: 76.45893003033947

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment