Skip to content

Instantly share code, notes, and snippets.

@simeonschaub
Created September 2, 2019 18:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 to your computer and use it in GitHub Desktop.
Save simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 to your computer and use it in GitHub Desktop.
using ChainRulesCore
import ChainRulesCore: wirtinger_primal, wirtinger_conjugate
using ChainRulesCore: AbstractDifferential
using FiniteDifferences
using Test
"""
    abs_to_pow(x, p)

Return the magnitude of `x` raised to the power `p`, i.e. `abs(x)^p`.
Custom derivative rules for real and complex `x` are declared right after this
definition.
"""
function abs_to_pow(x, p)
    magnitude = abs(x)
    return magnitude^p
end
# Partial derivatives of abs_to_pow for real x, given as (∂/∂x, ∂/∂p), with r = |x|:
#   ∂/∂x |x|^p = p r^(p-1) sign(x). The p == 0 branch returns Zero(): |x|^0 is
#     constant, and the branch also avoids evaluating 0^-1 at x == 0.
#   ∂/∂p |x|^p = |x|^p log|x| = Ω log r, where Ω is the primal result.
@scalar_rule(
    abs_to_pow(x::Real, p),
    @setup(r = abs(x)),
    (
        p == 0 ? Zero() : p * r^(p - 1) * sign(x),
        Ω * log(r)
    )
)
# Partial derivatives of abs_to_pow for complex x, with r = |x| = sqrt(x * conj(x)).
# The x-derivative is a Wirtinger pair (∂/∂x, ∂/∂conj(x)):
#   ∂r/∂x = conj(x) / 2r  and  ∂r/∂conj(x) = x / 2r,
# so ∂/∂(x, conj x) |x|^p = p r^(p-1) * Wirtinger(conj(x) / 2r, x / 2r).
# The p-derivative is the ordinary (holomorphic) derivative Ω log r.
@scalar_rule(
    abs_to_pow(x::Complex, p),
    @setup(r = abs(x)),
    (
        p == 0 ? Zero() : p * r^(p - 1) * Wirtinger(conj(x) / (2 * r), x / (2 * r)),
        Ω * log(r)
    )
)
"""
    test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)

Check the AD-computed derivative `df` of the scalar function `f` at `x` against a
finite-difference estimate obtained with `fdm`.

If `test_wirtinger` is `false`, `extern(df)` is compared with `fdm(f, x)`. Otherwise
finite differences are taken along the real and the imaginary direction (`d_re`,
`d_im`). They are combined into the Wirtinger derivatives `(d_re - im*d_im)/2` and
`(d_re + im*d_im)/2`, which are compared with `wirtinger_primal(df)` and
`wirtinger_conjugate(df)` respectively.

`rtol`, `atol` and any extra `kwargs` are forwarded to `isapprox`. Every comparison is
recorded with `@test`.
"""
function test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                     test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)
    if test_wirtinger
        # Directional derivatives along the real and imaginary axes at x.
        d_re = fdm(h -> f(x + h), 0)
        d_im = fdm(h -> f(x + im * h), 0)
        # Wirtinger calculus: ∂f/∂z = (∂x - i∂y)/2 and ∂f/∂z̄ = (∂x + i∂y)/2.
        expected_primal = 0.5 * (d_re - im * d_im)
        expected_conjugate = 0.5 * (d_re + im * d_im)
        @test isapprox(wirtinger_primal(df), expected_primal;
                       rtol=rtol, atol=atol, kwargs...)
        @test isapprox(wirtinger_conjugate(df), expected_conjugate;
                       rtol=rtol, atol=atol, kwargs...)
    else
        # Plain derivative: compare the materialized differential directly.
        @test isapprox(extern(df), fdm(f, x); rtol=rtol, atol=atol, kwargs...)
    end
end
# A Wirtinger differential has two components, so it cannot be compared with a single
# finite-difference value. `test_scalar` compares its parts via `wirtinger_primal` /
# `wirtinger_conjugate` instead. Reaching this method is therefore a test-setup error.
# NOTE(review): extending Base.isapprox for a ChainRulesCore type is type piracy;
# tolerable in a throwaway test script, not in library code.
Base.isapprox(d_ad::Wirtinger, d_fd; kwargs...) =
    error("Finite differencing with Wirtinger rules not implemented")
# Compare a Casted differential element-wise: broadcast `isapprox` over the
# materialized value and the finite-difference result. The comparison passes only if
# every element matches.
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
    elementwise_match = isapprox.(extern(d_ad), d_fd; kwargs...)
    return all(elementwise_match)
end
# A DNE differential carries no value to check against finite differences, so
# comparing one signals a mistake in the test itself.
Base.isapprox(d_ad::DNE, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a DNE")
# Generic fallback for AbstractDifferential arguments not matched by a more specific
# method: materialize the differential with `extern` and compare the plain value.
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) =
    isapprox(extern(d_ad), d_fd; kwargs...)
# TODO: upstream the two methods below to ChainRulesCore in a PR.
# Let the Wirtinger component accessors see through Casted and Thunk wrappers by
# materializing the wrapped value with `extern` before extracting the component.
function wirtinger_primal(x::Union{Casted,Thunk})
    return wirtinger_primal(extern(x))
end
function wirtinger_conjugate(x::Union{Casted,Thunk})
    return wirtinger_conjugate(extern(x))
end
# Short alias so the finite-difference closures below stay readable.
const f = abs_to_pow

# Check the frule and rrule of abs_to_pow against finite differences. Every
# combination of a real or complex base x and a real or complex exponent p is tested.
# The derivative w.r.t. p is declared as the ordinary number Ω * log|x| in both
# rules. So in addition to matching finite differences, it must never come back as a
# Wirtinger differential.
@testset "abs_to_pow" begin
    xs = (2, 3.4, -2.1, -10 + 0im, 2.3 - 2im)
    ps = (0, 1, 2, 4.3, -2.1, 1 + 0.2im)
    @testset "f($(args[1]), $(args[2]))" for args in Iterators.product(xs, ps)
        # Forward mode: push a unit tangent through each argument in turn.
        res = frule(f, args...)
        @test res !== nothing  # the rule was defined
        fx, pushforward = res
        @test fx == f(args...)  # the primal value is unchanged
        df_dx = pushforward(One(), Zero())
        df_dp = pushforward(Zero(), One())
        test_scalar(ϵ -> f(args[1] + ϵ, args[2]), 0, df_dx)
        test_scalar(ϵ -> f(args[1], args[2] + ϵ), 0, df_dp)
        @test !(extern(df_dp) isa Wirtinger)

        # Reverse mode: pull a unit cotangent back to each argument.
        res = rrule(f, args...)
        @test res !== nothing  # the rule was defined
        fx, pullbacks = res
        @test fx == f(args...)  # the primal value is unchanged
        df_dx = pullbacks[1](One())
        df_dp = pullbacks[2](One())
        test_scalar(ϵ -> f(args[1] + ϵ, args[2]), 0, df_dx)
        test_scalar(ϵ -> f(args[1], args[2] + ϵ), 0, df_dp)
        @test !(extern(df_dp) isa Wirtinger)
    end
end
@simeonschaub
Copy link
Author

simeonschaub commented Sep 3, 2019

We should probably also include a test checking that `!isa(extern(df_dp), Wirtinger)` always holds.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment