using ChainRulesCore  # provides rrule, NoTangent, AbstractZero

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val
"""
`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)`, depending on whether you want to be able to AD with respect to it.
If an argument is `Dead`, then for its derivative we return `NoTangent` (in the examples that follow),
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end
# Fallback:
# if in doubt, fall back to assuming all arguments are active
# and use a plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
# If all arguments are dead then this is easy
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) === 0  # ignoring functors for now
    # unwrap the markers before calling the primal function
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end
# Now define:
function active_rrule(::typeof(foo), a::Active, b)
    # slow way to get cotangent for both a and b
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way to get cotangent for b, and a returns NoTangent()
end
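As a rough illustration of the two `foo` methods above, here is a minimal sketch in which `foo` is assumed to be scalar multiplication. That choice is hypothetical, since the gist never defines `foo`. To keep it runnable without any packages, `NoTangent` is a local stand-in rather than the ChainRulesCore type, and the extra `strip_activity(x) = x` method is an assumption added so that an unmarked `b` passes through.

```julia
# Local stand-in so the sketch runs without ChainRulesCore (assumption).
struct NoTangent end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val
strip_activity(x) = x  # assumption: unmarked arguments pass through unchanged

# Hypothetical primal function; the gist leaves `foo` undefined.
foo(a, b) = a * b

# `a` is active: compute cotangents for both a and b.
function active_rrule(::typeof(foo), a::Active, b)
    x, y = strip_activity(a), strip_activity(b)
    foo_pullback(dy) = (NoTangent(), dy * y, dy * x)
    return foo(x, y), foo_pullback
end

# `a` is dead: only the cotangent for b is computed.
function active_rrule(::typeof(foo), a::Dead, b::Active)
    x, y = strip_activity(a), strip_activity(b)
    foo_pullback(dy) = (NoTangent(), NoTangent(), dy * x)
    return foo(x, y), foo_pullback
end

y, pb = active_rrule(foo, Dead(2.0), Active(3.0))
pb(1.0)  # the slot for `a` holds NoTangent()
```

For scalar multiplication the saving is trivial, but the same dispatch pattern would apply where the cotangent for `a` is genuinely expensive.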
Yeah, this seems like a nice solution. It's particularly nice in that it would be opt-in for AD systems -- if they don't care about activity, and just want regular `rrule`s, they can avoid this infrastructure entirely.
It would also be really easy to test these -- you just check for consistency with the regular `rrule` (probably we'd want to insist that `active_rrule`s have a fallback `rrule` for the all-active case to prevent accidentally creating `active_rrule`s without `rrule`s, which would prevent Zygote and Diffractor from using them).
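The consistency test described here could look roughly like the sketch below. `check_active_rrule` and the toy function `mul` are hypothetical names, and `NoTangent` and `rrule` are local stand-ins so the example runs without ChainRulesCore. The check is that the primal matches the plain rule, that every `Active` argument gets the same tangent as the plain rule, and that every `Dead` argument gets `NoTangent()`.

```julia
# Local stand-ins so the sketch runs without ChainRulesCore (assumption).
struct NoTangent end
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}; val::T; end
struct Dead{T} <: ActivityMarked{T}; val::T; end
strip_activity(x::ActivityMarked) = x.val

# Toy function with a plain rule and one activity-aware rule.
mul(a, b) = a * b
rrule(::typeof(mul), a, b) = mul(a, b), dy -> (NoTangent(), dy * b, dy * a)
function active_rrule(::typeof(mul), a::Dead, b::Active)
    x = a.val
    return mul(x, b.val), dy -> (NoTangent(), NoTangent(), dy * x)
end

# Hypothetical helper: compare an active_rrule against the plain rrule.
function check_active_rrule(f, dy, args::ActivityMarked...)
    y_ref, pb_ref = rrule(f, strip_activity.(args)...)
    y_act, pb_act = active_rrule(f, args...)
    y_act == y_ref || return false
    ref, act = pb_ref(dy), pb_act(dy)
    # Position 1 is the tangent for f itself, so compare from position 2 on.
    return all(zip(args, ref[2:end], act[2:end])) do (arg, r, a)
        arg isa Dead ? a isa NoTangent : a == r
    end
end

check_active_rrule(mul, 1.0, Dead(2.0), Active(3.0))
```

A real test helper would presumably use approximate comparison and random cotangents. Exact equality is used here only to keep the sketch short.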
Probably the fallback to `rrule` should also drop any tangents for anything that is `Dead`, so that if they were thunks it can for sure avoid anyone ever unthunking them.
👍
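One way the fallback could drop tangents for `Dead` arguments is sketched below. `mask_tangent` and `mul` are hypothetical names, and `NoTangent` and `rrule` are local stand-ins so the example runs without ChainRulesCore. The pullback from the plain rule is wrapped, and each tangent whose argument is `Dead` is replaced by `NoTangent()` without being inspected. Under the assumption that such tangents arrive as thunks, that means they would never be forced.

```julia
# Local stand-ins so the sketch runs without ChainRulesCore (assumption).
struct NoTangent end
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}; val::T; end
struct Dead{T} <: ActivityMarked{T}; val::T; end
strip_activity(x::ActivityMarked) = x.val

# Hypothetical helper: discard the tangent when the argument is Dead.
mask_tangent(::Dead, tangent) = NoTangent()
mask_tangent(::Any, tangent) = tangent

# Fallback: call the plain rule, then mask the tangents of Dead arguments.
function active_rrule(f, args::ActivityMarked...)
    y, pullback = rrule(f, strip_activity.(args)...)
    function masked_pullback(dy)
        tangents = pullback(dy)
        # tangents[1] belongs to f itself; the rest line up with args.
        return (tangents[1], map(mask_tangent, args, tangents[2:end])...)
    end
    return y, masked_pullback
end

# Toy function with only a plain rule, so the fallback is what gets hit.
mul(a, b) = a * b
rrule(::typeof(mul), a, b) = mul(a, b), dy -> (NoTangent(), dy * b, dy * a)

y, pb = active_rrule(mul, Dead(2.0), Active(3.0))
pb(1.0)  # the slot for the Dead argument holds NoTangent()
```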
To reduce ambiguities, we should also have an all-active case that hits `rrule`.
Not sure what you mean by that. Could you elaborate please?
We could define:
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
So that if everything is active it will just call the `rrule`, no matter what the function is.
Ahhh I see. Yeah, that makes sense.
Probably would want to integrate this with `@non_differentiable` somehow.
Definitely want to integrate it with scalar_rules for multi-input functions that have no setup step.