-
-
Save under-Peter/d2318d937304950702d91e787305d0b9 to your computer and use it in GitHub Desktop.
Sketch of how a general fixed-point iteration interface with a custom adjoint could be implemented to obtain more efficient gradients.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Zygote, IterTools, LinearAlgebra | |
using BenchmarkTools | |
"""
    fixedpoint(f, guess, n, stopfun)

Iterate `x -> f(x, n)` starting from `guess` and return the first iterate `v` for which
`stopfun(v)` is `true`.

`f` is always applied at least once, so `guess` itself is never tested or returned.
There is no iteration limit: the call returns only once `stopfun` accepts an iterate.
"""
function fixedpoint(f, guess, n, stopfun)
    current = f(guess, n)  # force at least one iteration
    while !stopfun(current)
        current = f(current, n)
    end
    return current
end
"""
    StopFunction(o, tol)

Stateful convergence test for fixed-point iterations.

`o` is a `Ref` holding the previously seen iterate (start it at something far away, e.g.
`Ref(Inf)`), and `tol` is the distance below which two consecutive iterates count as
converged. The stored iterate persists between calls, so an instance that is reused for
several iterations starts each one with the state left by the previous one.
"""
struct StopFunction{T,S}
    # `Ref{T}` is an abstract type; `Base.RefValue{T}` is the concrete type returned by
    # `Ref(x)`, which keeps access to `st.o[]` type-stable.
    o::Base.RefValue{T}
    tol::S
end

# Return `true` when `v` is within `tol` of the stored iterate; otherwise remember `v` as
# the new reference point and return `false`.
function (st::StopFunction)(v)
    norm(v - st.o[]) < st.tol && return true
    st.o[] = v
    return false
end
"""
    fixedpointbackward(next, r, n; tol=1e-12, maxiter=typemax(Int))

Return a function `Δ -> grad` that maps a cotangent `Δ` of the fixed point `r` of
`x -> next(x, n)` to the corresponding gradient with respect to `n`.

With `back` the pullback of `next` at `(r, n)`, the gradient is accumulated as the series
`back(Δ)[2] + back(Δ₁)[2] + back(Δ₂)[2] + …` where `Δₖ₊₁ = back(Δₖ)[1]` and `Δ₀ = Δ`
(as in 'Differentiable Programming Tensor Networks', arXiv 1903.09650).

# Keywords
- `tol`: the series is truncated after the first term whose norm is below `tol`.
- `maxiter`: upper bound on the number of terms added after the first one. If it is
  reached, a warning is logged and the partial sum is returned. The default is
  effectively unbounded.

`r` must satisfy `next(r, n) ≈ r` or `next(r, n) ≈ -r`.
"""
function fixedpointbackward(next, r, n; tol=1e-12, maxiter=typemax(Int))
    # NOTE(review): `Zygote.forward` appears to have been renamed `Zygote.pullback` in
    # later Zygote releases — confirm against the Zygote version in use. Both are accepted.
    vjp = isdefined(Zygote, :pullback) ? Zygote.pullback : Zygote.forward
    y, back = vjp(next, r, n)
    # Reuse the primal value from the pullback call instead of evaluating `next` again.
    # Accepting `-r` makes this work with eigenvectors; should be more general.
    @assert y ≈ r || y ≈ -r "`r` is not a fixed point of `next` (up to a global sign)"

    function backΔ(Δ)
        # One pullback call yields both the next cotangent and the current term.
        δ, grad = back(Δ)
        for _ in 1:maxiter
            δ, g = back(δ)
            grad += g
            norm(g) < tol && return grad
        end
        @warn "fixedpointbackward: series not converged, returning partial sum" maxiter tol
        return grad
    end
    return backΔ
end
# `fixedpointAD` computes exactly what `fixedpoint` does. It is a separate function so that
# the custom adjoint below can be compared with differentiating `fixedpoint` directly.
function fixedpointAD(f, g, n, sf)
    return fixedpoint(f, g, n, sf)
end

# Custom adjoint: run the plain iteration forward, then obtain the gradient with respect to
# `n` from `fixedpointbackward` evaluated at the converged value. No gradient is returned
# for `f`, `guess` or `stopfun`.
@Zygote.adjoint function fixedpointAD(f, guess, n, stopfun)
    converged = fixedpoint(f, guess, n, stopfun)
    backprop = function (Δ)
        ∂n = fixedpointbackward(f, converged, n)(Δ)
        return (nothing, nothing, ∂n, nothing)
    end
    return converged, backprop
end
# square root with fixed-point iteration: one step towards sqrt(n) starting from `guess`
function next(guess, n)
    return 0.5 * (guess + n / guess)
end
stopfun = StopFunction(Ref(Inf), 1e-9)
# NOTE(review): `stopfun` remembers the last rejected iterate and is shared by all calls
# below, so every call after the first starts from the state left by the previous one —
# confirm this is intended for the comparisons and timings.
fixedpoint(next, 9, 9, stopfun) ≈ sqrt(9)
fixedpointbackward(next, 3, 9)(1) ≈ Zygote.gradient(fixedpoint, next, 3.0001, 9, stopfun)[3]
Zygote.gradient(fixedpoint, next, 3, 9, stopfun)[3]
Zygote.gradient(fixedpointAD, next, 3, 9, stopfun)[3]
# `stopfun` is a non-constant global: interpolate it with `$` so the benchmark does not
# include the cost of accessing an untyped global (as the eigenvector benchmarks do).
# The timings in the trailing comments are the ones reported in the original sketch.
@btime Zygote.gradient(fixedpoint, next, 3, 9, $stopfun)[3] # 36.883 μs (215 allocations: 7.42 KiB)
@btime Zygote.gradient(fixedpointAD, next, 3, 9, $stopfun)[3] # 4.405 μs (30 allocations: 768 bytes)
# Check that it works: the derivative of sqrt(n) is 1/(2sqrt(n))
all(fixedpointbackward(next, sqrt(n), n)(1) ≈ 1/(2sqrt(n)) for n in abs.(randn(10)))
# Works for scalars!
# getting an eigenvector with the power method
# adjust the stop function, since eigenvectors are the same up to a global sign:
# `v` counts as converged when it is within `tol` of either `+o` or `-o`
function (st::StopFunction)(v)
    previous = st.o[]
    if norm(v - previous) < st.tol || norm(v + previous) < st.tol
        return true
    end
    st.o[] = v
    return false
end
"""
    mynormalize(x)

Return `x` divided elementwise by its norm.
"""
mynormalize(x) = x ./ norm(x)

"""
    next(x, m)

One power-method step: apply `m` to `x`, normalize, and fix the global sign so that the
first component of the result is non-negative.
"""
function next(x, m)
    y = mynormalize(m * x)
    s = sign(y[1])
    # `sign(0) == 0` would wipe out the whole vector when the first component is exactly
    # zero (e.g. the eigenvector [0, 1]); leave the vector untouched in that case.
    return iszero(s) ? y : y * s
end
stopfun = StopFunction(Ref([Inf, Inf]), 1e-9)
# The power method needs a real dominant eigenvalue. A random real 2×2 matrix can instead
# have a complex-conjugate pair, in which case the iterates need not converge and
# `fixedpoint` has no iteration limit — so redraw until the spectrum is real.
m = let candidate = randn(2, 2)
    while !isreal(eigvals(candidate))
        candidate = randn(2, 2)
    end
    candidate
end
evs, evecs = eigen(m)
evec = evecs[:, argmax(abs.(evs))]  # eigenvector of the eigenvalue with largest magnitude
evecfp = fixedpoint(next, rand(2), m, stopfun)
evecfp ≈ evec || evecfp ≈ -evec
# gradient of the summed eigenvector w.r.t. `m`: plain Zygote vs. the custom adjoint
gradzyg = Zygote.gradient(sum ∘ fixedpoint, (x, m) -> mynormalize(m * x), rand(2), m, StopFunction(Ref([Inf, Inf]), 1e-14))[3]
gradfp = Zygote.gradient(sum ∘ fixedpointAD, (x, m) -> mynormalize(m * x), rand(2), m, StopFunction(Ref([Inf, Inf]), 1e-14))[3]
gradzyg ≈ gradfp
@btime Zygote.gradient(sum ∘ fixedpoint, (x, m) -> mynormalize(m * x), $(rand(2)), $m, $(StopFunction(Ref([Inf, Inf]), 1e-9)))
# 673.281 μs (4817 allocations: 174.05 KiB)
@btime Zygote.gradient(sum ∘ fixedpointAD, (x, m) -> mynormalize(m * x), $(rand(2)), $m, $(StopFunction(Ref([Inf, Inf]), 1e-9)))
# 200.665 μs (2839 allocations: 106.69 KiB)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment