Skip to content

Instantly share code, notes, and snippets.

@simeonschaub
Created September 2, 2019 18:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 to your computer and use it in GitHub Desktop.
Save simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 to your computer and use it in GitHub Desktop.
using ChainRulesCore
import ChainRulesCore: wirtinger_primal, wirtinger_conjugate
using ChainRulesCore: AbstractDifferential
using FiniteDifferences
using Test
"""
    abs_to_pow(x, p)

Raise the magnitude `abs(x)` of `x` to the power `p`.
"""
function abs_to_pow(x, p)
    magnitude = abs(x)
    return magnitude^p
end
# Partials of |x|^p for real x, in argument order (∂/∂x, ∂/∂p), where Ω = |x|^p:
#   ∂/∂x = p |x|^(p-1) sign(x), short-circuited to Zero() when p == 0; this also
#          avoids evaluating |x|^(-1), which can throw a DomainError for integer x
#   ∂/∂p = |x|^p log|x| = Ω log|x|
@scalar_rule(
    abs_to_pow(x::Real, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * sign(x),
        Ω * log(u)
    )
)
# Partials of |x|^p for complex x, where u = |x| = sqrt(x * conj(x)) and Ω = u^p.
# The Wirtinger derivatives of u are ∂u/∂x = conj(x) / 2u and ∂u/∂conj(x) = x / 2u,
# so by the chain rule both pick up the factor p * u^(p-1).
@scalar_rule(
    abs_to_pow(x::Complex, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * Wirtinger(conj(x) / 2u, x / 2u),
        Ω * log(u)
    )
)
"""
    test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)

Compare the AD derivative `df` of the scalar function `f` at `x` against a
finite-difference estimate computed with `fdm`. Extra `kwargs` are forwarded to
`isapprox`.

When `test_wirtinger` is `false`, `extern(df)` is compared with `fdm(f, x)`.
Otherwise the Wirtinger pair of `df` is compared with `(∂Re - im*∂Im) / 2`
(primal) and `(∂Re + im*∂Im) / 2` (conjugate). Here `∂Re` and `∂Im` are finite
differences along the real and imaginary directions.
"""
function test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                     test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)
    if test_wirtinger
        # Directional finite differences along the real and imaginary axes.
        along_re = fdm(ϵ -> f(x + ϵ), 0)
        along_im = fdm(ϵ -> f(x + im*ϵ), 0)
        # Recover the Wirtinger derivatives ∂/∂z and ∂/∂z̄ from the two directions.
        expected_primal = 0.5(along_re - im*along_im)
        expected_conj = 0.5(along_re + im*along_im)
        @test isapprox(
            wirtinger_primal(df), expected_primal;
            rtol=rtol, atol=atol, kwargs...
        )
        @test isapprox(
            wirtinger_conjugate(df), expected_conj;
            rtol=rtol, atol=atol, kwargs...
        )
    else
        # Ordinary (non-Wirtinger) derivative: compare the plain value directly.
        @test isapprox(
            extern(df), fdm(f, x);
            rtol=rtol, atol=atol, kwargs...
        )
    end
end
# `isapprox` overloads that let AD differentials be compared with finite-difference
# results. NOTE(review): extending `Base.isapprox` for ChainRulesCore types from a
# script is type piracy; tolerable in a throwaway test, but avoid it in a package.

# A Wirtinger differential has two components, so it cannot be checked against a
# single finite-difference value.
Base.isapprox(ad::Wirtinger, fd; kwargs...) =
    error("Finite differencing with Wirtinger rules not implemented")

# Casted differentials: materialize, then compare elementwise via broadcasting.
Base.isapprox(d_ad::Casted, d_fd; kwargs...) =
    all(broadcast((a, b) -> isapprox(a, b; kwargs...), extern(d_ad), d_fd))

# A DNE differential has nothing to compare, so reaching here is an error.
Base.isapprox(d_ad::DNE, d_fd; kwargs...) =
    error("Tried to differentiate w.r.t. a DNE")

# Any other differential: unwrap with `extern` and compare the plain value.
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) =
    isapprox(extern(d_ad), d_fd; kwargs...)
# TODO: PR in ChainRulesCore
# Wrapped differentials (`Casted`, `Thunk`) are first unwrapped with `extern`, so
# their Wirtinger components can be queried like any other differential.
function wirtinger_primal(x::Union{Casted,Thunk})
    unwrapped = extern(x)
    return wirtinger_primal(unwrapped)
end
function wirtinger_conjugate(x::Union{Casted,Thunk})
    unwrapped = extern(x)
    return wirtinger_conjugate(unwrapped)
end
const f = abs_to_pow

# Check the frule and rrule of `abs_to_pow` against finite differences over a grid
# of real/complex bases and exponents. This uses the older calling convention:
# the pushforward takes only argument tangents, and the rrule returns one pullback
# per argument.
@testset "abs_to_pow" begin
    @testset "f($(args[1]), $(args[2]))" for args in Iterators.product(
        (2, 3.4, -2.1, -10+0im, 2.3-2im),
        (0, 1, 2, 4.3, -2.1, 1+.2im)
    )
        x, p = args

        # Forward mode: seed one input direction at a time.
        fwd = frule(f, x, p)
        @test fwd !== nothing  # Check the rule was defined
        fx, pushforward = fwd
        @test fx == f(x, p)  # Check we still get the normal value, right
        test_scalar(ϵ -> f(x + ϵ, p), 0, pushforward(One(), Zero()))
        test_scalar(ϵ -> f(x, p + ϵ), 0, pushforward(Zero(), One()))

        # Reverse mode: one pullback per input, each seeded with One().
        rev = rrule(f, x, p)
        @test rev !== nothing  # Check the rule was defined
        fx, pullbacks = rev
        @test fx == f(x, p)  # Check we still get the normal value, right
        test_scalar(ϵ -> f(x + ϵ, p), 0, pullbacks[1](One()))
        test_scalar(ϵ -> f(x, p + ϵ), 0, pullbacks[2](One()))
    end
end
@oxinabox
Copy link

oxinabox commented Sep 3, 2019

Awesome, thank you!

Here is my modified version of that, which works after JuliaDiff/ChainRulesCore.jl#30:

using FiniteDifferences
using Test
using ChainRulesCore
using Random


# |x|^p: the magnitude of `x` (anything supporting `abs`) raised to the power `p`.
abs_to_pow(x, p) = let magnitude = abs(x)
    magnitude^p
end
# Partials of |x|^p for real x, in argument order (∂/∂x, ∂/∂p), where Ω = |x|^p:
#   ∂/∂x = p |x|^(p-1) sign(x), short-circuited to Zero() when p == 0; this also
#          avoids evaluating |x|^(-1), which can throw a DomainError for integer x
#   ∂/∂p = |x|^p log|x| = Ω log|x|
@scalar_rule(
    abs_to_pow(x::Real, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * sign(x),
        Ω * log(u)
    )
)

# Partials of |x|^p for complex x, where u = |x| = sqrt(x * conj(x)) and Ω = u^p.
# The Wirtinger derivatives of u are ∂u/∂x = conj(x) / 2u and ∂u/∂conj(x) = x / 2u,
# so by the chain rule both pick up the factor p * u^(p-1).
@scalar_rule(
    abs_to_pow(x::Complex, p),
    @setup(u = abs(x)),
    (
        p == 0 ? Zero() : p * u^(p - 1) * Wirtinger(conj(x) / 2u, x / 2u),
        Ω * log(u)
    )
)


"""
    test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)

Check the AD derivative `df` of the scalar function `f` at `x` against finite
differences computed with `fdm`. Remaining `kwargs` are forwarded to `isapprox`.

If `test_wirtinger` is `false`, `extern(df)` must match `fdm(f, x)`. Otherwise
`wirtinger_primal(df)` must match `(∂Re - im*∂Im) / 2`, and
`wirtinger_conjugate(df)` must match `(∂Re + im*∂Im) / 2`. Here `∂Re` and `∂Im`
are finite differences of `f` along the real and imaginary axes at `x`.
"""
function test_scalar(f, x, df; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1),
                     test_wirtinger=x isa Complex || extern(df) isa Wirtinger, kwargs...)
    # Check that we get the derivative right:
    if !test_wirtinger
        @test isapprox(
            extern(df), fdm(f, x);
            rtol=rtol, atol=atol, kwargs...
        )
    else
        # For complex arguments, also check that the Wirtinger derivative is correct.
        # `fdm(g, 0)` estimates g′(0), so each axis perturbation must be wrapped in
        # a closure of the real step ϵ.
        ∂Re = fdm(ϵ -> f(x + ϵ), 0)
        ∂Im = fdm(ϵ -> f(x + im*ϵ), 0)
        ∂ = 0.5(∂Re - im*∂Im)
        ∂̅ = 0.5(∂Re + im*∂Im)
        @test isapprox(
            wirtinger_primal(df), ∂;
            rtol=rtol, atol=atol, kwargs...
        )
        @test isapprox(
            wirtinger_conjugate(df), ∂̅;
            rtol=rtol, atol=atol, kwargs...
        )
    end
end

# `isapprox` overloads so AD differentials can be checked against finite-difference
# values. NOTE(review): these extend `Base.isapprox` for ChainRulesCore types from
# a script (type piracy); fine for a test script, not for a package.

# A Wirtinger differential carries two components; there is no single
# finite-difference value to compare it with.
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
    return error("Finite differencing with Wirtinger rules not implemented")
end

# A Casted differential is materialized and compared elementwise.
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
    materialized = extern(d_ad)
    elementwise = isapprox.(materialized, d_fd; kwargs...)
    return all(elementwise)
end

# A DNE differential has nothing to compare, so reaching here is an error.
function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
    return error("Tried to differentiate w.r.t. a DNE")
end

# Fallback for every other differential: compare its `extern`-ed value.
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
    value = extern(d_ad)
    return isapprox(value, d_fd; kwargs...)
end

# TODO: PR in ChainRulesCore
# Unwrap `Casted`/`Thunk` differentials via `extern` before asking for their
# Wirtinger components, then re-dispatch on the unwrapped value.
function ChainRulesCore.wirtinger_primal(x::Union{Casted,Thunk})
    return wirtinger_primal(extern(x))
end
function ChainRulesCore.wirtinger_conjugate(x::Union{Casted,Thunk})
    return wirtinger_conjugate(extern(x))
end

const f = abs_to_pow

# Check the frule and rrule of `abs_to_pow` against finite differences over a
# grid of real/complex bases `x` and exponents `p`.
@testset "f=abs_to_pow" begin
    @testset "f($x, $p)" for (x, p) in Iterators.product(
        (2, 3.4, -2.1, -10+0im, 2.3-2im),
        (0, 1, 2, 4.3, -2.1, 1+.2im)
    )
        # Forward mode: the pushforward takes a tangent for the function itself
        # (the empty NamedTuple() here), then one tangent per argument.
        res = frule(f, x, p)
        @test res !== nothing  # Check the rule was defined
        fx, f_pushforward = res
        df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)

        df_dx, = df(One(), Zero())
        df_dp, = df(Zero(), One())
        @test fx == f(x, p)  # Check we still get the normal value, right
        test_scalar(ϵ -> f(x + ϵ, p), 0, df_dx)
        test_scalar(ϵ -> f(x, p + ϵ), 0, df_dp)
        # The ∂/∂p partial (Ω * log|x|) is an ordinary derivative in both rules,
        # so it must never come back as a Wirtinger pair.
        @test !isa(extern(df_dp), Wirtinger)

        # Reverse mode: the pullback returns (Δself, Δx, Δp).
        res = rrule(f, x, p)
        @test res !== nothing  # Check the rule was defined
        fx, f_pullback = res
        _, df_dx, df_dp = f_pullback(One())
        @test fx == f(x, p)  # Check we still get the normal value, right
        test_scalar(ϵ -> f(x + ϵ, p), 0, df_dx)
        test_scalar(ϵ -> f(x, p + ϵ), 0, df_dp)
        @test !isa(extern(df_dp), Wirtinger)
    end
end

@simeonschaub
Copy link
Author

simeonschaub commented Sep 3, 2019

We should probably also include a test checking that `!isa(extern(df_dp), Wirtinger)` always holds.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment