-
-
Save jmuchovej/2e8e295cabb4b3e76626b166e1549992 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Random | |
using POMDPs | |
using POMDPTools | |
using Distances | |
using StaticArrays | |
using Base.Iterators | |
using LinearAlgebra | |
using NativeSARSOP | |
using SARSOP | |
using QMDP | |
using PointBasedValueIteration | |
"""
    Point

Alias for `SVector{2,Int}`; used for the `pos` fields of `State` and `Box`.
"""
const Point = SVector{2,Int}
# Closed set of box contents; the first member is pinned to integer code 1.
# NOTE(review): the member names look mis-encoded in this copy (two of them render
# identically) -- restore the original glyphs from the upstream source before running.
@enum Fruit begin
    π = 1
    π
    π«
end
# Conversions from a `Fruit` to (1) any integer type via its enum code, (2) its member
# name as a `Symbol`, and (3) that member name as a `String`.
Base.convert(::Type{T}, f::Fruit) where {T<:Integer} = T(Integer(f))
Base.convert(::Type{Symbol}, f::Fruit) = Symbol(f)
Base.convert(::Type{String}, f::Fruit) = String(Symbol(f))
"""
    State(pos, boxes)

A single world state: a position `pos` together with one `Fruit` per box slot
(exactly three slots).
"""
struct State
    pos::Point
    boxes::SVector{3,Fruit}
end
# Base.:(==)(s1::State, s2::State) = (s1.pos == s2.pos) && (s1.boxes == s2.boxes) | |
"""
    Box(pos, item)

A box located at `pos` that holds the fruit `item`.
"""
struct Box
    pos::Point
    item::Fruit
end
# Convenience constructors that place `item` in a box at a fixed position.

"Box at position (1, 5)."
function Box1(item::Fruit)
    return Box(Point(1, 5), item)
end

"Box at position (5, 5)."
function Box2(item::Fruit)
    return Box(Point(5, 5), item)
end

"Box at position (5, 1)."
function Box3(item::Fruit)
    return Box(Point(5, 1), item)
end
"""
    Action{kind}

An action whose flavour is carried by the type parameter `kind` (e.g. `:move`) and whose
`targetbox` field stores the number of the box it refers to.
"""
struct Action{kind}
    targetbox::Number
end

# Two-argument form: `targetbox` has to be strictly positive.
function Action(type, targetbox::Number)
    @assert targetbox > 0
    action = Action{type}(targetbox)
    return action
end

# One-argument form: no target given, so 0 is stored.
function Action(type)
    return Action{type}(0)
end
# Aliases for the three action flavours, used for dispatch.
const MoveAction = Action{:move}
const TakeAction = Action{:take}
const OpenAction = Action{:open}

# Shorthand constructors; each forwards to the validating `Action(type, targetbox)`.
function Move(box::Number)
    return Action(:move, box)
end

function Take(box::Number)
    return Action(:take, box)
end

function Open(box::Number)
    return Action(:open, box)
end
"""
    FruitWorld <: POMDP{State, Action, Fruit}

Problem definition with `State` states, `Action` actions and `Fruit` observations.

# Fields
- `spawn`: position that initial states must have.
- `mapshape`: map extent (not read by the model functions in this file).
- `belief`: optional probabilities over the spawn states; `nothing` means uniform.
- `target`: a `Fruit` (not read by the model functions in this file).
- `boxes`: the boxes of the world; actions index into this vector.
- `items`: the fruits that may be observed or stored in a box slot.
- `locations`: positions a non-terminal state may take.
- `rewards`: reward lookup keyed by `Symbol`.
"""
struct FruitWorld <: POMDP{State,Action,Fruit}
    spawn::Point
    mapshape::Point
    belief::Union{Vector{<:Real},Nothing}
    target::Fruit
    boxes::Vector{Box}
    items::Vector{Fruit}
    locations::Vector{Point}
    rewards::Dict{Symbol,Real}
end
"""
    FruitWorld(p::FruitWorld; kwargs...)

Copy `p`, overriding any field given as a keyword. `belief` may also be passed as a
`SparseCat`, in which case only its `probs` are stored.
"""
function FruitWorld(
    p::FruitWorld;
    spawn::Point=p.spawn,
    mapshape::Point=p.mapshape,
    belief::Union{SparseCat,Vector{<:Real},Nothing}=p.belief,
    target::Fruit=p.target,
    boxes::Vector{Box}=p.boxes,
    items::Vector{Fruit}=p.items,
    locations::Vector{Point}=p.locations,
    rewards::Dict{Symbol,Real}=p.rewards,
)
    # Unwrap a distribution into its raw probability vector.
    if isa(belief, SparseCat)
        belief = belief.probs
    end
    return FruitWorld(spawn, mapshape, belief, target, boxes, items, locations, rewards)
end
# Discount factor of the problem.
function POMDPs.discount(p::FruitWorld)
    return 0.999999
end
# A state is terminal exactly when its position is the sentinel (-1, -1).
POMDPs.isterminal(p::FruitWorld, s::State) = s.pos == Point(-1, -1)
"""
    POMDPs.actions(p::FruitWorld)

Return the action space: one `Move` per box followed by one `Take` per box.
`Open` actions are not part of the action space.
"""
function POMDPs.actions(p::FruitWorld)
    targets = 1:length(p.boxes)
    return vcat(Move.(targets), Take.(targets))
end
# Position of `a` within `actions(p)`; `nothing` when `a` is not in the action space.
function POMDPs.actionindex(p::FruitWorld, a::Action)
    return findfirst(isequal(a), actions(p))
end
# Index space of the non-terminal states: the first axis ranges over locations, followed
# by one axis per box, each ranging over the items.
function Base.CartesianIndices(p::FruitWorld)
    nitems = length(p.items)
    boxaxes = ntuple(_ -> nitems, length(p.boxes))
    return CartesianIndices((length(p.locations), boxaxes...))
end
# Number of states: every cartesian combination plus one extra slot for the terminal state.
function Base.length(p::FruitWorld)
    return length(CartesianIndices(p)) + 1
end
"""
    getindex(p::FruitWorld, stateindex::Int)

Return the state stored at linear index `stateindex`. The last index, `length(p)`, is the
terminal state (position `(-1, -1)`, box contents copied from `p.boxes`); every other
index is decoded through `CartesianIndices(p)`.
"""
function Base.getindex(p::FruitWorld, stateindex::Int)
    if stateindex == length(p)
        return State(Point(-1, -1), [b.item for b in p.boxes])
    end
    # First coordinate selects the location, the rest are the per-box fruit codes.
    (where, contents...) = Tuple(CartesianIndices(p)[stateindex])
    return State(p.locations[where], [Fruit.(contents)...])
end
# Iterate over every state of `p` in linear-index order, terminal state last.
function Base.iterate(p::FruitWorld, stateindex::Int=1)
    stateindex > length(p) && return nothing
    return (p[stateindex], stateindex + 1)
end
"""
    POMDPs.stateindex(p::FruitWorld, s::State)

Return the linear index of `s`. Any state at the sentinel position `(-1, -1)` maps to
`length(p)`; otherwise the index is the location index plus a stride-weighted offset of
the zero-based fruit codes.
"""
function POMDPs.stateindex(p::FruitWorld, s::State)
    s.pos == Point(-1, -1) && return length(p)
    location = findfirst(isequal(s.pos), p.locations)
    # Strides of the box axes; the final cumulative product is the total count, so drop it.
    strides = cumprod(size(CartesianIndices(p)))[1:end - 1]
    # Fruit codes start at 1, hence the shift to zero-based before weighting.
    offset = dot(strides, Integer.(s.boxes) .- 1)
    return location + offset
end
# Materialise the full state space by iterating `p`.
function POMDPs.states(p::FruitWorld)
    return collect(p)
end
"""
    POMDPs.initialstate(p::FruitWorld)

Return a `SparseCat` over the states located at `p.spawn`. The weights are `p.belief`
when one is set and uniform otherwise.
"""
function POMDPs.initialstate(p::FruitWorld)
    spawnstates = filter(states(p)) do s
        s.pos == p.spawn
    end
    probs = if isnothing(p.belief)
        fill(1 / length(spawnstates), length(spawnstates))
    else
        p.belief
    end
    return SparseCat(spawnstates, probs)
end
# The observation space is the list of items.
function POMDPs.observations(p::FruitWorld)
    return p.items
end

# Position of `obs` within `observations(p)`; `nothing` when absent.
function POMDPs.obsindex(p::FruitWorld, obs::Fruit)
    return findfirst(isequal(obs), observations(p))
end
"""
    POMDPs.observation(p::FruitWorld, a::MoveAction, sp::State)

Observation model for moving: all probability mass goes to the fruit that `sp` records
for the targeted box, so the observation is noise-free.
"""
function POMDPs.observation(p::FruitWorld, a::MoveAction, sp::State)
    obs_dist = zeros(length(observations(p)))
    obs_dist[obsindex(p, sp.boxes[a.targetbox])] = 1.0
    return SparseCat(observations(p), obs_dist)
end
"""
    POMDPs.observation(p::FruitWorld, a::TakeAction, sp::State)

Observation model for taking: the last entry of `observations(p)` is emitted with
probability one, regardless of `sp`.
"""
function POMDPs.observation(p::FruitWorld, a::TakeAction, sp::State)
    weights = zeros(length(observations(p)))
    weights[end] = 1.
    return SparseCat(observations(p), weights)
end
# Reward lookup helpers: by raw key, by fruit (via its member name), or by box (via its item).
function R(p::FruitWorld, key::Symbol)
    return p.rewards[key]
end

function R(p::FruitWorld, fruit::Fruit)
    return R(p, Symbol(fruit))
end

function R(p::FruitWorld, box::Box)
    return R(p, box.item)
end
# Moving costs the Euclidean distance between the current position and the target box.
function POMDPs.reward(p::FruitWorld, s::State, a::MoveAction)
    destination = p.boxes[a.targetbox].pos
    return -euclidean(destination, s.pos)
end
# Taking pays the box's reward only when standing on that box; otherwise it pays 0.
function POMDPs.reward(p::FruitWorld, s::State, a::TakeAction)
    box = p.boxes[a.targetbox]
    return box.pos == s.pos ? R(p, box) : 0
end
"""
    POMDPs.transition(p::FruitWorld, s::State, a::MoveAction)

Deterministic move: the next state sits at the targeted box's position and keeps the box
contents of `s` unchanged. Terminal states are absorbing.
"""
function POMDPs.transition(p::FruitWorld, s::State, a::MoveAction)
    isterminal(p, s) && return Deterministic(s)
    destination = p.boxes[a.targetbox].pos
    return Deterministic(State(destination, s.boxes))
end
"""
    POMDPs.transition(p::FruitWorld, s::State, a::TakeAction)

Deterministic take: when `s` is at the targeted box the world moves to the terminal
state; otherwise (wrong position, or already terminal) the state is unchanged.
"""
function POMDPs.transition(p::FruitWorld, s::State, a::TakeAction)
    box = p.boxes[a.targetbox]
    if isterminal(p, s) || box.pos != s.pos
        return Deterministic(s)
    end
    # The terminal state lives at the last linear index; fetch it directly instead of
    # collecting the whole state space on every call.
    return Deterministic(p[length(p)])
end
# Solver selection. The commented lines are alternative solvers kept for reference.
# solver = NativeSARSOP.SARSOPSolver(; epsilon=1e-1, precision=1e-4, verbose=true)
# solver = QMDPSolver(; verbose=true)
# solver = PBVISolver(; ϵ=20.0, verbose=true)
solver = SARSOP.SARSOPSolver(; verbose=true)

# Problem layout.
# NOTE(review): the fruit glyphs below look mis-encoded in this copy (distinct fruits
# render identically); restore them from the upstream source before running.
SPAWN = Point(1, 1)
MAPSHAPE = Point(5, 5)
BOXES = [Box1(π), Box2(π), Box3(π)]
TARGET = π

pomdp = FruitWorld(
    SPAWN,
    MAPSHAPE,
    nothing,                                   # no prior belief: uniform initial state
    TARGET,
    BOXES,
    [instances(Fruit)...],                     # every fruit is a possible item
    [SPAWN, [b.pos for b in BOXES]...],        # spawn point plus each box position
    Dict(:movement => -1., :π => 90., :π => 45., :π« => 1.),
)
"""
    marginals_to_belief(p::FruitWorld, marginal::Matrix)

Build a joint belief over the initial states of `p` from per-box marginals.
`marginal[box, fruitcode]` is the probability that `box` holds the fruit with that integer
code; each state's weight is the product of its boxes' marginal entries.
"""
function marginals_to_belief(p::FruitWorld, marginal::Matrix)
    support = initialstate(p).vals
    weights = map(support) do candidate
        factors = [marginal[box, Int(candidate.boxes[box])] for box in 1:length(p.boxes)]
        prod(factors)
    end
    return SparseCat(support, weights)
end
# Candidate prior beliefs, one 3x3 table each: row = box, column = fruit code.
marginalbeliefs = [
    [.98 .01 .01
     .02 .49 .49
     .02 .49 .49],
    [.02 .49 .49
     .98 .01 .01
     .02 .49 .49],
    [.02 .49 .49
     .02 .49 .49
     .98 .01 .01],
    [.10 .45 .45
     .10 .45 .45
     .90 .05 .05],
    [.90 .05 .05
     .05 .40 .55
     .80 .10 .10],
    [.10 .45 .45
     .10 .45 .45
     .20 .40 .40],
]

# Solve the problem under the first prior.
belief = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_unif = initialstate(FruitWorld(pomdp; belief=nothing))
# belief_box1 = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_box3 = marginals_to_belief(pomdp, marginalbeliefs[3])
policy = solve(solver, FruitWorld(pomdp; belief=belief))
# policy_unif = solve(solver, FruitWorld(pomdp, belief=nothing))
# policy_box1 = solve(solver, FruitWorld(pomdp, belief=belief_box1))
# policy_box3 = solve(solver, FruitWorld(pomdp, belief=belief_box3))

# Compare action values under the uniform prior (b0) and the informed prior (b1).
b0 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=nothing)))
b1 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=belief)))
@show actionvalues(policy, b0)
@show actionvalues(policy, b1)
# Roll the policy forward from belief `b1` for at most 19 steps (or until a terminal
# state), printing each transition and accumulating the discounted reward.
begin
    b = b1
    s = rand(b)
    reward = 0.
    d = 1.
    up = updater(policy)
    counter = 1
    while counter < 20 && !isterminal(pomdp, s)
        # These are module-level variables. Without the declaration, running the file
        # non-interactively makes the assignments below create loop-local variables and
        # `reward += ...` fails with an UndefVarError.
        global b, s, reward, d, counter
        a = action(policy, b)
        # Sample next state, observation and reward from the generative model.
        sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a)
        reward += d * r
        d *= discount(pomdp)
        b = update(up, b, a, o)
        c_str = lpad(counter, 3)
        @show (c_str, s, a, sp, r)
        counter += 1
        s = sp
    end
end
# Belief obtained from `b1` after `Move(2)` and the given observation.
# NOTE(review): the observation glyph looks mis-encoded in this copy; restore it upstream.
bp1 = update(DiscreteUpdater(pomdp), b1, Move(2), π);
showdistribution(bp1)

# Switch to the third marginal table for the step-by-step updates that follow.
belief = marginals_to_belief(pomdp, marginalbeliefs[3]);
# Re-initialise `b1` from the new prior and list its most likely state(s).
begin
    b1 = initialize_belief(
        updater(policy),
        initialstate(FruitWorld(pomdp; belief=belief)),
    )
    maxb1s = findall(isequal(maximum(b1.b)), b1.b)
    @show b1.state_list[maxb1s]
    showdistribution(b1)
end
# Update `b1` with `Move(2)` and the given observation, then list the most likely state(s).
# NOTE(review): the observation glyph looks mis-encoded in this copy; restore it upstream.
begin
    b2 = update(DiscreteUpdater(pomdp), b1, Move(2), π)
    maxb2s = findall(isequal(maximum(b2.b)), b2.b)
    @show b2.state_list[maxb2s]
    showdistribution(b2)
end
# Update `b2` with `Move(3)` and the given observation, then list the most likely state(s).
# NOTE(review): the observation glyph looks mis-encoded in this copy; restore it upstream.
begin
    b3 = update(DiscreteUpdater(pomdp), b2, Move(3), π)
    maxb3s = findall(isequal(maximum(b3.b)), b3.b)
    @show b3.state_list[maxb3s]
    showdistribution(b3)
end
# Update `b3` with `Move(1)` and the given observation, then list the most likely state(s).
# NOTE(review): the observation glyph looks mis-encoded in this copy; restore it upstream.
begin
    b4 = update(DiscreteUpdater(pomdp), b3, Move(1), π)
    maxb4s = findall(isequal(maximum(b4.b)), b4.b)
    # Index with this step's maxima (`maxb4s`), not the previous step's `maxb3s`.
    @show b4.state_list[maxb4s]
    showdistribution(b4)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment