Created
September 2, 2019 18:34
-
-
Save simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ChainRulesCore | |
import ChainRulesCore: wirtinger_primal, wirtinger_conjugate | |
using ChainRulesCore: AbstractDifferential | |
using FiniteDifferences | |
using Test | |
"""
    abs_to_pow(x, p)

Return the magnitude of `x` raised to the power `p`, i.e. `abs(x)^p`.
"""
function abs_to_pow(x, p)
    magnitude = abs(x)
    return magnitude^p
end
# Rule for real `x`:
#   ∂/∂x |x|^p = p |x|^(p-1) sign(x)
#   ∂/∂p |x|^p = |x|^p log|x|   (`Ω` is the primal result |x|^p)
# The `p == 0` branch returns a hard `Zero()` instead of evaluating |x|^(-1), which is
# non-finite at x == 0 and throws a DomainError for integer `x`.
@scalar_rule(
    abs_to_pow(x::Real, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * sign(x),
        Ω * log(u)
    )
)
# Rule for complex `x`. With u = |x| = sqrt(x x̄), the partials of u are
# ∂u/∂x = x̄ / 2u and ∂u/∂x̄ = x / 2u. The x-partial of u^p is therefore the
# `Wirtinger` pair (x̄ / 2u, x / 2u) scaled by p u^(p-1). The p-partial is
# u^p log(u) = Ω log(u), reusing `u` from `@setup`.
@scalar_rule(
    abs_to_pow(x::Complex, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * Wirtinger(conj(x) / 2u, x / 2u),
        Ω * log(u)
    )
)
"""
    test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                test_wirtinger=..., kwargs...)

Check the differential `df` of the scalar function `f` at point `x` against
finite-difference estimates computed with `fdm`.

- If `test_wirtinger` is `false`, `extern(df)` is compared with `fdm(f, x)`.
- Otherwise both Wirtinger derivatives are estimated from finite differences
  along the real and imaginary directions:
  `∂f/∂z = (∂f/∂Re - im * ∂f/∂Im) / 2` and `∂f/∂z̄ = (∂f/∂Re + im * ∂f/∂Im) / 2`.
  They are then compared with `wirtinger_primal(df)` and `wirtinger_conjugate(df)`.

By default the Wirtinger path is taken when `x` is complex or `df` is, or unwraps
to, a `Wirtinger`. `rtol`, `atol` and any extra keyword arguments are forwarded
to `isapprox`.
"""
function test_scalar(
    f, x, df;
    rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
    # `df isa Wirtinger` is checked before `extern(df)` because `extern` does not
    # return a value for a bare `Wirtinger`.
    # NOTE(review): it throws in ChainRulesCore of this era; confirm against the
    # pinned version.
    test_wirtinger=x isa Complex || df isa Wirtinger || extern(df) isa Wirtinger,
    kwargs...
)
    if !test_wirtinger
        # Plain derivative: compare the unwrapped value with a finite-difference estimate.
        @test isapprox(
            extern(df), fdm(f, x);
            rtol=rtol, atol=atol, kwargs...
        )
    else
        # Directional finite differences at `x` along the real and imaginary axes.
        ∂Re = fdm(ϵ -> f(x + ϵ), 0)
        ∂Im = fdm(ϵ -> f(x + im*ϵ), 0)
        # Combine them into the Wirtinger derivatives ∂f/∂z and ∂f/∂z̄.
        d_primal = 0.5(∂Re - im*∂Im)
        d_conjugate = 0.5(∂Re + im*∂Im)
        @test isapprox(
            wirtinger_primal(df), d_primal;
            rtol=rtol, atol=atol, kwargs...
        )
        @test isapprox(
            wirtinger_conjugate(df), d_conjugate;
            rtol=rtol, atol=atol, kwargs...
        )
    end
end
# A bare `Wirtinger` differential cannot be compared with a single finite-difference
# value. `test_scalar` compares its two components separately, via `wirtinger_primal`
# and `wirtinger_conjugate`.
Base.isapprox(ad::Wirtinger, fd; kwargs...) =
    error("Finite differencing with Wirtinger rules not implemented")
# Compare a `Casted` differential elementwise with `d_fd`. The unwrapped value is
# broadcast against `d_fd`, and every pair must be approximately equal.
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
    unwrapped = extern(d_ad)
    elementwise = isapprox.(unwrapped, d_fd; kwargs...)
    return all(elementwise)
end
# A `DNE` differential has no value to compare with finite differences, so fail
# loudly instead of letting the comparison pass or fail silently.
Base.isapprox(d_ad::DNE, d_fd; kwargs...) =
    error("Tried to differentiate w.r.t. a DNE")
# Fallback for every other differential type: unwrap it with `extern`, then compare
# the plain value using the regular `isapprox`.
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) =
    isapprox(extern(d_ad), d_fd; kwargs...)
# TODO: upstream these two methods to ChainRulesCore.
# Unwrap a `Casted`/`Thunk` differential with `extern` before extracting a Wirtinger
# component, so the component accessor operates on the wrapped value.
function wirtinger_primal(x::Union{Casted,Thunk})
    inner = extern(x)
    return wirtinger_primal(inner)
end

function wirtinger_conjugate(x::Union{Casted,Thunk})
    inner = extern(x)
    return wirtinger_conjugate(inner)
end
const f = abs_to_pow

# For each (x, p) combination, check that the frule and the rrule of `f`:
# - are defined,
# - reproduce the primal value, and
# - give partials that match finite differences taken along one argument at a time.
@testset "abs_to_pow" begin
    xs = (2, 3.4, -2.1, -10+0im, 2.3-2im)
    ps = (0, 1, 2, 4.3, -2.1, 1+.2im)
    @testset "f($x, $p)" for (x, p) in Iterators.product(xs, ps)
        # Forward mode: seed one argument with `One()` and the other with `Zero()`.
        fwd = frule(f, x, p)
        @test fwd !== nothing  # the rule was defined
        fx, pushforward = fwd
        @test fx == f(x, p)    # the primal value is unchanged
        test_scalar(ϵ -> f(x + ϵ, p), 0, pushforward(One(), Zero()))
        test_scalar(ϵ -> f(x, p + ϵ), 0, pushforward(Zero(), One()))

        # Reverse mode: one entry of `pullbacks` per argument.
        rev = rrule(f, x, p)
        @test rev !== nothing  # the rule was defined
        fx, pullbacks = rev
        @test fx == f(x, p)    # the primal value is unchanged
        test_scalar(ϵ -> f(x + ϵ, p), 0, pullbacks[1](One()))
        test_scalar(ϵ -> f(x, p + ϵ), 0, pullbacks[2](One()))
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It would probably also be worth adding a test that
!isa(extern(df_dp), Wirtinger)
always holds, i.e. that the derivative with respect to `p` is never a `Wirtinger`.