Created
July 27, 2020 20:43
-
-
Save samuela/9a1daca41fd46ce5e67f5df150933373 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| In-place: | |
| [ Info: forward | |
| 20.009 ms (23413 allocations: 58.50 MiB) | |
| [ Info: BacksolveAdjoint | |
| 5.056 s (28059882 allocations: 5.93 GiB) | |
| [ Info: InterpolatingAdjoint | |
| 1.596 s (9013271 allocations: 1.89 GiB) | |
| [ Info: QuadratureAdjoint | |
| 473.475 ms (2782989 allocations: 599.98 MiB) | |
| Out-of-place: | |
| [ Info: forward | |
| 22.901 ms (45777 allocations: 60.26 MiB) | |
| [ Info: BacksolveAdjoint | |
| 7.333 s (25147876 allocations: 3.42 GiB) | |
| [ Info: InterpolatingAdjoint | |
| 5.475 s (18755710 allocations: 2.54 GiB) | |
| [ Info: QuadratureAdjoint | |
| 1.659 s (5810587 allocations: 807.58 MiB) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import DifferentialEquations: Tsit5
import DiffEqFlux: FastChain, FastDense, initial_params, ODEProblem, solve
import DiffEqSensitivity:
    BacksolveAdjoint, InterpolatingAdjoint, QuadratureAdjoint, ODEAdjointProblem
import Random: seed!
using BenchmarkTools
# Seed the global RNG. `sample_x0()` draws its initial state with `rand()`/`randn()`
# and no explicit RNG argument, so this makes `x0` identical from run to run.
seed!(123)
"""
Quadrotor environment: rigid-body dynamics, a quadratic running cost, an initial-state
sampler and a hand-crafted observation (feature) map for policies.

The 12-element state is ordered `[x, y, z, ψ, θ, ϕ, ẋ, ẏ, ż, p, q, r]`: position,
Euler angles, linear velocity and body angular rates. The Euler-rate rows of `dynamics`
are the ZYX kinematics (ψ yaw, θ pitch, ϕ roll). The control `u` has 4 entries:
- `u[1]` scales the translational acceleration along the thrust direction.
- `u[2:4]` enter the angular-rate equations divided by `Ix`, `Iy`, `Iz`.
"""
module QuadrotorEnv

import StaticArrays: SA

"""
    env(floatT, gravity, mass, Ix, Iy, Iz) -> (dynamics, cost, sample_x0, observation)

Build closures for a quadrotor with gravitational acceleration `gravity`, mass `mass` and
principal moments of inertia `Ix`, `Iy`, `Iz`.

`floatT` is the element type of the state vectors returned by `sample_x0`. It is also the
type of the 2π constant used by `observation` to reduce angles.
"""
function env(floatT, gravity, mass, Ix, Iy, Iz)
    twopi = convert(floatT, 2π)

    # Continuous-time dynamics d(state)/dt = f(state, u); see Eq 2.25.
    # Returns a 12-element static vector.
    function dynamics(state, u)
        x, y, z, ψ, θ, ϕ, ẋ, ẏ, ż, p, q, r = state
        sinψ, cosψ = sincos(ψ)
        sinθ, cosθ = sincos(θ)
        sinϕ, cosϕ = sincos(ϕ)
        tanθ = sinθ / cosθ
        # Thrust direction = third column of R = Rz(ψ)·Ry(θ)·Rx(ϕ), scaled by -1/mass:
        #   [cosψ sinθ cosϕ + sinψ sinϕ,  sinψ sinθ cosϕ - cosψ sinϕ,  cosθ cosϕ].
        # All three rows must use that same column. With the y-row sign flipped, a
        # positive roll would accelerate the vehicle toward -y.
        g_1_7 = -1 / mass * (sinϕ * sinψ + cosϕ * cosψ * sinθ)
        g_1_8 = -1 / mass * (cosϕ * sinψ * sinθ - cosψ * sinϕ)
        g_1_9 = -1 / mass * cosϕ * cosθ
        return SA[
            ẋ,
            ẏ,
            ż,
            q * sinϕ / cosθ + r * cosϕ / cosθ,
            q * cosϕ - r * sinϕ,
            p + q * sinϕ * tanθ + r * cosϕ * tanθ,
            g_1_7 * u[1],
            g_1_8 * u[1],
            # `gravity` enters with a + sign, so z presumably points down (NED-style).
            gravity + g_1_9 * u[1],
            (Iy - Iz) / Ix * q * r + u[2] / Ix,
            (Iz - Ix) / Iy * p * r + u[3] / Iy,
            (Ix - Iy) / Iz * p * q + u[4] / Iz,
        ]
    end

    # Running cost: squared distance from the origin plus squared speed,
    # plus a small control penalty 0.01 * |u|^2.
    function cost(state, u)
        x, y, z, ψ, θ, ϕ, ẋ, ẏ, ż, p, q, r = state
        return (x^2 + y^2 + z^2) + (ẋ^2 + ẏ^2 + ż^2) + 0.01 * u' * u
    end

    # Random initial state:
    # - position uniform in [-2.5, 2.5)^3;
    # - angles and body rates ~ N(0, 0.1^2);
    # - zero linear velocity.
    # Uses the global RNG.
    function sample_x0()
        x = rand() * 5 - 2.5
        y = rand() * 5 - 2.5
        z = rand() * 5 - 2.5
        ψ = randn() * 0.1
        θ = randn() * 0.1
        ϕ = randn() * 0.1
        ẋ = 0
        ẏ = 0
        ż = 0
        p = randn() * 0.1
        q = randn() * 0.1
        r = randn() * 0.1
        return convert(Array{floatT}, [x, y, z, ψ, θ, ϕ, ẋ, ẏ, ż, p, q, r])
    end

    # 32-element feature vector fed to the policy:
    # - sin/cos/tan of each angle and each body rate (18 entries);
    # - squared position norm and squared speed (2 entries);
    # - the raw state, with angles and rates reduced by `% twopi` (12 entries).
    # `%` is `rem`, so reduced values keep the sign of the input and lie in (-2π, 2π).
    function observation(state)
        x, y, z, ψ, θ, ϕ, ẋ, ẏ, ż, p, q, r = state
        sinψ, cosψ = sincos(ψ)
        sinθ, cosθ = sincos(θ)
        sinϕ, cosϕ = sincos(ϕ)
        tanψ = sinψ / cosψ
        tanθ = sinθ / cosθ
        tanϕ = sinϕ / cosϕ
        sinp, cosp = sincos(p)
        sinq, cosq = sincos(q)
        sinr, cosr = sincos(r)
        tanp = sinp / cosp
        tanq = sinq / cosq
        tanr = sinr / cosr
        # Using a StaticArray here doesn't work with FastChain and the rest.
        return [
            sinψ,
            cosψ,
            tanψ,
            sinθ,
            cosθ,
            tanθ,
            sinϕ,
            cosϕ,
            tanϕ,
            sinp,
            cosp,
            tanp,
            sinq,
            cosq,
            tanq,
            sinr,
            cosr,
            tanr,
            x^2 + y^2 + z^2,
            ẋ^2 + ẏ^2 + ż^2,
            x,
            y,
            z,
            ψ % twopi,
            θ % twopi,
            ϕ % twopi,
            ẋ,
            ẏ,
            ż,
            p % twopi,
            q % twopi,
            r % twopi,
        ]
    end

    return dynamics, cost, sample_x0, observation
end

end
# Problem setup.
# These globals are read on every right-hand-side evaluation (by `aug_dynamics`,
# `aug_dynamics!`, `loss_pullback` and the policy closure). Declaring them `const`
# keeps those calls type-stable; untyped non-const globals force dynamic dispatch
# inside the benchmarked code.
const floatT = Float32  # matches the Float32 gravity literal `9.8f0` below
const T = 25.0          # final time; the forward solve runs over (0, T)
const dynamics, cost, sample_x0, obs = QuadrotorEnv.env(floatT, 9.8f0, 1, 1, 1, 1)
x0 = sample_x0()
const num_hidden = 64
# Policy: the 32 observation features -> two tanh hidden layers -> 4 controls.
const policy = FastChain(
    (x, _) -> obs(x),
    FastDense(32, num_hidden, tanh),
    FastDense(num_hidden, num_hidden, tanh),
    FastDense(num_hidden, 4),
)
"""
    aug_dynamics(z, policy_params, t)

Out-of-place right-hand side of the cost-augmented system.

- `z[1]` integrates the running cost.
- `z[2:end]` is the quadrotor state.
- The control comes from `policy` evaluated on the current state.

Returns the running cost followed by the state derivative, concatenated into one vector.
"""
function aug_dynamics(z, policy_params, t)
    state = view(z, 2:lastindex(z))
    action = policy(state, policy_params)
    running_cost = cost(state, action)
    state_rate = dynamics(state, action)
    return vcat(running_cost, state_rate)
end
"""
    aug_dynamics!(dz, z, policy_params, t)

In-place counterpart of `aug_dynamics`. Stores the running cost in `dz[1]` and the
state derivative in `dz[2:end]`.
"""
function aug_dynamics!(dz, z, policy_params, t)
    state = view(z, 2:lastindex(z))
    action = policy(state, policy_params)
    dz[1] = cost(state, action)
    # Write through a slice assignment rather than a mutating
    # `dynamics!(dz[2:end], x, u)`; per the author, the mutating form breaks Zygote.
    dz[2:lastindex(dz)] = dynamics(state, action)
end
"""
    loss_pullback(x0, policy_params; inplace=false) -> (fwd_sol, pullback)

Solve the cost-augmented system forward from `x0` over `(0, T)` with `Tsit5`. The
augmented initial state is `[0; x0]`, so the accumulated cost starts at zero.

Returns the forward solution and a function `pullback(g_zT, sensealg)`. `pullback` runs
only the adjoint ODE solve for a discrete gradient `g_zT` applied to the augmented state
at time `T`.

# Keywords
- `inplace::Bool = false`: use `aug_dynamics!` instead of `aug_dynamics` for the
  forward problem, to compare in-place and out-of-place performance.
"""
function loss_pullback(x0, policy_params; inplace::Bool = false)
    z0 = vcat(0.0, x0)
    rhs = inplace ? aug_dynamics! : aug_dynamics
    fwd_sol = solve(
        ODEProblem(rhs, z0, (0, T), policy_params),
        Tsit5(),
        u0 = z0,
        p = policy_params,
    )

    # Adjoint solve with a single discrete gradient contribution `g_zT` at t = T.
    function _adjoint_solve(g_zT, sensealg; kwargs...)
        return solve(
            ODEAdjointProblem(fwd_sol, sensealg, (out, x, p, t, i) -> (out[:] = g_zT), [T]),
            Tsit5();
            kwargs...,
        )
    end

    # This is the pullback using the augmented system and a discrete gradient input at
    # time T. Alternatively one could use the continuous adjoints on the non-augmented
    # system, although that seems to be slower.
    # Backsolve and Interpolating adjoints need identical solver options, so they share
    # one method. Only the final adjoint state is kept.
    function pullback(g_zT, sensealg::Union{BacksolveAdjoint,InterpolatingAdjoint})
        return _adjoint_solve(
            g_zT,
            sensealg,
            dense = false,
            save_everystep = false,
            save_start = false,
        )
        # Not bothering to slice out the gradient from the results of the adjoint
        # solve; just trying to measure performance.
    end
    function pullback(g_zT, sensealg::QuadratureAdjoint)
        # See https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/quadrature_adjoint.jl#L173.
        # This is 75% of the time and allocs of the pullback; quadgk is relatively
        # lightweight. The full trajectory is saved, as the quadrature step needs it.
        return _adjoint_solve(g_zT, sensealg, save_everystep = true, save_start = true)
        # Skip the whole quadrature bit, only measuring the adjoint solve.
    end

    return fwd_sol, pullback
end
policy_params = initial_params(policy)

# Interpolate globals with `$` so @btime times the call itself, not untyped
# global-variable lookup and dynamic dispatch.
@info "forward"
@btime loss_pullback($x0, $policy_params)
fwd_sol, vjp = loss_pullback(x0, policy_params)

# Gradient seed at t = T. The 1 selects the accumulated running cost z[1];
# the state components get zero.
g = vcat(1, zero(x0))

@info "BacksolveAdjoint"
@btime $vjp($g, BacksolveAdjoint())
@info "InterpolatingAdjoint"
@btime $vjp($g, InterpolatingAdjoint())
@info "QuadratureAdjoint"
@btime $vjp($g, QuadratureAdjoint())
nothing
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment