@under-Peter
Created July 5, 2019 08:36
Sketch of how a general fixed-point-method interface with a custom adjoint could be implemented to get more efficient gradients.
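The custom adjoint follows the implicit-differentiation argument from 'Differential Programming Tensor Networks' (arXiv 1903.09650). As a sketch of the formula that fixedpointbackward implements (the notation here is mine, chosen for illustration): differentiating the fixed-point condition r = f(r, n) gives ∂r/∂n = (I - ∂f/∂r)⁻¹ ∂f/∂n, so the pullback of an upstream gradient Δ onto n is

    n̄ = (∂r/∂n)ᵀ Δ = (∂f/∂n)ᵀ (I - (∂f/∂r)ᵀ)⁻¹ Δ = Σ_{k≥0} (∂f/∂n)ᵀ ((∂f/∂r)ᵀ)ᵏ Δ.

The backward pass below accumulates exactly this geometric series: back1 repeatedly pulls Δ back through the first argument of f, back2 maps each intermediate through the pullback w.r.t. the second argument, and the terms are summed until they become negligible.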
using Zygote, IterTools, LinearAlgebra
using BenchmarkTools
function fixedpoint(f, guess, n, stopfun)
    # force at least one iteration
    for v in Iterators.drop(IterTools.iterated(x -> f(x, n), guess), 1)
        stopfun(v) && return v
    end
end
# Stateful stopping criterion: remembers the previous iterate in `o` and
# reports convergence once successive iterates are within `tol`.
struct StopFunction{T,S}
    o::Ref{T}
    tol::S
end
(st::StopFunction)(v) = norm(v - st.o[]) < st.tol ? true : (st.o[] = v; false)
function fixedpointbackward(next, r, n)
    @assert next(r, n) ≈ r || next(r, n) ≈ -r # to work with eigenvectors; should be more general
    _, back = Zygote.forward(next, r, n)
    back1 = x -> back(x)[1] # pullback through the first argument (the iterate)
    back2 = x -> back(x)[2] # pullback through the second argument (the parameter n)
    function backΔ(Δ) # geometric series as in 'Differential Programming Tensor Networks', arXiv 1903.09650
        grad = back2(Δ)
        for g in IterTools.imap(back2, Iterators.drop(IterTools.iterated(back1, Δ), 1))
            grad += g
            norm(g) < 1e-12 && break
        end
        grad
    end
    return backΔ
end
# Wrap fixedpoint under a second name so the custom adjoint can be compared
# against Zygote differentiating through the iteration itself.
fixedpointAD(f, g, n, sf) = fixedpoint(f, g, n, sf)
Zygote.@adjoint function fixedpointAD(f, guess, n, stopfun)
    r = fixedpoint(f, guess, n, stopfun)
    # only the gradient w.r.t. n is returned; f, guess and stopfun get `nothing`
    return r, Δ -> (nothing, nothing, fixedpointbackward(f, r, n)(Δ), nothing)
end
# square root via fixed-point iteration
next(guess, n) = 1/2*(guess + n/guess)
stopfun = StopFunction(Ref(Inf), 1e-9)
fixedpoint(next, 9, 9, stopfun) ≈ sqrt(9)
fixedpointbackward(next, 3, 9)(1) ≈ Zygote.gradient(fixedpoint, next, 3.0001, 9, stopfun)[3]
Zygote.gradient(fixedpoint, next, 3, 9, stopfun)[3]
Zygote.gradient(fixedpointAD, next, 3, 9, stopfun)[3]
@btime Zygote.gradient(fixedpoint, next, 3, 9, stopfun)[3] # 36.883 μs (215 allocations: 7.42 KiB)
@btime Zygote.gradient(fixedpointAD, next, 3, 9, stopfun)[3] # 4.405 μs (30 allocations: 768 bytes)
# Check that it works
all(fixedpointbackward(next, sqrt(n), n)(1) ≈ 1/(2sqrt(n)) for n in abs.(randn(10)))
# Works for scalars!
# getting an eigenvector with the power-method
# adjust stopfun since eigenvectors are the same up to global sign
(st::StopFunction)(v) = norm(v - st.o[]) < st.tol || norm(v + st.o[]) < st.tol ? true : (st.o[] = v; false)
mynormalize(x) = x ./ norm(x)
function next(x, m)
    y = mynormalize(m * x)
    y *= sign(y[1]) # fix the overall sign so the iteration has a unique fixed point
    y
end
stopfun = StopFunction(Ref([Inf,Inf]), 1e-9)
m = randn(2,2)
evs, evecs = eigen(m)
evec = evecs[:, argmax(abs.(evs))]
evecfp = fixedpoint(next, rand(2), m, stopfun)
evecfp ≈ evec || evecfp ≈ -evec
gradzyg = Zygote.gradient(sum ∘ fixedpoint, (x, m) -> mynormalize(m * x), rand(2), m, StopFunction(Ref([Inf, Inf]), 1e-14))[3]
gradfp = Zygote.gradient(sum ∘ fixedpointAD, (x, m) -> mynormalize(m * x), rand(2), m, StopFunction(Ref([Inf, Inf]), 1e-14))[3]
gradzyg ≈ gradfp
@btime Zygote.gradient(sum ∘ fixedpoint, (x, m) -> mynormalize(m * x), $(rand(2)), $m, $(StopFunction(Ref([Inf, Inf]), 1e-9)))
# 673.281 μs (4817 allocations: 174.05 KiB)
@btime Zygote.gradient(sum ∘ fixedpointAD, (x, m) -> mynormalize(m * x), $(rand(2)), $m, $(StopFunction(Ref([Inf, Inf]), 1e-9)))
# 200.665 μs (2839 allocations: 106.69 KiB)