@jrevels
Last active February 23, 2018 16:57
recursive_ad.jl
# This script requires an up-to-date ForwardDiff, ReverseDiff, and Julia v0.6 installation.
using ForwardDiff, ReverseDiff
D_f(f) = x::Number -> ForwardDiff.derivative(f, x)
# ReverseDiff's API only supports array inputs, so we just wrap our scalar input in a
# 1-element array and extract our scalar derivative from the returned gradient array
D_r(f) = x::Number -> ReverseDiff.gradient(y -> f(y[1]), [x])[1]
function ff(x::Number)
    if x > 0
        return sin(x)
    else
        return D_f(ff)(x + 1)
    end
end
function fr(x::Number)
    if x > 0
        return sin(x)
    else
        return D_r(fr)(x + 1)
    end
end
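# Why the check below should equal -cos(1.0): for x <= 0, each function returns its own
# derivative at x + 1, i.e. ff(x) = ff'(x + 1). AD differentiates whichever branch is
# taken at the evaluation point (and `0 > 0` is false, so x = 0 also recurses), giving
#   ff'(-1) = ff''(0) = ff'''(1) = sin'''(1) = -cos(1)
# so evaluating D_f(ff)(-1.0) nests three levels of differentiation before reaching sin.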
# The expression below returns `true`, as it should. Note that it takes a while to
# run the first time it is called; this is just compilation time, since Julia has a
# method-invocation JIT (each method is compiled the first time it is called with a
# given combination of argument types). I've found nested differentiation can be quite
# hard on the compiler.
# Note also that Julia's inference infers the concrete output type (`Float64`) in every
# case except `D_f(ff)` (inference results can be checked with Julia's `@code_typed`
# macro). For `D_f(ff)`, it infers an output type of `Any`, which is still correct but
# suboptimal, since it could induce dynamic dispatch in downstream code (which Julia
# handles well, but which is not as fast as a dispatch resolved at compile time). I
# suspect this `Any` result isn't because inference couldn't determine the concrete
# output type given enough time, but rather because it gave up upon hitting an
# internally tuned heuristic limit (to avoid, e.g., overspecialization).
D_r(ff)(-1.0) == D_f(ff)(-1.0) == D_r(fr)(-1.0) == D_f(fr)(-1.0) == -cos(1.0)
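The comment in the script suggests checking inference results with `@code_typed`. A lighter-weight option is `Base.return_types(f, argtypes)`, which returns one inferred return type per matching method. It is unexported, so treat it as an internal API that may change between Julia versions. The sketch below uses two hypothetical toy functions, `stable` and `unstable`, rather than the AD code itself, so it runs without ForwardDiff installed:

```julia
# Two hypothetical toy functions, for illustration only
stable(x::Float64) = sin(x)                 # always returns a Float64
unstable(x::Float64) = x > 0 ? sin(x) : 0   # returns a Float64 or an Int

# Base.return_types gives one inferred type per matching method;
# each toy function has a single method, so take the first entry
stable_ret = first(Base.return_types(stable, (Float64,)))
unstable_ret = first(Base.return_types(unstable, (Float64,)))

println(stable_ret)    # a concrete type (Float64)
println(unstable_ret)  # a small Union of Float64 and Int
```

Applied to the script's own closures, e.g. `Base.return_types(D_f(ff), (Float64,))`, the same call should surface the `Any` result described in the comment, although results may differ across Julia and package versions.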
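`ForwardDiff.derivative` works by evaluating `f` on dual numbers, so nesting `D_f` nests dual numbers inside dual numbers. The following minimal sketch, in plain Julia with no packages, illustrates that idea for the script above. `Dual`, `D`, and `g` are hypothetical names for this sketch only and are not ForwardDiff's API. ForwardDiff's own dual type also carries a tag parameter to guard against perturbation confusion (when an inner derivative closes over an outer differentiation variable); this sketch omits that and defines only the few operations the examples need.

```julia
# Minimal forward-mode AD sketch (hypothetical; not ForwardDiff's implementation).
# A Dual holds a value and its derivative with respect to one perturbation.
# Nesting (Dual{Dual{Float64}}) adds one perturbation per level of differentiation.
struct Dual{T<:Real} <: Real
    value::T
    partial::T
end

# Sum and product rules, enough for the functions below
Base.:+(a::Dual, b::Dual) = Dual(a.value + b.value, a.partial + b.partial)
Base.:+(a::Dual, b::Real) = Dual(a.value + b, a.partial)
Base.:+(a::Real, b::Dual) = b + a
Base.:-(a::Dual) = Dual(-a.value, -a.partial)
Base.:*(a::Dual, b::Dual) = Dual(a.value * b.value, a.value * b.partial + a.partial * b.value)

# Chain rule for sin and cos
Base.sin(d::Dual) = Dual(sin(d.value), cos(d.value) * d.partial)
Base.cos(d::Dual) = Dual(cos(d.value), -sin(d.value) * d.partial)

# Constants 1 and 0 at the same nesting level as d (used to seed a new perturbation)
Base.one(d::Dual) = Dual(one(d.value), zero(d.partial))
Base.zero(d::Dual) = Dual(zero(d.value), zero(d.partial))

# Branches compare on the primal (innermost) value only
Base.:>(a::Dual, b::Real) = a.value > b

# Seed a new outermost perturbation on x, evaluate f, and read off its coefficient
D(f) = x -> f(Dual(x, one(x))).partial

# Recursive function in the style of `ff` in the script
g(x) = x > 0 ? sin(x) : D(g)(x + 1)
```

With these definitions, `D(sin)(1.0)` should agree with `cos(1.0)`, `D(D(sin))(1.0)` with `-sin(1.0)`, and `D(g)(-1.0)` with `-cos(1.0)`, mirroring the check in the script. `g` is safe without tags because it never captures an outer perturbed variable; each recursive call only adds a new outermost perturbation and strips it off before returning.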