Skip to content

Instantly share code, notes, and snippets.

@jmuchovej
Last active June 1, 2023 14:40
Show Gist options
  • Save jmuchovej/2e8e295cabb4b3e76626b166e1549992 to your computer and use it in GitHub Desktop.
Save jmuchovej/2e8e295cabb4b3e76626b166e1549992 to your computer and use it in GitHub Desktop.
using Random
using POMDPs
using POMDPTools
using Distances
using StaticArrays
using Base.Iterators
using LinearAlgebra
using NativeSARSOP
using SARSOP
using QMDP
using PointBasedValueIteration
# 2-D integer grid coordinate.
const Point = SVector{2, Int}

# Closed set of things a box can contain; integer codes start at 1.
@enum Fruit πŸ‹=1 πŸ’ 🫐

# Conversions from a `Fruit` to its integer code, to its name as a `Symbol`,
# and to its name as a `String`.
Base.convert(::Type{T}, f::Fruit) where {T <: Integer} = T(Integer(f))
Base.convert(::Type{Symbol}, f::Fruit) = Symbol(f)
Base.convert(::Type{String}, f::Fruit) = String(Symbol(f))
"""
    State

A single world state: the current position `pos` together with the fruit
recorded for each of the three boxes in `boxes`.
"""
struct State
    pos::Point
    boxes::SVector{3, Fruit}
end
# Base.:(==)(s1::State, s2::State) = (s1.pos == s2.pos) && (s1.boxes == s2.boxes)
"""
    Box

A box placed at grid cell `pos` that holds the fruit `item`.
"""
struct Box
    pos::Point
    item::Fruit
end

# Convenience constructors: each one pins the box to a fixed grid cell and only
# takes the fruit it contains.
function Box1(item::Fruit)
    return Box(Point(1, 5), item)
end

function Box2(item::Fruit)
    return Box(Point(5, 5), item)
end

function Box3(item::Fruit)
    return Box(Point(5, 1), item)
end
"""
    Action{t}

An action of kind `t` (used in this file with `:move`, `:take` and `:open`) aimed
at the box whose index is `targetbox`.
"""
struct Action{t}
    targetbox::Number
end

"""
    Action(type, targetbox::Number)

Build an `Action{type}` aimed at box `targetbox`.

# Throws
- `AssertionError` if `targetbox` is not strictly positive.
"""
function Action(type, targetbox::Number)
    # Thrown explicitly instead of via `@assert`: assertions may be disabled at
    # some optimisation levels, which would silently let invalid targets through.
    # The exception type and message match what `@assert targetbox > 0` produced.
    targetbox > 0 || throw(AssertionError("targetbox > 0"))
    return Action{type}(targetbox)
end

"""
    Action(type)

Build an `Action{type}` with `targetbox` set to `0`, bypassing the positivity check.
"""
Action(type) = Action{type}(0)

# Aliases for the concrete action kinds and shorthand constructors for each of them.
const MoveAction = Action{:move}
const TakeAction = Action{:take}
const OpenAction = Action{:open}

Move(box::Number) = Action(:move, box)
Take(box::Number) = Action(:take, box)
Open(box::Number) = Action(:open, box)
"""
    FruitWorld <: POMDP{State, Action, Fruit}

POMDP whose states are `State`s, whose actions are `Action`s and whose
observations are `Fruit`s.

# Fields
- `spawn`: position used to select the states of the initial distribution.
- `mapshape`: extent of the map.
- `belief`: optional probabilities for the initial distribution; `nothing`
  selects a uniform distribution.
- `target`: the fruit designated as the target.
- `boxes`: the boxes present in the world.
- `items`: the fruits that can occur; also used as the observation space.
- `locations`: the positions that can appear in a non-terminal state.
- `rewards`: reward values looked up by `Symbol` key.
"""
struct FruitWorld <: POMDP{State, Action, Fruit}
    spawn::Point
    mapshape::Point
    belief::Union{Vector{<:Real}, Nothing}
    target::Fruit
    boxes::Vector{Box}
    items::Vector{Fruit}
    locations::Vector{Point}
    rewards::Dict{Symbol, Real}
end
"""
    FruitWorld(p::FruitWorld; kwargs...)

Return a copy of `p` in which every field named by a keyword argument is
replaced by the given value; fields that are not named keep the value from `p`.

`belief` may also be a `SparseCat`, in which case only its `probs` are stored.
"""
function FruitWorld(
    p::FruitWorld;
    spawn::Point = p.spawn,
    mapshape::Point = p.mapshape,
    belief::Union{SparseCat, Vector{<:Real}, Nothing} = p.belief,
    target::Fruit = p.target,
    boxes::Vector{Box} = p.boxes,
    items::Vector{Fruit} = p.items,
    locations::Vector{Point} = p.locations,
    rewards::Dict{Symbol, Real} = p.rewards,
)
    # The struct stores plain probabilities, so unwrap a distribution if one was given.
    if belief isa SparseCat
        belief = belief.probs
    end
    return FruitWorld(spawn, mapshape, belief, target, boxes, items, locations, rewards)
end
"""
    POMDPs.discount(p::FruitWorld)

Return the discount factor of the problem, a constant just below one.
"""
function POMDPs.discount(p::FruitWorld)
    return 0.999999
end
"""
    POMDPs.isterminal(p::FruitWorld, s::State)

Return `true` when `s` sits at the sentinel position `Point(-1, -1)`, which
marks the terminal state.
"""
POMDPs.isterminal(p::FruitWorld, s::State) = s.pos == Point(-1, -1)
"""
    POMDPs.actions(p::FruitWorld)

Return the action space of `p`: one `Move` per box followed by one `Take` per box.

`Open` actions are currently not part of the action space.
"""
function POMDPs.actions(p::FruitWorld)
    boxids = 1:length(p.boxes)
    # The `Open` actions used to be constructed here as well but were never
    # returned, so that unused work has been dropped. Append `Open.(boxids)`
    # to the `vcat` call to enable them.
    return vcat(Move.(boxids), Take.(boxids))
end
"""
    POMDPs.actionindex(p::FruitWorld, a::Action)

Return the position of `a` within `actions(p)`, or `nothing` if it is not there.
"""
function POMDPs.actionindex(p::FruitWorld, a::Action)
    available = actions(p)
    return findfirst(isequal(a), available)
end
"""
    Base.CartesianIndices(p::FruitWorld)

Index space of the non-terminal states: the first axis ranges over
`p.locations`, followed by one axis of length `length(p.items)` per box.
"""
function Base.CartesianIndices(p::FruitWorld)
    nitems = length(p.items)
    dims = (length(p.locations), ntuple(_ -> nitems, length(p.boxes))...)
    return CartesianIndices(dims)
end

"""
    Base.length(p::FruitWorld)

Number of states of `p`: every non-terminal combination plus one terminal state.
"""
Base.length(p::FruitWorld) = 1 + length(CartesianIndices(p))
"""
    Base.getindex(p::FruitWorld, stateindex::Int)

Return the state stored at linear index `stateindex`.

The last index, `length(p)`, is the terminal state: position `Point(-1, -1)`
with the box contents taken from `p.boxes`. Every other index is decoded through
`CartesianIndices(p)` into a location and one fruit code per box.
"""
function Base.getindex(p::FruitWorld, stateindex::Int)
    if stateindex != length(p)
        coords = Tuple(CartesianIndices(p)[stateindex])
        # First coordinate selects the location, the rest are fruit codes.
        where_ = p.locations[first(coords)]
        contents = [Fruit.(Base.tail(coords))...]
        return State(where_, contents)
    end
    return State(Point(-1, -1), [b.item for b in p.boxes])
end
"""
    Base.iterate(p::FruitWorld, stateindex::Int = 1)

Iterate over the states of `p` in index order; the iteration state is the index
of the next state to produce.
"""
function Base.iterate(p::FruitWorld, stateindex::Int = 1)
    return stateindex > length(p) ? nothing : (p[stateindex], stateindex + 1)
end
"""
    POMDPs.stateindex(p::FruitWorld, s::State)

Return the linear index of `s`. A state at the sentinel position
`Point(-1, -1)` maps to the last index, `length(p)`.
"""
function POMDPs.stateindex(p::FruitWorld, s::State)
    s.pos == Point(-1, -1) && return length(p)

    slot = findfirst(isequal(s.pos), p.locations)
    # Strides of the box axes; the final cumulative product is the total count
    # of non-terminal states and is not a stride, so it is dropped.
    strides = cumprod(size(CartesianIndices(p)))[1:end - 1]
    # Fruit codes start at 1; shift them to 0 so the first fruit adds no offset.
    offsets = Integer.(s.boxes) .- 1
    return slot + dot(strides, offsets)
end
"""
    POMDPs.states(p::FruitWorld)

Return all states of `p` as a `Vector{State}`, in index order with the terminal
state last.
"""
# `FruitWorld` defines no `eltype`, so a bare `collect(p)` would produce a
# `Vector{Any}`; naming the element type keeps the state list concretely typed.
POMDPs.states(p::FruitWorld) = collect(State, p)
"""
    POMDPs.initialstate(p::FruitWorld)

Return the initial state distribution: a `SparseCat` over every state located at
`p.spawn`. The probabilities are `p.belief` when it is set and uniform otherwise.
"""
function POMDPs.initialstate(p::FruitWorld)
    candidates = filter(s -> s.pos == p.spawn, states(p))
    weights = if isnothing(p.belief)
        fill(1 / length(candidates), length(candidates))
    else
        p.belief
    end
    return SparseCat(candidates, weights)
end
"""
    POMDPs.observations(p::FruitWorld)

Return the observation space, which is the list of fruits `p.items`.
"""
function POMDPs.observations(p::FruitWorld)
    return p.items
end

"""
    POMDPs.obsindex(p::FruitWorld, obs::Fruit)

Return the position of `obs` within `observations(p)`, or `nothing` if absent.
"""
function POMDPs.obsindex(p::FruitWorld, obs::Fruit)
    return findfirst(isequal(obs), observations(p))
end
"""
    POMDPs.observation(p::FruitWorld, a::MoveAction, sp::State)

Observation distribution after moving to box `a.targetbox`.

The observation is noiseless: all probability mass is placed on the fruit that
`sp` records for the targeted box.
"""
function POMDPs.observation(p::FruitWorld, a::MoveAction, sp::State)
    obs_dist = zeros(length(observations(p)))
    # An earlier revision spread a small amount of noise (originally `0.05`)
    # over the other observations; that is disabled, so the full mass goes to
    # the fruit stored in the successor state.
    obs_dist[obsindex(p, sp.boxes[a.targetbox])] = 1.0
    return SparseCat(observations(p), obs_dist)
end
"""
    POMDPs.observation(p::FruitWorld, a::TakeAction, sp::State)

Observation distribution after a take action: the last entry of
`observations(p)` is emitted with probability one, whatever `a` and `sp` are.
"""
function POMDPs.observation(p::FruitWorld, a::TakeAction, sp::State)
    obs = observations(p)
    weights = zeros(length(obs))
    weights[lastindex(weights)] = 1.
    return SparseCat(obs, weights)
end
# Reward lookup helpers: by raw key, by fruit (keyed on the fruit's name), and
# by box (keyed on the name of the fruit the box holds).
R(p::FruitWorld, key::Symbol) = getindex(p.rewards, key)
R(p::FruitWorld, fruit::Fruit) = R(p, Symbol(fruit))
R(p::FruitWorld, box::Box) = R(p, Symbol(box.item))
"""
    POMDPs.reward(p::FruitWorld, s::State, a::MoveAction)

Cost of moving: the negated Euclidean distance between the position in `s` and
the targeted box.
"""
function POMDPs.reward(p::FruitWorld, s::State, a::MoveAction)
    destination = p.boxes[a.targetbox].pos
    # NOTE(review): the factor is hard-coded; `p.rewards` is not consulted here.
    return -1 * euclidean(destination, s.pos)
end
"""
    POMDPs.reward(p::FruitWorld, s::State, a::TakeAction)

Reward for taking from box `a.targetbox`: the reward registered for the box's
fruit when `s` is located at that box, and `0` otherwise.
"""
function POMDPs.reward(p::FruitWorld, s::State, a::TakeAction)
    chosen = p.boxes[a.targetbox]
    return chosen.pos == s.pos ? R(p, chosen) : 0
end
"""
    POMDPs.transition(p::FruitWorld, s::State, a::MoveAction)

Deterministic move: the successor keeps the box contents of `s` and is located
at the position of box `a.targetbox`. A terminal state is left unchanged.
"""
function POMDPs.transition(p::FruitWorld, s::State, a::MoveAction)
    isterminal(p, s) && return Deterministic(s)
    destination = p.boxes[a.targetbox].pos
    # Moving never alters the box contents, so they are reused directly. The
    # earlier detour through a mutable copy only served an update that is
    # commented out.
    return Deterministic(State(destination, s.boxes))
end
"""
    POMDPs.transition(p::FruitWorld, s::State, a::TakeAction)

Deterministic take: when `s` is located at box `a.targetbox` the successor is
the terminal state. Otherwise, or when `s` is already terminal, `s` is returned
unchanged.
"""
function POMDPs.transition(p::FruitWorld, s::State, a::TakeAction)
    box = p.boxes[a.targetbox]
    if isterminal(p, s) || box.pos != s.pos
        return Deterministic(s)
    end
    # The terminal state lives at the last index. Fetch it directly rather than
    # collecting the whole state space on every call just to read its last entry.
    return Deterministic(p[length(p)])
end
# Solver selection. Exactly one assignment is active; the commented-out lines
# are alternative solvers from the other packages loaded at the top of the file.
# solver = NativeSARSOP.SARSOPSolver(; epsilon=1e-1, precision=1e-4, verbose=true)
solver = SARSOP.SARSOPSolver(; verbose=true)
# solver = QMDPSolver(; verbose=true)
# solver = PBVISolver(; ϵ=20.0, verbose=true)
# Problem instance: spawn position, map extent, the three boxes with their true
# contents, and the target fruit.
SPAWN = Point(1, 1)
MAPSHAPE = Point(5, 5)
BOXES = [Box1(πŸ’), Box2(πŸ‹), Box3(πŸ‹)]
TARGET = πŸ‹
# Positional arguments follow the field order of `FruitWorld`: spawn, mapshape,
# belief (`nothing` = uniform initial distribution), target, boxes, items (every
# `Fruit` instance), locations (the spawn point followed by each box position),
# and the reward table.
pomdp = FruitWorld(
SPAWN, MAPSHAPE, nothing,
TARGET, BOXES, [instances(Fruit)...], [SPAWN, [b.pos for b=BOXES]...],
Dict(:movement => -1., :πŸ‹ => 90., :πŸ’ => 45., :🫐 => 1.),
)
"""
    marginals_to_belief(p::FruitWorld, marginal::Matrix)

Turn per-box marginals into a joint distribution over the initial states of `p`.

`marginal[box, k]` is the probability that box `box` holds the fruit whose
integer code is `k`. Each initial state is weighted by the product of the
entries selected by its box contents, and the result is returned as a `SparseCat`.
"""
function marginals_to_belief(p::FruitWorld, marginal::Matrix)
    candidates = initialstate(p).vals
    # Joint weight of one state: product of the marginal entries it selects.
    weight(candidate) = prod(
        [marginal[i, Int(candidate.boxes[i])] for i in 1:length(p.boxes)]
    )
    weights = [weight(candidate) for candidate in candidates]
    return SparseCat(candidates, weights)
end
# Candidate prior beliefs, each given as a matrix of per-box marginals. As
# consumed by `marginals_to_belief`, rows index boxes and columns index fruit
# codes, so each row is the distribution over the contents of one box.
marginalbeliefs = [
[.98 .01 .01; .02 .49 .49; .02 .49 .49],
[.02 .49 .49; .98 .01 .01; .02 .49 .49],
[.02 .49 .49; .02 .49 .49; .98 .01 .01],
[.10 .45 .45; .10 .45 .45; .90 .05 .05],
[.90 .05 .05; .05 .40 .55; .80 .10 .10],
[.10 .45 .45; .10 .45 .45; .20 .40 .40],
]
# Joint belief over the initial states built from the first set of marginals.
belief = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_unif = initialstate(FruitWorld(pomdp; belief=nothing))
# belief_box1 = marginals_to_belief(pomdp, marginalbeliefs[1])
# belief_box3 = marginals_to_belief(pomdp, marginalbeliefs[3])
# Solve a copy of the problem whose initial distribution uses that belief.
policy = solve(solver, FruitWorld(pomdp, belief=belief ))
# policy_unif = solve(solver, FruitWorld(pomdp, belief=nothing))
# policy_box1 = solve(solver, FruitWorld(pomdp, belief=belief_box1))
# policy_box3 = solve(solver, FruitWorld(pomdp, belief=belief_box3))
# Two starting beliefs for the policy's updater: `b0` from the uniform initial
# distribution and `b1` from the informed one; then show the action values of each.
b0 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=nothing)))
b1 = initialize_belief(updater(policy), initialstate(FruitWorld(pomdp; belief=belief)))
@show actionvalues(policy, b0)
@show actionvalues(policy, b1)
# Roll out the solved policy from a state sampled from `b1`, for at most 19 steps
# or until a terminal state is reached, accumulating the discounted reward in
# `reward` and printing each step.
begin
    b = b1
    s = rand(b)
    # NOTE(review): this global shares its name with `POMDPs.reward`; the name is
    # kept so that code relying on it continues to work.
    reward = 0.
    d = 1.
    up = updater(policy)
    counter = 1
    while counter < 20 && !isterminal(pomdp, s)
        # `begin` does not introduce a scope, so the variables above are globals.
        # Without this declaration the loop body only works in interactive
        # sessions; run as a script, the assignments below would create new
        # locals and fail with `UndefVarError`.
        global b, s, reward, d, counter
        a = action(policy, b)
        sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a)
        reward += d * r
        d *= discount(pomdp)
        b = update(up, b, a, o)
        c_str = lpad(counter, 3)
        @show (c_str, s, a, sp, r)
        counter += 1
        s = sp
    end
end
# One exact belief update from `b1`: move to box 2 and observe the first fruit.
bp1 = update(DiscreteUpdater(pomdp), b1, Move(2), πŸ‹);
showdistribution(bp1)
# Switch to the third set of marginals and walk through a sequence of belief
# updates by hand, showing the most likely state(s) after each step.
belief = marginals_to_belief(pomdp, marginalbeliefs[3]);
begin
# Starting belief for the walkthrough (rebinds `b1`).
b1 = initialize_belief(
updater(policy), initialstate(FruitWorld(pomdp; belief=belief))
);
# Indices of every state that attains the maximum probability.
maxb1s = findall(isequal(maximum(b1.b)), b1.b);
@show b1.state_list[maxb1s]
showdistribution(b1)
end
begin
# Step 1: move to box 2.
b2 = update(DiscreteUpdater(pomdp), b1, Move(2), πŸ‹);
maxb2s = findall(isequal(maximum(b2.b)), b2.b);
@show b2.state_list[maxb2s]
showdistribution(b2)
end
begin
# Step 2: move to box 3.
b3 = update(DiscreteUpdater(pomdp), b2, Move(3), πŸ‹);
maxb3s = findall(isequal(maximum(b3.b)), b3.b);
@show b3.state_list[maxb3s]
showdistribution(b3)
end
# Step 3 of the manual walkthrough: move to box 1, update the belief exactly,
# and show the most likely state(s) under the new belief.
begin
    b4 = update(DiscreteUpdater(pomdp), b3, Move(1), πŸ’);
    maxb4s = findall(isequal(maximum(b4.b)), b4.b);
    # Index with the maxima of `b4` itself; this previously reused `maxb3s`,
    # the indices computed for the preceding belief.
    @show b4.state_list[maxb4s]
    showdistribution(b4)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment