Skip to content

Instantly share code, notes, and snippets.

@WilliamJou
Created December 5, 2017 00:12
Show Gist options
  • Save WilliamJou/98a30e1be430680f805329311a35db2a to your computer and use it in GitHub Desktop.
Save WilliamJou/98a30e1be430680f805329311a35db2a to your computer and use it in GitHub Desktop.
faucet
importall POMDPs
using POMDPToolbox
#using SARSOP
using BasicPOMCP
using D3Trees
using ParticleFilters
# Task identifiers, stored in the `task` field of `FState`.
const DISH = 1
const HAND = 2
const POT = 3 # pot task; not used anywhere else in the model yet

# Discretized faucet settings.
const TEMPS = 0:10:30 # temperature settings, in degrees Celsius
const FLOWS = 0:1:5   # flow settings

# Reverse lookups: setting value => 1-based position within its range.
const TINDEX = Dict{Int, Int}(temp => idx for (idx, temp) in enumerate(TEMPS))
const FINDEX = Dict{Int, Int}(flow => idx for (idx, flow) in enumerate(FLOWS))
"""
    FState(task, time, prev_temp, prev_flow)

Hidden state of the faucet problem: the task being performed (`DISH` or
`HAND`), the current time step, and the previous temperature and flow settings.
"""
struct FState
    task::Int      # task identifier, e.g. DISH or HAND
    time::Int      # time step; the initial distribution starts at 0
    prev_temp::Int # previous temperature setting (a value from TEMPS)
    prev_flow::Int # previous flow setting (a value from FLOWS); may turn out unnecessary
end
"""
    FPOMDP(p_change, max_time)

Faucet POMDP: `FState` states, `(temperature, flow)` tuple actions and
`(temperature, flow)` tuple observations.

- `p_change`: stored but not read by any method in this file (presumably a
  probability of the task changing — TODO confirm).
- `max_time`: last non-terminal time step.
"""
struct FPOMDP <: POMDP{FState, Tuple{Int, Int}, Tuple{Int, Int}} # POMDP{State, Action, Observation}
    p_change::Float64
    max_time::Int
end
# Global problem instance; the state/action/observation index tables are built from it.
p = let p_change = 0.5, max_time = 10
    FPOMDP(p_change, max_time)
end

# Desired faucet setting for each task: temperature (degrees Celsius) and flow.
const DTEMP = Dict{Int, Int}([DISH => 30, HAND => 20])
const DFLOW = Dict{Int, Int}([DISH => 4, HAND => 2])
# An episode ends once the time step has moved past `p.max_time`.
function isterminal(p::FPOMDP, s::FState)
    return p.max_time < s.time
end
# Every combination of task, time step (0 through p.max_time), previous
# temperature and previous flow. The 4-D comprehension is flattened in
# column-major order, so `task` varies fastest.
states(p::FPOMDP) = vec([FState(task, t, temp, flow)
                         for task in [DISH, HAND], t in 0:p.max_time,
                             temp in TEMPS, flow in FLOWS])

# Size of the state space above: 2 tasks x (max_time + 1) steps x temps x flows.
n_states(p::FPOMDP) = 2 * (p.max_time + 1) * length(TEMPS) * length(FLOWS)

# State => position in `states(p)`, built once from the global instance `p`.
const SINDEX = Dict{FState, Int}(st => k for (k, st) in enumerate(states(p)))

# NOTE(review): the `p` argument is ignored. Indices are only valid for problems
# with the same `max_time` as the global `p`, and states with
# `time > max_time` are not in the table (KeyError).
function state_index(p::FPOMDP, s::FState)
    return SINDEX[s]
end
# All faucet commands as `(temperature, flow)` pairs; temperature varies fastest.
actions(p::FPOMDP) = vec(collect((t, f) for t in TEMPS, f in FLOWS))
n_actions(p::FPOMDP) = length(TEMPS) * length(FLOWS)

# Action => position in `actions(p)`. Keys are `(temperature, flow)` tuples,
# the action type declared by `FPOMDP`.
const AINDEX = Dict{Tuple{Int, Int}, Int}(a => i for (i, a) in enumerate(actions(p)))

# Position of the `(temperature, flow)` action `a` in `actions(p)`.
action_index(p::FPOMDP, a::Tuple{Int, Int}) = AINDEX[a]
# All observations as `(temperature, flow)` pairs; temperature varies fastest.
observations(p::FPOMDP) = vec(collect((t, f) for t in TEMPS, f in FLOWS))
n_observations(p::FPOMDP) = length(TEMPS) * length(FLOWS)

# Observation => position in `observations(p)`. Keys are `(temperature, flow)`
# tuples, the observation type declared by `FPOMDP`.
const OINDEX = Dict{Tuple{Int, Int}, Int}(o => i for (i, o) in enumerate(observations(p)))

# Position of the `(temperature, flow)` observation `o` in `observations(p)`.
obs_index(p::FPOMDP, o::Tuple{Int, Int}) = OINDEX[o]
# Deterministic transition: the task is unchanged, time advances by one step,
# and the commanded `(temperature, flow)` pair is recorded as the new
# `prev_temp`/`prev_flow`.
# NOTE(review): `p.p_change` is not used here, so the task never changes — confirm intent.
function transition(p::FPOMDP, s::FState, a::Tuple{Int, Int})
    temp, flow = a
    return SparseCat([FState(s.task, s.time + 1, temp, flow)], [1.0])
end
# Observation model over `(temperature, flow)` pairs; `(0, 0)` encodes "no feedback".
# After time step 2, if the commanded setting `a` differs from the task's desired
# (temperature, flow), the desired setting is observed with probability 0.5
# (presumably a user correction). Otherwise nothing is observed.
function observation(p::FPOMDP, a::Tuple{Int, Int}, sp::FState)
    desired = (DTEMP[sp.task], DFLOW[sp.task])
    no_feedback = (0, 0)
    if sp.time > 2 && a != desired
        return SparseCat([desired, no_feedback], [0.5, 0.5])
    else
        return SparseCat([no_feedback], [1.0])
    end
end
# +10 when the commanded `(temperature, flow)` matches the desired setting for
# the current task, -10 otherwise.
function reward(p::FPOMDP, s::FState, a::Tuple{Int, Int})
    desired = (DTEMP[s.task], DFLOW[s.task])
    return a == desired ? 10.0 : -10.0
end
# Start at time 0 with temperature and flow settings of 0; the task is DISH or
# HAND with equal probability.
function initial_state_distribution(p::FPOMDP)
    start_states = [FState(task, 0, 0, 0) for task in (DISH, HAND)]
    return SparseCat(start_states, [0.5, 0.5])
end
# Alternative baseline: policy = RandomPolicy(p)
solver = POMCPSolver(c=100) # c: presumably the UCB exploration constant (see BasicPOMCP docs)
policy = solve(solver, p)

# Simulate one episode, printing each step together with the belief over tasks.
for (b, s, a, r, o) in stepthrough(p, policy, "bsaro")
    # Fraction of belief particles that are in the HAND task; `count` avoids
    # materializing a filtered particle array.
    frac_hand = count(x -> x.task == HAND, particles(b)) / n_particles(b)
    @show frac_hand
    @show s
    @show a
    @show r
    @show o
end

# Visualize the planner's search tree (D3Trees).
inchrome(D3Tree(policy))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment