# ===== A Temporal-Difference Learning Snapshot =====
# Patrick M. Pilarski, pilarski@ualberta.ca, Feb. 11, 2017
# ----
# 'xt' and 'xtp1' are the state information for the current (time t) and next (time t+1) time steps, in the form of binary vectors
# e.g., if you had a house with five rooms, 'xtp1' could be [0,0,1,0,0], with the single '1' indicating that you are in Room 3;
# in this case, 'xt' might be [0,1,0,0,0] indicating you were just in Room 2;
# for a robot servo, instead of "rooms" you could use binned joint angle ranges.
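# ----- Example sketch: building binary state vectors -----
# A minimal sketch of the one-hot coding described above, assuming states are
# numbered from 0 (so Room 3 is index 2). The helpers 'one_hot' and
# 'bin_joint_angle' are hypothetical, not part of the original snapshot, and the
# 0-180 degree range with six bins is an arbitrary choice for illustration.
import numpy


def one_hot(index, n):
    """Hypothetical helper: length-n binary vector with a single 1 at 'index'."""
    x = numpy.zeros(n)
    x[index] = 1.0
    return x


def bin_joint_angle(angle, low, high, n_bins):
    """Hypothetical helper: one-hot code an angle by its equal-width bin in [low, high]."""
    clipped = min(max(angle, low), high)        # out-of-range angles go to the end bins
    index = int((clipped - low) / (high - low) * n_bins)
    return one_hot(min(index, n_bins - 1), n_bins)  # angle == high lands in the last bin


room3_x = one_hot(2, 5)                          # [0, 0, 1, 0, 0]: "you are in Room 3"
servo_x = bin_joint_angle(45.0, 0.0, 180.0, 6)   # 30-degree bins; 45 degrees is bin index 1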
# 'r' is the signal to be predicted; this is a scalar and could represent reward or any other signal of interest
# e.g., in the examples above, it might be the reward you get for entering a given room, or the current draw of the servo.
# 'gamma' is the desired time scale of the prediction (with 0 <= gamma <= 1);
# e.g., gamma = 0.9 will predict the (discounted) sum of future 'r' values over roughly 1/(1-gamma) = 10 steps
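# ----- Example sketch: what 'gamma' means numerically -----
# An illustration, assuming a constant signal r = 1 on every step: the discounted
# sum r + gamma*r + gamma**2*r + ... approaches 1/(1-gamma), so gamma = 0.9 gives
# a prediction of about 10, as if 'r' were summed over roughly 10 steps.
# 'discounted_sum' is a hypothetical helper written only for this example.


def discounted_sum(rewards, gamma):
    """Hypothetical helper: sum of gamma**k * rewards[k] over a finite list."""
    total = 0.0
    for k, r_k in enumerate(rewards):
        total += (gamma ** k) * r_k
    return total


horizon = 1.0 / (1.0 - 0.9)                  # about 10 steps
approx = discounted_sum([1.0] * 200, 0.9)    # very close to 10, since 0.9**200 is tiny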
# 'alpha' is the learning rate; to start, try 0.1 divided by the norm of 'xt'
# 'lamda' (spelled without the 'b' because 'lambda' is a reserved word in Python) is the eligibility trace decay; try 0.0 to start, then speed up learning by changing it to 0.9 later.
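# ----- Example sketch: how the eligibility trace 'e' evolves -----
# This applies the trace update from the main loop, e = gamma*lamda*e + xt, while
# walking through Rooms 1, 2, and 3 (indices 0, 1, 2). With lamda = 0.0 the trace
# is just the current state; with lamda = 0.9, earlier rooms keep a decaying share
# of each TD error, which is how traces can speed up learning. 'trace_after' is a
# hypothetical helper for this illustration only.
import numpy


def trace_after(visited_rooms, n_rooms, gamma, lamda):
    """Hypothetical helper: the trace after visiting the given room indices in order."""
    e = numpy.zeros(n_rooms)
    for room in visited_rooms:
        xt = numpy.zeros(n_rooms)
        xt[room] = 1.0
        e = gamma * lamda * e + xt   # same accumulating-trace update as the main loop
    return e


trace_no_lambda = trace_after([0, 1, 2], 5, 0.9, 0.0)   # [0, 0, 1, 0, 0]
trace_lambda = trace_after([0, 1, 2], 5, 0.9, 0.9)      # about [0.6561, 0.81, 1, 0, 0]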
# Initialize 'xt', 'xtp1', 'w', and 'e' as numpy vectors of the same length (e.g., numpy.zeros(5))
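# ----- Example sketch: initialization before the main loop -----
# A minimal setup for the five-room example, assuming one-hot states and the
# starting values suggested above. For a one-hot vector the norm is 1, so
# alpha = 0.1. The norm of the all-zero initial 'xt' is 0, so alpha is computed
# here from a typical active state ('typical_x', an illustrative name) instead.
import numpy

n = 5                        # number of binary features (rooms)
xt = numpy.zeros(n)          # current state (time t)
xtp1 = numpy.zeros(n)        # next state (time t+1)
w = numpy.zeros(n)           # learned weights; all predictions start at 0
e = numpy.zeros(n)           # eligibility trace
gamma = 0.9                  # time scale of roughly 1/(1-gamma) = 10 steps
lamda = 0.0                  # start without traces; try 0.9 later
typical_x = numpy.zeros(n)
typical_x[0] = 1.0           # any one-hot state vector
alpha = 0.1 / numpy.linalg.norm(typical_x)   # 0.1 for one-hot states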
# The learner's prediction (the "value" of a state or state-action pair) is a linear function of the state: the dot product of 'xt' with the learned weight vector 'w'
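# ----- Example sketch: the linear prediction -----
# With a one-hot state, numpy.dot(x, w) simply reads out the weight of the active
# room, so each entry of 'w' acts as the learned value of one room; with several
# active features, the prediction is the sum of their weights. The weights in
# 'w_example' are made up purely for illustration.
import numpy

w_example = numpy.array([0.5, 1.2, 3.0, 0.0, -0.4])    # hypothetical learned weights
x_room3 = numpy.array([0.0, 0.0, 1.0, 0.0, 0.0])       # one-hot: Room 3
x_two_active = numpy.array([1.0, 0.0, 1.0, 0.0, 0.0])  # two active features

value_room3 = numpy.dot(x_room3, w_example)            # 3.0
value_two = numpy.dot(x_two_active, w_example)         # 0.5 + 3.0 = 3.5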
# For more detail on reinforcement learning and these algorithms, please refer to the following references:
# - White, 2015. Developing a predictive approach to knowledge. PhD thesis, University of Alberta. (cf. Chapter 2.)
# - Sutton and Barto, 1998. Reinforcement Learning: An Introduction. MIT Press.
import numpy  # needed for numpy.dot below; place this at the top of your script

# ----- The following code goes in your main loop. It does the learning. -----
# xtp1 = <every step, sample some signals from the world and use them to construct 'xtp1'>
# r = <every step, also sample or provide the signal of interest 'r'>
pred = numpy.dot(xt, w)  # Compute the prediction for the current timestep (i.e., for state 'xt')
delta = r + gamma*numpy.dot(xtp1, w) - pred  # Compute the temporal-difference error 'delta'
e = gamma*lamda*e + xt  # Update eligibility trace
w = w + alpha*delta*e  # Update the weight vector
xt = xtp1.copy()  # Advance one step; copy so that rebuilding 'xtp1' in place cannot alter 'xt'
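# ----- Example sketch: the main loop in a complete, runnable setting -----
# A self-contained example, assuming a hypothetical world: five rooms in a ring,
# a walker that moves from one room to the next (wrapping around) on every step,
# and r = 1 whenever it enters Room 3 (index 2), else r = 0. Because this world is
# deterministic, the true prediction for the room at index i is
# gamma**m / (1 - gamma**5), where m = (1 - i) % 5 counts the steps before the
# rewarded move, so the learned weights can be checked. 'run_td' and the ring
# world are illustrations, not part of the original snapshot; the update lines
# inside the loop are the ones above.
import numpy


def run_td(n_steps=10000, n_rooms=5, goal=2, gamma=0.9, lamda=0.0, alpha=0.1):
    """Hypothetical driver: run linear TD(lambda) on the ring world, return 'w'."""
    w = numpy.zeros(n_rooms)
    e = numpy.zeros(n_rooms)
    room = 0
    xt = numpy.zeros(n_rooms)
    xt[room] = 1.0                                # start in Room 1 (index 0)
    for _ in range(n_steps):
        room = (room + 1) % n_rooms               # the world moves one room
        xtp1 = numpy.zeros(n_rooms)
        xtp1[room] = 1.0                          # one-hot next state
        r = 1.0 if room == goal else 0.0          # signal of interest
        pred = numpy.dot(xt, w)
        delta = r + gamma * numpy.dot(xtp1, w) - pred
        e = gamma * lamda * e + xt
        w = w + alpha * delta * e
        xt = xtp1                                 # 'xtp1' is rebuilt each step, so no copy needed
    return w


learned = run_td()
# True values for goal = 2 and gamma = 0.9, from the formula above
true_values = numpy.array([0.9 ** ((1 - i) % 5) for i in range(5)]) / (1.0 - 0.9 ** 5)

# 'learned' should closely match 'true_values'; the largest value (about 2.442)
# belongs to Room 2 (index 1), the room just before the rewarded move. Passing
# lamda=0.9 to run_td lets you experiment with eligibility traces.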