@oxinabox
Created August 10, 2021 12:16
Sketch: Extension of rrule to take in the activity (i.e. whether you want to get the derivative wrt each argument)
using ChainRulesCore  # provides `rrule` and `NoTangent`

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val
"""
`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)` depending on whether you want to be able to AD wrt it.
If an argument is Dead, then for its derivative we return `NoTangent()` (in the examples that follow)
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end
# Fallback
# if in doubt, fall back to assuming all arguments are active
# and using a plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
# if all are dead then this is easy
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) === 0  # ignoring functors for now
    # unwrap the primals before calling `f`
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end
# Now define:
function active_rrule(::typeof(foo), a::Active, b)
    # slow way to get cotangent for both a and b
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way to get cotangent for b, and a returns NoTangent()
end
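
As a rough illustration of how the sketch above might be used, here is a minimal, self-contained example. The function `scale` is hypothetical, and `NoTangent` and `rrule` are defined locally as stand-ins so the snippet runs without dependencies; in practice they would come from ChainRulesCore. When `x` is marked `Dead`, the specialised rule avoids allocating a cotangent array for it.

```julia
# Stand-ins so this runs on its own; real code would use ChainRulesCore's
# `NoTangent` and `rrule` instead.
struct NoTangent end
function rrule end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Hypothetical example function (not from the original sketch).
scale(a, x) = a * sum(x)

# Plain rrule: always builds cotangents for both `a` and `x`.
function rrule(::typeof(scale), a, x)
    s = sum(x)
    scale_pullback(dy) = (NoTangent(), dy * s, fill(dy * a, size(x)))
    return a * s, scale_pullback
end

# Fallback: assume everything is active and use the plain rrule.
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# Specialised rule: `x` is dead, so skip allocating its cotangent array.
function active_rrule(::typeof(scale), a::Active, x::Dead)
    s = sum(x.val)
    scale_pullback(dy) = (NoTangent(), dy * s, NoTangent())
    return a.val * s, scale_pullback
end

# `x` dead: hits the specialised rule.
y, pb = active_rrule(scale, Active(2.0), Dead([1.0, 2.0, 3.0]))
# Everything active: falls back to the plain rrule.
y_all, pb_all = active_rrule(scale, Active(2.0), Active([1.0, 2.0, 3.0]))
```

Both calls return the same primal; only the pullback for the dead argument differs.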
@oxinabox

We would probably want to integrate this with `@non_differentiable` somehow, and we
definitely want to integrate it with scalar rules (`@scalar_rule`) for multi-input functions that have no setup step.

@willtebbutt

Yeah, this seems like a nice solution. It's particularly nice in that it would be opt-in for AD systems -- if they don't care about activity, and just want regular rrules, they can avoid this infrastructure entirely.

It would also be really easy to test these -- you just check for consistency with the regular rrule. (We'd probably want to insist that every active_rrule has a fallback rrule for the all-active case, to avoid accidentally creating active_rrules without rrules, which would prevent Zygote and Diffractor from using them.)
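
A consistency check along these lines could be sketched as follows. The helper `consistent_with_rrule` and the function `mul` are hypothetical names introduced only for illustration, and `NoTangent` and `rrule` are local stand-ins for the ChainRulesCore versions. The check only constrains the tangents of `Active` arguments; what a rule returns for `Dead` arguments is left unconstrained here.

```julia
# Stand-ins so this runs on its own; real code would use ChainRulesCore.
struct NoTangent end
function rrule end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

active_rrule(f, args...) = rrule(f, strip_activity.(args)...)

# Hypothetical function with a plain rrule and one activity-aware rule.
mul(a, b) = a * b
rrule(::typeof(mul), a, b) = (a * b, dy -> (NoTangent(), dy * b, dy * a))
function active_rrule(::typeof(mul), a::Dead, b::Active)
    return a.val * b.val, dy -> (NoTangent(), NoTangent(), dy * a.val)
end

# Hypothetical checker: the active_rrule must agree with the plain rrule on
# the primal and on the tangent of every argument marked `Active`.
function consistent_with_rrule(f, dy, args::ActivityMarked...)
    y, pb = active_rrule(f, args...)
    y_ref, pb_ref = rrule(f, strip_activity.(args)...)
    isapprox(y, y_ref) || return false
    ts, ts_ref = pb(dy), pb_ref(dy)
    length(ts) == length(ts_ref) || return false
    # Skip the first entry, which is the tangent wrt `f` itself.
    for (arg, t, t_ref) in zip(args, ts[2:end], ts_ref[2:end])
        if arg isa Active
            t isa NoTangent && return false  # an active argument needs a real tangent
            isapprox(t, t_ref) || return false
        end
    end
    return true
end

ok_mixed = consistent_with_rrule(mul, 1.0, Dead(2.0), Active(3.0))
ok_all_active = consistent_with_rrule(mul, 1.0, Active(2.0), Active(3.0))
```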

@oxinabox

The fallback to rrule should probably also drop the tangents for anything that is Dead.
That way, if they were thunks, we can be sure that nobody ever unthunks them.
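
One way such a fallback might look is sketched below. `Thunk` and `unthunk` here are minimal local stand-ins for the ChainRulesCore versions, `mul` is a hypothetical example function, and the `evaluated` counter exists only to show that the thunk belonging to the dead argument is never forced.

```julia
# Stand-ins so this runs on its own; real code would use ChainRulesCore's
# `NoTangent`, `rrule`, `Thunk` and `unthunk`.
struct NoTangent end
function rrule end
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()
unthunk(x) = x

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Counts how many thunks were actually evaluated (illustration only).
evaluated = Ref(0)

# Hypothetical function whose plain rrule returns lazy tangents.
mul(a, b) = a * b
function rrule(::typeof(mul), a, b)
    function mul_pullback(dy)
        da = Thunk(() -> (evaluated[] += 1; dy * b))
        db = Thunk(() -> (evaluated[] += 1; dy * a))
        return NoTangent(), da, db
    end
    return a * b, mul_pullback
end

# Fallback that replaces the tangent of every Dead argument with NoTangent(),
# discarding the corresponding thunk without ever forcing it.
function active_rrule(f, args::ActivityMarked...)
    y, pb = rrule(f, map(strip_activity, args)...)
    function masked_pullback(dy)
        ts = pb(dy)
        # ts[1] is the tangent wrt `f`; ts[i + 1] belongs to args[i].
        masked = ntuple(i -> args[i] isa Dead ? NoTangent() : ts[i + 1], length(args))
        return (ts[1], masked...)
    end
    return y, masked_pullback
end

y, pb = active_rrule(mul, Dead(2.0), Active(3.0))
df, da, db = pb(1.0)
db_value = unthunk(db)  # only the active argument's thunk is evaluated
```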

@willtebbutt

👍

@oxinabox

To reduce ambiguities, we should also have an all-active case that hits rrule.

@willtebbutt

Not sure what you mean by that. Could you elaborate please?

@oxinabox

We could define:
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
So that if everything is active it will just call the rrule, no matter what the function is.

@willtebbutt

Ahhh I see. Yeah, that makes sense.