@oxinabox
Created August 10, 2021 12:16
Sketch: Extension of rrule to take in the activity (i.e. whether you want to get the derivative wrt this argument)
using ChainRulesCore: rrule, NoTangent

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val
"""
`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)`, depending on whether you want to be able to AD wrt it.
If it is `Dead`, then for its derivative we return `NoTangent` (in the examples that follow),
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end
# Fallback
# if in doubt, fall back to assuming all are active
# and using a plain rrule
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
# if all are dead then this is easy
all_dead_pullback(n) = _->ntuple(_->NoTangent(), n+1)
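function active_rrule(f, args::Dead...)
    # fieldcount expects a type, so query typeof(f)
    @assert fieldcount(typeof(f)) == 0 # ignoring functors for now
    # unwrap the Dead markers before calling the primal function
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end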
# Now define:
function active_rrule(::typeof(foo), a::Active, b)
    # slow way to get cotangent for both a and b
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way to get cotangent for b, and a returns NoTangent()
end
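# Illustration only (not part of the original sketch): a concrete stand-in for `foo`.
# `mymul` and its pullbacks are hypothetical names chosen for this example. It assumes
# real scalar inputs and that both arguments arrive wrapped in `Active` or `Dead`.
using ChainRulesCore: NoTangent

mymul(a, b) = a * b

# Both arguments active: compute both cotangents.
function active_rrule(::typeof(mymul), a::Active, b::Active)
    x, y = strip_activity(a), strip_activity(b)
    mymul_pullback(dy) = (NoTangent(), dy * y, dy * x)
    return mymul(x, y), mymul_pullback
end

# `a` is dead: skip the work for its cotangent and return NoTangent() in its slot.
function active_rrule(::typeof(mymul), a::Dead, b::Active)
    x, y = strip_activity(a), strip_activity(b)
    mymul_pullback(dy) = (NoTangent(), NoTangent(), dy * x)
    return mymul(x, y), mymul_pullback
end

# Usage: only the cotangent for `b` is computed.
y_dead, pb_dead = active_rrule(mymul, Dead(2.0), Active(3.0))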
@oxinabox
Author

We could define:
active_rrule(args::Active...) = rrule(strip_activity.(args)...)
So that if everything is active, it will just call the rrule, no matter what the function is.
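
A minimal sketch of how that all-active method might behave. It assumes the function `f` is passed unwrapped as a leading argument, as in the gist; that is one reading, and it differs slightly from the one-liner as written. `mymul` and its `rrule` are hypothetical, defined only for this example, and the marker types are repeated so the block runs on its own.

```julia
using ChainRulesCore

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# Hypothetical primal function with an ordinary rrule (real scalars only).
mymul(a, b) = a * b
function ChainRulesCore.rrule(::typeof(mymul), a::Real, b::Real)
    mymul_pullback(dy) = (NoTangent(), dy * b, dy * a)
    return a * b, mymul_pullback
end

# All-active method (assumption: `f` itself is not wrapped in `Active`).
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)

# Every argument is active, so this forwards to the plain rrule above.
y, pullback = active_rrule(mymul, Active(2.0), Active(3.0))
```

Here `y` is `6.0`, and `pullback(1.0)` returns the same tuple the plain `rrule` would give, `(NoTangent(), 3.0, 2.0)`.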

@willtebbutt

Ahhh I see. Yeah, that makes sense.