Skip to content

Instantly share code, notes, and snippets.

@dressel
Created December 17, 2017 20:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dressel/b5f44535df5c4ca244e6b8f733184769 to your computer and use it in GitHub Desktop.
Save dressel/b5f44535df5c4ca244e6b8f733184769 to your computer and use it in GitHub Desktop.
New implementation of discrete belief
# Goals: minimize calls to ordered_states (allocates memory)
# Stores the pomdp because pdf(b, s) needs state_index(pomdp, s)
# Stores the list of ordered states because rand(b) must map a sampled index back to a state
"""
    DiscreteBelief{P<:POMDP, S}

Probability distribution over the discrete states of a POMDP.

`b[i]` is the probability assigned to `state_list[i]`, so `b` and `state_list`
are expected to have the same length and ordering.
"""
mutable struct DiscreteBelief{P<:POMDP, S}
    pomdp::P               # model; needed to look up state indices in pdf(b, s)
    state_list::Vector{S}  # vector of ordered states (cached to avoid re-calling ordered_states)
    b::Vector{Float64}     # probability of each state, aligned with state_list
end
"""
    DiscreteBelief(pomdp)

Construct a uniform belief over the states returned by `ordered_states(pomdp)`.
"""
function DiscreteBelief(pomdp)
    state_list = ordered_states(pomdp)
    # Size the probability vector from the cached state list so the two are
    # guaranteed to line up (rand(b) indexes state_list with an index into b).
    ns = length(state_list)
    b = ones(ns) / ns
    return DiscreteBelief(pomdp, state_list, b)
end
"""
    DiscreteUpdater(pomdp)

Updater that works with `DiscreteBelief`s for the wrapped `pomdp`.
"""
mutable struct DiscreteUpdater{M<:POMDP} <: Updater
    pomdp::M   # model read by create_belief and initialize_belief
end
# Functions provided by the original implementation (checklist for this rewrite):
# Base.length
# index
# weight
# iterator
# rand
# pdf
"""
    fill!(belief::DiscreteBelief, x::Float64)

Overwrite every state probability of `belief` with `x` in place, then return
`belief` itself.
"""
function Base.fill!(belief::DiscreteBelief, x::Float64)
    probs = belief.b
    fill!(probs, x)
    return belief
end
"""
    rand(rng::AbstractRNG, b::DiscreteBelief)

Sample a state from `b`: `b.state_list[i]` is chosen with probability
proportional to `b.b[i]`.
"""
function rand(rng::AbstractRNG, b::DiscreteBelief)
    # NOTE(review): `sample`/`Weights` presumably come from StatsBase, and this
    # method should extend the generic `rand` (e.g. via `import Base: rand`)
    # rather than shadow it — confirm against the enclosing module's imports.
    i = sample(rng, Weights(b.b))
    return b.state_list[i]
end
"""
    pdf(b::DiscreteBelief, s)

Return the probability that `b` assigns to state `s`, locating `s` with
`state_index(b.pomdp, s)`.
"""
pdf(b::DiscreteBelief, s) = b.b[state_index(b.pomdp, s)]
# Open question: each create_belief call constructs a new DiscreteBelief (which calls ordered_states); is that too costly if called often?
"""
    create_belief(bu::DiscreteUpdater)

Return a newly constructed `DiscreteBelief` for the updater's `pomdp`.
"""
function create_belief(bu::DiscreteUpdater)
    model = bu.pomdp
    return DiscreteBelief(model)
end
"""
    initialize_belief(bu::DiscreteUpdater, dist, belief::DiscreteBelief=create_belief(bu))

Overwrite `belief` so that it represents the state distribution `dist`, then
return it. States not yielded by `iterator(dist)` get probability zero.

Note: a `belief` passed in by the caller is modified in place.
"""
function initialize_belief(bu::DiscreteUpdater, dist::Any,
                           belief::DiscreteBelief=create_belief(bu))
    fill!(belief, 0.0)
    for s in iterator(dist)
        sidx = state_index(bu.pomdp, s)
        # Write into the probability vector; DiscreteBelief has no setindex!.
        belief.b[sidx] = pdf(dist, s)
    end
    return belief
end
"""
    update(bu::DiscreteUpdater, b::DiscreteBelief, a, o)

Return the Bayes-filtered belief after taking action `a` from belief `b` and
receiving observation `o`:

    b'(sp) ∝ O(o | a, sp) * Σ_s T(sp | s, a) * b(s)

The model and state ordering are taken from `b`, and the result shares
`b.state_list`. `b` itself is not modified.

Throws an error if every state ends up with zero posterior mass, i.e. `o` is
impossible under `b` and `a`.
"""
function update(bu::DiscreteUpdater, b::DiscreteBelief, a, o)
    pomdp = b.pomdp
    state_space = b.state_list
    bp = zeros(length(state_space))

    # Prediction step: push each state's mass through the transition model.
    # transition(pomdp, s, a) depends only on s, so it is evaluated once per
    # state rather than once per (s, sp) pair. Iterating s in order keeps the
    # per-sp summation order identical to a sum over s.
    for (si, s) in enumerate(state_space)
        bs = b.b[si]
        bs == 0.0 && continue  # zero-mass states add nothing
        td = transition(pomdp, s, a)
        for (spi, sp) in enumerate(state_space)
            bp[spi] += pdf(td, sp) * bs
        end
    end

    # Correction step: weight each predicted state by the observation likelihood.
    bp_sum = 0.0  # to normalize the distribution
    for (spi, sp) in enumerate(state_space)
        po = pdf(observation(pomdp, a, sp), o)
        if po == 0.0
            bp[spi] = 0.0  # o cannot be observed in sp
        else
            bp[spi] = po * bp[spi]
            bp_sum += bp[spi]
        end
    end

    bp_sum == 0.0 && error("INVALID BELIEF UPDATE")
    bp ./= bp_sum
    return DiscreteBelief(pomdp, b.state_list, bp)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment