Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Last active December 26, 2020 20:59
Show Gist options
  • Save sethaxen/4071b401b9b4ff4f5421136cec2fa7da to your computer and use it in GitHub Desktop.
Save sethaxen/4071b401b9b4ff4f5421136cec2fa7da to your computer and use it in GitHub Desktop.
Chain rules for the action of the matrix exponential
# rrule for expv(t, A, v) = exp(t * A) * v
# The pullback with respect to A is computed by solving an adjoint ODE system
# with OrdinaryDiffEq.
# Since OrdinaryDiffEq itself depends on ExponentialUtilities, it doesn't make
# sense to include this code in ExponentialUtilities.
using Pkg
Pkg.add(["ChainRulesCore", "ChainRulesTestUtils", "ExponentialUtilities", "FiniteDifferences", "OrdinaryDiffEq", "Test", "LinearAlgebra"])
using ChainRulesCore, ChainRulesTestUtils, FiniteDifferences, OrdinaryDiffEq, Test, LinearAlgebra, Random
using FiniteDifferences: rand_tangent
using ExponentialUtilities: ExponentialUtilities
"""
    expv!(w, t, A, v)

Store the action of the matrix exponential, `exp(t * A) * v`, in `w` and return `w`.

This generic method builds a Krylov subspace with `ExponentialUtilities.arnoldi` and
delegates to `ExponentialUtilities.expv!`.
"""
function expv!(w, t, A, v)
    krylov = ExponentialUtilities.arnoldi(A, v)
    ExponentialUtilities.expv!(w, t, krylov)
    return w
end
"""
    expv!(w, t, A::Diagonal, v)

Specialization for diagonal `A`: the exponential acts elementwise, so
`w[i] = exp(t * A.diag[i]) * v[i]`. Writes into `w` and returns it.
"""
function expv!(w, t, A::Diagonal, v)
    λ = A.diag
    @. w = exp(t * λ) * v
    return w
end
"""
    expv(t, A, v; kwargs...)

Return `exp(t * A) * v` in a freshly allocated vector whose element type is the
promotion of the element types of `t`, `A` and `v`.

If `A` is an `AbstractMatrix` for which `isdiag(A)` holds, the elementwise `Diagonal`
method of `expv!` is used; otherwise the generic `expv!` method is called.
Keyword arguments are accepted but not used by this function.
"""
function expv(t, A, v; kwargs...)
    T = Base.promote_eltypeof(t, A, v)
    w = similar(v, T)
    if A isa AbstractMatrix && isdiag(A)
        # diagonal fast path: no Krylov subspace needed
        expv!(w, t, Diagonal(A), v)
    else
        expv!(w, t, A, v)
    end
    return w
end
"""
    rrule(::typeof(expv), t, A, v; adjoint_solver, adjoint_solve_kwargs, kwargs...)

Reverse-mode rule for `expv(t, A, v)`.

Returns the primal `w = expv(t, A, v; kwargs...)` and a pullback that maps a cotangent
`Δw` to `(NO_FIELDS, ∂t, ∂A, ∂v)`. Each of the three cotangents is a thunk, so it is
only computed when it is unthunked.

# Keywords
- `adjoint_solver`: ODE solver passed on to `expv_rev_A` (default `Tsit5()`).
- `adjoint_solve_kwargs`: solver keywords passed on to `expv_rev_A`.
- `kwargs...`: forwarded to `expv` (primal) and to `expv_rev_v`.
"""
function ChainRulesCore.rrule(
    ::typeof(expv), t, A, v;
    adjoint_solver = Tsit5(),
    adjoint_solve_kwargs = (; abstol=1e-8, reltol=1e-8),
    kwargs...
)
    w = expv(t, A, v; kwargs...)

    function expv_pullback(Δw)
        # The three cotangents are independent lazy thunks. ∂v could be obtained as
        # a by-product of ∂A, but ∂A is far more expensive, so ∂v is computed on its
        # own for callers that never need ∂A.
        v̄ = @thunk expv_rev_v(t, A, w, Δw; kwargs...)
        Ā = @thunk expv_rev_A(t, A, w, Δw, adjoint_solver, adjoint_solve_kwargs)
        t̄ = @thunk expv_rev_t(t, A, w, Δw)
        return NO_FIELDS, t̄, Ā, v̄
    end

    return w, expv_pullback
end
"""
    expv_rev_t(t, A, w, Δw)

Cotangent of `t`: the conjugate of `dot(Δw, A, w)`, with the inner product first
passed through `mayberealify!` so that a real `t` receives a real cotangent.
"""
function expv_rev_t(t, A, w, Δw)
    inner = dot(Δw, A, w)
    projected = mayberealify!(t, inner)
    return conj(projected)
end
"""
    expv_rev_v(t, A, w, Δw; kwargs...)

Cotangent of `v`: the action `exp(conj(t) * A') * Δw`, computed with `expv`.
`w` is accepted for a uniform signature but is not used.
"""
function expv_rev_v(t, A, w, Δw; kwargs...)
    return expv(conj(t), adjoint(A), Δw; kwargs...)
end
"""
    expv_rev_A(t, A::Diagonal, w, Δw, args...)

Cotangent of a diagonal `A`: a new `Diagonal` (allocated with `similar(A)`) filled by
`outer!` with the diagonal of `conj(t) * Δw * w'`. Extra positional arguments are
ignored.
"""
function expv_rev_A(t, A::Diagonal, w, Δw, args...)
    ∂A = similar(A)
    # β = false discards whatever the uninitialised buffer held
    return outer!(∂A, Δw, w, conj(t), false)
end
"""
    expv_rev_A(t, A, w, Δw, adjoint_solver, adjoint_solve_kwargs)

Cotangent of a general `A`, returned in an array allocated with `similar(A)`.

If `isdiag(A)`, the `Diagonal` method is used and its result copied into the output.
Otherwise an augmented ODE (right-hand side `f_expv_rev_A!`) is integrated from time
one to time zero with `adjoint_solver`, and the accumulated parameter adjoint is
scaled by `conj(t)`. Entries where `A` is exactly zero receive a zero cotangent.
"""
function expv_rev_A(t, A, w, Δw, adjoint_solver, adjoint_solve_kwargs)
    ∂A = similar(A)
    if isdiag(A)
        diag_cotangent = expv_rev_A(t, Diagonal(A), w, Δw)
        copyto!(∂A, diag_cotangent)
        return ∂A
    end

    # Solve the system backwards, augmented to evolve adjoints to the parameters.
    # Based on Algorithm 1 of https://arxiv.org/abs/1806.07366, though the approach
    # is older.
    n = length(w)
    param_cols = 3:(n + 2)
    T = Base.promote_eltypeof(w, Δw, A)
    state0 = Matrix{T}(undef, n, n + 2)
    view(state0, :, 1) .= w            # column 1: primal output
    view(state0, :, 2) .= Δw           # column 2: output cotangent
    fill!(view(state0, :, param_cols), false)  # remaining columns: accumulator

    solver_kwargs = merge(adjoint_solve_kwargs, (; save_everystep=false))

    # `t` travels as an ODE parameter rather than as the integration variable, which
    # allows complex `t`; the integration itself runs over a real unit interval.
    Tt = real(typeof(t))
    tspan = (one(Tt), zero(Tt))
    problem = ODEProblem(f_expv_rev_A!, state0, tspan, (A, n, t))
    solution = solve(problem, adjoint_solver; solver_kwargs...)
    state1 = last(solution)

    # preserve sparsity pattern of A
    # NOTE: inspect efficiency
    broadcast!(∂A, A, view(state1, :, param_cols)) do a_ij, μ_ij
        masked = ifelse(iszero(a_ij), zero(μ_ij), μ_ij)
        mayberealify!(a_ij, conj(t) * masked)
    end
    return ∂A
end
"""
    f_expv_rev_A!(du, u, (A, n, c), t)

In-place right-hand side of the augmented adjoint ODE used by `expv_rev_A`.

With `z = u[:, 1]`, `a = u[:, 2]`, it writes `c * A * z` into `du[:, 1]`,
`-conj(c) * A' * a` into `du[:, 2]`, and `-a * z'` into `du[:, 3:n+2]`.
The time argument `t` is not used.
"""
function f_expv_rev_A!(du, u, (A, n, c), t)
    primal = view(u, :, 1)
    adj = view(u, :, 2)
    dprimal = view(du, :, 1)
    dadj = view(du, :, 2)
    dparam = view(du, :, 3:(n + 2))

    mul!(dprimal, A, primal, c, false)
    mul!(dadj, A', adj, -conj(c), false)
    return outer!(dparam, adj, primal, -1, false)
end
# in-place outer product Z = x α y' + Z β, for vectors x,y and scalars α,β,
# respecting sparsity pattern of Z and avoiding unnecessary computation if possible
"""
    outer!(Z, x, y, α, β)

Generic fallback for the in-place scaled outer product. Each entry of `Z` is
overwritten with `β * Z[i, j] + α * x[i] * y'[j]`, except that the outer-product term
is replaced by zero wherever `Z[i, j]` is currently zero, so the existing zero pattern
of `Z` decides which entries receive a contribution. Returns `Z`.

# Arguments
- `Z`: destination array, read and overwritten in place.
- `x`, `y`: vectors; `y` enters through its adjoint `y'`.
- `α`, `β`: scalar weights of the outer product and of the old `Z`, respectively.
"""
function outer!(Z, x, y, α, β)
    # Broadcast over exactly the three operands the kernel consumes: Z itself,
    # x down the rows, and y' across the columns.
    broadcast!(Z, Z, x, y') do Zij, xi, yj
        β * Zij + α * ifelse(iszero(Zij), zero(Base.promote_typeof(xi, yj)), xi * yj)
    end
    return Z
end
"""
    outer!(Z::StridedMatrix, x, y, α, β)

Dense specialization: compute `Z = α * x * y' + β * Z` in place with five-argument
`mul!` and return its result.
"""
function outer!(Z::StridedMatrix, x, y, α, β)
    return mul!(Z, x, adjoint(y), α, β)
end
"""
    outer!(Z::Diagonal, x, y, α, β)

Diagonal specialization: only the diagonal of the outer product is formed, i.e.
`Z.diag[i] = Z.diag[i] * β + x[i] * α * conj(y[i])`. Updates `Z` in place and
returns it.
"""
function outer!(Z::Diagonal, x, y, α, β)
    d = Z.diag
    @. d = (d * β) + ((x * α) * conj(y))
    return Z
end
# this is essentially the pullback of `promote` for `Real` and `Complex`,
# i.e. a projection to ℝ if the primal was in ℝ
# Scalar methods: `x` only selects the method; nothing is mutated.
# A real primal keeps just the real part of `y` ...
function mayberealify!(x::Real, y)
    return real(y)
end
# ... while any other number lets `y` through unchanged.
function mayberealify!(x::Number, y)
    return y
end
# Container method: overwrite `x` with `y` and return `x`. When `x` has a real
# element type and `y` does not, only the real part of `y` is stored.
function mayberealify!(x, y)
    if !(eltype(x) <: Real) || eltype(y) <: Real
        # element types are compatible: plain copy
        copyto!(x, y)
    else
        x .= real.(y)
    end
    return x
end
# test the rule using finite differences
# WARNING: type-piracy
# Vectorise a `Tridiagonal` through its dense `Matrix` form. The returned closure
# rebuilds a `Tridiagonal` from such a vector.
function FiniteDifferences.to_vec(x::Tridiagonal)
    dense_vec, dense_back = to_vec(Matrix(x))
    function Tridiagonal_back(xvec)
        return Tridiagonal(dense_back(xvec))
    end
    return dense_vec, Tridiagonal_back
end
# Random tangent for a `Tridiagonal`: draw a tangent for the dense `Matrix` form and
# wrap it back into a `Tridiagonal`.
function FiniteDifferences.rand_tangent(rng::AbstractRNG, x::Tridiagonal)
    dense_tangent = rand_tangent(rng, Matrix(x))
    return Tridiagonal(dense_tangent)
end
# Check the rule against finite differences for each matrix wrapper type `TA` and
# element type `T`, then check inference of the primal, the rule and the thunks.
@testset "$(TA{T}), n=$n" for TA in (Matrix, Diagonal, Tridiagonal), T in (Float64, ComplexF64), n in (10,)
    t, A, v = rand(T), TA(randn(T, n, n)), randn(T, n)
    ∂t, ∂A, ∂v = rand_tangent(t), rand_tangent(A), rand_tangent(v)
    w = expv(t, A, v)
    Δw = rand_tangent(w)
    rrule_test(expv, Δw, (t, ∂t), (A, ∂A), (v, ∂v))
    # For complex element types, additionally test a real `t` with complex `A`, `v`.
    # The guard is on the element type `T`; the wrapper `TA` is never `<: Complex`.
    if T <: Complex
        rrule_test(expv, Δw, (real(t), real(∂t)), (A, ∂A), (v, ∂v))
    end
    # check type-stable
    @inferred expv(t, A, v)
    @inferred rrule(expv, t, A, v)
    _, back = rrule(expv, t, A, v)
    # calling the pullback itself is marked as not yet inferrable
    @test_broken @inferred back(Δw)
    _, ∂t, ∂A, ∂v = back(Δw)
    @inferred unthunk(∂t)
    @inferred unthunk(∂A)
    @inferred unthunk(∂v)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment