Last active October 9, 2018 07:29
-
-
Save dominusmi/dff15ac1c3d211ae250ac0a5d73e0c3c to your computer and use it in GitHub Desktop.
Value iteration for GridWorlds.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Value-iteration algorithm """ | |
function value_iteration(mdp, gamma = 1.0) | |
""" Value-iteration algorithm """ | |
v = zeros(n_states(mdp)) # initialize value-function | |
max_iterations = 1000 | |
eps = 1e-10 | |
for iter in 1:max_iterations | |
prev_v = copy(v) | |
for s in states(mdp) | |
sᵢ = stateindex(mdp, s) | |
Q = zeros(n_actions(mdp)) | |
for a in actions(mdp) | |
aᵢ = actionindex(mdp, a) | |
sparse_cat = transition(mdp, s, a) | |
new_s = sparse_cat.vals | |
new_p = sparse_cat.probs | |
for j in 1:size(new_p,1) | |
sⱼ⁻ = stateindex(mdp, new_s[j]) | |
curr_v = prev_v[sⱼ⁻] | |
r = reward(mdp, new_s[j], a) | |
Q[aᵢ] += new_p[j] * (r + discount(mdp) * curr_v) | |
end | |
end | |
v[sᵢ] = maximum(Q) | |
end | |
if sum(abs.(prev_v - v)) <= eps | |
println("Value-iteration converged at iteration #$iter") | |
break | |
end | |
end | |
return v | |
end | |
# Example: solve the default grid world with value iteration and plot the
# resulting value function as a heatmap.
using GridWorlds, Plots

mdp = GridWorld()
V = value_iteration(mdp)

# NOTE(review): assumes `value_iteration(mdp)` returns exactly 10 * 10 = 100
# values for the default GridWorld — confirm; `reshape` throws a
# DimensionMismatch otherwise.
let value_grid = reshape(V, 10, 10)
    heatmap(value_grid)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.