# Zygote power series tests
# Gist sethaxen/fa67e541c4a2a5e773b475349ed87fb9, created April 29, 2020 22:32.
# Note: GitHub flagged this file as containing bidirectional Unicode text; when reviewing,
# open it in an editor that reveals hidden Unicode characters.
using FiniteDifferences, LinearAlgebra, Zygote, Random, Test
# adapted from ChainRulesTestUtils.rrule_test
"""
    pullback_test(f, ȳ, xx̄s::Tuple{Any,Any}...; rtol=1e-9, atol=1e-9,
                  fkwargs=NamedTuple(), fdm=central_fdm(5, 1), kwargs...)

Check `Zygote.pullback` of `f` against a direct evaluation of `f` and against a
finite-difference estimate of each input cotangent.

# Arguments
- `f`: function under test; it is called as `f(xs...; fkwargs...)`.
- `ȳ`: output cotangent handed to the pullback and to `j′vp`.
- `xx̄s`: one `(x, x̄)` pair per positional argument of `f`. Only `x̄ === nothing` is
  inspected: `nothing` asserts that the AD cotangent for that argument is also `nothing`;
  any other value requests the finite-difference comparison for that argument.

# Keywords
- `rtol`, `atol`: tolerances forwarded to `isapprox`.
- `fkwargs`: keyword arguments forwarded to `f`.
- `fdm`: finite-difference method passed to `j′vp`.
- `kwargs...`: additional keywords forwarded to every `isapprox` call.

# Returns
`nothing`; all results are reported through `@test`.
"""
function pullback_test(
    f,
    ȳ,
    xx̄s::Tuple{Any,Any}...;
    rtol = 1e-9,
    atol = 1e-9,
    fkwargs = NamedTuple(),
    fdm = central_fdm(5, 1),
    kwargs...,
)
    # Compare two results; when both are tuples they are compared element by element.
    function check_approx(actual, expected)
        if actual isa Tuple && expected isa Tuple
            # `zip` stops at the shorter tuple, so a length mismatch must be caught here.
            @test length(actual) == length(expected)
            for (a, e) in zip(actual, expected)
                @test isapprox(a, e; rtol = rtol, atol = atol, kwargs...)
            end
        else
            @test isapprox(actual, expected; rtol = rtol, atol = atol, kwargs...)
        end
        return nothing
    end

    # Split the (x, x̄) pairs into the primal inputs and their cotangent markers.
    xs, x̄s = collect(zip(xx̄s...))

    # Check correctness of evaluation: the primal returned by Zygote must match `f`.
    y_ad, pullback = Zygote.pullback((args...) -> f(args...; fkwargs...), xs...)
    y = f(xs...; fkwargs...)
    check_approx(y_ad, y)

    # Correctness testing via finite differencing, one argument at a time.
    x̄s_ad = pullback(ȳ)
    for (i, x̄_ad) in enumerate(x̄s_ad)
        if x̄s[i] === nothing
            @test x̄_ad === nothing
            continue
        end
        # Vary only the i-th argument while holding all the others fixed.
        f_i = x -> f(xs[1:(i-1)]..., x, xs[(i+1):end]...; fkwargs...)
        x̄_fd = j′vp(fdm, f_i, ȳ, xs[i])[1]
        check_approx(x̄_ad, x̄_fd)
    end
    return nothing
end
# Shared fixtures for the tests below. Every array is drawn from one seeded RNG, so the
# order of the `randn` calls determines the values; do not reorder these lines.
N = 10
rng = MersenneTwister(87)

# N×N real and complex inputs, each followed by a same-shaped x̄ partner.
xr = randn(rng, N, N)
xc = randn(rng, ComplexF64, N, N)
x̄r = randn(rng, N, N)
x̄c = randn(rng, ComplexF64, N, N)

# N×N output cotangents; the `2` variants are a second, independent draw.
ȳr = randn(rng, N, N)
ȳr2 = randn(rng, N, N)
ȳc = randn(rng, ComplexF64, N, N)
ȳc2 = randn(rng, ComplexF64, N, N)

# Exponents: integer, real and complex. The integer exponent's partner is `nothing`.
pint = 2
p̄int = nothing
pr = 0.5
pc = 0.5 + 0.25im
p̄r = randn(rng)
p̄c = randn(rng, ComplexF64)

# Length-N real and complex vectors with their λ̄ partners.
λr = randn(rng, N)
λc = randn(rng, ComplexF64, N)
λ̄r = randn(rng, N)
λ̄c = randn(rng, ComplexF64, N)
# Do a few quick tests that we expect should pass
pullback_test(*, ȳc, (xc, x̄c), (xc, x̄c))
pullback_test(\, ȳc, (xc, x̄c), (xc, x̄c))

# Check Hermitian and Symmetric
pullback_test(Symmetric, ȳr, (xr, x̄r))
pullback_test(Symmetric, ȳc, (xc, x̄c))
pullback_test(Hermitian, ȳr, (xr, x̄r))
pullback_test(Hermitian, ȳc, (xc, x̄c))

# Check eigvals just because
# (the output cotangent is the real vector λ̄r in all three cases, including complex input)
pullback_test(eigvals ∘ Symmetric, λ̄r, (xr, x̄r))
pullback_test(eigvals ∘ Hermitian, λ̄r, (xr, x̄r))
pullback_test(eigvals ∘ Hermitian, λ̄r, (xc, x̄c))
# One test set per matrix function, named after the function.
@testset "$(nameof(f))" for f in (
    exp,
    log,
    cos,
    sin,
    tan,
    cosh,
    sinh,
    tanh,
    acos,
    asin,
    atan,
    acosh,
    asinh,
    #atanh, # skip because it has special domain requirements, but has no specialized adjoints
    sqrt,
)
    # NOTE(review): `_hasrealdomain` is an internal Zygote helper; presumably it reports
    # whether `f` stays real on these eigenvalues — confirm against the Zygote source.
    if Zygote._hasrealdomain(f, eigvals(Symmetric(xr))) # j′vp is real
        pullback_test(f ∘ Symmetric, ȳr, (xr, x̄r))
        pullback_test(f ∘ Hermitian, ȳr, (xr, x̄r))
    else # j′vp is complex, test the real and imaginary parts separately
        pullback_test(f ∘ Symmetric ∘ real, ȳc, (xc, x̄r))
        pullback_test(f ∘ Symmetric ∘ imag, ȳc, (xc, x̄r))
        pullback_test(f ∘ Hermitian ∘ real, ȳc, (xc, x̄r))
        pullback_test(f ∘ Hermitian ∘ imag, ȳc, (xc, x̄r))
    end
    # The complex Hermitian case is tested for every function, whichever branch ran above.
    pullback_test(f ∘ Hermitian, ȳc, (xc, x̄c))
end
# integer powers
# (p̄int is `nothing`, so the exponent's cotangent is asserted to be `nothing`)
pullback_test((x, p) -> Symmetric(x)^p, ȳr, (xr, x̄r), (pint, p̄int))
pullback_test((x, p) -> Hermitian(x)^p, ȳr, (xr, x̄r), (pint, p̄int))
pullback_test((x, p) -> Symmetric(x)^p, ȳc, (xc, x̄c), (pint, p̄int))
pullback_test((x, p) -> Hermitian(x)^p, ȳc, (xc, x̄c), (pint, p̄int))

# float powers
# FD cannot handle this case, see https://github.com/JuliaDiff/FiniteDifferences.jl/issues/80
# pullback_test((x, p) -> Symmetric(real(x))^real(p), ȳr, (xc, x̄c), (pc, p̄c))
# pullback_test((x, p) -> Hermitian(real(x))^real(p), ȳr, (xc, x̄c), (pc, p̄c))
pullback_test((x, p) -> Hermitian(x)^real(p), ȳc, (xc, x̄c), (pc, p̄c))
pullback_test((x, p) -> Hermitian(x)^imag(p), ȳc, (xc, x̄c), (pc, p̄c))
# I don't know where the rule is defined for this, but it works too
pullback_test((x, p) -> Hermitian(x)^p, ȳc, (xc, x̄c), (pc, p̄c))

# sincos
# (the output is a pair, so the cotangent is given as a tuple of two matrices)
pullback_test(sincos ∘ Symmetric, (ȳr, ȳr2), (xr, x̄r))
pullback_test(sincos ∘ Hermitian, (ȳr, ȳr2), (xr, x̄r))
pullback_test(sincos ∘ Hermitian, (ȳc, ȳc2), (xc, x̄c))

# TODO: check cases when eigenvalues are similar
# TODO: check low-rank cases
# End of gist. (The text that followed was the GitHub page's sign-up / sign-in prompt, not part of the script.)