Skip to content

Instantly share code, notes, and snippets.

@femtomc
Created April 22, 2020 14:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save femtomc/7735f959322a0d96b75da27dd96bf098 to your computer and use it in GitHub Desktop.
Save femtomc/7735f959322a0d96b75da27dd96bf098 to your computer and use it in GitHub Desktop.
module ReverseModeADwithContexts

using Cassette

# `J` is the Cassette context active while tracing. The name follows the notation
# for the operator that produces pullbacks (the closures recorded on the tape below).
Cassette.@context J

"""
    ReverseTrace()

Gradient tape: one pullback closure per traced primitive call, stored in the
order the primal calls happened.

The tape is replayed as a single chain (each pullback receives the previous
one's output), so it is only correct for straight-line compositions of the
traced unary primitives, such as `foo` below. Calls for which no recording
overdub exists push nothing onto the tape.
"""
mutable struct ReverseTrace
    gradient_tape::Vector{Function}
    ReverseTrace() = new(Function[])
end

"""
    apply(tape, seed)

Back-propagate `seed` through `tape`. The pullbacks are called from last to
first, each one receiving the previous result, and the final cotangent is
returned. An empty tape returns `seed` unchanged.

The tape is not modified, so it can be replayed any number of times.
"""
function apply(tape::AbstractVector, seed)
    # foldr computes tape[1](tape[2](... tape[end](seed))), so the last
    # recorded pullback is applied first, as reverse mode requires.
    return foldr((pullback, ȳ) -> pullback(ȳ), tape; init = seed)
end

"""
    gradient(tr::ReverseTrace, seed = 1)

Run the reverse pass over `tr`'s tape, starting from the output cotangent
`seed` (default `1`). For a traced straight-line composition this is the
derivative of the output with respect to the input. Does not consume the tape.
"""
gradient(tr::ReverseTrace, seed = 1) = apply(tr.gradient_tape, seed)

# Append `pullback` to the tape carried in the context's metadata.
record!(ctx::J, pullback) = push!(ctx.metadata.gradient_tape, pullback)

# Traced primitives. Each one evaluates the primal first and only then records
# the pullback ȳ ↦ ȳ * f'(x), so a throwing call leaves no stray tape entry.
# The methods dispatch on any `Real` argument, not just `Float64`.
function Cassette.overdub(ctx::J, ::typeof(sin), x::Real)
    y = sin(x)
    record!(ctx, ȳ -> ȳ * cos(x))
    return y
end

function Cassette.overdub(ctx::J, ::typeof(cos), x::Real)
    y = cos(x)
    record!(ctx, ȳ -> ȳ * -sin(x))
    return y
end

# Demo: foo(x) = cos(sin(x)), whose derivative is -sin(sin(x)) * cos(x).
function foo(x::Float64)
    y = sin(x)
    z = cos(y)
    return z
end

tr = ReverseTrace()
Cassette.overdub(Cassette.disablehooks(J(metadata = tr)), foo, 5.0)
println(tr.gradient_tape)
println(gradient(tr))

end # module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment