@bikcrum
Created January 31, 2024 00:20
Policy iteration using numpy broadcasting
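# Overview (added note, not part of the original gist): policy iteration alternates two steps
# until the greedy policy stops changing.
#   1) Policy evaluation: repeatedly apply
#        V(s) <- sum_a P[s, a] * sum_s' T(s' | s, a) * (R(s, a) + beta * V(s'))
#      until the largest change in V drops below theta.
#   2) Policy improvement: make P deterministic and greedy with respect to V, i.e.
#        best_A(s) = argmax_a sum_s' T(s' | s, a) * (R(s, a) + beta * V(s')).
# The code below vectorizes both steps with numpy broadcasting instead of looping over states and actions.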
import numpy as np

# env [MDP environment]
# nS [state_dim]
# nA [action_dim]
# env.TransitProb [action_dim x state_dim x state_dim]; the np.transpose calls below bring it to
# [state_dim x action_dim x state_dim]. If your environment already stores it in that layout,
# drop the transposes.
# env.TransitReward [state_dim x action_dim]. Note: it may instead have shape
# [state_dim x action_dim x state_dim]; in that case no expand_dims is needed. If it has shape
# [state_dim], expand dims twice accordingly.
beta = 0.99 # [discount factor]
theta = 0.0001 # [threshold for termination]
V = np.zeros(nS) # [Value function]
delta = float('inf')
P = np.ones((nS, nA)) / nA # [Initialize uniform stochastic policy]
while True:
    # Evaluate the current policy (iterative policy evaluation)
    delta = float('inf')
    V = np.zeros(nS)
    while delta > theta:
        # Q[s, a] = sum_s' T(s' | s, a) * (R(s, a) + beta * V(s')); weight each row by the policy P[s, a]
        X = P * np.sum(
            np.transpose(env.TransitProb, [1, 0, 2]) * (np.expand_dims(env.TransitReward, axis=-1) + beta * V),
            axis=-1)
        V_new = np.sum(X, axis=-1)
        delta = np.max(np.abs(V - V_new))
        V = V_new
    # Improve the policy: act greedily with respect to V
    chosen_A = np.argmax(P, axis=-1)
    best_A = np.argmax(
        np.sum(np.transpose(env.TransitProb, [1, 0, 2]) * (np.expand_dims(env.TransitReward, axis=-1) + beta * V),
               axis=-1), axis=-1)
    # Stop once the greedy policy no longer changes
    if not np.any(chosen_A != best_A):
        break
    P.fill(0.0)
    P[np.arange(nS), best_A] = 1.0
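
# Minimal usage sketch (added; not part of the original gist). It assumes a hypothetical
# 3-state, 2-action MDP as a stand-in for env, with TransitProb shaped
# [action_dim x state_dim x state_dim] and TransitReward shaped [state_dim x action_dim].
# Defining env, nS, and nA this way before the loop above makes the snippet runnable end-to-end.

class TinyMDP:
    # Hypothetical stand-in environment; TinyMDP is not part of the original gist.
    def __init__(self):
        # TransitProb[a, s, s']: probability of landing in s' after taking action a in state s
        self.TransitProb = np.array([
            [[0.8, 0.2, 0.0],   # action 0, rows are states 0, 1, 2
             [0.1, 0.8, 0.1],
             [0.0, 0.2, 0.8]],
            [[0.1, 0.9, 0.0],   # action 1, rows are states 0, 1, 2
             [0.0, 0.1, 0.9],
             [0.0, 0.0, 1.0]],
        ])
        # TransitReward[s, a]: expected immediate reward for taking action a in state s
        self.TransitReward = np.array([
            [0.0, 0.0],
            [0.0, 0.0],
            [1.0, 1.0],
        ])

    def GetStateSpace(self):
        return self.TransitProb.shape[1]

    def GetActionSpace(self):
        return self.TransitProb.shape[0]

env = TinyMDP()
nS, nA = env.GetStateSpace(), env.GetActionSpace()
# ...run the policy-iteration loop above, then inspect the result:
# np.argmax(P, axis=-1) gives the greedy action per state; V holds the corresponding state values.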