NLsolve ChainRules implicit differentiation
using NLsolve
using Zygote
using ChainRulesCore
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools
using Random
# Reverse-mode rule for `nlsolve` using implicit differentiation.
#
# The forward pass simply runs `nlsolve(f, x0; kwargs...)`. The pullback does not
# differentiate through the solver iterations; instead it
#   1. obtains the pullback of `f` at the computed root via `rrule_via_ad`,
#   2. solves the adjoint linear system JT*Δfx = -Δx matrix-free with `gmres`,
#   3. pulls Δfx back through `f` to get the tangent w.r.t. `f` itself
#      (i.e. the variables `f` closes over).
# The initial guess `x0` receives a `ZeroTangent()`.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...
)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        # Cotangent of the root stored in the solver result.
        # NOTE(review): the `[]` dereference assumes the AD backend wraps the
        # cotangent of the result struct in a `Ref` (as Zygote does for mutable
        # structs) — confirm against the Zygote/NLsolve versions in use.
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic nonlinearity.
const N = 10000
const nonlin = 0.1

# Sparse tridiagonal operator: 10 on the main diagonal, -1 on both off-diagonals.
const A = spdiagm(
    -1 => fill(-1.0, N - 1),
    0 => fill(10.0, N),
    1 => fill(-1.0, N - 1),
)

# Random parameter vector at which the gradient is evaluated.
const p0 = randn(N)

# Residual of the nonlinear system: linear part plus elementwise quadratic term,
# minus the parameter vector `p`.
function h(x, p)
    linear_part = A * x
    quadratic_part = nonlin * x .^ 2
    return linear_part + quadratic_part - p
end

# Root of `x -> h(x, p)` found by Anderson acceleration (history size 10),
# starting from the zero vector.
function solve_x(p)
    residual = x -> h(x, p)
    solution = nlsolve(residual, zeros(N); method=:anderson, m=10)
    return solution.zero
end

# Scalar objective: sum of the entries of the root for parameters `p`.
function obj(p)
    return sum(solve_x(p))
end
# need an rrule for h as Zygote otherwise densifies the sparse matrix A
#
# Hand-written reverse rule for `h(x, p) = A*x + nonlin*x.^2 - p`.
# Returns the primal value together with a pullback that maps the output
# cotangent `ȳ` to the cotangents w.r.t. `x` and `p`.
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        # Transposed Jacobian w.r.t. x applied to ȳ: A'ȳ plus the diagonal
        # contribution 2*nonlin*x from the elementwise square.
        ∂x = @thunk(A' * ȳ + 2nonlin * x .* ȳ)
        # h depends on p only through `- p`.
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
# Gradient of `obj` at `p0` computed by Zygote; `[1]` selects the gradient
# w.r.t. the single argument.
g_auto = Zygote.gradient(obj, p0)[1]
# Reference gradient: gmres solve with the transposed matrix
# (A + Diagonal(2*nonlin*x))', x = solve_x(p0), against a right-hand side of ones.
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
# Mean absolute difference between the two gradients (trailing value: recorded output).
@show sum(abs, g_auto - g_analytic) / N # 7.613631947123168e-17
# Timings of the AD gradient and of the reference computation (trailing values: recorded output).
@btime Zygote.gradient(obj, p0); # 11.730 ms (784 allocations: 19.87 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N)); # 11.409 ms (626 allocations: 17.50 MiB)
# Print the active package environment; the recorded output follows as comments.
import Pkg; Pkg.status()
# Status `/tmp/nlsolve/Project.toml`
# [6e4b80f9] BenchmarkTools v1.2.2
# [d360d2e6] ChainRulesCore v1.11.6
# [42fd0dbc] IterativeSolvers v0.9.2
# [7a12625a] LinearMaps v3.5.1
# [2774e3e8] NLsolve v4.5.1
# [e88e6eb3] Zygote v0.6.34
# [37e2e46d] LinearAlgebra
# [9a3f8284] Random
# [2f01184e] SparseArrays
using NLsolve
using Zygote
using ChainRulesCore
using SparseArrays
using LinearAlgebra
using Random
using IterativeSolvers
using LinearMaps
# Reverse-mode rule for `nlsolve` using implicit differentiation.
#
# The forward pass simply runs `nlsolve(f, x0; kwargs...)`. The pullback does not
# differentiate through the solver iterations; instead it
#   1. obtains the pullback of `f` at the computed root via `rrule_via_ad`,
#   2. solves the adjoint linear system JT*Δfx = -Δx by calling `nlsolve` again
#      on the residual `v -> JT(v) + Δx`, reusing the caller's `kwargs`,
#   3. pulls Δfx back through `f` to get the tangent w.r.t. `f` itself
#      (i.e. the variables `f` closes over).
# The initial guess `x0` receives a `ZeroTangent()`.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...
)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        # Cotangent of the root stored in the solver result.
        # NOTE(review): the `[]` dereference assumes the AD backend wraps the
        # cotangent of the result struct in a `Ref` (as Zygote does for mutable
        # structs) — confirm against the Zygote/NLsolve versions in use.
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        Δfx = nlsolve(v -> JT(v) + Δx, zero(x); kwargs...).zero
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
# Problem size and strength of the quadratic nonlinearity.
const N = 100
const nonlin = 0.1

# Sparse tridiagonal operator: 10 on the main diagonal, -1 on both off-diagonals.
const A = spdiagm(
    -1 => fill(-1.0, N - 1),
    0 => fill(10.0, N),
    1 => fill(-1.0, N - 1),
)

# Random parameter vector at which the gradient is evaluated.
const p0 = randn(N)

# Residual of the nonlinear system: linear part plus elementwise quadratic term,
# minus the parameter vector `p`.
function h(x, p)
    linear_part = A * x
    quadratic_part = nonlin * x .^ 2
    return linear_part + quadratic_part - p
end

# Root of `x -> h(x, p)` found by Anderson acceleration (history size 10),
# starting from the zero vector and printing the solver trace.
function solve_x(p)
    residual = x -> h(x, p)
    solution = nlsolve(residual, zeros(N); method=:anderson, m=10, show_trace=true)
    return solution.zero
end

# Scalar objective: sum of the entries of the root for parameters `p`.
function obj(p)
    return sum(solve_x(p))
end
# need an rrule for h as Zygote otherwise densifies the sparse matrix A
#
# Hand-written reverse rule for `h(x, p) = A*x + nonlin*x.^2 - p`.
# Returns the primal value together with a pullback that maps the output
# cotangent `ȳ` to the cotangents w.r.t. `x` and `p`.
function ChainRulesCore.rrule(::typeof(h), x, p)
    y = h(x, p)
    function my_h_pullback(ȳ)
        # Transposed Jacobian w.r.t. x applied to ȳ: A'ȳ plus the diagonal
        # contribution 2*nonlin*x from the elementwise square.
        ∂x = @thunk(A' * ȳ + 2nonlin * x .* ȳ)
        # h depends on p only through `- p`.
        ∂p = @thunk(-ȳ)
        return (NoTangent(), ∂x, ∂p)
    end
    return y, my_h_pullback
end
# Gradient of `obj` at `p0` computed by Zygote; `[1]` selects the gradient
# w.r.t. the single argument.
g_auto = Zygote.gradient(obj, p0)[1]
# Reference gradient: gmres solve with the transposed matrix
# (A + Diagonal(2*nonlin*x))', x = solve_x(p0), against a right-hand side of ones.
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
# Mean absolute difference between the two gradients (trailing value: recorded output).
@show sum(abs, g_auto - g_analytic) / N # 8.878502030795765e-11
Thank you @niklasschmitz for the update. It works really well on Julia 1.7.1.

Cool! Where's the code? I couldn't find it in NonlinearSolve.jl (which doesn't even depend on ChainRulesCore)

It uses the DiffEqSensitivity machinery (which should be renamed SciMLSensitivity.jl at this point, but I digress) to get all of the AD compatibility without Requires (it throws an error mentioning this if you differentiate without it). It has heuristics for switching between Jacobian-based and Jacobian-free approaches based on problem size, reuses DiffEqSensitivity's handling of Zygote/Enzyme/etc. for VJPs (though that needs to be improved for this case), and so on.

