@zsunberg
Last active December 6, 2019 05:26

Here are the two ways of augmenting the state space that I was referring to (these are illustrative rather than efficient or complete implementations):

  1. Add a single new terminal state
struct VariableDiscountWrapper1{S, A, F<:Function} <: MDP{Union{S, TerminalState}, A}
    m::MDP{S, A}
    discount::F
end

function transition(m::VariableDiscountWrapper1, s, a)
    td = transition(m.m, s, a)
    disc = m.discount(s, a)
    return CombinedDistribution([disc=>td, (1-disc)=>Deterministic(terminalstate)])
end

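# the element type must admit terminalstate as well as the original states
states(m::VariableDiscountWrapper1{S}) where {S} = Union{S, TerminalState}[states(m.m)..., terminalstate]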

# forward other methods to m - we can make something that does this

(CombinedDistribution hasn't been implemented yet; if it doesn't make sense, let me know.)
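Since CombinedDistribution does not exist yet, here is a minimal sketch of what it might look like: a mixture that picks component i with probability weights[i] and then samples from it. Everything here is an assumption made for illustration, including the field names, the constructor taking weight => distribution pairs, the local `pdf` function, and the `PointMass` type that stands in for `Deterministic` so the block runs without any packages. In a real package one would presumably extend the distribution interface functions from POMDPs.jl instead of defining a fresh `pdf`.

```julia
using Random

# Hypothetical mixture: with probability weights[i], sample from components[i].
struct CombinedDistribution{D}
    weights::Vector{Float64}
    components::Vector{D}
end

# Build from a vector of weight => distribution pairs.
function CombinedDistribution(pairs::AbstractVector)
    w = Float64[first(p) for p in pairs]
    @assert all(>=(0), w) && isapprox(sum(w), 1.0)
    comps = [last(p) for p in pairs]
    return CombinedDistribution(w, comps)
end

# Stand-in for Deterministic (assumption, only here to keep the sketch self-contained).
struct PointMass{T}
    val::T
end
Base.rand(rng::AbstractRNG, d::PointMass) = d.val
pdf(d::PointMass, x) = x == d.val ? 1.0 : 0.0

# Sample a component by inverting the cumulative weights, then sample from it.
function Base.rand(rng::AbstractRNG, d::CombinedDistribution)
    u = rand(rng)
    c = 0.0
    for (w, comp) in zip(d.weights, d.components)
        c += w
        u <= c && return rand(rng, comp)
    end
    return rand(rng, last(d.components)) # guard against floating-point round-off
end

# The mixture density is the weighted sum of the component densities.
pdf(d::CombinedDistribution, x) =
    sum(w * pdf(c, x) for (w, c) in zip(d.weights, d.components))

# Usage: continue with probability 0.9, terminate with probability 0.1.
d = CombinedDistribution([0.9 => PointMass(:next), 0.1 => PointMass(:terminal)])
sample = rand(MersenneTwister(1), d)
```

A dense or sparse categorical over the union of the supports would likely be more efficient for small discrete problems; the mixture form is used here only because it mirrors the call in transition directly.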

  2. Keep track of the discount factor
struct VariableDiscountWrapper2{S, A, F<:Function} <: MDP{Tuple{S, Float64}, A}
    m::MDP{S, A}
    discount::F
end

initialstate(m::VariableDiscountWrapper2, rng) = (initialstate(m.m, rng), 1.0)

function gen(::DDNNode{:sp}, m::VariableDiscountWrapper2, sd, a, rng)
    s, d = sd
    sp = gen(DDNNode(:sp), m.m, s, a, rng)
    return (sp, d*m.discount(s, a))
end

function reward(m::VariableDiscountWrapper2, sd, a, sdp)
    s, d = sd
    sp = sdp[1]
    return d*reward(m.m, s, a, sp)
end

# forward other methods to m
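
As a sanity check on why both wrappers should reproduce the variable-discount objective, here is a short derivation. The symbols are my own notation rather than anything in the code: γ(s, a) is the value returned by the discount field, and r is the reward of the wrapped MDP, assumed for simplicity to depend only on (s, a). Note that both arguments rely on the wrapper's own discount(m) being 1.0, so discount is presumably one method that should not simply be forwarded to m.

```latex
% Target objective with a state-action-dependent discount:
J = \mathbb{E}\left[ \sum_{t=0}^{\infty} \left( \prod_{k=0}^{t-1} \gamma(s_k, a_k) \right) r(s_t, a_t) \right]

% Wrapper 2: the augmented state carries d_t, with
d_0 = 1, \qquad d_{t+1} = d_t \, \gamma(s_t, a_t)
\quad\Longrightarrow\quad d_t = \prod_{k=0}^{t-1} \gamma(s_k, a_k),
% and the wrapped reward is d_t \, r(s_t, a_t), so the undiscounted return is
\mathbb{E}\left[ \sum_{t=0}^{\infty} d_t \, r(s_t, a_t) \right] = J .

% Wrapper 1: the process survives step k with probability \gamma(s_k, a_k), so
\Pr(\text{not yet terminated at } t \mid s_{0:t-1}, a_{0:t-1}) = \prod_{k=0}^{t-1} \gamma(s_k, a_k),
% and with zero reward after termination the undiscounted expected return is again
\mathbb{E}\left[ \sum_{t=0}^{\infty} \left( \prod_{k=0}^{t-1} \gamma(s_k, a_k) \right) r(s_t, a_t) \right] = J .
```

If the reward also depends on the next state, wrapper 1 needs a convention for the reward on the transition into the terminal state (for example, the expected reward over the original next states) for the equality to hold.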