Skip to content

Instantly share code, notes, and snippets.

@zsunberg
Created August 31, 2018 23:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zsunberg/d4013bd63c71352e6e8269a196c08f1b to your computer and use it in GitHub Desktop.
Save zsunberg/d4013bd63c71352e6e8269a196c08f1b to your computer and use it in GitHub Desktop.
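Grid-world benchmark intended to show that the current Julia compiler does not handle a `Union` of state types (`GWPos` and `TerminalState`) efficiently, compared with a single concrete state type. Output for Julia 1.0 is at the bottom.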
using POMDPs
using POMDPModelTools
using POMDPSimulators
using POMDPPolicies
using StaticArrays
using Parameters
using Random
using BenchmarkTools
using POMDPModels
using Test
# Common
# A grid position: a 2-element static vector of integer (x, y) coordinates.
const GWPos = SVector{2,Int}

# Shared supertype of the grid-world variants below; `S` is the state type and
# actions are `Symbol`s.
abstract type AbstractGridWorld{S} <: MDP{S, Symbol} end

# Displacement applied to the position by each action.
const dir = Dict{Symbol,GWPos}(
    :up    => GWPos(0, 1),
    :down  => GWPos(0, -1),
    :left  => GWPos(-1, 0),
    :right => GWPos(1, 0),
)

# Action -> 1-based index, numbered in the order (:up, :down, :left, :right).
const aind = Dict(act => i for (i, act) in enumerate((:up, :down, :left, :right)))
# The four moves, as a fixed tuple in the same order used by `aind`.
POMDPs.actions(mdp::AbstractGridWorld) = (:up, :down, :left, :right)

# Uniformly sample one element of a tuple.
# NOTE(review): this extends `Base.rand` for `Tuple`, which this script owns
# neither the function nor the argument types of (type piracy). The original
# author added it because sampling from the action tuple did not work out of the
# box on Julia 1.0 (presumably needed by `RandomPolicy`); revisit on newer Julia.
function Base.rand(rng::AbstractRNG, t::Tuple)
    pick = rand(rng, 1:length(t))
    return t[pick]
end

# Every grid cell plus one extra index reserved for the terminal state.
function POMDPs.n_states(mdp::AbstractGridWorld)
    return prod(mdp.size) + 1
end

POMDPs.n_actions(mdp::AbstractGridWorld) = 4

POMDPs.discount(mdp::AbstractGridWorld) = mdp.discount

function POMDPs.actionindex(mdp::AbstractGridWorld, a::Symbol)
    return aind[a]
end

# Reward for acting in cell `s`; cells absent from `mdp.rewards` give 0.0.
function POMDPs.reward(mdp::AbstractGridWorld, s::GWPos, a::Symbol)
    return get(mdp.rewards, s, 0.0)
end

# Start in a uniformly random cell of the grid (x is drawn before y).
function POMDPs.initialstate(mdp::AbstractGridWorld, rng::AbstractRNG)
    x = rand(rng, 1:mdp.size[1])
    y = rand(rng, 1:mdp.size[2])
    return GWPos(x, y)
end
# Earlier attempt to eliminate extraneous allocations by clamping moves to the grid;
# currently unused (kept for reference).
# @inline clamp2(v::GWPos, l, u) = GWPos(clamp(v[1], l[1], u[1]), clamp(v[2], l[2], u[2]))
# function neighbors(mdp::AbstractGridWorld, s)
#     return (GWPos(s[1], min(s[2]+1, mdp.size[2])), # up
#             GWPos(s[1], max(s[2]-1, 1)),           # down
#             GWPos(max(s[1]-1, 1), s[2]),           # left
#             GWPos(min(s[1]+1, mdp.size[1]), s[2])  # right
#            )
# end
######################################
# Simple version using TerminalState #
######################################
# States of `SimpleGridWorld`: either a grid position or the singleton terminal state.
const StateTypes = Union{GWPos, TerminalState}

# Grid world whose state type is the `Union` `StateTypes`.
# Fields (all settable by keyword):
#   size         - grid dimensions (nx, ny)
#   rewards      - reward for acting in a cell; unlisted cells give 0.0
#   terminate_in - cells from which the next state is `terminalstate`
#   tprob        - probability of moving in the chosen direction
#   discount     - discount factor
@with_kw struct SimpleGridWorld <: AbstractGridWorld{StateTypes}
    size::Tuple{Int, Int} = (10, 10)
    rewards::Dict{GWPos, Float64} = Dict{GWPos, Float64}(
        GWPos(4, 3) => -10.0,
        GWPos(4, 6) => -5.0,
        GWPos(9, 3) => 10.0,
        GWPos(8, 8) => 3.0,
    )
    terminate_in::Set{GWPos} = Set(GWPos[GWPos(4, 3), GWPos(4, 6), GWPos(9, 3), GWPos(8, 8)])
    tprob::Float64 = 0.7
    discount::Float64 = 0.95
end
# All states of `mdp`: every grid cell with x varying fastest (the same column-major
# order `LinearIndices(mdp.size)` uses in `stateindex`), followed by `terminalstate`.
function POMDPs.states(mdp::SimpleGridWorld)
    nx, ny = mdp.size
    # Both coordinates need explicit ranges: `y in mdp.size[2]` would iterate over
    # the single integer ny and yield only the top row of the grid.
    ss = vec(StateTypes[GWPos(x, y) for x in 1:nx, y in 1:ny])
    push!(ss, terminalstate)
    return ss
end
# Grid cells are numbered column-major (x varies fastest) via `LinearIndices`;
# positions outside the grid throw a BoundsError.
function POMDPs.stateindex(mdp::SimpleGridWorld, s::GWPos)
    idx = LinearIndices(mdp.size)
    return idx[s[1], s[2]]
end

# The terminal state takes the index just after the last grid cell.
function POMDPs.stateindex(mdp::SimpleGridWorld, s::TerminalState)
    return prod(mdp.size) + 1
end
# Next-state distribution after taking action `a` in cell `s`.
# From a cell in `mdp.terminate_in` the next state is always `terminalstate`.
# Otherwise the agent moves:
#   - in the chosen direction with probability `mdp.tprob`;
#   - in each of the other three directions with probability `(1 - mdp.tprob)/3`.
# Moves are deliberately not clamped to the grid, so the next state may lie
# outside it; a clamped version was noted to cause allocations.
function POMDPs.transition(mdp::SimpleGridWorld, s::GWPos, a::Symbol)
    s in mdp.terminate_in && return Deterministic(terminalstate)

    acts = actions(mdp)
    targets = map(act -> s + dir[act], acts)
    weights = map(act -> act == a ? mdp.tprob : (1.0 - mdp.tprob)/3, acts)
    return SparseCat(targets, weights)
end
#######################
# Type-stable version #
#######################
# Sentinel position standing in for the terminal state, so every state of
# `SimpleTypeStableGridWorld` has the single concrete type `GWPos`.
const tv = GWPos(-1, -1)

# Same grid world as `SimpleGridWorld`, but with state type `GWPos` only.
# Fields (all settable by keyword):
#   size         - grid dimensions (nx, ny)
#   rewards      - reward for acting in a cell; unlisted cells give 0.0
#   terminate_in - cells from which the next state is the sentinel `tv`
#   tprob        - probability of moving in the chosen direction
#   discount     - discount factor
@with_kw struct SimpleTypeStableGridWorld <: AbstractGridWorld{GWPos}
    size::Tuple{Int, Int} = (10, 10)
    rewards::Dict{GWPos, Float64} = Dict{GWPos, Float64}(
        GWPos(4, 3) => -10.0,
        GWPos(4, 6) => -5.0,
        GWPos(9, 3) => 10.0,
        GWPos(8, 8) => 3.0,
    )
    terminate_in::Set{GWPos} = Set(GWPos[GWPos(4, 3), GWPos(4, 6), GWPos(9, 3), GWPos(8, 8)])
    tprob::Float64 = 0.7
    discount::Float64 = 0.95
end
# All states of `mdp`: every grid cell with x varying fastest (the same column-major
# order `LinearIndices(mdp.size)` uses in `stateindex`), followed by the terminal
# sentinel `tv`.
function POMDPs.states(mdp::SimpleTypeStableGridWorld)
    nx, ny = mdp.size
    # Both coordinates need explicit ranges: `y in mdp.size[2]` would iterate over
    # the single integer ny and yield only the top row of the grid.
    ss = vec(GWPos[GWPos(x, y) for x in 1:nx, y in 1:ny])
    push!(ss, tv)
    return ss
end
# Index of state `s`.
# - Positions with both coordinates positive are grid cells, numbered column-major
#   via `LinearIndices`; cells beyond `mdp.size` throw a BoundsError.
# - Anything else is given the terminal index, one past the last grid cell.
# NOTE(review): every non-positive position maps to the terminal index, not just
# the sentinel `tv`, while `isterminal` uses its own test; confirm this is intended.
function POMDPs.stateindex(mdp::SimpleTypeStableGridWorld, s::GWPos)
    if all(s .> 0)
        return LinearIndices(mdp.size)[s...]
    else
        # `prod(mdp.size + 1)` added an Int to a Tuple, which is a MethodError;
        # the terminal index is the grid size plus one.
        return prod(mdp.size) + 1
    end
end
# Next-state distribution after taking action `a` in cell `s`.
# From a cell in `mdp.terminate_in` the next state is the sentinel `tv` with
# probability 1. That case is expressed as a 4-outcome `SparseCat` so both branches
# return the same concrete type.
# Otherwise the agent moves:
#   - in the chosen direction with probability `mdp.tprob`;
#   - in each of the other three directions with probability `(1 - mdp.tprob)/3`.
# Moves are deliberately not clamped to the grid, so the next state may lie
# outside it; a clamped version was noted to cause allocations.
function POMDPs.transition(mdp::SimpleTypeStableGridWorld, s::GWPos, a::Symbol)
    s in mdp.terminate_in && return SparseCat((tv, tv, tv, tv), (1.0, 0.0, 0.0, 0.0))

    acts = actions(mdp)
    targets = map(act -> s + dir[act], acts)
    weights = map(act -> act == a ? mdp.tprob : (1.0 - mdp.tprob)/3, acts)
    return SparseCat(targets, weights)
end
# A state is terminal only when it is the sentinel `tv`, which `transition` returns
# from cells in `terminate_in`. The previous test, `any(s .< 0)`, also ended an
# episode whenever an unclamped move stepped below the grid. Rollouts of this
# variant therefore stopped early and were not comparable to `SimpleGridWorld`,
# whose only terminal state is `terminalstate`.
# (An off-grid walk could still land exactly on (-1, -1) and be treated as terminal.)
POMDPs.isterminal(::SimpleTypeStableGridWorld, s::GWPos) = s == tv
####################
# Benchmark Script #
####################
# Benchmark one rollout (at most 10_000 steps, seeded random policy) for:
#   - the POMDPModels reference `GridWorld`;
#   - the `Union`-state `SimpleGridWorld`;
#   - the concretely-typed `SimpleTypeStableGridWorld`.
# Terminal cells are disabled by passing empty `terminals` / `terminate_in` sets.
mdps = [
    GridWorld(terminals = Set()),
    SimpleGridWorld(terminate_in = Set{GWPos}()),
    SimpleTypeStableGridWorld(terminate_in = Set{GWPos}()),
]

# Fails loudly if the type-stable transition is not concretely inferred.
@inferred transition(mdps[3], GWPos(1, 1), :up)

for m in mdps
    @show typeof(m)
    # Fresh, identically seeded RNGs per model so each benchmark starts alike.
    policy = RandomPolicy(m, rng = MersenneTwister(7))
    rosim = RolloutSimulator(max_steps = 10_000, rng = MersenneTwister(2))
    @btime simulate($rosim, $m, $policy)
end
############
## OUTPUT ##
############
#=
julia> include("gw_bench.jl")
WARNING: redefining constant dir
WARNING: redefining constant aind
typeof(m) = GridWorld
3.265 ms (40000 allocations: 3.05 MiB)
typeof(m) = SimpleGridWorld
100.529 ms (436593 allocations: 12.55 MiB)
typeof(m) = SimpleTypeStableGridWorld
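12.241 μs (73 allocations: 2.09 KiB)

NOTE: these timings were recorded while `isterminal(::SimpleTypeStableGridWorld, s)`
returned `any(s .< 0)`. Because `transition` does not clamp moves to the grid, a
SimpleTypeStableGridWorld rollout ends as soon as a coordinate becomes negative. It
therefore likely ran far fewer than 10_000 steps. SimpleGridWorld has no reachable
terminal state in this setup, so the gap between the two is probably overstated.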
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment