Skip to content

Instantly share code, notes, and snippets.

@samuela
Created July 3, 2020 06:09
Show Gist options
  • Save samuela/d53088b8aa7403a77ec5b6e51166c0f3 to your computer and use it in GitHub Desktop.
julia> include("HagerZhang_bug.jl")
Loss 25.880611
Loss 34.07
ERROR: LoadError: ArgumentError: Value and slope at step length = 0 must be finite.
Stacktrace:
[1] (::LineSearches.HagerZhang{Float64,Base.RefValue{Bool}})(::Function, ::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}},Array{Float32,1},Array{Float32,1},Array{Float32,1}}, ::Float32, ::Float32, ::Float32) at /Users/skainswo/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:117
[2] HagerZhang at /Users/skainswo/.julia/packages/LineSearches/WrsMD/src/hagerzhang.jl:101 [inlined]
[3] perform_linesearch!(::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.ManifoldObjective{NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/utilities/perform_linesearch.jl:56
[4] update_state!(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/solvers/first_order/l_bfgs.jl:198
[5] optimize(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#55"{var"#472#473",DataLoader}}, ::Optim.LBFGSState{Array{Float32,1},Array{Array{Float32,1},1},Array{Array{Float32,1},1},Float32,Array{Float32,1}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:57
[6] optimize(::NLSolversBase.TwiceDifferentiable{Float32,Array{Float32,1},Array{Float32,2},Array{Float32,1}}, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::Optim.Options{Float64,DiffEqFlux.var"#_cb#55"{var"#472#473",DataLoader}}) at /Users/skainswo/.julia/packages/Optim/L5T76/src/multivariate/optimize/optimize.jl:33
[7] sciml_train(::Function, ::Array{Float32,1}, ::LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},Optim.var"#19#21"}, ::DataLoader; cb::Function, maxiters::Int64, diffmode::DiffEqFlux.ZygoteDiffMode, kwargs::Base.Iterators.Pairs{Symbol,Real,NTuple{8,Symbol},NamedTuple{(:iterations, :allow_f_increases, :x_abstol, :x_reltol, :f_abstol, :f_reltol, :g_abstol, :g_reltol),Tuple{Int64,Bool,Float64,Float64,Float64,Float64,Float64,Float64}}}) at /Users/skainswo/.julia/packages/DiffEqFlux/vOgUc/src/train.jl:269
[8] top-level scope at /Users/skainswo/dev/research/julia/odecontrol/pendulum_unconstrained.jl:101
[9] include(::String) at ./client.jl:439
[10] top-level scope at REPL[55]:1
in expression starting at /Users/skainswo/dev/research/julia/odecontrol/pendulum_unconstrained.jl:101
import DifferentialEquations: Tsit5
import Flux: ADAM
import Flux.Data: DataLoader
import DiffEqFlux:
FastChain, FastDense, initial_params, sciml_train, ODEProblem, solve
import Random: seed!, randn
import Plots: plot
import Statistics: mean
import Zygote
using Optim: LBFGS
import DiffEqSensitivity: InterpolatingAdjoint
# pendulum.jl defines `pendulum_env` (used below); presumably also the plant
# model itself — confirm against that file.
include("pendulum.jl")
seed!(1)
# Rollout horizon. NOTE(review): Float64 literal while the states below are
# Float32 — mixed precision flows into the ODE solves and may be related to
# the Float32/Float64 line-search failure in the stacktrace above; confirm.
T = 5.0
batch_size = 1
num_hidden = 32
# Policy network: 4 input features (the output of `preprocess`) -> 1 control.
policy = FastChain(
FastDense(4, num_hidden, tanh),
# FastDense(num_hidden, num_hidden, tanh),
FastDense(num_hidden, 1),
# (x, _) -> 2 * x,
)
# policy = FastDense(4, 1) # linear policy
# The model weights are destructured into a vector of parameters
init_policy_params = initial_params(policy)
# NOTE(review): pendulum_env argument meanings (mass, length, g, ...?) are not
# visible here — confirm against pendulum.jl.
dynamics, cost, sample_x0 = pendulum_env(1, 1, 9.8, 0)
# Expand a raw pendulum state (θ, θ̇) into the 4-element feature vector
# [θ, θ̇, sin(θ), cos(θ)] consumed by the policy network.
function preprocess(x)
    angle, angular_rate = x[1], x[2]
    return [angle, angular_rate, sin(angle), cos(angle)]
end
# In-place augmented dynamics. The augmented state is z = [running cost; x]:
# dz[1] accumulates the instantaneous cost, dz[2:end] is the plant derivative.
function aug_dynamics!(dz, z, policy_params, t)
    # Peel the physical state off the augmented vector (copy, not a view).
    state = z[2:end]
    control = policy(preprocess(state), policy_params)[1]
    dz[1] = cost(state, control)
    # Assign the returned derivative rather than mutating through a slice:
    # dynamics!(dz[2:end], state, control) breaks Zygote.
    dz[2:end] = dynamics(state, control)
end
# @benchmark aug_dynamics!(rand(3), rand(3), init_policy_params, 0.0)
### Example rollout.
# x0 = [π - 0.1f0, 0f0]::Array{Float32}
# z0 = [0f0, x0...]::Array{Float32}
# rollout = solve(
# ODEProblem(aug_dynamics!, z0, (0, T)),
# Tsit5(),
# u0 = z0,
# p = init_policy_params,
# )
# tspan = 0:0.05:T
# plot(tspan, hcat(rollout.(tspan)...)', label = ["cost" "θ" "θ dot"])
# Monte-Carlo training loss: roll out the closed-loop system from each initial
# condition in `data` and average the final accumulated cost (the first
# component of the augmented state). Zygote differentiates through the solves.
function loss(policy_params, data...)
# TODO: use the ensemble thing
mean([
begin
# Augmented initial state: zero accumulated cost prepended to x0.
z0 = [0f0, x0...]
rollout = solve(
ODEProblem(aug_dynamics!, z0, (0, T), policy_params),
Tsit5(),
u0 = z0,
p = policy_params,
# Continuous adjoint used for the gradient of the ODE solve.
sensealg = InterpolatingAdjoint(),
)
# Final running-cost value. NOTE(review): z0 is Float32 while T = 5.0 is
# Float64 — this mixed precision is visible in the Float32/Float64 types of
# the stacktrace above; confirm whether unifying precision fixes the error.
Array(rollout)[1, end]
end for x0 in data
])
end
# Optimizer callback: print the current loss and visualize one rollout
# starting from the hanging-down equilibrium. Returning `false` tells the
# optimizer to continue.
callback = function (policy_params, loss_val)
    println("Loss $loss_val")
    # Augmented start state: zero accumulated cost, θ = 0, θ̇ = 0.
    z0_bottom = [0f0, 0f0, 0f0]::Array{Float32}
    sol = solve(
        ODEProblem(aug_dynamics!, z0_bottom, (0, T), policy_params),
        Tsit5(),
        u0 = z0_bottom,
        p = policy_params,
    )
    ts = 0:0.05:T
    # Columns of rollout_arr are the augmented state sampled along `ts`.
    rollout_arr = hcat(sol.(ts)...)
    fig = plot(
        ts,
        [cos.(rollout_arr[2, :]), rollout_arr[3, :]],
        label = ["cos(θ)" "θ dot"],
        title = "Swing up cost: $(rollout_arr[1, end])",
    )
    display(fig)
    false
end
# Initial conditions are sampled up front; DataLoader serves them in
# minibatches of `batch_size`.
data = DataLoader([sample_x0() for _ = 1:1_000_000], batchsize = batch_size)
# res1 = sciml_train(loss, init_policy_params, ADAM(), data, cb = callback)
# LBFGS with its default HagerZhang line search is the configuration that
# produces the "Value and slope at step length = 0 must be finite" error in
# the stacktrace above; the commented-out alternatives are presumably
# workarounds that were tried — confirm.
res1 = sciml_train(
loss,
init_policy_params,
LBFGS(
# alphaguess = LineSearches.InitialStatic(alpha = 0.01),
# linesearch = LineSearches.Static(),
# linesearch = LineSearches.BackTracking(order = 2),
),
data,
cb = callback,
iterations = length(data),
allow_f_increases = true,
# allow_outer_f_increases = true,
# NaN tolerances: presumably intended to defeat the convergence checks so the
# optimizer consumes every minibatch (NaN comparisons are always false) —
# TODO confirm against Optim.Options semantics.
x_abstol = NaN,
x_reltol = NaN,
f_abstol = NaN,
f_reltol = NaN,
g_abstol = NaN,
g_reltol = NaN,
)
# @profile res1 =
# sciml_train(loss, init_policy_params, ADAM(), [first(data)], cb = callback)
# 1.180s median
# @benchmark Zygote.gradient(loss, init_policy_params, first(data)...)
# @benchmark Zygote.gradient(loss, init_policy_params, first(data)...)
# 1.192s median
# @benchmark sciml_train(loss, init_policy_params, ADAM(), data, maxiters = 1)
# begin
# import LineSearches
# opt = LBFGS(
# alphaguess = LineSearches.InitialStatic(alpha = 0.1),
# linesearch = LineSearches.Static(),
# )
# @timev sciml_train(loss, init_policy_params, opt, data, maxiters = 1)
# @benchmark sciml_train(loss, init_policy_params, opt, data, maxiters = 1)
# end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment