Skip to content

Instantly share code, notes, and snippets.

@pkofod
Created August 18, 2016 11:22
Show Gist options
  • Save pkofod/41ef3bceab07586e62a2cd01590bc56e to your computer and use it in GitHub Desktop.
Save pkofod/41ef3bceab07586e62a2cd01590bc56e to your computer and use it in GitHub Desktop.
# Fit `solver.μ` for `env` by handing a per-episode reward function to `optimize`
# with the `CrossEntropy()` method, then copying the optimizer's result back into
# `solver.μ` in place.
#
# Arguments:
#   solver  - the CrossEntropyMethod whose `μ` is used as the starting point and
#             overwritten with the result; `solver.options[:maxiter]` bounds each episode
#   env     - the environment the episodes are run in
#   doanim  - accepted for interface compatibility; not read in this method body
#
# NOTE(review): `cem_episode` returns a reward, while `Optim.minimizer` suggests a
# minimizing convention — confirm the sign convention with the CrossEntropy method.
function LearnBase.learn!(solver::CrossEntropyMethod, env::AbstractEnvironment, doanim = false)
    # !!! INIT:
    # mappable function from a parameter vector θ to the reward of one episode
    cem_episode = θ -> begin
        # named `policy` rather than `π` to avoid shadowing the built-in constant
        policy = cem_policy(env, θ)
        R, T = episode!(env, policy; maxiter = solver.options[:maxiter])
        R
    end

    result = optimize(cem_episode, solver.μ, CrossEntropy())

    # NOTE(review): is solver.μ the parameter to be found?
    # `Optim.minimizer` (previously misspelled `minimzer`) extracts the solution vector.
    solver.μ[:] = Optim.minimizer(result)
end
# Human-readable name for the `CrossEntropy` method; the argument is used only
# for dispatch.
function method_string(method::CrossEntropy)
    return "Nicely Formatted Method"
end
# State container for the cross-entropy method, parameterized by `T`.
# Placeholder: no fields are declared yet.
# NOTE(review): other definitions in this file read `state.R`, `state.T`,
# `state.options`, `state.doanim` and `state.anim` — confirm whether those
# belong here as fields.
type CrossEntropyState{T}
    # method-specific fields go here, if any
    # default fields go here; TODO: document them (e.g. initial_x)
end
# Convenience method: pull the objective function out of `d` (its `f` field)
# and forward to the `initialize_state` method that takes the function itself.
function initialize_state(method::Method, options, d, initial_x::Array)
    return initialize_state(method, options, d.f, initial_x)
end
# Build the initial state for the cross-entropy method (sketch).
#
# Computes the elite-set size from `:cem_batch_size` and `:cem_elite_frac`,
# allocates a buffer shaped like `solver.μ`, creates an empty optimization trace,
# and packs everything into the returned object.
#
# NOTE(review): `doanim`, `solver` and `default_variables` are neither parameters
# nor locals of this method — they must be supplied from an enclosing scope or
# threaded through `options`; confirm.
# NOTE(review): the result is built with `CrossEntropy(...)`, while the state type
# declared in this file is `CrossEntropyState` — confirm which is intended.
function initialize_state(method::Method, options, f::Function, initial_x)
    # something like
    anim = doanim ? Animation() : nothing
    n_elite = round(Int, solver.options[:cem_batch_size] * solver.options[:cem_elite_frac])
    last_μ = similar(solver.μ)
    # Parameterize the trace by the concrete type of the `method` instance.
    # `typeof(Method)` would be the type of the type object, not of the method in use.
    tr = OptimizationTrace{typeof(method)}()
    CrossEntropy(anim, n_elite, last_μ, tr, default_variables...)
end
# Convenience method: pull the objective function out of `d` (its `f` field)
# and forward to the `update!` method that takes the function itself.
# NOTE(review): dispatches on `ParticleSwarmState`/`ParticleSwarm` although the
# surrounding definitions are for the cross-entropy method — confirm.
function update!(d, state::ParticleSwarmState, method::ParticleSwarm)
    return update!(d.f, state, method)
end
# One iteration of the cross-entropy update:
#   1. sample a batch of parameter vectors from a multivariate normal built from
#      `solver.μ` and `solver.σ`,
#   2. score each sample with `cem_episode` and keep the `n_elite` highest-scoring ones,
#   3. refit `solver.μ`, `solver.Z` and `solver.σ` coordinate-by-coordinate from the elite set,
#   4. run one episode with the updated `solver.μ` and store its result in
#      `state.R` / `state.T`.
# Always returns `false`.
#
# NOTE(review): `solver`, `cem_episode`, `n_elite`, `t`, `env`, `hist_min`,
# `hist_mean`, `hist_max` and `anim` are neither parameters nor locals here; they
# must come from an enclosing scope — confirm. `f` and `method` are not read, and
# the local `last_μ` is assigned but never read inside this method.
function update!{T}(f::Function, state::ParticleSwarmState{T}, method::ParticleSwarm)
    last_μ = copy(solver.μ)

    # draw one candidate θ per batch slot from the current search distribution
    sampler = MultivariateNormal(solver.μ, solver.σ)
    θs = [rand(sampler) for i = 1:solver.options[:cem_batch_size]]

    # score every candidate; `rev=true` sorts descending, so the leading
    # `n_elite` entries of the ranking are the highest rewards
    Rs = map(cem_episode, θs)
    ranking = sortperm(Rs, rev=true)
    elite_θs = θs[ranking[1:n_elite]]
    info("Iteration $t. mean(R): $(mean(Rs)) max(R): $(maximum(Rs))")

    # refit each coordinate of the distribution from the elite set
    for j in 1:length(solver.μ)
        coord = [elite[j] for elite in elite_θs]
        solver.μ[j] = mean(coord)
        solver.Z[j] = solver.noise_func(t)
        # extra noise term Z[j] is added to the elite variance before the sqrt
        solver.σ[j] = sqrt(var(coord) + solver.Z[j])
    end
    @show solver.μ solver.σ solver.Z

    # evaluate the refitted mean parameters with one more episode
    state.R, state.T = episode!(
        env,
        cem_policy(env, solver.μ);
        maxiter = solver.options[:maxiter],
        stepfunc = myplot(t, hist_min, hist_mean, hist_max, anim),
    )
    return false
end
# Convergence test: returns `true` (after logging a message) when the norm of the
# difference between `solver.μ` and `last_μ` falls below
# `state.options[:stopping_norm]`, otherwise `false`.
#
# NOTE(review): `solver`, `last_μ` and `t` are neither parameters nor locals here,
# and the `options` parameter is not read (`state.options` is used instead) — confirm.
function assess_convergence(state::CrossEntropy, options)
    normdiff = norm(solver.μ - last_μ)
    @show normdiff

    converged = normdiff < state.options[:stopping_norm]
    if converged
        info("Converged after $(t*state.options[:cem_batch_size]) episodes.")
    end
    return converged
end
# Trace hook called with the trace `tr`, the current `state`, the `iteration`
# counter, the `method` and the `options`.
# Placeholder: the body is empty, so nothing is recorded yet.
function trace!(tr, state, iteration, method, options)
    # TODO: save the three values in the extended trace using a dictionary
    # (presumably the min/mean/max histories passed to `myplot` in `update!` — confirm)
end
# Post-loop hook: when `state.doanim` is true, pass `state.anim` to `gif`.
# Returns `false` when `state.doanim` is false, otherwise whatever `gif` returns.
# `d`, `method` and `options` are accepted but not read.
#
# The `function` keyword was missing from the original definition, which left a
# bare call followed by an unmatched `end`.
function after_while!(d, state, method, options)
    return state.doanim && gif(state.anim)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment