-
-
Save simeonschaub/b8c6c233fbde79ad7cde0d4bc490d676 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Cassette, Test | |
using Cassette: @context, enabletagging, @overdub, overdub, recurse, | |
hasmetadata, metadata, tag, untag | |
using ChainRules: frule, Zero, extern | |
# Cassette context under which forward-mode tangents are carried as tag metadata.
@context DiffCtx

# The metadata (tangent) attached to a tagged real number of type `T` in a `DiffCtx`
# context has that same type `T`.
function Cassette.metadatatype(::Type{<:DiffCtx}, ::Type{T}) where {T<:Real}
    return T
end
"""
    D(f, x)

Return the derivative of `f` at `x`, computed with forward-mode differentiation.

`x` is tagged in a `DiffCtx` context with the seed tangent `dx/dx = 1`, and `f` is
then evaluated under that context. The tangent attached to the result is the
derivative. If the result carries no tangent in this context, `f` did not depend on
`x` (e.g. `D(x -> 5, 1)`). In that case the derivative is zero, and `zero` of the
primal result is returned.
"""
function D(f, x)
    # Tagging is keyed on `f`. NOTE(review): this is presumably what keeps nested `D`
    # calls on different closures from confusing each other's perturbations.
    ctx = enabletagging(DiffCtx(), f)
    # Seed tangent 1, converted to `typeof(x)` because `Cassette.metadatatype` for
    # `DiffCtx` declares the metadata type of a real `T` to be `T`.
    result = overdub(ctx, f, tag(x, ctx, oftype(x, 1.0)))
    # An untagged result means no dependence on `x`. Without this check, `metadata`
    # would hand back Cassette's `NoMetaData()` sentinel instead of a number.
    hasmetadata(result, ctx) || return zero(untag(result, ctx))
    return metadata(result, ctx)
end
# Unary calls under `DiffCtx`.
# If ChainRules has a forward rule for `f`, evaluate it on the untagged primal and push
# the argument's tangent through the returned pushforward. Otherwise, let Cassette
# recurse into `f`'s body.
function Cassette.overdub(ctx::DiffCtx, f, x)
    primal = untag(x, ctx)
    rule = frule(f, primal)
    rule === nothing && return Cassette.recurse(ctx, f, x)
    Ω, pushforward = rule
    # Argument not being differentiated in this context: return the plain primal output.
    hasmetadata(x, ctx) || return Ω
    return tag_outputs(ctx, Ω, pushforward, metadata(x, ctx))
end
# Binary calls under `DiffCtx`.
# If ChainRules has a forward rule for `f`, evaluate it on the untagged primals and push
# the arguments' tangents through the pushforward. An argument without a tangent
# contributes `Zero()`. Otherwise, let Cassette recurse into `f`'s body.
function Cassette.overdub(ctx::DiffCtx, f, x, y)
    rule = frule(f, untag(x, ctx), untag(y, ctx))
    rule === nothing && return Cassette.recurse(ctx, f, x, y)
    Ω, pushforward = rule
    x_has_tangent = hasmetadata(x, ctx)
    y_has_tangent = hasmetadata(y, ctx)
    # Neither argument is being differentiated: return the plain primal output.
    (x_has_tangent || y_has_tangent) || return Ω
    ẋ = x_has_tangent ? metadata(x, ctx) : Zero()
    ẏ = y_has_tangent ? metadata(y, ctx) : Zero()
    return tag_outputs(ctx, Ω, pushforward, ẋ, ẏ)
end
"""
    tag_outputs(ctx, Ω, dΩ, Δ...)

Attach forward-mode tangents to the primal output `Ω` of a rule.

`dΩ` is the pushforward returned by `frule`, and `Δ...` are the argument tangents.
`dΩ(Δ...)` is converted with `extern` and stored as the `ctx` metadata of `Ω`.
When `Ω` and `dΩ` are both 2-tuples (a rule for a function with two outputs), each
output is tagged with its own pushforward.
"""
function tag_outputs(ctx, Ω, dΩ, Δ...)
    tangent = extern(dΩ(Δ...))
    return tag(Ω, ctx, tangent)
end

function tag_outputs(ctx, Ω::NTuple{2,Any}, dΩ::NTuple{2,Any}, Δ...)
    (first_primal, second_primal) = Ω
    (first_pushforward, second_pushforward) = dΩ
    first_tagged = tag(first_primal, ctx, extern(first_pushforward(Δ...)))
    second_tagged = tag(second_primal, ctx, extern(second_pushforward(Δ...)))
    # The pair is rebuilt with `tuple` under `ctx` rather than natively.
    # NOTE(review): presumably so Cassette treats it as a container of tagged values.
    return @overdub ctx tuple(first_tagged, second_tagged)
end
@testset "D: forward-mode derivatives through Cassette + ChainRules" begin
    # Single rule application: cos(1) * 1 is exact, so `===` also pins the Float64 type.
    @test D(sin, 1) === cos(1)
    # Multi-step results depend on the operation order inside the rules (e.g. `muladd`
    # vs separate `*`/`+`), so compare approximately rather than bit-for-bit.
    @test D(x -> D(sin, x), 1) ≈ -sin(1)
    @test D(x -> sin(x) * cos(x), 1) ≈ cos(1)^2 - sin(1)^2
    # Perturbation-confusion checks: the inner closure captures the outer variable.
    # Integer arithmetic is exact, so `===` pins both value and type.
    @test D(x -> x * D(y -> x * y, 1), 2) === 4
    @test D(x -> x * D(y -> x * y, 2), 1) === 2
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment