Skip to content

Instantly share code, notes, and snippets.

@samuela
Last active July 27, 2020 19:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samuela/8d7cf55cf921decfffc7559691bfad12 to your computer and use it in GitHub Desktop.
Save samuela/8d7cf55cf921decfffc7559691bfad12 to your computer and use it in GitHub Desktop.
import DifferentialEquations: Tsit5
import DiffEqFlux: FastChain, FastDense, initial_params, ODEProblem, solve
import Random: seed!
import DiffEqSensitivity:
    InterpolatingAdjoint, BacksolveAdjoint, QuadratureAdjoint, ODEAdjointProblem
using BenchmarkTools
# Seed the global RNG up front so runs are reproducible (presumably the policy's
# initial parameters are drawn from it -- confirm against DiffEqFlux).
seed!(123)

# Problem setup. These are `const` because `aug_dynamics` and `loss_pullback` read
# `policy` and `T` as globals; non-const globals are untyped and would force
# dynamic dispatch inside the ODE right-hand side that is being benchmarked.
const x_dim = 64        # dimension of the state x
const T = 25.0          # final time; the forward solve integrates over (0, T)
const x0 = ones(x_dim)  # initial state
const num_hidden = 32   # width of the policy's hidden layer

# Policy network: maps an x_dim state to an x_dim control through one tanh layer.
const policy = FastChain(
    FastDense(x_dim, num_hidden, tanh),
    FastDense(num_hidden, x_dim),
)
"""
    aug_dynamics(z, policy_params, t)

Out-of-place right-hand side of the augmented ODE.

`z[1]` is the accumulated cost and `z[2:end]` is the state `x`. The control is
`u = policy(x, policy_params)`. The returned vector has `x'x + u'u` as its first
entry, followed by `u` (the time derivative of `x`). `t` is unused.
"""
function aug_dynamics(z, policy_params, t)
    state = @view z[2:end]
    control = policy(state, policy_params)
    running_cost = state' * state + control' * control
    # Assembled out-of-place with `vcat`: the original author noted that an
    # in-place formulation writing into slices of `dz` breaks Zygote.
    return vcat(running_cost, control)
end
"""
    loss_pullback(x0, policy_params)

Solve the augmented system forward from `vcat(0.0, x0)` over `(0, T)` with `Tsit5`
and return `(fwd_sol, pullback)`.

`pullback(g_zT, sensealg)` runs the adjoint solve for a discrete gradient `g_zT`
injected at time `T`, dispatching on the sensitivity algorithm. It returns the raw
adjoint solution; the parameter gradient is deliberately not sliced out because the
script only measures solve performance.
"""
function loss_pullback(x0, policy_params)
    # Augmented initial condition: cost accumulator (starting at zero), then the state.
    z_init = vcat(0.0, x0)
    fwd_prob = ODEProblem(aug_dynamics, z_init, (0, T), policy_params)
    fwd_sol = solve(fwd_prob, Tsit5(); u0 = z_init, p = policy_params)

    # Save options shared by the adjoints that only need the terminal value.
    terminal_only = (dense = false, save_everystep = false, save_start = false)

    # Build and solve the adjoint problem, writing `g_zT` as the gradient at time T.
    function run_adjoint(g_zT, sensealg; kwargs...)
        inject_gradient = (out, x, p, t, i) -> (out[:] = g_zT)
        adj_prob = ODEAdjointProblem(fwd_sol, sensealg, inject_gradient, [T])
        return solve(adj_prob, Tsit5(); kwargs...)
    end

    # Pullback through the augmented system with a discrete gradient input at
    # time T. Continuous adjoints on the non-augmented system are an alternative,
    # but that appeared to be slower.
    function pullback(g_zT, sensealg::BacksolveAdjoint)
        # The gradient is not sliced out of the adjoint solution; only the
        # solve itself is being timed.
        return run_adjoint(g_zT, sensealg; terminal_only...)
    end

    function pullback(g_zT, sensealg::InterpolatingAdjoint)
        return run_adjoint(g_zT, sensealg; terminal_only...)
    end

    function pullback(g_zT, sensealg::QuadratureAdjoint)
        # See https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/quadrature_adjoint.jl#L173.
        # The adjoint solve accounts for 75% of the pullback's time and
        # allocations; quadgk is comparatively lightweight, so the quadrature
        # step is skipped and only the adjoint solve is measured.
        return run_adjoint(g_zT, sensealg; save_everystep = true, save_start = true)
    end

    return fwd_sol, pullback
end
policy_params = initial_params(policy)

# Globals are interpolated with `$` in every `@btime` call so that BenchmarkTools
# times the call itself rather than the lookup and dynamic dispatch that come
# with untyped global variables.
@info "forward"
@btime loss_pullback($x0, $policy_params)

fwd_sol, vjp = loss_pullback(x0, policy_params)

# Gradient fed in at time T: one for the accumulated-cost entry of the augmented
# state, zero for every state entry.
g = vcat(1, zero(x0))

@info "BacksolveAdjoint"
@btime $vjp($g, BacksolveAdjoint())

@info "InterpolatingAdjoint"
@btime $vjp($g, InterpolatingAdjoint())

@info "QuadratureAdjoint"
@btime $vjp($g, QuadratureAdjoint())

# Suppress the REPL echo of the last benchmark's return value.
nothing
[ Info: forward
1.443 ms (9420 allocations: 7.27 MiB)
[ Info: BacksolveAdjoint
50.525 ms (115033 allocations: 113.32 MiB)
[ Info: InterpolatingAdjoint
30.859 ms (81025 allocations: 75.65 MiB)
[ Info: QuadratureAdjoint
21.833 ms (50868 allocations: 48.77 MiB)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment