Skip to content

Instantly share code, notes, and snippets.

@willtebbutt
Last active April 16, 2019 20:46
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save willtebbutt/39205ab845b22e6452a42705eac8d254 to your computer and use it in GitHub Desktop.
Save willtebbutt/39205ab845b22e6452a42705eac8d254 to your computer and use it in GitHub Desktop.
Toy tape-based reverse-mode AD with minimal Cassette usage.
#
# This uses the Nabla.jl-style interception mechanism whereby
# we wrap things that are to be differentiated w.r.t. in a
# thin wrapper. There are lots of things that you can't
# propagate derivative information through with this kind of
# approach without quite a lot of extra machinery, but the
# examples at the bottom do work.
#
using ChainRules, Cassette
using Cassette: @context
using ChainRules: rrule, extern, Zero
##############################
# Types for tracking objects #
##############################
"""
    Node{T}

Supertype of every tracked value; `T` is the type of the wrapped primal value.
"""
abstract type Node{T} end

"""
    Leaf{T} <: Node{T}

Tracked input value.

# Fields
- `y`: the wrapped value.
- `pos`: index of this leaf in `tape`.
- `tape`: the tape shared by all nodes of one forward pass.
"""
struct Leaf{T} <: Node{T}
    y::T
    pos::Int
    tape::Vector{Any}
end
"""
    Branch{Y, F, X, D} <: Node{Y}

Tracked result of an intercepted call `f(xs...)`.

# Fields
- `y`: the value the call produced.
- `f`: the function that was called.
- `xs`: the arguments of the call (tracked nodes and/or plain values).
- `Δxs`: one reverse rule per argument, as obtained from `rrule`.
- `pos`: index of this node in `tape`.
- `tape`: the tape shared by all nodes of one forward pass.
"""
struct Branch{Y, F, X, D} <: Node{Y}
    y::Y
    f::F
    xs::X
    Δxs::D
    pos::Int
    tape::Vector{Any}
end
# Helper functions.
# Predicate: `true` exactly for tracked values (any `Node`), `false` for everything else.
is_tagged(::Node) = true
is_tagged(::Any) = false
# Strip the tracking wrapper: a `Node` yields the value it stores, while any
# other object is passed through unchanged.
untag(node::Node) = node.y
untag(plain) = plain
"""
    get_tape(x...)

Return the tape of the first tagged (`Node`) argument in `x`.

Throws an `ArgumentError` when none of the arguments is tagged, rather than
failing with an obscure indexing error on `x[nothing]`.
"""
function get_tape(x...)
    idx = findfirst(is_tagged, x)
    if idx === nothing
        throw(ArgumentError("get_tape requires at least one tagged (Node) argument"))
    end
    return x[idx].tape
end
######################################################
# Use Cassette to define the interception mechanisms #
######################################################
# Cassette context used for differentiation: the `Cassette.overdub` methods
# below dispatch on it, and `forward` runs the target function under it.
@context DiffCtx
# Intercept a single-argument call `f(x)` made under `DiffCtx`.
#
# * Untracked `x`: run `f` natively, nothing is recorded.
# * Tracked `x` and `rrule(f, ...)` is defined: record a `Branch` on the tape and
#   return it in place of the raw result.
# * Tracked `x` but no rule: recurse into the body of `f` and intercept the calls
#   it makes instead.
function Cassette.overdub(ctx::DiffCtx, f, x)
    if !is_tagged(x)
        return f(x)
    end

    result = rrule(f, untag(x))
    result === nothing && return Cassette.recurse(ctx, f, x)

    # For a unary call the rule result unpacks into the value and a single rule.
    value, pullback = result
    record = x.tape
    node = Branch(value, f, (x,), (pullback,), length(record) + 1, record)
    push!(record, node)
    return node
end
# Intercept a call `f(x...)` with any number of arguments made under `DiffCtx`.
#
# * No tracked argument: run `f` natively, nothing is recorded.
# * Some tracked argument and `rrule(f, ...)` is defined: record a `Branch` on the
#   tape of the first tracked argument and return it in place of the raw result.
# * Some tracked argument but no rule: recurse into the body of `f`.
function Cassette.overdub(ctx::DiffCtx, f, x...)
    if !any(is_tagged, x)
        return f(x...)
    end

    result = rrule(f, map(untag, x)...)
    result === nothing && return Cassette.recurse(ctx, f, x...)

    # The rule result unpacks into the value and the collection of per-argument rules.
    value, pullbacks = result
    record = get_tape(x...)
    node = Branch(value, f, x, pullbacks, length(record) + 1, record)
    push!(record, node)
    return node
end
#############################################
# Implement reverse-mode AD in not many LoC #
#############################################
"""
    forward(f, x...)

Run `f(x...)` under `DiffCtx`, recording every intercepted call that has an
`rrule` on a fresh tape, and return a tuple `(value, pullback)`.

`value` is the plain (untagged) result of `f(x...)`. `pullback(ȳ)` walks the tape
in reverse starting from the seed `ȳ` and returns a tuple with one entry per
element of `x`, each passed through `extern`.

Throws an `ArgumentError` if the result of `f` is not a tracked `Node`, since
there is then no tape position to seed.
"""
function forward(f, x...)
    tape = Vector{Any}()
    # One leaf per input, pushed in order so that leaf `n` sits at tape position `n`.
    leaves = map(((n, xn),) -> Leaf(xn, n, tape), enumerate(x))
    foreach(leaf -> push!(tape, leaf), leaves)
    y = Cassette.overdub(DiffCtx(), f, leaves...)
    if !is_tagged(y)
        throw(ArgumentError(
            "forward: the output of `f` is not tracked, so it cannot be differentiated",
        ))
    end
    return y.y, function (ȳ)
        back_tape = Vector{Any}(undef, length(tape))
        fill!(back_tape, Zero())
        # Seed the slot of the node that `f` actually returned. This is not
        # necessarily the last tape entry: `f` may return an earlier intermediate
        # (or one of its inputs) after further calls have been recorded.
        back_tape[y.pos] = ȳ
        for n in reverse(eachindex(back_tape))
            node = tape[n]
            # Leaves have no parents to propagate to.
            node isa Branch || continue
            for (p, arg) in enumerate(node.xs)
                # Untracked arguments have no slot on the tape.
                is_tagged(arg) || continue
                back_tape[arg.pos] = ChainRules.accumulate(
                    back_tape[arg.pos],
                    node.Δxs[p],
                    back_tape[n],
                )
            end
        end
        # The first `length(x)` tape slots belong to the input leaves.
        return (extern.(back_tape[1:length(x)])...,)
    end
end
# Example 1: a function of one argument, evaluated at x = 5.0.
foo(x) = sin(x) + cos(x)
y, back = forward(foo, 5.0);
# Pull back a unit seed. The result is a 1-tuple with the derivative w.r.t. `x`,
# which should equal cos(5.0) - sin(5.0).
back(1)
# Example 2: a function of two arguments, evaluated at (5.0, 4.0).
bar(x, y) = sin(x) + cos(y)
y, back = forward(bar, 5.0, 4.0);
# The result is a 2-tuple with the partial derivatives w.r.t. `x` and `y`,
# which should equal (cos(5.0), -sin(4.0)).
back(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment