Skip to content

Instantly share code, notes, and snippets.

Created October 10, 2016 20:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save harahu/5ef5c044aac6c1cac31fb86792e1e053 to your computer and use it in GitHub Desktop.
Save harahu/5ef5c044aac6c1cac31fb86792e1e053 to your computer and use it in GitHub Desktop.
Q-learning on the Frozen Lake domain
import gym, random
class qTable:
    """Implements a table tracking the estimated values
    for state-action pairs in an MDP.

    States and actions are integer indices ``0..nS-1`` and ``0..nA-1``;
    every estimate starts at 0.
    """

    def __init__(self, nS, nA):
        """Create an ``nS`` x ``nA`` table of zero estimates.

        Args:
            nS: number of states.
            nA: number of actions.
        """
        self.nS = nS
        self.nA = nA
        # Built with a comprehension (not ``[[0] * nA] * nS``) so that each
        # state gets its own independent row.
        self.table = [[0 for _ in range(nA)] for _ in range(nS)]

    def getQ(self, s, a):
        """Return the current estimate Q(s, a)."""
        return self.table[s][a]

    def setQ(self, s, a, value):
        """Overwrite the estimate Q(s, a) with ``value``."""
        self.table[s][a] = value

    def getMaxQ(self, s):
        """Return the highest Q-value
        for a given state."""
        return max(self.table[s])

    def getMaxQAction(self, s):
        """Return the action that has the highest Q-value
        for a given state.

        Ties are broken in favour of the lowest action index, because
        ``max`` returns the first maximal element it encounters.
        """
        row = self.table[s]
        return max(range(self.nA), key=row.__getitem__)
def epsilonGreedy(epsilon, env, obs, qtab):
    """Choose an action with an epsilon-greedy policy.

    With probability ``epsilon`` a random action is drawn from
    ``env.action_space``; otherwise the action with the highest Q-value for
    ``obs`` in ``qtab`` is taken.

    Args:
        epsilon: exploration probability. Since ``random.random()`` lies in
            [0, 1), values >= 1 always explore and values <= 0 never do.
        env: environment exposing ``action_space.sample()``.
        obs: current state index.
        qtab: Q-table exposing ``getMaxQAction(state)``.

    Returns:
        The selected action.
    """
    if random.random() < epsilon:
        # Explore: the greedy choice must NOT overwrite this random action.
        return env.action_space.sample()
    # Exploit: pick the currently best-valued action.
    return qtab.getMaxQAction(obs)
def main():
    """Train a tabular Q-learning agent on FrozenLake-v0 and report progress.

    Runs 8000 episodes. Actions come from an epsilon-greedy policy whose
    exploration rate starts at 1 and is multiplied by 0.998 after every
    episode. After each episode, prints the episode index and the total
    reward collected over the last 100 episodes.

    NOTE(review): ``LEARNING_RATE`` and ``DISCOUNT`` are module-level names
    that are not defined in the visible source; confirm they are set before
    ``main`` runs.
    """
    env = gym.make('FrozenLake-v0')
    window_size = 100
    # Ring buffer of the total reward of the last ``window_size`` episodes.
    rewardWindow = [0] * window_size
    qtab = qTable(env.observation_space.n, env.action_space.n)
    epsilon = 1
    for i_episode in range(8000):
        observation = env.reset()
        accumulatedReward = 0
        for _ in range(10000):
            # Select action
            action = epsilonGreedy(epsilon, env, observation, qtab)
            # Perform action
            prevObs = observation
            observation, reward, done, _info = env.step(action)
            accumulatedReward += reward
            # Q-learning update:
            #   Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a))
            # Terminal states are never acted from (see the break below), so
            # their Q-values stay 0 and contribute nothing to the target.
            oldQ = qtab.getQ(prevObs, action)
            maxCurrQ = qtab.getMaxQ(observation)
            newQ = oldQ + LEARNING_RATE * (reward + DISCOUNT * maxCurrQ - oldQ)
            qtab.setQ(prevObs, action, newQ)
            if done:
                # Index modulo the full window size; the previous ``% 99``
                # never wrote the last slot, leaving a stale 0 in the window.
                rewardWindow[i_episode % window_size] = accumulatedReward
                # Decrease exploration rate once per episode.
                epsilon *= 0.998
                # Sum (not average) of rewards over the window.
                windowSum = sum(rewardWindow)
                print(i_episode, " ", windowSum)
                # NOTE(review): an early-stop check ``windowSum >= 78`` was
                # left commented out here; re-enable if a stopping rule is wanted.
                # The episode is over: stop stepping the finished environment.
                break
if __name__ == '__main__':
    # Run the training loop only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment