Last active
August 30, 2018 22:19
-
-
Save zsunberg/50518f90dff071ca830cfb1e46380164 to your computer and use it in GitHub Desktop.
A simpler version of a grid world problem for POMDPs.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A grid coordinate `(x, y)`, stored as a 2-element static vector of `Int`.
const Vec2 = SVector{2, Int}

# Every state of `SimpleGridWorld`: either a grid cell or the terminal-state sentinel.
const StateTypes = Union{TerminalState, Vec2}
# Parameters of a small stochastic grid-world MDP. `@with_kw` provides a keyword
# constructor that uses the defaults listed below.
@with_kw struct SimpleGridWorld <: MDP{StateTypes, Symbol}
    # Grid dimensions `(nx, ny)`; cells are `Vec2(x, y)` with 1 <= x <= nx, 1 <= y <= ny.
    size::Tuple{Int, Int} = (10, 10)
    # Reward attached to specific cells; cells absent from this Dict give 0.0.
    rewards::Dict{Vec2, Float64} = Dict(
        Vec2(4, 3) => -10.0,
        Vec2(4, 6) => -5.0,
        Vec2(9, 3) => 10.0,
        Vec2(8, 8) => 3.0,
    )
    # Cells whose next transition goes to `terminalstate`.
    terminate_in::Set{Vec2} = Set([Vec2(4, 3), Vec2(4, 6), Vec2(9, 3), Vec2(8, 8)])
    # Probability that an action moves in its intended direction.
    tprob::Float64 = 0.7
    # Discount factor reported by `discount`.
    discount::Float64 = 0.95
end
"""
    states(mdp::SimpleGridWorld)

Return every state of `mdp`: each grid cell `Vec2(x, y)` with `1 ≤ x ≤ mdp.size[1]`
and `1 ≤ y ≤ mdp.size[2]`, in column-major order (x varies fastest, matching
`LinearIndices(mdp.size)` used by `stateindex`), followed by `terminalstate`.
The result has `prod(mdp.size) + 1` entries.
"""
function POMDPs.states(mdp::SimpleGridWorld)
    nx, ny = mdp.size
    # `y` must range over `1:ny`; iterating the bare integer `ny` yields only y == ny.
    ss = vec(StateTypes[Vec2(x, y) for x in 1:nx, y in 1:ny])
    push!(ss, terminalstate)
    return ss
end
# Unit displacement on the grid produced by each action.
const dir = Dict(
    :up => Vec2(0, 1),
    :down => Vec2(0, -1),
    :left => Vec2(-1, 0),
    :right => Vec2(1, 0),
)

# Action symbol => action index (1 through 4), in the order :up, :down, :left, :right.
const aind = Dict(act => i for (i, act) in enumerate((:up, :down, :left, :right)))
# Action space: the four unit moves. Their order matches the indices in `aind`.
POMDPs.actions(::SimpleGridWorld) = SVector(:up, :down, :left, :right)

# One state per grid cell, plus the single terminal state.
function POMDPs.n_states(mdp::SimpleGridWorld)
    nx, ny = mdp.size
    return nx * ny + 1
end

# Derived from the action space (always 4) so the two cannot drift apart.
POMDPs.n_actions(mdp::SimpleGridWorld) = length(actions(mdp))

# Discount factor stored on the problem instance.
POMDPs.discount(mdp::SimpleGridWorld) = mdp.discount
# Column-major linear index of a grid cell (x varies fastest). `LinearIndices`
# raises a `BoundsError` for coordinates outside the grid.
POMDPs.stateindex(mdp::SimpleGridWorld, s::Vec2) = LinearIndices(mdp.size)[s[1], s[2]]

# The terminal state takes the index immediately after the last grid cell.
POMDPs.stateindex(mdp::SimpleGridWorld, ::TerminalState) = length(LinearIndices(mdp.size)) + 1

# Index of an action symbol via the `aind` table (a `KeyError` for unknown symbols).
POMDPs.actionindex(::SimpleGridWorld, a::Symbol) = aind[a]
# Reward depends only on the current cell: its entry in `mdp.rewards`, or 0.0 if
# the cell has no entry. The action `a` does not affect the result.
function POMDPs.reward(mdp::SimpleGridWorld, s::Vec2, a::Symbol)
    table = mdp.rewards
    return haskey(table, s) ? table[s] : 0.0
end
# Initial-state distribution, delegated to `uniform_belief`.
# NOTE(review): `uniform_belief` presumably spreads mass over all of `states(mdp)`,
# which includes `terminalstate`, whereas `initialstate` below samples grid cells
# only. Confirm the two are meant to differ.
POMDPs.initialstate_distribution(mdp::SimpleGridWorld) = uniform_belief(mdp)

# Sample a starting cell with x, then y, each drawn uniformly from its grid range.
# Cells listed in `terminate_in` are not excluded.
function POMDPs.initialstate(mdp::SimpleGridWorld, rng::AbstractRNG)
    nx, ny = mdp.size
    x = rand(rng, 1:nx)
    y = rand(rng, 1:ny)
    return Vec2(x, y)
end
"""
    transition(mdp::SimpleGridWorld, s::Vec2, a::Symbol)

Return the distribution over next states when taking action `a` in cell `s`.

- If `s` is in `mdp.terminate_in`, the next state is `terminalstate` with certainty.
- Otherwise the agent moves one cell in the direction of `a` with probability
  `mdp.tprob`. It moves in each other direction with probability
  `(1 - mdp.tprob) / (length(actions(mdp)) - 1)`.
- A move that would leave the grid is clamped coordinate-wise into
  `1:mdp.size[1]` × `1:mdp.size[2]`. For an in-grid `s`, an off-edge move
  therefore leaves the agent in `s`.

The returned `SparseCat` may list the same cell more than once, for example at
corners where several moves clamp back to `s`.
"""
function POMDPs.transition(mdp::SimpleGridWorld, s::Vec2, a::Symbol)
    if s in mdp.terminate_in
        return Deterministic(terminalstate)
    end
    acts = actions(mdp)
    lo = Vec2(1, 1)          # the grid is 1-based
    hi = Vec2(mdp.size)
    # Broadcast `clamp.` so each coordinate is clamped independently. Scalar `clamp`
    # on vectors compares them lexicographically and can return out-of-grid cells.
    neighbors = map(act -> clamp.(s + dir[act], lo, hi), acts)
    pslip = (1.0 - mdp.tprob) / (length(acts) - 1)   # spread evenly over the other moves
    probs = map(act -> act == a ? mdp.tprob : pslip, acts)
    return SparseCat(neighbors, probs)
end
# Vector encoding of the terminal state: the zero vector `(0, 0)`, which no cell
# of the 1-based grid can equal. Grid cells are already `Vec2` vectors.
POMDPs.convert_s(::Type{A}, s::TerminalState, mdp::SimpleGridWorld) where A = convert(A, Vec2(0, 0))

"""
    convert_s(TerminalState, s::AbstractArray, mdp::SimpleGridWorld)

Decode the zero-vector encoding back to `terminalstate`.

Throws an `ArgumentError` if `s` is not equal to the zero vector `(0, 0)`.
"""
function POMDPs.convert_s(::Type{TerminalState}, s::AbstractArray, mdp::SimpleGridWorld)
    # Explicit check rather than `@assert`: assertions are for internal invariants
    # and may be disabled.
    if s != SVector(0, 0)
        throw(ArgumentError("expected the zero vector encoding of the terminal state, got $s"))
    end
    return terminalstate
end
# Vector encoding of an action: a 1-element vector holding its index from `aind`.
POMDPs.convert_a(::Type{A}, a::Symbol, mdp::SimpleGridWorld) where A<:AbstractArray =
    convert(A, SVector(aind[a]))

# Decode a 1-element index vector back to its action symbol. The index goes through
# `Int` so float encodings round-trip; e.g. `[2.0]` (from encoding into a
# `Vector{Float64}`) decodes to `:down`. A non-integral value throws `InexactError`.
POMDPs.convert_a(::Type{Symbol}, a::AbstractArray, mdp::SimpleGridWorld) =
    actions(mdp)[Int(first(a))]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment