Skip to content

Instantly share code, notes, and snippets.

@jmuchovej
Last active June 1, 2023 14:40
using Random
using POMDPs
using POMDPTools
using Distances
using StaticArrays
using Base.Iterators
using LinearAlgebra
using NativeSARSOP
using SARSOP
using QMDP
using PointBasedValueIteration
# Integer 2-D grid position, used for the agent's and the boxes' positions.
const Point = SVector{2, Int}
# The possible box contents / observations. Values start at 1: `getindex` and `stateindex`
# rely on a fruit's integer value being equal to its index along a box's item axis.
# NOTE(review): the first two names look like UTF-8 emoji that were mis-decoded through a
# legacy code page (likely lemon and cherries). They must stay identical to the `rewards`
# keys and every use site, so confirm and fix them everywhere at once, not here alone.
@enum Fruit πŸ‹=1 πŸ’ 🫐
# Conversions out of a `Fruit`:
# - to any integer type, via the enum's integer value;
# - to a `Symbol`, which is the enum instance's name;
# - to a `String`, which is that name spelled out.
function Base.convert(::Type{T}, f::Fruit) where {T <: Integer}
    return T(Integer(f))
end

function Base.convert(::Type{Symbol}, f::Fruit)
    return Symbol(f)
end

function Base.convert(::Type{String}, f::Fruit)
    return String(Symbol(f))
end
# A POMDP state: the agent's position plus the item held by each of the three boxes
# (the box contents are what the agent cannot see directly).
# The terminal state is marked by the sentinel position `Point(-1, -1)`.
struct State
    pos::Point
    boxes::SVector{3, Fruit}
end
# No custom `==`/`hash` is needed. Both fields are immutable isbits values, so the
# defaults (which fall back to `===`) already compare and hash states by content.
# A box on the map: its fixed position and the item it holds in the model configuration.
struct Box
    pos::Point
    item::Fruit
end
# Constructors for the three boxes. Each box has a fixed map position; only its item
# is supplied by the caller.
function Box1(item::Fruit)
    return Box(Point(1, 5), item)
end

function Box2(item::Fruit)
    return Box(Point(5, 5), item)
end

function Box3(item::Fruit)
    return Box(Point(5, 1), item)
end
# An action whose kind is the type parameter `t` (`:move`, `:take`, or `:open` below),
# aimed at the box with number `targetbox`. The target-less `Action(type)` form stores 0.
struct Action{t}
    targetbox::Number
end
"""
    Action(type, targetbox::Number)
    Action(type)

Build an `Action{type}`.

The two-argument form targets box number `targetbox`, which must be positive; otherwise
an `ArgumentError` is thrown. The one-argument form builds an action with no target box
(`targetbox == 0`).
"""
function Action(type, targetbox::Number)
    # Validate with an exception rather than `@assert`, which may be disabled.
    targetbox > 0 || throw(ArgumentError("targetbox must be positive, got $targetbox"))
    return Action{type}(targetbox)
end
Action(type) = Action{type}(0)
# Concrete action kinds, distinguished by the `Action` type parameter.
const MoveAction = Action{:move}
const TakeAction = Action{:take}
const OpenAction = Action{:open}

# Convenience constructors. They route through `Action(type, targetbox)`, so the target
# box number is checked there.
function Move(box::Number)
    return Action(:move, box)
end

function Take(box::Number)
    return Action(:take, box)
end

function Open(box::Number)
    return Action(:open, box)
end
# The FruitWorld POMDP: states are `State`, actions are `Action`, observations are `Fruit`.
struct FruitWorld <: POMDP{State, Action, Fruit}
    # Agent's starting position; `initialstate` only puts mass on states at this position.
    spawn::Point
    # Map size. NOTE(review): not read by any method shown here.
    mapshape::Point
    # Optional probabilities over the spawn states (in `initialstate` order);
    # `nothing` means a uniform initial distribution.
    belief::Union{Vector{<:Real}, Nothing}
    # NOTE(review): not read by any method shown here.
    target::Fruit
    # Box positions and their configured items.
    boxes::Vector{Box}
    # The item values a box can hold; also the observation space.
    items::Vector{Fruit}
    # Positions the agent may occupy; the first axis of the state grid.
    locations::Vector{Point}
    # Rewards keyed by `Symbol`; fruit rewards are looked up by the enum name via `R`.
    rewards::Dict{Symbol, Real}
end
"""
    FruitWorld(p::FruitWorld; kwargs...)

Return a copy of `p` with any of its fields replaced by the given keywords.

`belief` may be a `SparseCat` (only its `probs` are kept), a probability vector, or
`nothing`. `rewards` may be any `AbstractDict` from `Symbol` to real numbers, e.g. a plain
`Dict{Symbol, Float64}`. The struct's default constructor converts it to the
`Dict{Symbol, Real}` field type.
"""
function FruitWorld(
    p::FruitWorld;
    spawn::Point = p.spawn, mapshape::Point = p.mapshape,
    belief::Union{SparseCat, Vector{<:Real}, Nothing} = p.belief,
    target::Fruit = p.target, boxes::Vector{Box} = p.boxes,
    items::Vector{Fruit} = p.items, locations::Vector{Point} = p.locations,
    # Widened from `Dict{Symbol, Real}`. Because parametric types are invariant, that
    # annotation rejected ordinary concretely-typed dicts with a `TypeError`.
    rewards::AbstractDict{Symbol, <:Real} = p.rewards,
)
    belief = isa(belief, SparseCat) ? belief.probs : belief
    return FruitWorld(spawn, mapshape, belief, target, boxes, items, locations, rewards)
end
# Discount factor: very close to 1, so future rewards are barely discounted.
function POMDPs.discount(::FruitWorld)
    return 0.999999
end
# A state is terminal exactly when it sits at the sentinel position (-1, -1).
# This is where a successful `Take` sends the agent.
POMDPs.isterminal(p::FruitWorld, s::State) = s.pos == Point(-1, -1)
# Action space: one `Move` per box, then one `Take` per box, in box order.
# `actionindex` relies on this ordering.
function POMDPs.actions(p::FruitWorld)
    nboxes = length(p.boxes)
    # `Open` actions are intentionally left out: no transition, observation, or reward
    # methods for `OpenAction` are defined in this file.
    return vcat(Move.(1:nboxes), Take.(1:nboxes))
end
# Position of `a` within `actions(p)`, or `nothing` if it is not one of them.
function POMDPs.actionindex(p::FruitWorld, a::Action)
    all_actions = actions(p)
    return findfirst(candidate -> isequal(candidate, a), all_actions)
end
# Grid of non-terminal states. The first axis is the agent's location index; it is
# followed by one axis per box over that box's item index.
function Base.CartesianIndices(p::FruitWorld)
    nitems = length(p.items)
    dims = (length(p.locations), ntuple(_ -> nitems, length(p.boxes))...)
    return CartesianIndices(dims)
end
# Number of states: every cell of the state grid, plus one extra terminal state.
function Base.length(p::FruitWorld)
    return prod(size(CartesianIndices(p))) + 1
end
# Decode a state index into a `State`.
# The last index, `length(p)`, is the terminal state: it sits at the sentinel position and
# carries the configured box items. Every other index is a cell of `CartesianIndices(p)`.
function Base.getindex(p::FruitWorld, stateindex::Int)
    if stateindex == length(p)
        return State(Point(-1, -1), [box.item for box in p.boxes])
    end
    location_idx, item_idxs... = Tuple(CartesianIndices(p)[stateindex])
    # A box's item index is the `Fruit` enum value itself.
    contents = [Fruit(i) for i in item_idxs]
    return State(p.locations[location_idx], contents)
end
# `collect` (and therefore `POMDPs.states`) uses `eltype` to choose the element type of
# its result. Without this method it falls back to `Any` and returns a `Vector{Any}`.
Base.eltype(::Type{FruitWorld}) = State

# Iterate the states in index order: p[1], p[2], ..., p[length(p)].
function Base.iterate(p::FruitWorld, stateindex::Int = 1)
    if stateindex > length(p)
        return nothing
    end
    state = p[stateindex]
    return (state, stateindex + 1)
end
# Inverse of `getindex`: map a `State` to its index.
# Terminal states map to `length(p)`. Any other state maps to the linear index of its
# (location, item of box 1, ..., item of box n) cell in `CartesianIndices(p)`.
# Throws `ArgumentError` if the state's position is not one of `p.locations`.
function POMDPs.stateindex(p::FruitWorld, s::State)
    if s.pos == Point(-1, -1)
        return length(p)
    end
    location = findfirst(isequal(s.pos), p.locations)
    if isnothing(location)
        throw(ArgumentError("state position $(s.pos) is not one of the model's locations"))
    end
    # `Integer(::Fruit)` is the 1-based item index, matching the `Fruit.(...)` decode in
    # `getindex`. `LinearIndices` does the column-major stride arithmetic and bounds checks.
    return LinearIndices(CartesianIndices(p))[location, Integer.(s.boxes)...]
end
# All states, materialized by iterating the model; position i holds state index i.
function POMDPs.states(p::FruitWorld)
    return collect(p)
end
# Initial-state distribution over every state located at `p.spawn`.
# Uses `p.belief` as the probabilities when it is set (one entry per spawn state, in
# `states(p)` order); otherwise the distribution is uniform.
# Throws `DimensionMismatch` if `p.belief` has the wrong number of entries.
function POMDPs.initialstate(p::FruitWorld)
    spawnstates = filter(s -> s.pos == p.spawn, states(p))
    if isnothing(p.belief)
        probs = fill(1 / length(spawnstates), length(spawnstates))
    else
        probs = p.belief
        # A mismatched vector would otherwise yield an inconsistent `SparseCat`.
        length(probs) == length(spawnstates) || throw(DimensionMismatch(
            "belief has $(length(probs)) entries, expected $(length(spawnstates))",
        ))
    end
    return SparseCat(spawnstates, probs)
end
# The observation space is the list of items. An observation's index is its position in
# that list (or `nothing` if absent).
function POMDPs.observations(p::FruitWorld)
    return p.items
end

function POMDPs.obsindex(p::FruitWorld, obs::Fruit)
    return findfirst(item -> isequal(item, obs), observations(p))
end
# Observation after moving to a box: the item that box holds in the next state `sp`,
# observed with certainty.
# (A noisy variant, with some probability mass on other items, was tried earlier; it
# would be introduced here.)
function POMDPs.observation(p::FruitWorld, a::MoveAction, sp::State)
    obs_dist = zeros(length(observations(p)))
    obs_dist[obsindex(p, sp.boxes[a.targetbox])] = 1.0
    return SparseCat(observations(p), obs_dist)
end
# Observation after a `Take`: always the last item in the observation list, whatever the
# state. Taking therefore reveals nothing.
function POMDPs.observation(p::FruitWorld, a::TakeAction, sp::State)
    obs = observations(p)
    probs = zeros(length(obs))
    probs[lastindex(probs)] = 1.0
    return SparseCat(obs, probs)
end
# Reward-table lookups. The table can be queried by raw key, by a fruit (its enum name is
# the key), or by a box (the box's configured item).
function R(p::FruitWorld, key::Symbol)
    return p.rewards[key]
end

function R(p::FruitWorld, fruit::Fruit)
    return R(p, Symbol(fruit))
end

function R(p::FruitWorld, box::Box)
    return R(p, box.item)
end
# Moving costs the Euclidean distance to the target box, scaled by the configured
# `:movement` reward.
# Falls back to -1 per unit (the previously hard-coded value) when the rewards table has
# no `:movement` entry.
function POMDPs.reward(p::FruitWorld, s::State, a::MoveAction)
    box = p.boxes[a.targetbox]
    per_unit = get(p.rewards, :movement, -1)
    return per_unit * euclidean(box.pos, s.pos)
end
# Taking from a box pays that box's item reward, but only when the agent stands on the
# box; otherwise the reward is 0.0.
function POMDPs.reward(p::FruitWorld, s::State, a::TakeAction)
    box = p.boxes[a.targetbox]
    if box.pos != s.pos
        return 0.0  # float zero keeps the return type consistent with `R`
    end
    # Use the box's contents in the state `s` (the hidden part of the state), not the
    # model's configured `p.boxes` items. Otherwise the reward ignores the belief and
    # disagrees with what `observation` reveals.
    return R(p, s.boxes[a.targetbox])
end
# Moving deterministically places the agent on the target box, with the box contents
# unchanged. Terminal states are absorbing.
function POMDPs.transition(p::FruitWorld, s::State, a::MoveAction)
    if isterminal(p, s)
        return Deterministic(s)
    end
    box = p.boxes[a.targetbox]
    # `s.boxes` is already an immutable SVector; no copy is needed to reuse it.
    return Deterministic(State(box.pos, s.boxes))
end
# Taking ends the episode when the agent stands on the target box. Otherwise, and from
# terminal states, the state is unchanged.
function POMDPs.transition(p::FruitWorld, s::State, a::TakeAction)
    box = p.boxes[a.targetbox]
    if isterminal(p, s) || box.pos != s.pos
        return Deterministic(s)
    end
    # The terminal state is always the last index. Build it directly instead of
    # materializing every state via `states(p)[end]`.
    return Deterministic(p[length(p)])
end
# Solver selection: SARSOP.jl's solver is active; the alternatives are left commented out.
# solver = NativeSARSOP.SARSOPSolver(; epsilon=1e-1, precision=1e-4, verbose=true)
solver = SARSOP.SARSOPSolver(; verbose=true)
# solver = QMDPSolver(; verbose=true)
# solver = PBVISolver(; ϵ=20.0, verbose=true)
# Experiment configuration: spawn point, map size, box contents, and target fruit.
SPAWN = Point(1, 1)
MAPSHAPE = Point(5, 5)
BOXES = [Box1(πŸ’), Box2(πŸ‹), Box3(πŸ‹)]
TARGET = πŸ‹
# Positional arguments in field order: spawn, mapshape, belief (`nothing` = uniform
# initial distribution), target, boxes, items (every `Fruit` instance), locations (the
# spawn point plus each box position), and the rewards table (keys match the enum names).
pomdp = FruitWorld(
    SPAWN, MAPSHAPE, nothing,
    TARGET, BOXES, [instances(Fruit)...], [SPAWN, [b.pos for b=BOXES]...],
    Dict(:movement => -1., :πŸ‹ => 90., :πŸ’ => 45., :🫐 => 1.),
)
"""
    marginals_to_belief(p::FruitWorld, marginal::AbstractMatrix)

Build an initial belief over `p`'s spawn states from independent per-box marginals.

`marginal[box, k]` is the probability that box number `box` holds the fruit whose enum
value is `k`. Each spawn state's probability is the product of its boxes' marginals.
Returns a `SparseCat` over the same states, in the same order, as `initialstate(p)`.
It is normalized whenever every row of `marginal` sums to 1.
"""
function marginals_to_belief(p::FruitWorld, marginal::AbstractMatrix)
    # Named `spawnstates` so the `POMDPs.states` function is not shadowed.
    spawnstates = initialstate(p).vals
    probs = map(spawnstates) do state
        prod(marginal[box, Int(state.boxes[box])] for box in eachindex(p.boxes))
    end
    return SparseCat(spawnstates, probs)
end
# Example inputs for `marginals_to_belief`. In each matrix, row = box number and
# column = `Fruit` enum value, so each row (summing to 1) is that box's item distribution.
marginalbeliefs = [
    [.98 .01 .01; .02 .49 .49; .02 .49 .49],
    [.02 .49 .49; .98 .01 .01; .02 .49 .49],
    [.02 .49 .49; .02 .49 .49; .98 .01 .01],
    [.10 .45 .45; .10 .45 .45; .90 .05 .05],
    [.90 .05 .05; .05 .40 .55; .80 .10 .10],
    [.10 .45 .45; .10 .45 .45; .20 .40 .40],
]
# Solve the model using the first marginal set as its initial-state distribution.
belief = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_unif = initialstate(FruitWorld(pomdp; belief=nothing))
# belief_box1 = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_box3 = marginals_to_belief(pomdp, marginalbeliefs[3])
policy = solve(solver, FruitWorld(pomdp, belief=belief ))
# policy_unif = solve(solver, FruitWorld(pomdp, belief=nothing))
# policy_box1 = solve(solver, FruitWorld(pomdp, belief=belief_box1))
# policy_box3 = solve(solver, FruitWorld(pomdp, belief=belief_box3))
# Compare the policy's action values under a uniform initial belief (b0) and under the
# marginal-based one (b1).
b0 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=nothing)))
b1 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=belief)))
@show actionvalues(policy, b0)
@show actionvalues(policy, b1)
# Roll out the policy for at most 19 steps. The start state is sampled from `b1`; the
# belief is updated with the policy's updater, and the discounted reward is accumulated in
# `total_reward`. That name avoids clashing with the `reward` function exported by POMDPs.
begin
    b = b1
    s = rand(b)
    total_reward = 0.
    d = 1.
    up = updater(policy)
    counter = 1
    while counter < 20 && !isterminal(pomdp, s)
        # These globals are reassigned inside the loop. Declaring them `global` makes the
        # loop behave the same in a script run (`julia file.jl`) as in the REPL. Without
        # it, non-interactive soft-scope rules turn them into new loop locals, and the
        # first read of `b` fails with `UndefVarError`.
        global s, b, total_reward, d, counter
        a = action(policy, b)
        sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a)
        total_reward += d * r
        d *= discount(pomdp)
        b = update(up, b, a, o)
        c_str = lpad(counter, 3)
        @show (c_str, s, a, sp, r)
        counter += 1
        s = sp
    end
end
# One belief update from `b1` after moving to box 2 and receiving the given observation.
bp1 = update(DiscreteUpdater(pomdp), b1, Move(2), πŸ‹);
showdistribution(bp1)
# Switch to the third marginal set for the manual belief-tracking steps that follow.
belief = marginals_to_belief(pomdp, marginalbeliefs[3]);
# Re-initialize `b1` from the current global `belief` using the policy's updater, then
# show the states that share its maximum probability.
begin
    b1 = initialize_belief(
        updater(policy), initialstate(FruitWorld(pomdp; belief=belief))
    );
    maxb1s = findall(isequal(maximum(b1.b)), b1.b);
    @show b1.state_list[maxb1s]
    showdistribution(b1)
end
# Update `b1` after moving to box 2 with the given observation, then show the most likely
# states.
begin
    b2 = update(DiscreteUpdater(pomdp), b1, Move(2), πŸ‹);
    maxb2s = findall(isequal(maximum(b2.b)), b2.b);
    @show b2.state_list[maxb2s]
    showdistribution(b2)
end
# Update `b2` after moving to box 3 with the given observation, then show the most likely
# states.
begin
    b3 = update(DiscreteUpdater(pomdp), b2, Move(3), πŸ‹);
    maxb3s = findall(isequal(maximum(b3.b)), b3.b);
    @show b3.state_list[maxb3s]
    showdistribution(b3)
end
# Update `b3` after moving to box 1 with the given observation, then show the most likely
# states.
begin
    b4 = update(DiscreteUpdater(pomdp), b3, Move(1), πŸ’);
    maxb4s = findall(isequal(maximum(b4.b)), b4.b);
    # Index with b4's own argmax set. This previously used `maxb3s`, a copy-paste slip
    # that displayed b3's most likely states instead of b4's.
    @show b4.state_list[maxb4s]
    showdistribution(b4)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment