Created
December 10, 2017 06:18
-
-
Save WilliamJou/9cc7293b3437ae7d7e91f888219637ad to your computer and use it in GitHub Desktop.
Faucet Code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
workspace() | |
importall POMDPs | |
using POMDPToolbox | |
using BasicPOMCP | |
using D3Trees | |
using ParticleFilters | |
using StatsBase | |
# Task identifiers: what the user is doing at the faucet.
const DISH = 1
const HAND = 2
const POT = 3
const TEMPS = 27:3:51 #temperature range in celsius
const TINDEX = Dict{Int, Int}(t=>i for (i,t) in enumerate(TEMPS))  # temperature -> linear index
const FLOWS = 5:10:95  # flow settings (units unspecified -- TODO confirm)
const FINDEX = Dict{Int, Int}(t=>i for (i,t) in enumerate(FLOWS))  # flow -> linear index
const USERS = 1:1:4  # user personality ids: [picky, resource-conscious, patient, doubting]
# A single state of the faucet POMDP.
struct FState
    task::Int       # what is being washed: DISH, HAND, or POT
    time::Int       # timestep, 0 .. max_time (terminal past max_time)
    prev_temp::Int  # temperature chosen by the previous action
    prev_flow::Int  # flow chosen by the previous action
    user::Int       # user type in USERS (1=picky, 2=resource-conscious, 3=patient, 4=doubting)
end
# Faucet POMDP.
#   action      = (temperature, flow)
#   observation = (requested temp, requested flow, metal-sensor reading, sampled user id)
struct FPOMDP <: POMDP{FState, Tuple{Int,Int}, Tuple{Int, Int, Float64, Int}} # POMDP{State, Action, Observation}
    max_time::Int  # horizon; states with time > max_time are terminal
end
p = FPOMDP(10)  # global problem instance; the index tables below are built against it
const DTEMP = Dict{Int, Int}(DISH=>36, HAND=>27, POT=>39)  # desired temperature (celsius) per task
const DFLOW = Dict{Int, Int}(DISH=>65, HAND=>45, POT=>85) #desired states of flow for each of these tasks
# Possible metal-sensor readings; used as the observation's third component.
const METAL = [.05,.1,.85,.95]
# Noisy user-identity observation: U_WEIGHTS[u] gives sampling weights over the
# reported user id when the true user is u (note the true id is the LEAST likely draw).
const U_WEIGHTS = Dict{Int, Any}(1=>pweights([.1,.3,.3,.3]), 2=>pweights([.3,.1,.3,.3]), 3=>pweights([.3,.3,.1,.3]), 4=>pweights([.3,.3,.3,.1]))
#const U_WEIGHTS = Dict{Int, Any}(1=>pweights([.33,.22,.22,.22]), 2=>pweights([.22,.33,.22,.22]), 3=>pweights([.22,.22,.33,.22]), 4=>pweights([.22,.22,.22,.33]))
"""
    isterminal(p::FPOMDP, s::FState)

A state is terminal once its timestep has passed the problem horizon.
"""
function isterminal(p::FPOMDP, s::FState)
    return s.time > p.max_time
end
# Enumerate the full (task, time, prev_temp, prev_flow, user) state space.
states(p::FPOMDP) = vec(collect(FState(task, time, pt, pf, u) for task in [DISH, HAND, POT], time in 0:p.max_time, pt in TEMPS, pf in FLOWS, u in USERS))
#user is an integer vector of [picky, resource-conscious, patient, and doubting]
n_states(p::FPOMDP) = length(TEMPS)*(p.max_time+1)*3*length(FLOWS)*length(USERS)  # 3 = number of tasks
# NOTE(review): built once against the global instance `p`. Initial states use
# prev_temp = prev_flow = 0, which is outside TEMPS/FLOWS, so they are NOT keys
# here -- state_index would throw KeyError for them. Verify this is never hit.
const SINDEX = Dict{FState, Int}(s=>i for (i,s) in enumerate(states(p)))
state_index(p::FPOMDP, s::FState) = SINDEX[s]
"""
    actions(p::FPOMDP)

All actions as (temperature, flow) pairs over the TEMPS x FLOWS grid.
"""
actions(p::FPOMDP) = vec(collect((t, f) for t in TEMPS, f in FLOWS))

n_actions(p::FPOMDP) = length(TEMPS) * length(FLOWS)

# Action -> linear index, built once against the global instance `p`.
const AINDEX = Dict(a => i for (i, a) in enumerate(actions(p)))

# BUG FIX: actions are Tuple{Int,Int}, not Int. The old signature
# `action_index(p, a::Int)` could never match a real action, so AINDEX[a]
# was unreachable (and would have been a KeyError even if called).
action_index(p::FPOMDP, a::Tuple{Int,Int}) = AINDEX[a]
# All observations: (temp, flow, metal-sensor reading, reported user id).
observations(p::FPOMDP) = vec(collect((t,f,m,w) for t in TEMPS, f in FLOWS, m in METAL, w in USERS))
n_observations(p::FPOMDP) = length(TEMPS)*length(FLOWS)*length(METAL)*length(USERS)
const OINDEX = Dict(o=>i for (i,o) in enumerate(observations(p)))  # observation -> linear index
# NOTE(review): the (0, 0, m, u) "leave" observations emitted by `observation`
# use 0, which is outside TEMPS/FLOWS, so they are not keys of OINDEX and
# obs_index would throw KeyError on them. Verify this path is never exercised.
obs_index(p::FPOMDP, o::Tuple{Int,Int,Float64, Int}) = OINDEX[o]
"""
    transition(p::FPOMDP, s::FState, a::Tuple{Int,Int})

With probability 0.8 the user stays on the current task and the action's
(temp, flow) becomes the new previous setting; with probability 0.2 the user
switches to one of the two other tasks (uniformly), keeping the old previous
setting. Time always advances by one step.
"""
function transition(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    stay = FState(s.task, s.time + 1, a[1], a[2], s.user)
    other_tasks = [1, 2, 3]
    deleteat!(other_tasks, s.task)  # drop the task we are currently on
    switch = FState(rand(other_tasks), s.time + 1, s.prev_temp, s.prev_flow, s.user)
    return SparseCat([stay, switch], [.8, .2])
end
"""
    observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)

Observation model. The metal-sensor reading `m` is sampled with task-dependent
weights, and the reported user id `u_val` is sampled from U_WEIGHTS. The user
then either requests the desired (temp, flow) for the task ("change") or stays
quiet ("leave" = (0, 0, m, u_val)), with user-type-specific behavior:
  user 2 (resource-conscious): demands the desired setting whenever the action
         exceeds it in either temperature or flow;
  user 3 (patient): complains (p=0.95) only after time 4 when off-target;
  user 4 (doubting): never says anything;
  user 1 (picky, default branch): complains (p=0.90) after time 2 when off-target.
"""
function observation(p::FPOMDP, a::Tuple{Int,Int}, sp::FState)
    # Task-dependent metal-sensor weights over METAL.
    # NOTE(review): original comments labeled task 2 "dishwashing" and the
    # default "handwashing", but the constants define HAND=2, DISH=1 -- confirm.
    m_weight = if sp.task == 3
        pweights([.05, .05, .45, .45])  # pot: high readings likely
    elseif sp.task == 2
        pweights([.4, .4, .1, .1])
    else
        pweights([.45, .45, .05, .05])
    end
    m = sample(METAL, m_weight)
    u_val = sample([1, 2, 3, 4], U_WEIGHTS[sp.user])

    want_t = DTEMP[sp.task]
    want_f = DFLOW[sp.task]
    change = (want_t, want_f, m, u_val)  # user requests the desired setting
    leave = (0, 0, m, u_val)             # user says nothing
    off_target = a[1] != want_t || a[2] != want_f

    if sp.user == 2  # resource-conscious: reacts to anything above the desired setting
        if a[1] > want_t || a[2] > want_f
            return SparseCat([change], [1.0])
        end
        return SparseCat([leave], [1.0])
    elseif sp.user == 3  # patient: only complains late in the episode
        if sp.time > 4 && off_target
            return SparseCat([change, leave], [.95, 0.05])
        end
        return SparseCat([leave], [1.0])
    elseif sp.user == 4  # doubting: never gives feedback
        return SparseCat([leave], [1.0])
    else  # picky (user 1)
        if sp.time > 2 && off_target
            return SparseCat([change, leave], [.90, 0.10])
        end
        return SparseCat([leave], [1.0])
    end
end
"""
    reward(p::FPOMDP, s::FState, a::Tuple{Int,Int})

User-type-specific reward for setting the faucet to `a = (temp, flow)` while
the user performs `s.task`. Exactly matching the desired setting is rewarded
for users 1, 3 and 4; user 2 (resource-conscious) is rewarded MOST for
undershooting slightly and penalized for meeting or exceeding the target.
"""
function reward(p::FPOMDP, s::FState, a::Tuple{Int,Int})
    dt = DTEMP[s.task]
    df = DFLOW[s.task]
    on_target = a[1] == dt && a[2] == df

    if s.user == 1  # picky: all or nothing
        return on_target ? 5.0 : -5.0
    elseif s.user == 2  # resource-conscious
        if on_target
            return 1.0
        elseif (0 < dt - a[1] <= 6) || (0 < df - a[2] <= 20)
            return 5.0   # slightly under target: best outcome for this user
        elseif a[1] >= dt || a[2] >= df
            return -3.0  # wasteful: at or above target in some dimension
        else
            return -5.0  # far under target
        end
    elseif s.user == 4  # doubting: muted version of picky
        return on_target ? 2.0 : -2.0
    else  # patient (user 3): partial credit per dimension
        if on_target
            return 5.0
        elseif a[1] == dt
            return -2.0
        elseif a[2] == df
            return -3.0
        else
            return -5.0
        end
    end
end
#initial_user = sample([1,2,3,4], pweights([.25,.25,.25,.25])) | |
"""
    initial_state_distribution(p::FPOMDP)

Uniform distribution over the 12 (task, user) combinations, starting at
time 0 with sentinel previous settings (0, 0).
"""
# BUG FIX: the hand-typed weight vector (twelve 0.083s) summed to 0.996,
# not 1.0; use exact uniform weights instead.
initial_state_distribution(p::FPOMDP) = SparseCat(vec([FState(t, 0, 0, 0, u) for t in [DISH, HAND, POT], u in USERS]), fill(1/12, 12))
#initial_state_distribution(p::FPOMDP) = SparseCat([FState(t, 0, 0, 0,initial_user) for t in [DISH, HAND, POT]], [.3,.3,.3]) | |
# policy = RandomPolicy(p)
# POMCP: exploration constant c=5, fixed seed for reproducibility,
# 500 tree queries per action selection.
solver = POMCPSolver(c=5, rng = MersenneTwister(7), tree_queries = 500)
policy = solve(solver, p)
"""
    my_policy(b::ParticleCollection)

Heuristic baseline policy: draw one particle from the belief and return the
desired (temperature, flow) for that particle's task.
"""
function my_policy(b::ParticleCollection)
    guess = rand(Base.GLOBAL_RNG, b)
    return (DTEMP[guess.task], DFLOW[guess.task])
end
# Run 100 simulated episodes under the POMCP policy, tracking the per-episode
# return (Reward), the grand total across episodes (total), and the cumulative
# total after each episode (aggpomcp_r).
pomcp_r = 0
aggpomcp_r = fill(0.0,100)  # cumulative reward through episode i
Reward = fill(0.0,100)      # return of episode i
total = 0                   # grand total across all episodes so far
for i in 1:100
    pomcp_r = 0  # reset the per-episode accumulator
    @show i
    for (b, s, a, r, o) in stepthrough(p, policy, "bsaro")
        #frac_hand = length(filter(s->s.task==HAND, particles(b)))/n_particles(b)
        #frac_dish = length(filter(s->s.task==DISH, particles(b)))/n_particles(b)
        #frac_pot = length(filter(s->s.task==POT, particles(b)))/n_particles(b)
        #calculate pomcp_r
        pomcp_r = pomcp_r + r
        total = total + r
    end
    Reward[i]= pomcp_r
    aggpomcp_r[i] = total
end
@show aggpomcp_r
# Running mean of the cumulative reward: total through episode i divided by i.
avg_pomcp = aggpomcp_r ./ (1:100)
# Per-episode return statistics over the 100 runs.
mean90 = mean(Reward)
std90 = std(Reward)
@show mean90
@show std90
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment