Created July 3, 2020 06:09
julia> include("HagerZhang_bug.jl")
Loss 25.880611
Loss 34.07
ERROR: LoadError: ArgumentError: Value and slope at step length = 0 must be finite.
Stacktrace:
[1] (::LineSearches.HagerZhang{Float64,Base.RefValue{Bool}})(::Function, ::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}},Array{Float32,1},Array{Float32,1},Array{Float32,1}}, ::Float32, ::Float32, ::Float32) at /Users/skainswo/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:117
[2] HagerZhang at /Users/skainswo/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:101 [inlined]
[3] perform_linesearch!(::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.ManifoldObjective{NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/utilities/perform_linesearch.jl:56
[4] update_state!(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/solvers/first_order/l_bfgs.jl:198
[5] optimize(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#55"{var"#472#473",DataLoader}}, ::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:57
[6] optimize(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#55"{var"#472#473",DataLoader}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:33
[7] sciml_train(::Function, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::DataLoader; cb::Function, maxiters::Int64, diffmode::DiffEqFlux.ZygoteDiffMode, kwargs::Base.Iterators.Pairs{Symbol,Real,NTuple{8,Symbol},NamedTuple{(:iterations, :allow_f_increases, :x_abstol, :x_reltol, :f_abstol, :f_reltol, :g_abstol, :g_reltol),Tuple{Int64,Bool,Float64,Float64,Float64,Float64,Float64,Float64}}}) at /Users/skainswo/.julia/packages/DiffEqFlux/vOgUc/src/train.jl:269
[8] top-level scope at /Users/skainswo/dev/research/julia/odecontrol/pendulum_unconstrained.jl:101
[9] include(::String) at ./client.jl:439
[10] top-level scope at REPL[55]:1
in expression starting at /Users/skainswo/dev/research/julia/odecontrol/pendulum_unconstrained.jl:101
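
For what it's worth, the error comes out of the default HagerZhang line search, and the commented-out options later in the script already hint at one possible workaround: passing an explicit LineSearches algorithm to LBFGS. A minimal sketch of that variant, assuming the same loss, init_policy_params, data, and callback defined in the script below (not verified against this exact setup):

# Possible workaround sketch (assumption, not verified): swap the default
# HagerZhang line search for backtracking, following the commented-out
# options in the script below.
import LineSearches
opt = LBFGS(
    alphaguess = LineSearches.InitialStatic(alpha = 0.01),
    linesearch = LineSearches.BackTracking(order = 2),
)
res = sciml_train(loss, init_policy_params, opt, data, cb = callback, maxiters = 1)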
import DifferentialEquations: Tsit5
import Flux: ADAM
import Flux.Data: DataLoader
import DiffEqFlux:
    FastChain, FastDense, initial_params, sciml_train, ODEProblem, solve
import Random: seed!, randn
import Plots: plot
import Statistics: mean
import Zygote
using Optim: LBFGS
import DiffEqSensitivity: InterpolatingAdjoint

include("pendulum.jl")

seed!(1)

T = 5.0
batch_size = 1
num_hidden = 32

policy = FastChain(
    FastDense(4, num_hidden, tanh),
    # FastDense(num_hidden, num_hidden, tanh),
    FastDense(num_hidden, 1),
    # (x, _) -> 2 * x,
)
# policy = FastDense(4, 1) # linear policy

# The model weights are destructured into a vector of parameters
init_policy_params = initial_params(policy)

dynamics, cost, sample_x0 = pendulum_env(1, 1, 9.8, 0)

# Map the raw pendulum state (θ, θ_dot) to the features fed to the policy.
function preprocess(x)
    θ, θ_dot = x
    sinθ, cosθ = sincos(θ)
    [θ, θ_dot, sinθ, cosθ]
end
# Augmented dynamics: z = [accumulated cost; x], so integrating the ODE also
# accumulates the running cost in z[1].
function aug_dynamics!(dz, z, policy_params, t)
    x = z[2:end]
    u = policy(preprocess(x), policy_params)[1]
    dz[1] = cost(x, u)
    # Note that dynamics!(dz[2:end], x, u) breaks Zygote :(
    dz[2:end] = dynamics(x, u)
end

# @benchmark aug_dynamics!(rand(3), rand(3), init_policy_params, 0.0)

### Example rollout.
# x0 = [π - 0.1f0, 0f0]::Array{Float32}
# z0 = [0f0, x0...]::Array{Float32}
# rollout = solve(
#     ODEProblem(aug_dynamics!, z0, (0, T)),
#     Tsit5(),
#     u0 = z0,
#     p = init_policy_params,
# )
# tspan = 0:0.05:T
# plot(tspan, hcat(rollout.(tspan)...)', label = ["cost" "θ" "θ dot"])
# Roll out the augmented dynamics from each sampled initial condition and
# average the accumulated cost (first state component) at time T.
function loss(policy_params, data...)
    # TODO: use the ensemble thing
    mean([
        begin
            z0 = [0f0, x0...]
            rollout = solve(
                ODEProblem(aug_dynamics!, z0, (0, T), policy_params),
                Tsit5(),
                u0 = z0,
                p = policy_params,
                sensealg = InterpolatingAdjoint(),
            )
            Array(rollout)[1, end]
        end for x0 in data
    ])
end

# Print the loss and plot a rollout from θ = 0, θ_dot = 0 after every step.
callback = function (policy_params, loss_val)
    println("Loss $loss_val")
    z0_bottom = [0f0, 0f0, 0f0]::Array{Float32}
    rollout = solve(
        ODEProblem(aug_dynamics!, z0_bottom, (0, T), policy_params),
        Tsit5(),
        u0 = z0_bottom,
        p = policy_params,
    )
    tspan = 0:0.05:T
    rollout_arr = hcat(rollout.(tspan)...)
    display(plot(
        tspan,
        [cos.(rollout_arr[2, :]), rollout_arr[3, :]],
        label = ["cos(θ)" "θ dot"],
        title = "Swing up cost: $(rollout_arr[1, end])",
    ))
    false
end
data = DataLoader([sample_x0() for _ = 1:1_000_000], batchsize = batch_size)

# res1 = sciml_train(loss, init_policy_params, ADAM(), data, cb = callback)
res1 = sciml_train(
    loss,
    init_policy_params,
    LBFGS(
        # alphaguess = LineSearches.InitialStatic(alpha = 0.01),
        # linesearch = LineSearches.Static(),
        # linesearch = LineSearches.BackTracking(order = 2),
    ),
    data,
    cb = callback,
    iterations = length(data),
    allow_f_increases = true,
    # allow_outer_f_increases = true,
    # NaN tolerances effectively disable these convergence checks, leaving the
    # iteration count as the only stopping criterion.
    x_abstol = NaN,
    x_reltol = NaN,
    f_abstol = NaN,
    f_reltol = NaN,
    g_abstol = NaN,
    g_reltol = NaN,
)

# @profile res1 =
#     sciml_train(loss, init_policy_params, ADAM(), [first(data)], cb = callback)

# 1.180s median
# @benchmark Zygote.gradient(loss, init_policy_params, first(data)...)
# @benchmark Zygote.gradient(loss, init_policy_params, first(data)...)

# 1.192s median
# @benchmark sciml_train(loss, init_policy_params, ADAM(), data, maxiters = 1)

# begin
#     import LineSearches
#     opt = LBFGS(
#         alphaguess = LineSearches.InitialStatic(alpha = 0.1),
#         linesearch = LineSearches.Static(),
#     )
#     @timev sciml_train(loss, init_policy_params, opt, data, maxiters = 1)
#     @benchmark sciml_train(loss, init_policy_params, opt, data, maxiters = 1)
# end
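
The commented-out @benchmark lines above also suggest a quick way to sanity-check the loss and its gradient on a single batch outside of sciml_train. A minimal sketch using the same names as the script:

# Evaluate the loss and its Zygote gradient on one batch, mirroring the
# commented-out benchmark calls above.
l = loss(init_policy_params, first(data)...)
g = Zygote.gradient(loss, init_policy_params, first(data)...)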