-
-
Save fredzett/2f2fc5613f8b760d58225b9149adaaad to your computer and use it in GitHub Desktop.
Jack's Car Rental
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distributions | |
using LinearAlgebra | |
using Plots | |
# Car counts with one entry per location; used for states `[n₁, n₂]` and actions `[i, -i]`.
const Amounts = Vector{Int}

# Abstract supertype for Markov decision problems (TODO: move to separate file).
abstract type MDP end
"""
    CarProblem(mcars, mret, γ, R, P)

Jack's car rental MDP (Sutton/Barto, p. 81).

# Fields
- `γ`: discount factor.
- `S`: all states `[n₁, n₂]` with `n₁, n₂ ∈ 0:mcars`, ordered with `n₁` varying fastest.
- `A`: all actions `[i, -i]` for `i ∈ -mret:mret`; the negative entry marks the
  location cars are moved away from.
- `R`: reward function, called as `R(s, a)`.
- `P`: transition function, called as `P(s, a, s′)`.

`R` and `P` are stored with their concrete types (type parameters) so that calls
through `problem.R` / `problem.P` can be specialized by the compiler.
"""
struct CarProblem{FR,FP} <: MDP
    γ::Float64
    S::Vector{Amounts}
    A::Vector{Amounts}
    R::FR
    P::FP
    function CarProblem(mcars, mret, γ, R, P)
        # Column-major `vec` keeps the original ordering: first index (n₁) varies fastest.
        S = vec([[i, j] for i in 0:mcars, j in 0:mcars])
        A = [[i, -i] for i in -mret:mret]
        return new{typeof(R),typeof(P)}(γ, S, A, R, P)
    end
end
# Assumptions (problem parameters).
# Declared `const` so functions that read them in hot loops see concrete types.
const MAX_CARS = 20      # maximum number of cars at each location
const MAX_MOVE = 5       # maximum number of cars moved overnight
const COST_MOVE = -2     # cost of moving one car (negative, added to the reward)
const PRICE_RENT = 10    # income from renting out one car
const CUT_OFF = 50       # truncation point for the Poisson distributions (note: 10 to 15 likely suffices)
const λs = [3, 3, 4, 2]  # Poisson rates: [requests loc1, returns loc1, requests loc2, returns loc2]
"""
    _move(s, a; maxcars = MAX_CARS)

Apply the overnight move `a` to state `s` (cars at each of the two locations).

The location with the negative entry in `a` gives up cars, but no more than it
has. Cars arriving at the other location beyond `maxcars` are discarded.

Returns `(s′, n_move)`: a fresh state vector (it never aliases `s`) and the
number of cars actually moved.
"""
function _move(s, a; maxcars = MAX_CARS)
    i_out = findfirst(<(0), a)
    # No negative entry (e.g. a = [0, 0]): nothing moves. Return a copy so callers
    # that write into the result cannot modify `s`, which may be a state stored in `P.S`.
    if isnothing(i_out)
        return (copy(s), 0)
    end
    i_in = i_out == 1 ? 2 : 1
    n_move = min(s[i_out], abs(a[i_out]))
    s′ = [0, 0]
    s′[i_out] = s[i_out] - n_move
    s′[i_in] = min(maxcars, s[i_in] + n_move)
    return (s′, n_move)
end
# Draw a single sample from a Poisson distribution with rate λ.
function Pλ(λ)
    dist = Poisson(λ)
    return rand(dist)
end
"""
    update(s, a)

Simulate one day: apply move `a` to state `s`, sample rental requests and returns,
and return `(s′, reward)`.

Rentals at each location are capped at the cars on hand after moving. The next-day
count is capped at `MAX_CARS`. The reward is rental income plus the (negative)
moving cost. `s` is never modified.
"""
function update(s, a)
    s_moved, n_move = _move(s, a)

    # Customers at loc1 / loc2; they can only rent cars that are there.
    requests₁ = min(s_moved[1], Pλ(3))
    requests₂ = min(s_moved[2], Pλ(4))

    sales = (requests₁ + requests₂) * PRICE_RENT
    costs = COST_MOVE * n_move          # COST_MOVE is negative
    reward = sales + costs

    returns₁, returns₂ = Pλ(3), Pλ(2)
    # Build a fresh vector instead of writing into `s_moved`, so `s` is never
    # mutated even if `_move` hands back `s` itself for a no-op move.
    s′ = [min(MAX_CARS, s_moved[1] - requests₁ + returns₁),
          min(MAX_CARS, s_moved[2] - requests₂ + returns₂)]
    return s′, reward
end
# E[min(n, X)] for X ~ Poisson(λ): rentals at a location are capped by the n cars on hand.
function _expected_rentals(n, λ)
    d = Poisson(λ)
    return sum(k * pdf(d, k) for k in 0:n) + n * ccdf(d, n)
end

"""
    R(s, a)

Expected one-day reward for taking move `a` in state `s`.

It is the moving cost `COST_MOVE * n_move` plus the expected rental income
`PRICE_RENT * E[min(nᵢ, Xᵢ)]` at each location. Here `nᵢ` is the car count after
moving and `Xᵢ ~ Poisson` is the number of requests (rates `λs[1]` and `λs[3]`).
The linear system in `policy_evaluation` and the `lookahead` backup need this
expectation; a single sampled reward would make the computed values random.
"""
function R(s, a)
    s_moved, n_move = _move(s, a)
    income = PRICE_RENT * (_expected_rentals(s_moved[1], λs[1]) +
                           _expected_rentals(s_moved[2], λs[3]))
    return income + COST_MOVE * n_move
end
# Poisson(λ) pmf on 0:cutoff, with the upper tail P(X > cutoff) folded into the last
# entry so the vector sums to one. The tail uses the `cutoff` argument, so this holds
# for any cutoff.
function _truncated_pmf(λ, cutoff)
    d = Poisson(λ)
    pmf = pdf.(d, 0:cutoff)
    pmf[end] += ccdf(d, cutoff)
    return pmf
end

"""
    make_probs(λs, cutoff)

Joint request/return probability tables for both locations.

`λs = [λ_req₁, λ_ret₁, λ_req₂, λ_ret₂]` are Poisson rates. Each marginal is truncated
to `0:cutoff`, with the tail mass lumped into the last entry.

Returns `(joint_p₁, joint_p₂)`, where `joint_pᵢ[req + 1, ret + 1]` is the probability
of `req` requests and `ret` returns at location `i`. Requests and returns are
treated as independent.
"""
function make_probs(λs, cutoff)
    λ_req₁, λ_ret₁, λ_req₂, λ_ret₂ = λs
    # Outer product: entry (i, j) = P(request = i - 1) * P(return = j - 1).
    joint_p₁ = _truncated_pmf(λ_req₁, cutoff) * _truncated_pmf(λ_ret₁, cutoff)'
    joint_p₂ = _truncated_pmf(λ_req₂, cutoff) * _truncated_pmf(λ_ret₂, cutoff)'
    return joint_p₁, joint_p₂
end
# Joint (request, return) probability tables for locations 1 and 2.
# `const` so the transition function reading them sees concrete types.
const P_JOINT₁, P_JOINT₂ = make_probs(λs, CUT_OFF)
# Per-location table of end-of-day counts: probs[m + 1, n + 1] is the probability of
# ending with n cars, given m cars after moving. joint[req + 1, ret + 1] is the
# probability of (req requests, ret returns). This mirrors `update`: rentals are capped
# at m, and the final count is capped at `maxcars`. Each row therefore sums to
# sum(joint), i.e. to 1 for a normalized table.
function _next_count_probs(joint, maxcars)
    probs = zeros(maxcars + 1, maxcars + 1)
    for idx in CartesianIndices(joint)
        req, ret = idx[1] - 1, idx[2] - 1
        for m in 0:maxcars
            n = min(maxcars, m - min(m, req) + ret)
            probs[m + 1, n + 1] += joint[idx]
        end
    end
    return probs
end

# Computed once so each T call is a table lookup instead of a scan over all
# (request, return) pairs.
const _NEXT_COUNT_P₁ = _next_count_probs(P_JOINT₁, MAX_CARS)
const _NEXT_COUNT_P₂ = _next_count_probs(P_JOINT₂, MAX_CARS)

"""
    T(s, a, s′)

Probability of ending the day in state `s′` after taking move `a` in state `s`.

The cars are first moved (`_move`). Then each location independently receives
Poisson requests and returns, as in `update`, so the probability is the product
of the two per-location probabilities. Returns `0.0` for a target outside
`0:MAX_CARS`.
"""
function T(s, a, s′)
    sₘ, _ = _move(s, a)
    (0 <= s′[1] <= MAX_CARS && 0 <= s′[2] <= MAX_CARS) || return 0.0
    return _NEXT_COUNT_P₁[sₘ[1] + 1, s′[1] + 1] * _NEXT_COUNT_P₂[sₘ[2] + 1, s′[2] + 1]
end
"""
    lookahead(P::CarProblem, V, s, a)

One-step lookahead value of action `a` in state `s`:
`R(s, a) + γ * Σ_{s′} T(s, a, s′) * V(s′)`, where `V[k]` is the value of `P.S[k]`.
"""
function lookahead(P::CarProblem, V, s, a)
    # The immediate reward is evaluated before the expectation over successor states.
    immediate = P.R(s, a)
    future = sum(P.P(s, a, sn) * V[k] for (k, sn) in enumerate(P.S))
    return immediate + P.γ * future
end
"""
    policy_evaluation(P, π)

Exact value of policy `π` on problem `P`. It solves the Bellman expectation
equations `V = R′ + γ T′ V` as the linear system `(I - γ T′) V = R′`.

`P` must provide fields `S`, `R` (called `R(s, a)`), `P` (transition, called
`P(s, a, s′)`) and `γ`. `π(s)` returns the action for state `s`. The result is a
vector aligned with `P.S`.
"""
function policy_evaluation(P, π)
    S, R, T, γ = P.S, P.R, P.P, P.γ
    # Query the policy once per state rather than once per (s, s′) pair:
    # π(s) can be expensive (e.g. a greedy lookahead over all actions).
    actions = [π(s) for s in S]
    R′ = [R(S[i], actions[i]) for i in eachindex(S)]
    # Row i holds the transition probabilities out of S[i].
    T′ = [T(S[i], actions[i], S[j]) for i in eachindex(S), j in eachindex(S)]
    return (I - γ * T′) \ R′
end
# Policy definitions
"""
    ValueFunctionPolicy(P, V)

Policy that acts greedily with respect to the value function `V` on problem `P`.
`V[i]` is the value of state `P.S[i]`.

`V` keeps its own concrete array type (type parameter). This avoids copying it
into an `Any`-element array, which would make every value lookup type-unstable.
"""
struct ValueFunctionPolicy{TV<:AbstractArray}
    P::CarProblem
    V::TV
end
# Return `(F(x), x)` for the first `x` in `xs` that maximizes `F` (see Kochenderfer).
# This is a private helper rather than a `Base.findmax` method. Extending Base for
# `::Function` is type piracy. It can also be ambiguous with Base's
# `findmax(f, ::AbstractArray)` (Julia ≥ 1.7), which returns an index rather than `x`.
function _findmax_by(F, xs)
    f_max = -Inf
    x_max = first(xs)
    for x in xs
        v = F(x)
        if v > f_max
            f_max, x_max = v, x
        end
    end
    return f_max, x_max
end

"""
    greedy(P::CarProblem, V, s)

Action in `P.A` with the highest one-step lookahead value in state `s` under `V`.
Returns the named tuple `(a = action, v = value)`.
"""
function greedy(P::CarProblem, V, s)
    v, a = _findmax_by(act -> lookahead(P, V, s, act), P.A)
    return (a = a, v = v)
end
# Calling a ValueFunctionPolicy on a state returns the action that is greedy
# with respect to the policy's value function `π.V`.
(π::ValueFunctionPolicy)(s) = greedy(π.P, π.V, s).a
# Configuration for `solve`: start policy iteration from policy `π` and run at
# most `k_max` evaluate/improve iterations.
struct PolicyIteration
    π::ValueFunctionPolicy
    k_max::Int
end
"""
    solve(M::PolicyIteration, P::CarProblem)

Policy iteration on `P`, starting from `M.π`. Each iteration evaluates the current
policy exactly and then builds the greedy policy for that value function. It runs
at most `M.k_max` iterations.

Returns a `ValueFunctionPolicy`. If the policy stops changing, the returned policy's
`V` is the value of that final policy, so it can be used directly (e.g. for plotting).
"""
function solve(M::PolicyIteration, P::CarProblem)
    π, S = M.π, P.S
    for k in 1:M.k_max
        @info "Policy iteration" k
        V = policy_evaluation(P, π)
        π′ = ValueFunctionPolicy(P, V)
        # Converged when the improved policy acts like π on every state. Return π′:
        # its V is the value of this policy, whereas π.V is the value of the policy
        # from the previous iteration.
        if all(π(s) == π′(s) for s in S)
            return π′
        end
        π = π′
    end
    return π
end
# Define the problem (discount factor 0.9; rewards from R, transitions from T).
p = CarProblem(MAX_CARS, MAX_MOVE, 0.9, R, T)

# Initial policy: greedy with respect to an all-zero value function.
# Named π₀ rather than π to avoid shadowing Base's constant π in Main.
V₀ = zeros(length(p.S))
π₀ = ValueFunctionPolicy(p, V₀)

# Policy iteration with at most 10 iterations.
M = PolicyIteration(π₀, 10)

# Solve and plot the value function. States are ordered with the location-1 count
# varying fastest, so the reshaped grid is indexed [n₁ + 1, n₂ + 1]: rows (y) are
# cars at location 1 and columns (x) are cars at location 2, both from 0.
π_best = solve(M, p)
V = π_best.V
heatmap(0:MAX_CARS, 0:MAX_CARS, reshape(V, MAX_CARS + 1, MAX_CARS + 1);
        xlabel = "cars at location 2", ylabel = "cars at location 1")
# Author's note: the implementation above was found to be wrong and very slow.
# Wrong: sum(T(s, a, s′) for s′ in p.S) != 1, i.e. T did not give a probability
# distribution over next states.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment