Skip to content

Instantly share code, notes, and snippets.

@niklasschmitz
Last active February 9, 2022 07:32
Show Gist options
  • Save niklasschmitz/b00223b9e9ba2a37ed09539a264bf423 to your computer and use it in GitHub Desktop.
Save niklasschmitz/b00223b9e9ba2a37ed09539a264bf423 to your computer and use it in GitHub Desktop.
NLsolve ChainRules implicit differentiation
using NLsolve
using Zygote
using ChainRulesCore
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools
using Random
# Fix the default RNG's state so `p0 = randn(N)` below is reproducible run to run.
Random.seed!(Random.default_rng(), 1234)
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` via the implicit function theorem.
#
# The forward pass is the plain `nlsolve` call. In the pullback, the solution `x`
# is assumed to satisfy `f(x) = 0`. Under that assumption, the cotangent of
# everything `f` closes over is `f_pullback(λ)[1]`, where `λ` solves the adjoint
# linear system `J(x)' λ = -Δx` and `J` is the Jacobian of `f` with respect to `x`.
# `J(x)'` is applied matrix-free through the AD pullback of `f` and inverted with
# GMRES (default tolerances). No iteration of the nonlinear solver is differentiated.
#
# Returns `NoTangent()` for `nlsolve` itself, the tangent for `f`, and
# `ZeroTangent()` for `x0`. The solution is treated as independent of the initial guess.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δres = unthunk(Δresult)
        # Nothing downstream depended on the result, or only on fields other than
        # `zero`: no cotangent flows back to `f`.
        Δres isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        Δx = unthunk(Δres.zero)
        Δx isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        # The implicit-function formula is only valid at a root of `f`.
        NLsolve.converged(result) ||
            @warn "nlsolve did not converge; the implicit gradient may be inaccurate"
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        # v ↦ J(x)' v, the vector-Jacobian product of f with respect to x
        JT(v) = unthunk(f_pullback(v)[2])
        # Solve J(x)' λ = -Δx matrix-free; the operator's eltype follows the solution.
        L = LinearMap{eltype(x)}(JT, length(x))
        Δfx = gmres(L, -Δx)
        # Cotangent w.r.t. f itself, i.e. its implicitly closed-over variables
        ∂f = f_pullback(Δfx)[1]
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Test problem (large instance): find x with h(x, p) = 0, where A is a sparse,
# strictly diagonally dominant tridiagonal matrix (10 on the diagonal, -1 off it).
const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)

"""
    h(x, p)

Residual `A*x + nonlin*x.^2 - p` whose root in `x` defines `solve_x(p)`.
The element-wise part is fully fused, so only `A*x` and the result are allocated.
"""
h(x, p) = A * x .+ nonlin .* x .^ 2 .- p

"""
    solve_x(p)

Solve `h(x, p) = 0` for `x` with NLsolve's Anderson method (memory `m = 10`),
starting from the zero vector, and return the solution vector.
"""
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10).zero

"""
    obj(p)

Scalar objective `sum(solve_x(p))`, differentiated through the `nlsolve` rrule.
"""
obj(p) = sum(solve_x(p))
# Hand-written rrule for h: without it Zygote densifies the sparse matrix A.
# https://github.com/FluxML/Zygote.jl/issues/931
#
# For h(x, p) = A*x + nonlin*x.^2 - p with output cotangent ȳ:
#   ∂x = A'ȳ + 2*nonlin*x .* ȳ,   ∂p = -ȳ
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        # Materialize a possibly lazy cotangent. A fresh name avoids boxing the
        # captured variable inside the thunks below.
        dy = unthunk(ȳ)
        # No cotangent from downstream: both inputs receive zero cotangent
        # (broadcasting an AbstractZero against arrays would not work).
        dy isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        ∂x = @thunk(A' * dy .+ 2nonlin .* x .* dy)
        ∂p = @thunk(-dy)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
# Gradient of obj at p0 via Zygote, which uses the implicit nlsolve rrule.
(g_auto,) = Zygote.gradient(obj, p0)
# Reference gradient: since ∂h/∂p = -I, d(sum x)/dp = J(x)⁻ᵀ·1 with J = A + Diagonal(2 nonlin x).
g_analytic = let xstar = solve_x(p0)
    jac = A + Diagonal(2*nonlin*xstar)
    gmres(jac', ones(N))
end
foreach(display, (g_auto, g_analytic))
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17
# Timing: AD gradient vs. forward solve + one hand-written adjoint solve.
@btime Zygote.gradient(obj, p0); # 11.730 ms (784 allocations: 19.87 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)); # 11.409 ms (626 allocations: 17.50 MiB)
# Print the package environment the timings above were recorded with.
import Pkg
Pkg.status()
# Recorded output:
# Status `/tmp/nlsolve/Project.toml`
# [6e4b80f9] BenchmarkTools v1.2.2
# [d360d2e6] ChainRulesCore v1.11.6
# [42fd0dbc] IterativeSolvers v0.9.2
# [7a12625a] LinearMaps v3.5.1
# [2774e3e8] NLsolve v4.5.1
# [e88e6eb3] Zygote v0.6.34
# [37e2e46d] LinearAlgebra
# [9a3f8284] Random
# [2f01184e] SparseArrays
using NLsolve
using Zygote
using ChainRulesCore
using SparseArrays
using LinearAlgebra
using Random
# Fix the default RNG's state so `p0 = randn(N)` below is reproducible run to run.
Random.seed!(Random.default_rng(), 1234)
using IterativeSolvers
using LinearMaps
# Reverse-mode rule for `nlsolve(f, x0; kwargs...)` via the implicit function theorem.
#
# This variant solves the adjoint linear system `J(x)' λ = -Δx` by calling `nlsolve`
# again, on the residual `v -> J(x)' v + Δx`, starting from zero. The same keyword
# arguments as the forward solve are passed, so `method`, `m`, `show_trace` etc.
# also apply to the adjoint solve. `J(x)'` is applied matrix-free through the AD
# pullback of `f`.
#
# Returns `NoTangent()` for `nlsolve` itself, the tangent for `f` (its closed-over
# variables), and `ZeroTangent()` for `x0`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δres = unthunk(Δresult)
        # Nothing downstream depended on the solution: no cotangent flows back.
        Δres isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        Δx = unthunk(Δres.zero)
        Δx isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        # The implicit-function formula is only valid at a root of `f`.
        NLsolve.converged(result) ||
            @warn "nlsolve did not converge; the implicit gradient may be inaccurate"
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        # v ↦ J(x)' v, the vector-Jacobian product of f with respect to x
        JT(v) = unthunk(f_pullback(v)[2])
        # Root of v ↦ J(x)' v + Δx is λ with J(x)' λ = -Δx.
        adj_result = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...)
        NLsolve.converged(adj_result) ||
            @warn "adjoint nlsolve did not converge; the implicit gradient may be inaccurate"
        Δfx = adj_result.zero
        # Cotangent w.r.t. f itself, i.e. its implicitly closed-over variables
        ∂f = f_pullback(Δfx)[1]
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Test problem (small instance): find x with h(x, p) = 0, where A is a sparse,
# strictly diagonally dominant tridiagonal matrix (10 on the diagonal, -1 off it).
const N = 100
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)

"""
    h(x, p)

Residual `A*x + nonlin*x.^2 - p` whose root in `x` defines `solve_x(p)`.
The element-wise part is fully fused, so only `A*x` and the result are allocated.
"""
h(x, p) = A * x .+ nonlin .* x .^ 2 .- p

"""
    solve_x(p)

Solve `h(x, p) = 0` for `x` with NLsolve's Anderson method (memory `m = 10`),
starting from the zero vector and printing the solver trace; return the solution.
"""
solve_x(p) = nlsolve(x -> h(x, p), zeros(N), method=:anderson, m=10, show_trace=true).zero

"""
    obj(p)

Scalar objective `sum(solve_x(p))`, differentiated through the `nlsolve` rrule.
"""
obj(p) = sum(solve_x(p))
# Hand-written rrule for h: without it Zygote densifies the sparse matrix A.
# https://github.com/FluxML/Zygote.jl/issues/931
#
# For h(x, p) = A*x + nonlin*x.^2 - p with output cotangent ȳ:
#   ∂x = A'ȳ + 2*nonlin*x .* ȳ,   ∂p = -ȳ
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        # Materialize a possibly lazy cotangent. A fresh name avoids boxing the
        # captured variable inside the thunks below.
        dy = unthunk(ȳ)
        # No cotangent from downstream: both inputs receive zero cotangent
        # (broadcasting an AbstractZero against arrays would not work).
        dy isa AbstractZero && return (NoTangent(), ZeroTangent(), ZeroTangent())
        ∂x = @thunk(A' * dy .+ 2nonlin .* x .* dy)
        ∂p = @thunk(-dy)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
# Gradient of obj at p0 via Zygote, which uses the implicit nlsolve rrule.
(g_auto,) = Zygote.gradient(obj, p0)
# Reference gradient: since ∂h/∂p = -I, d(sum x)/dp = J(x)⁻ᵀ·1 with J = A + Diagonal(2 nonlin x).
g_analytic = let xstar = solve_x(p0)
    jac = A + Diagonal(2*nonlin*xstar)
    gmres(jac', ones(N))
end
foreach(display, (g_auto, g_analytic))
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
@JulienPascal
Copy link

Thank you @niklasschmitz for the update. It works really well on Julia 1.7.1.

@ChrisRackauckas
Copy link

@antoine-levitt
Copy link

Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)

@ChrisRackauckas
Copy link

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl — it uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get full AD compatibility without Requires (it throws an error mentioning this if you differentiate without DiffEqSensitivity loaded). It has heuristics for switching between Jacobian-based and Jacobian-free methods depending on problem size, and it does the same DiffEqSensitivity thing of choosing Zygote/Enzyme/etc. for VJPs (though that needs to be improved for this case), etc.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment