import sys
from collections import defaultdict

import numpy as np


def mc_prediction_q(env, num_episodes, generate_episode, gamma=1.0):
    # Sum of returns observed for each state-action pair
    returns_sum = defaultdict(lambda: np.zeros(env.action_space.n))
    # Number of visits to each state-action pair
    N = defaultdict(lambda: np.zeros(env.action_space.n))
    # Action-value estimates for each state-action pair
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # Loop over episodes
    for i_episode in range(1, num_episodes + 1):
        # Monitor progress
        if i_episode % 1000 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
            sys.stdout.flush()
        # Generate an episode: a list of (state, action, reward) tuples
        episode = generate_episode(env)
        # Separate the states, actions, and rewards. zip(*episode) unpacks the
        # episode into three parallel tuples (see the example below).
        states, actions, rewards = zip(*episode)
        # Discounting multipliers: gamma raised to the ith power for each timestep
        discounts = np.array([gamma ** i for i in range(len(rewards) + 1)])
        # Update the sum of the returns, the number of visits, and the
        # action-value estimate for each state-action pair in the episode.
        """
        Recall what zip(*episode) returns.
        Example: a single episode with three timesteps,
        [((14, 9, False), 1, 0), ((16, 9, False), 1, 0), ((17, 9, False), 1, -1)]
        states:  ((14, 9, False), (16, 9, False), (17, 9, False))
        actions: (1, 1, 1)
        rewards: (0, 0, -1)
        """
        # Loop over timesteps i; index states and actions by timestep.
        for i in range(len(states)):
            # Accumulate the discounted return from timestep i onwards:
            # rewards[i:] holds the rewards from this moment on, and
            # discounts[:-(1+i)] supplies gamma^0, gamma^1, ... of matching length,
            # so the product sums to G_i = r_i + gamma*r_{i+1} + gamma^2*r_{i+2} + ...
            returns_sum[states[i]][actions[i]] += sum(rewards[i:] * discounts[:-(1 + i)])
            # Increment the number of visits for the state-action pair seen at this timestep
            N[states[i]][actions[i]] += 1.0
            # The Q-value of a state-action pair is the sum of its returns
            # divided by its number of visits, so update it accordingly.
            Q[states[i]][actions[i]] = returns_sum[states[i]][actions[i]] / N[states[i]][actions[i]]
    # My own version, written to be less confusing.
    return Q
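

# --- Usage sketch (not part of the original gist) ---
# A minimal example of how mc_prediction_q might be driven. It assumes the
# classic (pre-0.26) OpenAI Gym API and the Blackjack-v0 environment; the
# episode-generating policy, its hit/stick probabilities, and the episode
# count are illustrative assumptions, not values taken from the original code.
import gym


def generate_episode_from_limit_stochastic(bj_env):
    # Play one Blackjack episode with a fixed stochastic policy:
    # hit with probability 0.8 when the player's sum is 18 or less,
    # otherwise hit with probability 0.2.
    episode = []
    state = bj_env.reset()
    while True:
        probs = [0.8, 0.2] if state[0] > 18 else [0.2, 0.8]
        action = np.random.choice(np.arange(2), p=probs)
        next_state, reward, done, info = bj_env.step(action)
        episode.append((state, action, reward))
        state = next_state
        if done:
            break
    return episode


if __name__ == "__main__":
    env = gym.make('Blackjack-v0')
    Q = mc_prediction_q(env, 500000, generate_episode_from_limit_stochastic, gamma=1.0)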