
@oxinabox
Created August 10, 2021 12:16
Sketch: Extension of `rrule` to take in the activity of each argument (i.e. whether you want to get the derivative with respect to it)
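using ChainRulesCore  # for `rrule` and `NoTangent`

# Wrappers marking whether the derivative wrt an argument is wanted
abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
active(x) = false
active(x::Active) = true
strip_activity(x::ActivityMarked) = x.val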
"""
    active_rrule(f, args...)

`active_rrule` is like `rrule`, but rather than passing in primals, you pass in either `Active(primal)` or `Dead(primal)` depending on whether you want to be able to AD wrt it.
If an argument is `Dead`, then for its derivative we return `NoTangent()` (in the examples that follow),
or perhaps some new `DidNotRequestTangent <: AbstractZero`.
"""
function active_rrule end
# Fallback:
# if in doubt, fall back to assuming all are active
# and use a plain `rrule`
active_rrule(f, args...) = rrule(f, strip_activity.(args)...)
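# If all are dead then this is easy: every tangent (including the one for `f`) is `NoTangent()`
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) === 0  # ignoring functors for now
    # `f` needs the unwrapped primals, not the `Dead` wrappers
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end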
# Now, for a particular function `foo`, define:
function active_rrule(::typeof(foo), a::Active, b)
    # slow way: compute the cotangents for both `a` and `b`
end
function active_rrule(::typeof(foo), a::Dead, b::Active)
    # fast way: compute only the cotangent for `b`; `a` gets `NoTangent()`
end
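A self-contained, runnable sketch of the pattern above, assuming a hypothetical scalar function `mymul(a, b) = a * b` in place of `foo`. To avoid depending on ChainRulesCore it uses a local stand-in `NoTangent` struct (the real one is `ChainRulesCore.NoTangent`). When `a` is `Dead`, the pullback only computes the cotangent for `b`.

```julia
# Stand-in for ChainRulesCore.NoTangent (hypothetical, for a self-contained demo)
struct NoTangent end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# All-dead case: nothing to differentiate, every tangent is NoTangent()
all_dead_pullback(n) = _ -> ntuple(_ -> NoTangent(), n + 1)
function active_rrule(f, args::Dead...)
    @assert fieldcount(typeof(f)) == 0  # ignoring functors (closures) for now
    return f(strip_activity.(args)...), all_dead_pullback(length(args))
end

# Hypothetical example function standing in for `foo`
mymul(a, b) = a * b

# Fast path: only `b` is active, so skip computing the cotangent for `a`
function active_rrule(::typeof(mymul), a::Dead, b::Active)
    av, bv = a.val, b.val
    mymul_pullback(dy) = (NoTangent(), NoTangent(), dy * av)
    return mymul(av, bv), mymul_pullback
end

# Slow path: `a` is active (`b` may be either), compute what is requested
function active_rrule(::typeof(mymul), a::Active, b::ActivityMarked)
    av, bv = a.val, strip_activity(b)
    mymul_pullback(dy) = (NoTangent(), dy * bv, b isa Dead ? NoTangent() : dy * av)
    return mymul(av, bv), mymul_pullback
end
```

For example, `active_rrule(mymul, Dead(2.0), Active(3.0))` returns `6.0` and a pullback that only ever multiplies by `a`.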
@oxinabox
Author

Probably the fallback to `rrule` should also drop any tangents for arguments that are `Dead`,
so that if they are thunks, we can be sure nobody ever unthunks them.
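A minimal sketch of that idea for scalar `*`, using hypothetical stand-ins: a local `rrule` and `NoTangent` in place of ChainRulesCore's, and a `LazyTangent` struct with a `force` function in place of real thunks and `unthunk`. The fallback calls the plain `rrule`, then swaps the tangent of every `Dead` argument for `NoTangent()` without evaluating it; the `FORCED` counter shows that only the active argument's tangent is ever computed.

```julia
# Hypothetical stand-ins; ChainRulesCore provides the real rrule, NoTangent and thunks
struct NoTangent end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

# A lazy tangent: the closure only runs when `force` is called
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()

# Counts how many lazy tangents were actually evaluated
const FORCED = Ref(0)

# A plain rrule for scalar `*` that returns lazy tangents
function rrule(::typeof(*), a::Real, b::Real)
    times_pullback(dy) = (
        NoTangent(),
        LazyTangent(() -> (FORCED[] += 1; dy * b)),
        LazyTangent(() -> (FORCED[] += 1; dy * a)),
    )
    return a * b, times_pullback
end

# Fallback: use the plain rrule, then replace the tangent of every `Dead`
# argument by NoTangent() without ever forcing it
function active_rrule(f, args::ActivityMarked...)
    y, pb = rrule(f, strip_activity.(args)...)
    function active_pullback(dy)
        tangents = pb(dy)
        # tangents[1] is for `f` itself; the rest line up with `args`
        arg_tangents = map((t, x) -> x isa Dead ? NoTangent() : t, Base.tail(tangents), args)
        return (first(tangents), arg_tangents...)
    end
    return y, active_pullback
end
```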

@willtebbutt

👍

@oxinabox
Author

To reduce ambiguities, we should also have an all-active case that hits `rrule`.

@willtebbutt

Not sure what you mean by that. Could you elaborate please?

@oxinabox
Author

We could define:
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
So that if everything is active, it will just call the `rrule`, no matter what the function is.
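A small runnable sketch of the all-active case, with `f` left unmarked as in the other `active_rrule` methods. It assumes a hypothetical stand-in `rrule` for `sin` and a local `NoTangent` in place of ChainRulesCore's, just to show that such a call forwards straight to the plain rule.

```julia
# Hypothetical stand-ins for ChainRulesCore's NoTangent and rrule
struct NoTangent end

abstract type ActivityMarked{T} end
struct Active{T} <: ActivityMarked{T}
    val::T
end
struct Dead{T} <: ActivityMarked{T}
    val::T
end
strip_activity(x::ActivityMarked) = x.val

function rrule(::typeof(sin), x::Real)
    sin_pullback(dy) = (NoTangent(), dy * cos(x))
    return sin(x), sin_pullback
end

# All-active case: there is nothing to drop, so unwrap and forward to rrule,
# whatever `f` is
active_rrule(f, args::Active...) = rrule(f, strip_activity.(args)...)
```

One caveat to check: a function-specific method such as `active_rrule(::typeof(foo), a::Active, b)` and this method both match a call with two `Active` arguments, so Julia may report that call as ambiguous unless an `active_rrule(::typeof(foo), a::Active, b::Active)` method is also defined.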

@willtebbutt

Ahhh I see. Yeah, that makes sense.
