Created
January 31, 2024 00:20
-
-
Save bikcrum/5f336eeaee78d848b6453d27ce9b274a to your computer and use it in GitHub Desktop.
Policy iteration using numpy broadcasting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Policy iteration over a tabular MDP using numpy broadcasting.
#
# Expected environment interface (defined outside this snippet):
#   env.TransitProb   transition probabilities. The code transposes with
#       [1, 0, 2] before broadcasting against a [state, action, 1] reward,
#       which implies storage as [action_dim x state_dim x state_dim] —
#       NOTE(review): the original comment claimed [S x A x S']; confirm
#       against the env definition.
#   env.TransitReward [state_dim x action_dim]. It may also come as
#       [state_dim x action_dim x state_dim] (then drop the expand_dims) or
#       as [state_dim] (then expand dims twice).
#   nS = number of states, nA = number of actions.

beta = 0.99     # discount factor
theta = 0.0001  # convergence threshold for policy evaluation

# Start from a uniform stochastic policy P[s, a] = 1 / nA.
P = np.ones((nS, nA)) / nA

while True:
    # --- Policy evaluation: iterate the Bellman expectation backup ---
    # (fixed: original had a dead module-level `V = np.zeros(S)` with an
    # undefined name `S`, and hard-coded 0.0001 instead of using theta)
    V = np.zeros(nS)  # nS used consistently; P above is already sized by nS
    delta = float('inf')
    while delta > theta:
        # Q[s, a] = sum_s' T[s, a, s'] * (R[s, a] + beta * V[s'])
        Q = np.sum(
            np.transpose(env.TransitProb, [1, 0, 2])
            * (np.expand_dims(env.TransitReward, axis=-1) + beta * V),
            axis=-1)
        # V_new[s] = sum_a P[s, a] * Q[s, a]
        V_new = np.sum(P * Q, axis=-1)
        delta = np.max(np.abs(V - V_new))
        V = V_new

    # --- Policy improvement: act greedily w.r.t. the evaluated Q ---
    chosen_A = np.argmax(P, axis=-1)
    best_A = np.argmax(
        np.sum(
            np.transpose(env.TransitProb, [1, 0, 2])
            * (np.expand_dims(env.TransitReward, axis=-1) + beta * V),
            axis=-1),
        axis=-1)

    # Policy stable (greedy action already chosen everywhere) -> optimal.
    if not np.any(chosen_A != best_A):
        break

    # Make the policy deterministic on the greedy actions.
    P.fill(0.0)
    P[np.arange(nS), best_A] = 1.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment