Created January 31, 2024 00:18, gist bikcrum/505132f00aedda13974dfd53ab802ef5
Value iteration using numpy broadcasting
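For reference, the single broadcast expression in the loop computes the standard Bellman optimality backup for every state and action at once. The block below writes it out, assuming the reward has shape [state_dim x action_dim] as in the header comments. $T$, $R$ and $\pi$ are notation introduced here for env.TransitProb, env.TransitReward and the returned policy array.

```latex
\begin{aligned}
X(s,a) &= \sum_{s'} T(s,a,s')\,\bigl[R(s,a) + \beta V(s')\bigr]
        = R(s,a) + \beta \sum_{s'} T(s,a,s')\,V(s') \\
V_{\text{new}}(s) &= \max_a X(s,a) \\
\Delta &= \max_s \bigl| V(s) - V_{\text{new}}(s) \bigr| \quad \text{(iterate while } \Delta > \theta\text{)} \\
\pi(s) &= \arg\max_a X(s,a)
\end{aligned}
```

The second form of $X$ holds because each row $T(s,a,\cdot)$ sums to 1. In array terms, `np.expand_dims(env.TransitReward, axis=-1)` has shape [nS, nA, 1] and `beta * V` has shape [nS], which broadcasts as [1, 1, nS]. Their sum multiplied by env.TransitProb is therefore [nS, nA, nS], and summing over `axis=-1` yields X with shape [nS, nA].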
# Value iteration (compact)
# env                [MDP environment]
# nS                 [state_dim, i.e. number of states]
# env.TransitProb    [state_dim x action_dim x state_dim]
# env.TransitReward  [state_dim x action_dim]. Note: it may instead have shape
#                    [state_dim x action_dim x state_dim]; in that case drop the expand_dims.
#                    If it has shape [state_dim] (reward of the current state), expand it twice
#                    instead, e.g. env.TransitReward[:, None, None].
import numpy as np

beta = 0.99       # [discount factor]
theta = 0.0001    # [threshold for termination]
V = np.zeros(nS)  # [Value function]
delta = float('inf')
while delta > theta:
    # X[s, a] = sum_s' TransitProb[s, a, s'] * (TransitReward[s, a] + beta * V[s']), shape [state_dim x action_dim]
    X = np.sum(env.TransitProb * (np.expand_dims(env.TransitReward, axis=-1) + beta * V), axis=-1)
    V_new = np.max(X, axis=-1)         # greedy backup over actions
    delta = np.max(np.abs(V - V_new))  # largest change in value across states
    V = V_new
P = np.argmax(X, axis=-1)  # [Policy]: greedy action per state
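To check the broadcasting end to end, here is a minimal sketch on a hypothetical two-state, two-action MDP that is not part of the original gist. The `SimpleNamespace` env, the `value_iteration` helper, and its parameter names (`trans_prob`, `reward`) are assumptions made for illustration. The helper also applies the reward-shape handling described in the header comments, so R(s, a), R(s, a, s') and R(s) can all be passed in.

```python
import numpy as np
from types import SimpleNamespace


def value_iteration(trans_prob, reward, beta=0.99, theta=1e-4):
    """Broadcast value iteration (hypothetical helper mirroring the gist's loop).

    trans_prob: [nS, nA, nS] array, trans_prob[s, a, s'] = P(s' | s, a).
    reward: shape [nS, nA], [nS, nA, nS], or [nS] (reward of the current state).
    Returns (V, policy).
    """
    reward = np.asarray(reward, dtype=float)
    if reward.ndim == 1:    # R(s): expand twice -> [nS, 1, 1]
        reward = reward[:, None, None]
    elif reward.ndim == 2:  # R(s, a): expand once -> [nS, nA, 1]
        reward = reward[:, :, None]
    # reward.ndim == 3: R(s, a, s') already lines up with trans_prob

    V = np.zeros(trans_prob.shape[0])
    while True:
        # Q-values X[s, a]; beta * V broadcasts along the last (next-state) axis
        X = np.sum(trans_prob * (reward + beta * V), axis=-1)
        V_new = np.max(X, axis=-1)
        delta = np.max(np.abs(V - V_new))
        V = V_new
        if delta <= theta:  # same stopping rule as `while delta > theta`
            return V, np.argmax(X, axis=-1)


# Hypothetical toy MDP: action 0 = stay, action 1 = switch state.
T = np.zeros((2, 2, 2))
for s in range(2):
    T[s, 0, s] = 1.0      # stay
    T[s, 1, 1 - s] = 1.0  # switch
R_sa = np.array([[0.0, 0.0],   # state 0: no reward
                 [1.0, 0.0]])  # state 1: reward 1 only for staying
env = SimpleNamespace(TransitProb=T, TransitReward=R_sa)

V, policy = value_iteration(env.TransitProb, env.TransitReward, beta=0.9, theta=1e-6)
print(V, policy)  # V is close to [9, 10]; policy is [1, 0] (switch, then stay)
```

With beta = 0.9, staying in state 1 is worth 1 / (1 - 0.9) = 10 and switching from state 0 is worth 0.9 * 10 = 9, which is what the iteration converges to. Returning from inside `while True` is equivalent to the gist's `delta = float('inf')` sentinel; it simply guarantees `X` is defined before the policy is extracted.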