@dominusmi
Last active October 9, 2018 07:29
Value iteration for GridWorlds.jl
""" Value-iteration algorithm """
function value_iteration(mdp, gamma = 1.0)
""" Value-iteration algorithm """
v = zeros(n_states(mdp)) # initialize value-function
max_iterations = 1000
eps = 1e-10
for iter in 1:max_iterations
prev_v = copy(v)
for s in states(mdp)
sᵢ = stateindex(mdp, s)
Q = zeros(n_actions(mdp))
for a in actions(mdp)
aᵢ = actionindex(mdp, a)
sparse_cat = transition(mdp, s, a)
new_s = sparse_cat.vals
new_p = sparse_cat.probs
for j in 1:size(new_p,1)
sⱼ⁻ = stateindex(mdp, new_s[j])
curr_v = prev_v[sⱼ⁻]
r = reward(mdp, new_s[j], a)
Q[aᵢ] += new_p[j] * (r + discount(mdp) * curr_v)
end
end
v[sᵢ] = maximum(Q)
end
if sum(abs.(prev_v - v)) <= eps
println("Value-iteration converged at iteration #$iter")
break
end
end
return v
end
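The backup above can be sanity-checked on a tiny hand-coded MDP that needs none of the GridWorlds.jl machinery (a hypothetical two-state, two-action toy, plain dictionaries in place of the `transition`/`reward` interface):

```julia
# Toy MDP (hypothetical, for checking the backup): in state 1,
# action 1 stays and pays 1; action 2 jumps to the absorbing
# state 2, which pays nothing forever.
function toy_value_iteration(γ = 0.9)
    P = Dict((1,1) => [(1,1.0)], (1,2) => [(2,1.0)],     # (s,a) => [(s′, p)]
             (2,1) => [(2,1.0)], (2,2) => [(2,1.0)])
    R = Dict((1,1) => 1.0, (1,2) => 0.0, (2,1) => 0.0, (2,2) => 0.0)
    v = zeros(2)
    for _ in 1:1000
        prev = copy(v)
        for s in 1:2
            # same greedy Bellman backup as in value_iteration above
            v[s] = maximum(sum(p * (R[(s,a)] + γ * prev[s′]) for (s′, p) in P[(s,a)]) for a in 1:2)
        end
        sum(abs.(prev - v)) <= 1e-10 && break
    end
    return v
end
```

Always staying in state 1 earns a geometric series of rewards, so the fixed point is `v[1] = 1/(1-γ) = 10` and `v[2] = 0`, which the loop converges to.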
# Simple example
using GridWorlds, Plots
mdp = GridWorld()
V = value_iteration(mdp)
heatmap(reshape(V, (10, 10)))   # visualise V on the default 10×10 grid
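Once V has converged, the greedy policy is just the argmax of the same one-step lookahead. A self-contained sketch in tabular form (hypothetical `T`/`R` arrays rather than the GridWorlds.jl accessors; the toy numbers encode a two-state MDP where action 1 stays and pays 1 in state 1):

```julia
# Greedy policy extraction: T[a] is an S×S transition matrix,
# R an S×A reward matrix, v the converged value vector.
function greedy_policy(T, R, v, γ)
    S, A = size(R)
    # for each state, pick the action maximising the one-step lookahead
    [argmax([R[s, a] + γ * sum(T[a][s, s′] * v[s′] for s′ in 1:S) for a in 1:A]) for s in 1:S]
end

T = [[1.0 0.0; 0.0 1.0],    # action 1: stay put
     [0.0 1.0; 0.0 1.0]]    # action 2: jump to state 2
R = [1.0 0.0; 0.0 0.0]
γ = 0.9
v = [10.0, 0.0]             # fixed-point values for this toy MDP
greedy_policy(T, R, v, γ)   # → [1, 1]
```

In state 1 the lookahead values are 10 vs 0, so staying (action 1) is greedy; state 2 is absorbing, so ties break to the first action.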