Skip to content

Instantly share code, notes, and snippets.

@zsunberg
Last active August 30, 2018 22:19
Show Gist options
  • Save zsunberg/50518f90dff071ca830cfb1e46380164 to your computer and use it in GitHub Desktop.
Save zsunberg/50518f90dff071ca830cfb1e46380164 to your computer and use it in GitHub Desktop.
A simpler version of a grid world problem for POMDPs.jl
# Integer grid coordinate `(x, y)`; cells of the grid are 1-based (see `states`).
const Vec2 = SVector{2,Int}

# Every possible state of the problem: the special terminal state or a grid cell.
const StateTypes = Union{TerminalState, Vec2}
# Keyword-constructible parameters of a fully observable 2-D grid world.
#   size         - grid extent: x ranges over 1:size[1], y over 1:size[2]
#   rewards      - reward earned in each listed cell; unlisted cells give 0.0
#   terminate_in - cells from which the next state is `terminalstate`
#   tprob        - probability that the chosen move is carried out as intended
#   discount     - value returned by `POMDPs.discount`
@with_kw struct SimpleGridWorld <: MDP{StateTypes, Symbol}
    size::Tuple{Int, Int}        = (10, 10)
    rewards::Dict{Vec2, Float64} = Dict{Vec2, Float64}(
        Vec2(4, 3) => -10.0,
        Vec2(4, 6) => -5.0,
        Vec2(9, 3) => 10.0,
        Vec2(8, 8) => 3.0,
    )
    terminate_in::Set{Vec2}      = Set{Vec2}([Vec2(4, 3), Vec2(4, 6), Vec2(9, 3), Vec2(8, 8)])
    tprob::Float64               = 0.7
    discount::Float64            = 0.95
end
"""
    POMDPs.states(mdp::SimpleGridWorld)

Return every state of `mdp`.

The result lists all grid cells `Vec2(x, y)` with `1 ≤ x ≤ mdp.size[1]` and
`1 ≤ y ≤ mdp.size[2]`, with `terminalstate` appended as the last element. Cells are
listed in column-major order (x varies fastest). This matches `LinearIndices(mdp.size)`
as used by `stateindex`, and puts `terminalstate` at index `prod(mdp.size) + 1`.
"""
function POMDPs.states(mdp::SimpleGridWorld)
    # Both axes need an explicit range: iterating a bare Int yields only that one value.
    ss = vec(StateTypes[Vec2(x, y) for x in 1:mdp.size[1], y in 1:mdp.size[2]])
    push!(ss, terminalstate)
    return ss
end
# Displacement applied to a cell for each action: `:right` increases x, `:up` increases y.
const dir = Dict{Symbol,Vec2}(
    :up    => Vec2(0, 1),
    :down  => Vec2(0, -1),
    :left  => Vec2(-1, 0),
    :right => Vec2(1, 0),
)
# Action symbol => 1-based action index. The tuple order must match `actions(mdp)`.
const aind = Dict(act => i for (i, act) in enumerate((:up, :down, :left, :right)))
"Return the four movement actions, in the order used by `aind`."
function POMDPs.actions(mdp::SimpleGridWorld)
    return SVector{4,Symbol}(:up, :down, :left, :right)
end

"Return the number of states: one per grid cell plus the terminal state."
function POMDPs.n_states(mdp::SimpleGridWorld)
    width, height = mdp.size
    return width * height + 1
end

"Return the number of actions (always 4)."
function POMDPs.n_actions(mdp::SimpleGridWorld)
    return length(actions(mdp))
end

"Return the discount factor stored in `mdp`."
function POMDPs.discount(mdp::SimpleGridWorld)
    return mdp.discount
end
"Return the column-major linear index of cell `s` in the grid (throws `BoundsError` if off-grid)."
function POMDPs.stateindex(mdp::SimpleGridWorld, s::Vec2)
    x, y = s
    return LinearIndices(mdp.size)[x, y]
end

"Return the index of the terminal state, which comes right after the last grid cell."
function POMDPs.stateindex(mdp::SimpleGridWorld, s::TerminalState)
    ncells = prod(mdp.size)
    return ncells + 1
end

"Return the 1-based index of action `a` as given by `aind`."
function POMDPs.actionindex(mdp::SimpleGridWorld, a::Symbol)
    return aind[a]
end
"""
    POMDPs.reward(mdp::SimpleGridWorld, s::Vec2, a::Symbol)

Return the reward for being in cell `s`. The reward does not depend on the action `a`.
Cells with no entry in `mdp.rewards` yield `0.0`.
"""
function POMDPs.reward(mdp::SimpleGridWorld, s::Vec2, a::Symbol)
    return haskey(mdp.rewards, s) ? mdp.rewards[s] : 0.0
end
"""
    POMDPs.initialstate_distribution(mdp::SimpleGridWorld)

Return a uniform distribution over all grid cells `Vec2(x, y)` with
`1 ≤ x ≤ mdp.size[1]` and `1 ≤ y ≤ mdp.size[2]`.

`terminalstate` is deliberately excluded. This keeps the distribution consistent with
`initialstate`, which only ever samples grid cells.
"""
function POMDPs.initialstate_distribution(mdp::SimpleGridWorld)
    cells = vec([Vec2(x, y) for x in 1:mdp.size[1], y in 1:mdp.size[2]])
    n = length(cells)
    return SparseCat(cells, fill(1.0 / n, n))
end
"""
    POMDPs.initialstate(mdp::SimpleGridWorld, rng::AbstractRNG)

Sample a grid cell uniformly at random using `rng`. The x coordinate is drawn first,
then y.
"""
function POMDPs.initialstate(mdp::SimpleGridWorld, rng::AbstractRNG)
    width, height = mdp.size
    return Vec2(rand(rng, 1:width), rand(rng, 1:height))
end
"""
    POMDPs.transition(mdp::SimpleGridWorld, s::Vec2, a::Symbol)

Return the distribution over next states after taking action `a` in cell `s`.

- If `s` is in `mdp.terminate_in`, the next state is `terminalstate` with certainty.
- Otherwise, the agent moves one cell in direction `a` with probability `mdp.tprob`.
  It moves in each of the other three directions with probability
  `(1 - mdp.tprob) / 3`.

A move that would leave the grid is clamped back onto it coordinate by coordinate.
Starting from an in-grid cell, this means the agent stays where it is.
"""
function POMDPs.transition(mdp::SimpleGridWorld, s::Vec2, a::Symbol)
    if s in mdp.terminate_in
        return Deterministic(terminalstate)
    end

    lo = Vec2(1, 1)          # grid cells are 1-based
    hi = Vec2(mdp.size)
    neighbors = map(actions(mdp)) do act
        # Broadcast so x and y are clamped independently. Un-broadcast `clamp` on
        # vectors does not clamp per coordinate.
        clamp.(s + dir[act], lo, hi)
    end

    p_other = (1.0 - mdp.tprob) / 3   # probability of each unintended direction
    probs = map(actions(mdp)) do act
        act == a ? mdp.tprob : p_other
    end
    return SparseCat(neighbors, probs)
end
"Encode `terminalstate` as the zero coordinate `(0, 0)`, converted to type `A`."
function POMDPs.convert_s(::Type{A}, s::TerminalState, mdp::SimpleGridWorld) where A
    sentinel = Vec2(0, 0)
    return convert(A, sentinel)
end
"""
    POMDPs.convert_s(::Type{TerminalState}, s::AbstractArray, mdp::SimpleGridWorld)

Decode the array encoding of the terminal state back into `terminalstate`.

The terminal state is encoded as the coordinate `(0, 0)`.

# Throws
- `ArgumentError` if `s` is not equal to `[0, 0]`.
"""
function POMDPs.convert_s(::Type{TerminalState}, s::AbstractArray, mdp::SimpleGridWorld)
    s == SVector(0, 0) ||
        throw(ArgumentError("expected the terminal-state encoding [0, 0], got $s"))
    return terminalstate
end
"Encode action `a` as a one-element array holding its index from `aind`, converted to `A`."
function POMDPs.convert_a(::Type{A}, a::Symbol, mdp::SimpleGridWorld) where A<:AbstractArray
    idx = aind[a]
    return convert(A, SVector(idx))
end

"Decode an action array by using its first element as an index into `actions(mdp)`."
function POMDPs.convert_a(::Type{Symbol}, a::AbstractArray, mdp::SimpleGridWorld)
    idx = first(a)
    return actions(mdp)[idx]
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment