# Zygote power series tests
#
# Originally published by @sethaxen as GitHub gist fa67e541c4a2a5e773b475349ed87fb9
# (created April 29, 2020).
using FiniteDifferences, LinearAlgebra, Zygote, Random, Test
# adapted from ChainRulesTestUtils.rrule_test
"""
    pullback_test(f, ȳ, (x1, x̄1), (x2, x̄2), ...; rtol=1e-9, atol=1e-9,
                  fkwargs=NamedTuple(), fdm=central_fdm(5, 1), kwargs...)

Test Zygote's pullback of `f` at the primal inputs `x1, x2, ...` against finite differences.

First check that the primal value returned by `Zygote.pullback` matches a direct call
`f(xs...; fkwargs...)`. Then seed the pullback with the output cotangent `ȳ` and, for each
input, compare the AD cotangent with the one computed by `j′vp` using `fdm`.

# Arguments
- `f`: function under test.
- `ȳ`: output cotangent used to seed the pullback (a tuple if `f` returns a tuple).
- `xx̄s`: one `(x, x̄)` pair per positional argument of `f`. Only `x̄ === nothing` is
  inspected: it marks the argument as non-differentiable, in which case the AD cotangent
  must also be `nothing` and no finite-difference check is done. Otherwise the value of
  `x̄` is not used.

# Keywords
- `rtol`, `atol`: tolerances passed to `isapprox`.
- `fkwargs`: keyword arguments forwarded to `f`.
- `fdm`: finite-difference method passed to `j′vp`.
- `kwargs...`: any further keyword arguments are forwarded to `isapprox`.

Return `nothing`; results are recorded through `@test`.
"""
function pullback_test(
    f,
    ȳ,
    xx̄s::Tuple{Any,Any}...;
    rtol = 1e-9,
    atol = 1e-9,
    fkwargs = NamedTuple(),
    fdm = central_fdm(5, 1),
    kwargs...,
)
    # Split the (x, x̄) pairs into a tuple of primals and a tuple of cotangent markers.
    xs, x̄s = collect(zip(xx̄s...))

    # Check correctness of evaluation.
    y_ad, back = Zygote.pullback((xs...) -> f(xs...; fkwargs...), xs...)
    y = f(xs...; fkwargs...)
    _test_isapprox(y_ad, y; rtol = rtol, atol = atol, kwargs...)

    # Correctness testing via finite differencing.
    x̄s_ad = back(ȳ)
    for (i, x̄_ad) in enumerate(x̄s_ad)
        if isnothing(x̄s[i])
            @test x̄_ad === nothing
        else
            # Vary only the i-th argument and hold all the others fixed.
            fi = x -> f(xs[1:(i-1)]..., x, xs[(i+1):end]...; fkwargs...)
            x̄_fd = j′vp(fdm, fi, ȳ, xs[i])[1]
            _test_isapprox(x̄_ad, x̄_fd; rtol = rtol, atol = atol, kwargs...)
        end
    end
    return nothing
end

# Record `@test isapprox(a, b; kwargs...)` for non-tuple values.
_test_isapprox(a, b; kwargs...) = @test isapprox(a, b; kwargs...)

# Compare tuples (e.g. the output of `sincos`) element by element.
function _test_isapprox(a::Tuple, b::Tuple; kwargs...)
    # `zip` stops at the shorter tuple, so check lengths to avoid silently skipping entries.
    @test length(a) == length(b)
    for (ai, bi) in zip(a, b)
        _test_isapprox(ai, bi; kwargs...)
    end
    return nothing
end
# Shared test fixtures. Every random value comes from one seeded RNG, so the draw order
# below determines the values: reordering these lines changes the test inputs.
N = 10
rng = MersenneTwister(87)

# N×N real and complex primal matrices, and placeholder input cotangents.
xr, xc = randn(rng, N, N), randn(rng, ComplexF64, N, N)
x̄r, x̄c = randn(rng, N, N), randn(rng, ComplexF64, N, N)

# N×N output cotangents; the `2` variants seed the second output of tuple-valued functions.
ȳr, ȳr2 = randn(rng, N, N), randn(rng, N, N)
ȳc, ȳc2 = randn(rng, ComplexF64, N, N), randn(rng, ComplexF64, N, N)

# Exponents and their cotangents; `p̄int === nothing` marks the integer exponent as
# non-differentiable.
pint, p̄int = 2, nothing
pr, pc = 0.5, 0.5 + 0.25im
p̄r, p̄c = randn(rng), randn(rng, ComplexF64)

# Length-N real and complex vectors, and matching cotangents.
λr, λc = randn(rng, N), randn(rng, ComplexF64, N)
λ̄r, λ̄c = randn(rng, N), randn(rng, ComplexF64, N)
# Quick sanity checks that are expected to pass. They are grouped in a @testset so that a
# failure is recorded and the remaining checks still run; a bare top-level @test throws on
# the first failure and aborts the script.
@testset "sanity checks" begin
    # Plain matrix multiplication and left division.
    pullback_test(*, ȳc, (xc, x̄c), (xc, x̄c))
    pullback_test(\, ȳc, (xc, x̄c), (xc, x̄c))

    @testset "Symmetric and Hermitian" begin
        pullback_test(Symmetric, ȳr, (xr, x̄r))
        pullback_test(Symmetric, ȳc, (xc, x̄c))
        pullback_test(Hermitian, ȳr, (xr, x̄r))
        pullback_test(Hermitian, ȳc, (xc, x̄c))
    end

    # The eigenvalues of Symmetric/Hermitian matrices are real, hence the real cotangent.
    @testset "eigvals" begin
        pullback_test(eigvals ∘ Symmetric, λ̄r, (xr, x̄r))
        pullback_test(eigvals ∘ Hermitian, λ̄r, (xr, x̄r))
        pullback_test(eigvals ∘ Hermitian, λ̄r, (xc, x̄c))
    end
end
# Scalar functions lifted to Symmetric/Hermitian matrices, one test set per function.
@testset "$(nameof(f))" for f in (
    exp, log, cos, sin, tan,
    cosh, sinh, tanh,
    acos, asin, atan,
    acosh, asinh,
    #atanh, # skip because it has special domain requirements, but has no specialized adjoints
    sqrt,
)
    # NOTE(review): `Zygote._hasrealdomain` is Zygote-internal; presumably it reports
    # whether `f` maps the eigenvalues of the real test matrix to real values — confirm.
    realdomain = Zygote._hasrealdomain(f, eigvals(Symmetric(xr)))
    if !realdomain
        # j′vp is complex: feed the real and the imaginary part of a complex matrix
        # separately, with a complex output cotangent.
        for wrap in (Symmetric, Hermitian), part in (real, imag)
            pullback_test(f ∘ wrap ∘ part, ȳc, (xc, x̄r))
        end
    else
        # j′vp is real: real inputs with a real output cotangent.
        for wrap in (Symmetric, Hermitian)
            pullback_test(f ∘ wrap, ȳr, (xr, x̄r))
        end
    end
    # The complex Hermitian case is tested unconditionally.
    pullback_test(f ∘ Hermitian, ȳc, (xc, x̄c))
end
# Matrix powers A^p. The checks are grouped in @testsets so that a failing check is
# recorded without aborting the rest; a bare top-level @test throws on the first failure.
@testset "matrix powers" begin
    @testset "integer powers" begin
        # `p̄int` is `nothing`, so the integer exponent is treated as non-differentiable.
        pullback_test((x, p) -> Symmetric(x)^p, ȳr, (xr, x̄r), (pint, p̄int))
        pullback_test((x, p) -> Hermitian(x)^p, ȳr, (xr, x̄r), (pint, p̄int))
        pullback_test((x, p) -> Symmetric(x)^p, ȳc, (xc, x̄c), (pint, p̄int))
        pullback_test((x, p) -> Hermitian(x)^p, ȳc, (xc, x̄c), (pint, p̄int))
    end

    @testset "float and complex powers" begin
        # FiniteDifferences cannot handle the two cases below,
        # see https://github.com/JuliaDiff/FiniteDifferences.jl/issues/80
        # pullback_test((x, p) -> Symmetric(real(x))^real(p), ȳr, (xc, x̄c), (pc, p̄c))
        # pullback_test((x, p) -> Hermitian(real(x))^real(p), ȳr, (xc, x̄c), (pc, p̄c))
        pullback_test((x, p) -> Hermitian(x)^real(p), ȳc, (xc, x̄c), (pc, p̄c))
        pullback_test((x, p) -> Hermitian(x)^imag(p), ȳc, (xc, x̄c), (pc, p̄c))
        # Fully complex exponent: the rule handling this case has not been located,
        # but this check is also expected to pass.
        pullback_test((x, p) -> Hermitian(x)^p, ȳc, (xc, x̄c), (pc, p̄c))
    end
end
# sincos returns a (sin, cos) tuple, so each test is seeded with a tuple of two
# cotangents. Grouped in a @testset so a failure does not abort the remaining checks.
@testset "sincos" begin
    pullback_test(sincos ∘ Symmetric, (ȳr, ȳr2), (xr, x̄r))
    pullback_test(sincos ∘ Hermitian, (ȳr, ȳr2), (xr, x̄r))
    pullback_test(sincos ∘ Hermitian, (ȳc, ȳc2), (xc, x̄c))
end
# TODO: check cases when eigenvalues are similar
# TODO: check low-rank cases
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment