Created
December 7, 2017 08:30
-
-
Save WilliamJou/344e8cc879f4d51629bde3dcf166c99b to your computer and use it in GitHub Desktop.
Updated Faucet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
workspace() | |
importall POMDPs | |
using POMDPToolbox | |
using StatsBase | |
#using SARSOP | |
using BasicPOMCP | |
using D3Trees | |
using ParticleFilters | |
# Task identifiers used throughout the model.
const DISH = 1
const HAND = 2
const POT = 3

# Temperature grid (degrees Celsius) and a value -> position lookup for it.
const TEMPS = 27:3:52
const TINDEX = Dict{Int, Int}(zip(TEMPS, 1:length(TEMPS)))

# Flow grid and a value -> position lookup for it.
const FLOWS = 5:10:95
const FINDEX = Dict{Int, Int}(zip(FLOWS, 1:length(FLOWS)))

# User type indices.
const USERS = 1:1:4
# Hidden state of the faucet POMDP.
#   task      : DISH, HAND or POT
#   time      : steps taken so far (starts at 0, +1 per transition)
#   prev_temp : temperature setting currently applied (0 in the initial state)
#   prev_flow : flow setting currently applied (0 in the initial state)
#   user      : user type in USERS (picky, resource-conscious, patient, doubting)
struct FState
    task::Int
    time::Int
    prev_temp::Int
    prev_flow::Int
    user::Int
end
# Faucet POMDP. Type parameters are POMDP{State, Action, Observation}:
#   action      = (temperature, flow) command
#   observation = (temperature, flow, metal reading, reported user type)
# `max_time` is the last non-terminal time step (see `isterminal`).
struct FPOMDP <: POMDP{FState, Tuple{Int,Int}, Tuple{Int,Int,Float64,Int}}
    max_time::Int
end

# Problem instance (max_time = 10); the index tables below are built from it.
p = FPOMDP(10)
# Desired temperature (degrees Celsius) for each task.
# The nominal values for DISH (40) and POT (46) fall between points of the TEMPS
# action grid (27:3:52). If used verbatim, no action could ever match them, so the
# temperature-match rewards in `reward` would be unreachable for those tasks.
# Each nominal value is therefore snapped to the nearest grid temperature
# (ties resolve to the lower one, since TEMPS is ascending and findmin picks the first).
const DTEMP = Dict{Int, Int}(task => TEMPS[findmin(abs.(TEMPS .- t))[2]]
                             for (task, t) in (DISH => 40, HAND => 27, POT => 46))
# Desired flow for each task; these already lie on the FLOWS grid.
const DFLOW = Dict{Int, Int}(DISH=>65, HAND=>55, POT=>75)
# Possible metal-sensor readings (distribution over them depends on the task).
const METAL = [.05,.1,.75,.85]
# U_WEIGHTS[u]: weights over the reported user type when the true type is u;
# the true type gets 0.4 and every other type 0.2.
const U_WEIGHTS = Dict{Int, Any}(u => pweights([v == u ? .4 : .2 for v in USERS])
                                 for u in USERS)
# Terminal once the clock has moved past the horizon `p.max_time`.
isterminal(p::FPOMDP, s::FState) = p.max_time < s.time
# Enumerate every state the simulation can reach.
# user is an integer in USERS: [picky, resource-conscious, patient, doubting].
# The applied (temperature, flow) setting is either a point of the action grid or
# (0, 0): the setting used by `initial_state_distribution`, which the 10% "stay"
# branch of `transition` can keep for any number of steps. Time runs to
# max_time + 1 because a transition from the last step lands there (terminal).
function states(p::FPOMDP)
    settings = vcat([(0, 0)], vec([(t, f) for t in TEMPS, f in FLOWS]))
    return vec([FState(task, time, pt, pf, u)
                for task in [DISH, HAND, POT], time in 0:p.max_time+1,
                    (pt, pf) in settings, u in USERS])
end
# Must agree with length(states(p)).
n_states(p::FPOMDP) = 3 * (p.max_time + 2) * (length(TEMPS) * length(FLOWS) + 1) * length(USERS)
const SINDEX = Dict{FState, Int}(s=>i for (i,s) in enumerate(states(p)))
state_index(p::FPOMDP, s::FState) = SINDEX[s]
# Actions are (temperature, flow) commands over the TEMPS x FLOWS grid.
actions(p::FPOMDP) = vec(collect((t,f) for t in TEMPS, f in FLOWS))
n_actions(p::FPOMDP) = length(TEMPS)*length(FLOWS)
const AINDEX = Dict(a=>i for (i,a) in enumerate(actions(p)))
# `a` is a (temperature, flow) tuple as produced by `actions(p)`; an `Int`
# argument could never be found in AINDEX.
action_index(p::FPOMDP, a::Tuple{Int,Int}) = AINDEX[a]
# Observations are (temperature, flow, metal reading, reported user type) tuples.
# Besides the full action grid, the set must contain every setting `observation`
# can emit: (0, 0) when the user leaves the faucet alone, and each task's desired
# (DTEMP, DFLOW) setting, which need not lie on the TEMPS grid. Without these,
# obs_index would throw a KeyError on the observations actually produced.
# The grid part comes first so its indices are unchanged.
function observations(p::FPOMDP)
    grid = vec([(t, f, m, w) for t in TEMPS, f in FLOWS, m in METAL, w in USERS])
    user_settings = vcat([(0, 0)], [(DTEMP[k], DFLOW[k]) for k in (DISH, HAND, POT)])
    emitted = vec([(t, f, m, w) for (t, f) in user_settings, m in METAL, w in USERS])
    return unique(vcat(grid, emitted))
end
n_observations(p::FPOMDP) = length(observations(p))
const OINDEX = Dict(o=>i for (i,o) in enumerate(observations(p)))
obs_index(p::FPOMDP, o::Tuple{Int,Int,Float64, Int}) = OINDEX[o]
# One step of dynamics: the clock advances by one. With probability 0.9 the
# commanded (temperature, flow) `a` takes effect; with probability 0.1 the previous
# setting is kept. Task and user never change.
function transition(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    next_time = s.time + 1
    applied = FState(s.task, next_time, a[1], a[2], s.user)
    unchanged = FState(s.task, next_time, s.prev_temp, s.prev_flow, s.user)
    return SparseCat([applied, unchanged], [.9, .1])
end
# Distribution of the metal-sensor reading (over METAL) given the task:
# pots read mostly high, hands mostly low, dishes are spread out.
function _metal_probs(task::Int)
    if task == POT
        return [.05, .05, .45, .45]
    elseif task == DISH
        return [.3, .3, .3, .1]
    else # HAND
        return [.45, .45, .05, .05]
    end
end

# Probability that the user overrides command `a` by setting the task's desired
# (DTEMP, DFLOW); zero when `a` already matches it. Users differ in how many
# steps they wait before intervening and how reliably they do so.
function _adjust_prob(sp::FState, a::Tuple{Int,Int})
    (a[1] == DTEMP[sp.task] && a[2] == DFLOW[sp.task]) && return 0.0
    if sp.user == 3        # patient
        return sp.time > 4 ? 0.95 : 0.0
    elseif sp.user == 4    # doubting
        return sp.time > 1 ? 0.75 : 0.0
    else                   # picky, resource-conscious
        return sp.time > 2 ? 0.95 : 0.0
    end
end

# Observation model: exact joint distribution over
#   (user-set temperature, user-set flow, metal reading, reported user type).
# The user either sets the desired (DTEMP, DFLOW) for the task or leaves the faucet
# alone, reported as (0, 0). The metal reading depends on the task. The reported
# user type is drawn with weights U_WEIGHTS[sp.user]. All three are independent given sp.
# The distribution is enumerated rather than sampled, so `pdf` is consistent for
# particle-filter reweighting, and every user type gets a result (never `nothing`).
function observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)
    metal_p = _metal_probs(sp.task)
    user_w = U_WEIGHTS[sp.user]
    user_total = sum(user_w)
    p_adjust = _adjust_prob(sp, a)
    desired_temp = DTEMP[sp.task]
    desired_flow = DFLOW[sp.task]

    obs = Tuple{Int,Int,Float64,Int}[]
    probs = Float64[]
    for (i, m) in enumerate(METAL), (j, u) in enumerate(USERS)
        base = metal_p[i] * user_w[j] / user_total
        if p_adjust > 0
            push!(obs, (desired_temp, desired_flow, m, u))
            push!(probs, p_adjust * base)
        end
        if p_adjust < 1
            push!(obs, (0, 0, m, u))
            push!(probs, (1 - p_adjust) * base)
        end
    end
    return SparseCat(obs, probs)
end
# Reward for commanding `a` in state `s`, compared against the task's desired setting:
# both temperature and flow match -> 4, temperature only -> 2, flow only -> 1,
# neither -> -10.
function reward(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    temp_ok = a[1] == DTEMP[s.task]
    flow_ok = a[2] == DFLOW[s.task]
    temp_ok && flow_ok && return 4.0
    temp_ok && return 2.0
    flow_ok && return 1.0
    return -10.0
end
# The true user type is drawn once, uniformly, when the script runs. The initial
# belief below is certain about it; only the task is uncertain.
initial_user = sample([1,2,3,4], pweights([.25,.25,.25,.25]))
# To hide the user type from the planner as well, spread the mass over all users:
#initial_state_distribution(p::FPOMDP) = SparseCat(vec([FState(t, 0, 0, 0, u) for t in [DISH, HAND, POT], u in USERS]), fill(1/12, 12))
# Uniform over the three tasks. Probabilities must sum to 1 ([.3,.3,.3] sums to 0.9).
initial_state_distribution(p::FPOMDP) =
    SparseCat([FState(t, 0, 0, 0, initial_user) for t in [DISH, HAND, POT]], fill(1/3, 3))
# Online planning with POMCP; c is the exploration constant.
# Alternatives tried during development:
#   policy = RandomPolicy(p)
#   policy = FunctionPolicy(b -> (s = rand(Base.GLOBAL_RNG, b); (DTEMP[s.task], DFLOW[s.task])))
#   up = SIRParticleFilter(p, 1000)
solver = POMCPSolver(c=100)
policy = solve(solver, p)

# Simulate one episode. At each step, print how the belief particles split across
# tasks, then the true state, the action, the reward and the observation.
for (b, s, a, r, o) in stepthrough(p, policy, "bsaro")
    total = n_particles(b)
    frac_hand = count(x -> x.task == HAND, particles(b)) / total
    frac_dish = count(x -> x.task == DISH, particles(b)) / total
    frac_pot = count(x -> x.task == POT, particles(b)) / total
    @show frac_hand
    @show frac_dish
    @show frac_pot
    @show s
    @show a
    @show r
    @show o
end
# Visualization with D3Trees: inchrome(D3Tree(policy))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment