Created
May 24, 2012 06:48
-
-
Save basarat/2779865 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def agent(env, alpha, epsilon, initialQ, gamma, numberEp, alg): #the agent function
    """Run a tabular TD-learning agent (SARSA or Q-learning) on ``env``.

    Parameters:
        env      -- environment exposing numStates, numActions, initState,
                    generate(a) -> (next_state, reward), currentCoord,
                    goalCoord, and reset()
        alpha    -- learning rate
        epsilon  -- exploration probability for epsilon-greedy selection
        initialQ -- scalar used to initialize every Q-table entry
        gamma    -- discount factor
        numberEp -- number of episodes to run
        alg      -- 0 for SARSA (on-policy), 1 for Q-learning (off-policy)

    Returns:
        list of episode lengths, recorded only on evaluation episodes
        (every 10th episode, where epsilon is forced to 0).

    Raises:
        ValueError -- if alg is not 0 or 1 (the original looped forever).
    """
    n_s = env.numStates
    n_a = env.numActions
    Q = initialQ * np.ones([n_s, n_a])  # initialize Q-table
    epLength = []  # episode lengths of evaluation episodes, for plotting

    if alg not in (0, 1):
        raise ValueError('alg must be 0 (SARSA) or 1 (Q-learning), got %r' % (alg,))

    # If an episode runs this long it is considered stuck and the Q-value
    # updates are dumped for debugging.
    STUCK_STEPS = 100000

    eps = 0
    while eps < numberEp:  # run numberEp episodes
        s = env.initState
        # Every 10th episode is a greedy evaluation episode (epsilon = 0)
        # whose length is logged.  The original test `eps/10 == eps/10.0`
        # relied on Python 2 integer division; under Python 3 it is always
        # true, so the equivalent modulo test is used instead.
        evaluate = (eps % 10 == 0)
        explore = 0 if evaluate else epsilon
        eplen = 0
        if alg == 0:
            # SARSA chooses the first action before entering the loop.
            a = selectAction(Q, s, explore, n_a)
        while True:  # loop until the goal state ends the episode
            eplen += 1  # increase counter for episode length
            if alg == 0:
                # SARSA: act, then bootstrap on the action actually chosen next.
                (sc, r) = env.generate(a)
                ac = selectAction(Q, sc, explore, n_a)
                target = r + gamma * Q[sc, ac]
            else:
                # Q-learning: choose the action now, bootstrap on the greedy max.
                a = selectAction(Q, s, explore, n_a)
                (sc, r) = env.generate(a)
                target = r + gamma * max(Q[sc, :])
            if eplen > STUCK_STEPS:  # episode looks stuck -- trace the update
                print('previous state:', s, 'we took action:', a, 'orig Q', Q[s, :], end=' ')
            Q[s, a] = Q[s, a] + alpha * (target - Q[s, a])
            if eplen > STUCK_STEPS:
                print('new Q:', Q[s, :])
            if alg == 0:
                a = ac  # SARSA carries the already-chosen next action forward
            s = sc
            if (env.currentCoord == env.goalCoord):  # end of episode
                if evaluate:  # logging info for plotting
                    epLength.append(eplen)
                eps += 1
                env.reset()
                break  # move on to next episode
    return epLength
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment