Last active
December 26, 2020 20:59
-
-
Save sethaxen/4071b401b9b4ff4f5421136cec2fa7da to your computer and use it in GitHub Desktop.
Chain rules for the action of the matrix exponential
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# rrule for expv(t, A, v) = exp(t * A) * v
# Uses OrdinaryDiffEq to solve the adjoint system for the pullback.
# Since OrdinaryDiffEq depends on ExponentialUtilities, it doesn't make sense
# to include this code there.
using Pkg | |
Pkg.add(["ChainRulesCore", "ChainRulesTestUtils", "ExponentialUtilities", "FiniteDifferences", "OrdinaryDiffEq", "Test", "LinearAlgebra"]) | |
using ChainRulesCore, ChainRulesTestUtils, FiniteDifferences, OrdinaryDiffEq, Test, LinearAlgebra, Random | |
using FiniteDifferences: rand_tangent | |
using ExponentialUtilities: ExponentialUtilities | |
# Generic in-place method: hand the result of `ExponentialUtilities.arnoldi(A, v)` to
# `ExponentialUtilities.expv!`, which writes into `w`. Returns `w`.
function expv!(w, t, A, v)
    Ks = ExponentialUtilities.arnoldi(A, v)
    ExponentialUtilities.expv!(w, t, Ks)
    return w
end
# Diagonal specialization: the exponential acts entrywise on the diagonal, so
# `w[i] = exp(t * A.diag[i]) * v[i]`. Overwrites and returns `w`.
function expv!(w, t, A::Diagonal, v)
    d = A.diag
    @. w = exp(t * d) * v
    return w
end
# Out-of-place `expv`: allocate an output like `v` with the promoted element type of
# `t`, `A` and `v`, then fill it with `expv!`. A matrix that is numerically diagonal is
# wrapped in `Diagonal` first so that the `Diagonal` method of `expv!` is used.
# `kwargs` are accepted but not forwarded.
function expv(t, A, v; kwargs...)
    T = Base.promote_eltypeof(t, A, v)
    w = similar(v, T)
    if A isa AbstractMatrix && isdiag(A)
        expv!(w, t, Diagonal(A), v)
    else
        expv!(w, t, A, v)
    end
    return w
end
# Reverse-mode rule for `expv(t, A, v)`.
#
# Keywords:
# - `adjoint_solver`: ODE algorithm passed on to `expv_rev_A`.
# - `adjoint_solve_kwargs`: named tuple of solver options passed on to `expv_rev_A`.
# - remaining `kwargs` are forwarded to `expv` and `expv_rev_v`.
#
# Returns the primal output and a pullback whose cotangents are all lazy thunks.
function ChainRulesCore.rrule(
    ::typeof(expv),
    t,
    A,
    v;
    adjoint_solver = Tsit5(),
    adjoint_solve_kwargs = (; abstol=1e-8, reltol=1e-8),
    kwargs...
)
    primal = expv(t, A, v; kwargs...)
    function expv_pullback(Δw)
        t̄ = @thunk(expv_rev_t(t, A, primal, Δw))
        Ā = @thunk(
            expv_rev_A(t, A, primal, Δw, adjoint_solver, adjoint_solve_kwargs)
        )
        # NOTE: the cotangent of v is also produced while computing the cotangent of A,
        # but since the latter is much more expensive it is recomputed here in case the
        # caller never asks for the cotangent of A.
        v̄ = @thunk(expv_rev_v(t, A, primal, Δw; kwargs...))
        return (NO_FIELDS, t̄, Ā, v̄)
    end
    return primal, expv_pullback
end
# Cotangent of `t`: conjugate of `dot(Δw, A, w)`, projected to the reals by
# `mayberealify!` when `t` is real.
function expv_rev_t(t, A, w, Δw)
    inner = dot(Δw, A, w)
    return conj(mayberealify!(t, inner))
end
# Cotangent of `v`: apply `expv` with the conjugated time and the adjoint of `A` to `Δw`.
# `w` is unused; it is accepted so all `expv_rev_*` helpers share a leading signature.
function expv_rev_v(t, A, w, Δw; kwargs...)
    return expv(conj(t), adjoint(A), Δw; kwargs...)
end
# Cotangent of a `Diagonal` `A`: entry `i` of the diagonal is `Δw[i] * conj(t) * conj(w[i])`.
# Extra positional arguments (solver, solver kwargs) are accepted and ignored.
#
# Each entry is passed through `mayberealify!` so that a real-valued `A` receives a real
# cotangent even when `t`, `w` or `Δw` are complex. Without the projection, storing the
# complex product into `similar(A)` throws an `InexactError`. This matches the projection
# done by the general (ODE-based) method.
function expv_rev_A(t, A::Diagonal, w, Δw, args...)
    ∂A = similar(A)
    ∂A.diag .= mayberealify!.(A.diag, Δw .* conj(t) .* conj.(w))
    return ∂A
end
# Cotangent of a general `A`, obtained by integrating an augmented adjoint ODE backwards.
#
# - `adjoint_solver`: algorithm handed to `solve`.
# - `adjoint_solve_kwargs`: named tuple of solver options; `save_everystep=false` is
#   merged on top of it.
#
# Returns an array `similar(A)` holding the cotangent.
function expv_rev_A(t, A, w, Δw, adjoint_solver, adjoint_solve_kwargs)
    ∂A = similar(A)
    # A numerically diagonal matrix is handled by the closed-form `Diagonal` method.
    if isdiag(A)
        copyto!(∂A, expv_rev_A(t, Diagonal(A), w, Δw))
        return ∂A
    end

    # Solve the system backwards, augmented to evolve adjoints to parameters.
    # Based on Algorithm 1 of https://arxiv.org/abs/1806.07366, though the approach
    # is older.
    n = length(w)
    T = Base.promote_eltypeof(w, Δw, A)
    # Augmented state: column 1 holds w, column 2 holds Δw, and the remaining n columns
    # (zero-initialised) accumulate the parameter adjoint.
    state = Matrix{T}(undef, n, n + 2)
    state[:, 1] .= w
    state[:, 2] .= Δw
    fill!(@view(state[:, 3:n+2]), false)

    solver_kwargs = merge(adjoint_solve_kwargs, (; save_everystep=false))
    # Time enters as a parameter rather than through the integration span, which allows
    # t::Complex; the span itself runs over the real interval from 1 down to 0.
    Tt = real(typeof(t))
    tspan = (one(Tt), zero(Tt))
    params = (A, n, t)
    problem = ODEProblem(f_expv_rev_A!, state, tspan, params)
    final_state = last(solve(problem, adjoint_solver; solver_kwargs...))

    # Preserve the sparsity pattern of A: entries where A is zero get a zero cotangent.
    # NOTE: inspect efficiency
    μ = @view(final_state[:, 3:n+2])
    project(Aij, μij) = mayberealify!(Aij, conj(t) * ifelse(iszero(Aij), zero(μij), μij))
    broadcast!(project, ∂A, A, μ)
    return ∂A
end
# In-place right-hand side of the augmented adjoint ODE.
#
# `u` and `du` have n + 2 columns: column 1 is the state z, column 2 the adjoint a, and
# columns 3:n+2 the accumulated parameter adjoint. The parameter tuple is `(A, n, c)`,
# where `c` is the scalar time factor. The ODE time argument `t` is unused.
function f_expv_rev_A!(du, u, p, t)
    A, n, c = p
    z = view(u, :, 1)
    a = view(u, :, 2)
    dz = view(du, :, 1)
    da = view(du, :, 2)
    dμ = view(du, :, 3:n+2)
    mul!(dz, A, z, c, false)          # dz = c * A * z
    mul!(da, A', a, -conj(c), false)  # da = -conj(c) * A' * a
    return outer!(dμ, a, z, -1, false)  # dμ = -a * z'
end
# In-place outer product Z = x α y' + Z β, for vectors x, y and scalars α, β.
# This generic fallback respects the sparsity pattern of Z: wherever an entry of Z is
# zero on input, the outer-product contribution for that entry is replaced by zero.
# Returns Z.
#
# Fix: the original call passed an undefined extra argument `B` to `broadcast!`, which
# raised `UndefVarError` on every call; the kernel takes exactly (Zij, xi, yj).
function outer!(Z, x, y, α, β)
    broadcast!(Z, Z, x, y') do Zij, xi, yj
        # `ifelse` evaluates both branches; the typed zero keeps the result type stable.
        β * Zij + α * ifelse(iszero(Zij), zero(Base.promote_typeof(xi, yj)), xi * yj)
    end
    return Z
end
# Strided (dense) specialization: Z = x α y' + Z β via five-argument `mul!`.
# Returns the result of `mul!`.
function outer!(Z::StridedMatrix, x, y, α, β)
    return mul!(Z, x, adjoint(y), α, β)
end
# Diagonal specialization: only the diagonal of the outer product is needed, so update
# `Z.diag[i] = Z.diag[i] * β + x[i] * α * conj(y[i])` in place. Returns Z.
function outer!(Z::Diagonal, x, y, α, β)
    d = Z.diag
    d .= d .* β .+ x .* α .* conj.(y)
    return Z
end
# Essentially the pullback of `promote` for `Real` and `Complex`: project `y` onto ℝ
# when the primal `x` was real.
#
# Scalar methods return a new value: `real(y)` for a `Real` primal, `y` unchanged for any
# other `Number`. The container method overwrites `x` in place and returns it.
mayberealify!(::Real, y) = real(y)
mayberealify!(::Number, y) = y
function mayberealify!(x, y)
    keep_as_is = !(eltype(x) <: Real) || eltype(y) <: Real
    if keep_as_is
        copyto!(x, y)
    else
        # real destination, non-real source: drop the imaginary parts
        x .= real.(y)
    end
    return x
end
# test the rule using finite differences | |
# WARNING: type-piracy | |
# Vectorise a `Tridiagonal` by going through its dense `Matrix` form; the returned
# closure rebuilds a `Tridiagonal` from a vector of the same layout.
function FiniteDifferences.to_vec(x::Tridiagonal)
    flat, dense_from_vec = to_vec(Matrix(x))
    tridiagonal_from_vec(vals) = Tridiagonal(dense_from_vec(vals))
    return flat, tridiagonal_from_vec
end
# Random tangent for a `Tridiagonal`: draw a tangent for the dense form and wrap it back
# into a `Tridiagonal`.
function FiniteDifferences.rand_tangent(rng::AbstractRNG, x::Tridiagonal)
    dense_tangent = rand_tangent(rng, Matrix(x))
    return Tridiagonal(dense_tangent)
end
# Check the rrule against finite differences for each matrix structure and element type,
# then check inferrability of the primal, the rule, and the unthunked cotangents.
@testset "$(TA{T}), n=$n" for TA in (Matrix, Diagonal, Tridiagonal), T in (Float64, ComplexF64), n in (10,)
    t, A, v = rand(T), TA(randn(T, n, n)), randn(T, n)
    ∂t, ∂A, ∂v = rand_tangent(t), rand_tangent(A), rand_tangent(v)
    w = expv(t, A, v)
    Δw = rand_tangent(w)
    rrule_test(expv, Δw, (t, ∂t), (A, ∂A), (v, ∂v))
    # Also exercise a real `t` with complex `A` and `v`. The guard must test the element
    # type `T`; the matrix constructor `TA` is never a subtype of `Complex`, so testing
    # it left this branch dead.
    if T <: Complex
        rrule_test(expv, Δw, (real(t), real(∂t)), (A, ∂A), (v, ∂v))
    end
    # check type-stable
    @inferred expv(t, A, v)
    @inferred rrule(expv, t, A, v)
    _, back = rrule(expv, t, A, v)
    # the pullback itself is known not to infer
    @test_broken @inferred back(Δw)
    _, ∂t, ∂A, ∂v = back(Δw)
    @inferred unthunk(∂t)
    @inferred unthunk(∂A)
    @inferred unthunk(∂v)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment