

@niklasschmitz
Last active October 2, 2021 12:56
Getting an rrule from an frule in ChainRules.jl
# rrule from frule (transposition)
using Zygote
using ChainRulesCore
using LinearAlgebra

# Example function: c = sum(sin.(x)) * sin.(x)
function f(x)
    a = sin.(x)
    b = sum(a)
    c = b * a
    return c
end
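
# Forward-mode rule: propagate the tangent Δx alongside each primal value.
# The tangent outputs (ȧ, ḃ, ċ) depend linearly on Δx.
function ChainRulesCore.frule((Δself, Δx,), ::typeof(f), x)
    a, ȧ = sin.(x), cos.(x) .* Δx
    b, ḃ = sum(a), sum(ȧ)
    c, ċ = b * a, ḃ * a + b * ȧ
    return c, ċ
end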
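
# Reverse-mode rule obtained by transposing the pushforward: because the
# pushforward is linear in its tangent arguments, reverse-differentiating it
# yields the transposed linear map, i.e. the pullback of f.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(f), x)
    # Tangent map (Δself, Δx) ↦ ċ, with the primal x held fixed.
    pushforward(Δfx...) = frule(Δfx, f, x)[2]
    # (f, x) only serve as an evaluation point for the tangent arguments;
    # by linearity, the resulting pullback does not depend on that point.
    _, back = rrule_via_ad(config, pushforward, f, x)
    # Drop the cotangent of the pushforward closure itself, keep (∂self, ∂x).
    f_pullback(Δy) = back(Δy)[2:end]
    return f(x), f_pullback
end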
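
# Check the adjoint identity ⟨w, J v⟩ ≈ ⟨Jᵀ w, v⟩ for random x, v, w.
let x = rand(3)
    v = randn(3)
    w = randn(3)
    jvp(f, x, v) = frule((NoTangent(), v), f, x)[2]
    # Zygote dispatches to the custom rrule defined above for f.
    vjp(f, x, w) = rrule_via_ad(Zygote.ZygoteRuleConfig(), f, x)[2](w)[2]
    dot(w, jvp(f, x, v)) ≈ dot(vjp(f, x, w), v)  # should be true
end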
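
For a dependency-free cross-check of what the rrule computes, the sketch below assumes only the LinearAlgebra stdlib and uses hypothetical helper names (`pushforward_f`, `jacobian_by_columns`). It materializes the Jacobian of `f` column by column from the same tangent formulas as the frule, then transposes that matrix explicitly. This brute-force transposition costs one pushforward per input dimension, which the rrule above avoids by reverse-differentiating the pushforward once.

```julia
using LinearAlgebra

# Same tangent formulas as the frule for f(x) = sum(sin.(x)) * sin.(x),
# written without ChainRulesCore so this block runs on its own.
function pushforward_f(x, Δx)
    a, ȧ = sin.(x), cos.(x) .* Δx
    b, ḃ = sum(a), sum(ȧ)
    return ḃ * a + b * ȧ
end

# Build the Jacobian one column at a time: column j is the pushforward of
# the j-th standard basis vector (n forward passes for n inputs).
function jacobian_by_columns(x)
    n = length(x)
    J = zeros(eltype(x), n, n)
    for j in 1:n
        e = zeros(eltype(x), n)
        e[j] = one(eltype(x))
        J[:, j] = pushforward_f(x, e)
    end
    return J
end

x, v, w = rand(3), randn(3), randn(3)
J = jacobian_by_columns(x)
vjp_dense = J' * w  # transposition by materializing the matrix
# Same adjoint identity as the gist's check
@show dot(w, pushforward_f(x, v)) ≈ dot(vjp_dense, v)
```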
@niklasschmitz
Author

This is inspired by JAX's approach of decomposing reverse mode into forward-mode linearization followed by transposition, where transposition only needs rules for linear primitives.
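
To make that split concrete for the same `f`, here is a hand-transposed sketch assuming only the LinearAlgebra stdlib. `linearize_f` and `transpose_f` are hypothetical helpers, not JAX or ChainRules APIs: the first fixes the primal residuals and returns the linear tangent map, and the second walks that map's linear primitives (elementwise scaling, `sum`, scalar-times-vector, addition) in reverse order, applying only each primitive's transpose rule.

```julia
using LinearAlgebra

# Linearize f(x) = sum(sin.(x)) * sin.(x) at x: keep the primal residuals
# (a, b, cx) and return the tangent map Δx -> ċ, which is linear in Δx.
function linearize_f(x)
    a, cx = sin.(x), cos.(x)
    b = sum(a)
    function tangent_map(Δx)
        ȧ = cx .* Δx           # elementwise scaling by cos.(x)
        ḃ = sum(ȧ)             # sum
        return ḃ * a + b * ȧ   # scalar * vector, scaling by b, addition
    end
    return b * a, tangent_map, (a, b, cx)
end

# Transpose the tangent map by visiting its linear primitives in reverse
# order and applying each one's transpose rule to the cotangent w.
function transpose_f((a, b, cx), w)
    # ċ = ḃ * a + b * ȧ  ->  bbar = dot(a, w), abar = b .* w
    bbar = dot(a, w)
    abar = b .* w
    # ḃ = sum(ȧ)  ->  the transpose of sum spreads bbar over every entry
    abar = abar .+ bbar
    # ȧ = cx .* Δx  ->  diagonal scaling is its own transpose
    return cx .* abar
end

x, v, w = rand(3), randn(3), randn(3)
y, tangent_map, residuals = linearize_f(x)
# Adjoint identity: ⟨w, J v⟩ ≈ ⟨Jᵀ w, v⟩
@show dot(w, tangent_map(v)) ≈ dot(transpose_f(residuals, w), v)
```

The transposed map works out to `cos.(x) .* (b .* w .+ dot(a, w))`, i.e. Jᵀw for J = a·cos.(x)ᵀ + b·Diagonal(cos.(x)).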
