Skip to content

Instantly share code, notes, and snippets.

@Keno
Created March 10, 2021 00:00
Show Gist options
  • Save Keno/d8faa85dd64c878e0985e25942bce450 to your computer and use it in GitHub Desktop.
Save Keno/d8faa85dd64c878e0985e25942bce450 to your computer and use it in GitHub Desktop.
# ∂⃖rrule has a 4-recurrence - we model this as 4 separate structs (A, B, C, D)
# that we cycle between: each stage's call method returns an output together
# with the next stage (except the terminal stage C, which returns no next stage),
# and stage D wraps around to a stage A whose round counter `O` is one higher.
# The call methods specialized on `O == N` end the cycle, so `N` is the terminal
# round and `O` the current one.
# N.B.: These names match the names that these variables have in Snippet 19 of
# the terminology guide. They are probably not ideal, but if you rename them
# here, please also update the terminology guide.
# All fields are untyped (i.e. `Any`).
struct ∂⃖rruleA{N, O}
    ∂    # called as `∂(ȳ, Δ)`; result destructured as `(α, ᾱ)`
    ȳ    # first argument passed to `∂`
    ȳ̄    # handed on unchanged to stage B
end

struct ∂⃖rruleB{N, O}
    ᾱ    # called with the stage-B input tuple `Δ′`
    ȳ̄    # handed on unchanged to stage C
end

struct ∂⃖rruleC{N, O}
    ȳ̄    # called with the tuple `(Δ′′, Δ′′′)`
    Δ′′′ # produced by stage B's call to `ᾱ`
    β̄    # handed on to stage D (`nothing` in the terminal round)
end

struct ∂⃖rruleD{N, O}
    γ̄    # called as `γ̄(Zero(), Δ⁴...)`
    β̄    # becomes the `∂` field of the next round's stage A
end
# Stage A of round `O`: evaluate `a.∂(a.ȳ, Δ)` and destructure it into
# `(α, ᾱ)`. Returns `α` together with the stage-B object for the same round,
# which carries `ᾱ` and the untouched `a.ȳ̄`.
function (a::∂⃖rruleA{N, O})(Δ) where {N, O}
    @destruct (α, ᾱ) = a.∂(a.ȳ, Δ)
    next_stage = ∂⃖rruleB{N, O}(ᾱ, a.ȳ̄)
    return (α, next_stage)
end
# Stage B of round `O` (non-terminal): the varargs input is forwarded to `b.ᾱ`
# as a single tuple `Δ′`, and the result is destructured as `((Δ′′′, β), β̄)`.
# Returns `β` together with the stage-C object that carries `b.ȳ̄`, `Δ′′′`
# and `β̄`.
function (b::∂⃖rruleB{N, O})(Δ′...) where {N, O}
    @destruct ((Δ′′′, β), β̄) = b.ᾱ(Δ′)
    next_stage = ∂⃖rruleC{N, O}(b.ȳ̄, Δ′′′, β̄)
    return (β, next_stage)
end
# Stage C of round `O` (non-terminal): call `c.ȳ̄` on the tuple
# `(Δ′′, c.Δ′′′)` and destructure the result as `(γ, γ̄)`. Returns `γ` with its
# first element dropped, plus the stage-D object holding `γ̄` and `c.β̄`.
function (c::∂⃖rruleC{N, O})(Δ′′) where {N, O}
    @destruct (γ, γ̄) = c.ȳ̄((Δ′′, c.Δ′′′))
    γ_rest = Base.tail(γ)
    return (γ_rest, ∂⃖rruleD{N, O}(γ̄, c.β̄))
end
# Stage D of round `O` (non-terminal): call `d.γ̄(Zero(), Δ⁴...)`. Note that the
# inputs are splatted here, unlike the tuple-packed calls in stages B and C.
# The result is destructured as `((δ₁, δ₂), δ̄)`. Returns `δ₁` together with a
# stage-A object for round `O + 1`, whose `∂` is `d.β̄`, `ȳ` is `δ₂` and `ȳ̄` is `δ̄`.
# NOTE(review): this is the only stage that uses plain destructuring rather
# than `@destruct` — confirm whether that is intentional.
function (d::∂⃖rruleD{N, O})(Δ⁴...) where {N, O}
    (δ₁, δ₂), δ̄ = d.γ̄(Zero(), Δ⁴...)
    next_round = ∂⃖rruleA{N, O + 1}(d.β̄, δ₂, δ̄)
    return (δ₁, next_round)
end
# Terminal cases
# Terminal stage B (round `O == N`): here the result of `ᾱ` is destructured as
# just `(Δ′′′, β)`, with no further pullback. The stage-C object therefore gets
# `nothing` in its `β̄` slot.
function (b::∂⃖rruleB{N, N})(Δ′...) where {N}
    @destruct (Δ′′′, β) = b.ᾱ(Δ′)
    return (β, ∂⃖rruleC{N, N}(b.ȳ̄, Δ′′′, nothing))
end
# Terminal stage C (round `O == N`): the result of `c.ȳ̄((Δ′′, c.Δ′′′))` is
# returned with its first element dropped. No stage-D object is created, so
# the recurrence ends here.
function (c::∂⃖rruleC{N, N})(Δ′′) where {N}
    γ = c.ȳ̄((Δ′′, c.Δ′′′))
    return Base.tail(γ)
end
# Terminal stage D: within this recurrence a `∂⃖rruleD{N, N}` is never handed
# out, because the `O == N` stage-C method returns no next stage. Calling one
# is therefore a logic error.
function (::∂⃖rruleD{N, N})(Δ...) where {N}
    throw(ErrorException("Should not be reached"))
end
# ∂⃖rrule
# Number of A → B → C → D rounds used by `∂⃖rrule{N}`. It becomes the `N`
# parameter of the first `∂⃖rruleA`, so the recurrence reaches its `O == N`
# terminal methods in round `2^(N-2)`. For `N < 2` the negative integer power
# throws `DomainError`.
# NOTE(review): presumably marked pure so the result can be folded into a type
# parameter at compile time — confirm before removing the annotation.
Base.@pure function term_depth(N)
    return 2^(N - 2)
end
# Entry point of an order-`N` `∂⃖rrule`: split `z` into `(y, ȳ)` and return `y`
# together with the round-1 stage A of the recurrence. That stage uses
# `∂⃖{minus1(N)}()` as its `∂`, `ȳ` and `z̄` as its remaining fields, and
# terminates at round `term_depth(N)`.
function (::∂⃖rrule{N})(z, z̄) where {N}
    @destruct (y, ȳ) = z
    first_stage = ∂⃖rruleA{term_depth(N), 1}(∂⃖{minus1(N)}(), ȳ, z̄)
    return (y, first_stage)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment