@oxinabox
Created August 8, 2019 19:02
A Simple scalar ForwardDiff using ChainRules + DualNumbers
#==
A Simple scalar ForwardDiff using ChainRules + DualNumbers
---
No promises are made about its correctness or safety.
In fact, it probably errors even for very standard cases.
It is only meant to explain how this approach can work.
==#
## Setup
using Pkg: Pkg, @pkg_str
Pkg.activate(@__DIR__)
pkg"add ChainRules"
pkg"add DualNumbers"
pkg"add Cassette"
## Main Code
using DualNumbers
using ChainRules
using Cassette
Cassette.@context DiffCtx2
const diffctx2 = DiffCtx2()
# Route every call made inside a DiffCtx2 context through duel_based_grad
Cassette.overdub(::DiffCtx2, f, args...) = duel_based_grad(f, args...)
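function duel_based_grad(f, x...)
    @show f
    xr = realpart.(x)
    xd = dualpart.(x)
    # `frule(f, xs...)` as called here returns `nothing` when no rule exists,
    # otherwise the primal output and a function mapping the input tangents
    # to the output tangent. (Newer ChainRules versions take the tangents as
    # the first argument instead: `frule((Δf, Δx...), f, x...)`.)
    rule = ChainRules.frule(f, xr...)
    if rule === nothing
        @show "no rule"
        # No rule, so we need to do nontrivial AD.
        # x has dual parts, so calling f on it does the AD.
        # We could just do `y = f(x...)`, but we want the calls f makes
        # internally to be intercepted (so they can hit rules),
        # hence Cassette.recurse.
        y = Cassette.recurse(diffctx2, f, x...)
        @show y
        yr = realpart.(y)
        # y was computed from dual inputs, so its dual part already includes
        # the input tangents xd (chain rule); multiplying by xd again would
        # apply them twice.
        yd = dualpart.(y)
    else
        @show "hit rule"
        yr, yd_rule = rule
        # Push the input tangents through the rule
        yd = yd_rule(xd...)
    end
    @show yr
    @show yd
    return Dual.(yr, yd)
end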
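## Aside: the rule-or-fallback idea without any packages
# A minimal sketch, assuming only Base Julia, of the same two branches as
# `duel_based_grad`. `ToyDual`, `TOY_RULES`, `toy_pushforward` and
# `toy_derivative` are hypothetical names made up for this illustration; they
# are not part of DualNumbers or ChainRules. The hand-written `TOY_RULES`
# table stands in for `ChainRules.frule`, and plain method overloading stands
# in for Cassette intercepting nested calls.

struct ToyDual
    val::Float64  # real (primal) part
    eps::Float64  # dual (tangent) part
end

# Arithmetic on dual numbers: the tangent follows the sum and product rules.
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.val + b.val, a.eps + b.eps)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.val * b.val, a.val * b.eps + a.eps * b.val)
Base.:*(c::Real, b::ToyDual) = ToyDual(c * b.val, c * b.eps)

# frule-style rules: primal input -> (primal output, pushforward of a tangent).
const TOY_RULES = Dict{Function,Function}(
    sin => (x -> (sin(x), dx -> cos(x) * dx)),
    exp => (x -> (exp(x), dx -> exp(x) * dx)),
)

# Route nested sin/exp calls on dual numbers through the rule table,
# loosely mimicking what the Cassette overdub does for every call.
Base.sin(a::ToyDual) = toy_pushforward(sin, a)
Base.exp(a::ToyDual) = toy_pushforward(exp, a)

function toy_pushforward(f, x::ToyDual)
    if haskey(TOY_RULES, f)
        # "hit rule": evaluate the primal, then push the input tangent through
        y, pushforward = TOY_RULES[f](x.val)
        return ToyDual(y, pushforward(x.eps))
    else
        # "no rule": run f on the dual input itself; the result's dual part
        # already includes x.eps, so it is returned without rescaling
        return f(x)
    end
end

# Seed the tangent with 1.0, as `grad` does, and read off the derivative.
toy_derivative(f, x::Real) = toy_pushforward(f, ToyDual(x, 1.0)).eps

# Usage: toy_derivative(sin, 0.0) takes the rule branch and returns 1.0;
# toy_derivative(x -> 3 * x, 20.1) takes the fallback branch and returns 3.0.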
function grad(f, x)
    println("-"^40)
    # Seed the dual part with 1.0 so the output's dual part is df/dx at x
    y = duel_based_grad(f, Dual(x, 1.0))
    @assert length(y) == 1
    y = first(y)
    return (; result=realpart(y), derivative=dualpart(y))
end
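## Why the seed 1.0 in `grad` yields the derivative
# The standard dual-number algebra behind this approach, as a LaTeX block
# comment (a general derivation, not something taken from the packages).
#==
A dual number is $a + b\varepsilon$ with $\varepsilon^2 = 0$. Expanding a smooth $f$
in a Taylor series, every term of order $\varepsilon^2$ or higher vanishes:
$$f(a + b\varepsilon) = f(a) + f'(a)\,b\,\varepsilon .$$
With the seed $b = 1$ used in `grad`, the dual part is exactly $f'(a)$.
Composing two functions gives the chain rule:
$$g\big(f(a + b\varepsilon)\big) = g\big(f(a)\big) + g'\big(f(a)\big)\,f'(a)\,b\,\varepsilon ,$$
so a result computed from dual inputs already carries the input tangent $b$ in its dual part.
==#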
## Demo
@show grad(-, 20.1)
#== Output
grad(-, 20.1) = (result = -20.1, derivative = -1.0)
==#
@show grad(x->3*x, 20.1)
#== Output
grad(x->3x, 20.1) = (result = 60.300000000000004, derivative = 3.0)
==#