# eleciawhite/td-snapshot.py Created Feb 15, 2017

A Temporal-Difference Learning Snapshot
```python
# ===== A Temporal-Difference Learning Snapshot =====
# Patrick M. Pilarski, pilarski@ualberta.ca, Feb. 11, 2017
# ----
# 'xt' and 'xtp1' are the state information for the current (time t) and next
# (time t+1) time steps, in the form of binary vectors.
# e.g., if you had a house with five rooms, 'xtp1' could be [0,0,1,0,0], with
# the single '1' indicating that you are in Room 3; in this case, 'xt' might
# be [0,1,0,0,0], indicating you were just in Room 2. For a robot servo,
# instead of "rooms" you could use binned joint-angle ranges.
# 'r' is the signal to be predicted; this is a scalar and could represent
# reward, or any other signal of interest; e.g., in the examples above, it
# might be the reward you get for entering a given room, or the current draw
# of the servo.
# 'gamma' is the desired time scale of the prediction (with 0 <= gamma <= 1);
# e.g., gamma = 0.9 will predict 'r' values as summed over roughly
# 10 = 1/(1-gamma) steps into the future.
# 'alpha' is the learning rate; to start, try 0.1 divided by the norm of 'xt'.
# 'lamda' is the eligibility-trace decay; try 0.0 to start, and speed up
# learning by changing it to 0.9 later.
# Initialize 'xt', 'xtp1', 'w', and 'e' as numpy vectors of the same length
# (e.g., numpy.zeros(5)).
# The learner's prediction (the "value" of a state or state-action pair) is a
# linear combination of the state 'xt' and the learned weights 'w'.
# For more detail on reinforcement learning and algorithms, please refer to
# the following references:
# - White, 2015. Developing a predictive approach to knowledge. PhD Thesis,
#   University of Alberta. (c.f., Chapter 2.)
# - Sutton and Barto, 1998. Reinforcement Learning: An Introduction. MIT Press.
# -----
# The following code goes in your main loop. It does the learning.
# -----

xtp1 =  # <- acquire the new state vector here
r =     # <- acquire the new value of the signal to be predicted here

pred = numpy.dot(xt, w)                      # Compute the prediction for the current timestep (i.e., for state 'xt')
delta = r + gamma*numpy.dot(xtp1, w) - pred  # Compute the temporal-difference error 'delta'
e = gamma*lamda*e + xt                       # Update the eligibility trace
w = w + alpha*delta*e                        # Update the weight vector
xt = xtp1                                    # The next state becomes the current state
```
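The learning loop above can be exercised end to end with the five-room example from the comments. The sketch below is an illustration, not part of the original gist: the room sequence (a deterministic walk around the five rooms), the reward rule (1.0 for entering Room 3), and the parameter values are all assumptions chosen for demonstration; only the five TD(lamda) update lines are taken from the snapshot itself.

```python
import numpy

# A minimal, runnable sketch of the snapshot using the five-room example.
n = 5                # five rooms, represented as one-hot binary state vectors
gamma = 0.9          # prediction time scale: roughly 1/(1-gamma) = 10 steps
alpha = 0.1 / 1.0    # 0.1 divided by the norm of a one-hot 'xt'
lamda = 0.9          # eligibility-trace decay
w = numpy.zeros(n)   # learned weight vector
e = numpy.zeros(n)   # eligibility trace
xt = numpy.zeros(n)
xt[0] = 1.0          # start in Room 1

def reward_for(room):
    # Hypothetical reward rule: entering Room 3 (index 2) pays 1.0
    return 1.0 if room == 2 else 0.0

for step in range(1000):
    room = step % n                 # walk through the rooms in a fixed cycle
    xtp1 = numpy.zeros(n)
    xtp1[room] = 1.0                # one-hot vector for the room just entered
    r = reward_for(room)            # signal received on entering that room

    # The five update lines from the snapshot:
    pred = numpy.dot(xt, w)
    delta = r + gamma*numpy.dot(xtp1, w) - pred
    e = gamma*lamda*e + xt
    w = w + alpha*delta*e
    xt = xtp1

print(w)  # each weight approximates the discounted future reward from that room
```

With one-hot features and a deterministic cycle, the weights converge to the true discounted values; for example, the value of Room 2 (one step before the reward) approaches 1/(1 - gamma^5), since the reward recurs every five steps.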