# working faucet: a simple faucet-control POMDP (gist by @WilliamJou, December 6, 2017)
importall POMDPs
using POMDPToolbox
#using SARSOP
using BasicPOMCP
using D3Trees
using ParticleFilters
const DISH = 1
const HAND = 2
const POT = 3 #pot, but ignore for now
const TEMPS = 0:10:30 # temperature range in Celsius
const TINDEX = Dict{Int, Int}(t=>i for (i,t) in enumerate(TEMPS))
const FLOWS = 0:1:5 # flow-rate settings
const FINDEX = Dict{Int, Int}(f=>i for (i,f) in enumerate(FLOWS))
struct FState
    task::Int      # current task (DISH or HAND)
    time::Int      # elapsed time steps
    prev_temp::Int # temperature set by the previous action
    prev_flow::Int # flow set by the previous action; not sure if this is necessary
end
struct FPOMDP <: POMDP{FState, Tuple{Int,Int}, Tuple{Int,Int}} # POMDP{State, Action, Observation}
    p_change::Float64 # probability of a change (currently unused in the model below)
    max_time::Int     # horizon: states with time > max_time are terminal
end
p = FPOMDP(0.5, 10)
const DTEMP = Dict{Int, Int}(DISH=>30, HAND=>20) # desired temperature for each task
const DFLOW = Dict{Int, Int}(DISH=>4, HAND=>2)   # desired flow for each task
isterminal(p::FPOMDP, s::FState) = s.time > p.max_time
states(p::FPOMDP) = vec(collect(FState(task, time, pt, pf) for task in [DISH, HAND], time in 0:p.max_time, pt in TEMPS, pf in FLOWS))
n_states(p::FPOMDP) = length(TEMPS)*(p.max_time+1)*2*length(FLOWS)
const SINDEX = Dict{FState, Int}(s=>i for (i,s) in enumerate(states(p)))
state_index(p::FPOMDP, s::FState) = SINDEX[s]
actions(p::FPOMDP) = vec(collect((t,f) for t in TEMPS, f in FLOWS))
n_actions(p::FPOMDP) = length(TEMPS)*length(FLOWS)
const AINDEX = Dict(a=>i for (i,a) in enumerate(actions(p)))
action_index(p::FPOMDP, a::Tuple{Int,Int}) = AINDEX[a] # actions are (temp, flow) tuples, not Ints
observations(p::FPOMDP) = vec(collect((t,f) for t in TEMPS, f in FLOWS))
n_observations(p::FPOMDP) = length(TEMPS)*length(FLOWS)
const OINDEX = Dict(o=>i for (i,o) in enumerate(observations(p)))
obs_index(p::FPOMDP, o::Tuple{Int,Int}) = OINDEX[o]
function transition(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    # Deterministic: the action sets the temperature and flow, and time advances by one step.
    return SparseCat([FState(s.task, s.time+1, a[1], a[2])], [1.0])
end
function observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)
    # After a couple of time steps, if the action does not match the task's desired
    # settings, the user signals the desired (temp, flow) pair with probability 0.9.
    # Parentheses make the intended grouping explicit (&& binds tighter than || in Julia).
    if sp.time > 2 && (a[1] != DTEMP[sp.task] || a[2] != DFLOW[sp.task])
        change = (DTEMP[sp.task], DFLOW[sp.task])
        leave = (0,0)
        return SparseCat([change, leave], [0.9, 0.1]) # list of observations and associated probabilities
    else
        leave = (0,0) # no feedback
        return SparseCat([leave], [1.0])
    end
end
function reward(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    if a[1] == DTEMP[s.task] && a[2] == DFLOW[s.task]
        return 10.0  # both temperature and flow match the task
    elseif a[1] == DTEMP[s.task]
        return 5.0   # only temperature matches
    elseif a[2] == DFLOW[s.task]
        return 3.0   # only flow matches
    else
        return -10.0 # neither matches
    end
end
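# Worked example of the reward logic above (values taken from DTEMP/DFLOW): for the
# DISH task the desired settings are temperature 30 and flow 4, so:
# reward(p, FState(DISH, 0, 0, 0), (30, 4)) == 10.0  # both match
# reward(p, FState(DISH, 0, 0, 0), (30, 1)) == 5.0   # only temperature matches
# reward(p, FState(DISH, 0, 0, 0), (10, 4)) == 3.0   # only flow matches
# reward(p, FState(DISH, 0, 0, 0), (10, 1)) == -10.0 # neither matches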
initial_state_distribution(p::FPOMDP) = SparseCat([FState(t, 0, 0, 0) for t in [DISH, HAND]], [0.5, 0.5])
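# Note (assumption, not part of the original gist): the POMDPs.jl interface also
# includes a discount factor, which tree-search solvers generally use. If the solver
# complains about a missing `discount` method, a definition like the following would
# supply one (0.95 is an arbitrary illustrative value):
# discount(p::FPOMDP) = 0.95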
# policy = RandomPolicy(p)
solver = POMCPSolver(c=100) # POMCP tree-search planner; c is the UCB exploration constant
policy = solve(solver, p)
function my_policy(b::ParticleCollection)
    # Heuristic: sample a state from the belief and apply that task's desired settings.
    s = rand(Base.GLOBAL_RNG, b)
    return (DTEMP[s.task], DFLOW[s.task])
end
#policy = FunctionPolicy(my_policy)
#up = SIRParticleFilter(p, 1000)
for (b, s, a, r, o) in stepthrough(p, policy, "bsaro")
    # Fraction of belief particles that think the current task is hand-washing.
    frac_hand = length(filter(s->s.task==HAND, particles(b)))/n_particles(b)
    @show frac_hand
    @show s
    @show a
    @show r
    @show o
end
# inchrome(D3Tree(policy))
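# A hedged sketch of the commented-out alternative above: run the same kind of
# simulation with the hand-written heuristic wrapped in a FunctionPolicy and an
# explicit SIR particle filter as the belief updater. This assumes stepthrough
# accepts an updater as its third argument, as in the POMDPToolbox of this era.
# heuristic = FunctionPolicy(my_policy)
# up = SIRParticleFilter(p, 1000)
# for (b, s, a, r, o) in stepthrough(p, heuristic, up, "bsaro")
#     @show s, a, r, o
# end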