Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
using Cassette, Test
using Cassette: @context, enabletagging, @overdub, overdub, recurse,
hasmetadata, metadata, tag, untag
using ChainRules: frule, Zero, extern
# Cassette context under which `D` performs forward-mode differentiation: values
# are tagged with this context and their tangent is stored as the tag's metadata.
@context DiffCtx

# For a `Real` primal of type `T`, the attached metadata (the tangent) is also of
# type `T` — e.g. an `Int` input is seeded with an `Int` tangent by `D`.
function Cassette.metadatatype(::Type{<:DiffCtx}, ::Type{T}) where {T<:Real}
    return T
end
"""
    D(f, x)

Return the derivative of `f` at `x` using forward mode: `x` is tagged with a fresh,
tagging-enabled `DiffCtx` and seeded with a unit tangent (`oftype(x, 1.0)`), `f` is
run under that context, and the tangent metadata of the result is returned.

A new context is created per call, so nested calls (`D` inside the function passed to
`D`) keep their tangents separate.
"""
function D(f, x)
    ctx = enabletagging(DiffCtx(), f)
    # Unit seed of the same type as `x`, so the tangent matches `metadatatype`.
    seed = oftype(x, 1.0)
    tagged_input = tag(x, ctx, seed)
    return metadata(overdub(ctx, f, tagged_input), ctx)
end
# Intercept a unary call `f(x)` under `ctx`.
# If ChainRules has an `frule` for `f` at the untagged value of `x`, the rule is used
# directly:
#   - if `x` carries no tangent, the primal result is returned untagged;
#   - otherwise the result is tagged with the pushed-forward tangent via `tag_outputs`.
# Calls without a rule are recursed into, so their callees get intercepted instead.
function Cassette.overdub(ctx::DiffCtx, f, x)
    primal = untag(x, ctx)
    maybe_rule = frule(f, primal)
    # No rule available: descend into `f` and differentiate what it calls.
    !(maybe_rule isa Nothing) || return Cassette.recurse(ctx, f, x)
    value, pushforward = maybe_rule
    # An input without a tangent produces an output without one.
    hasmetadata(x, ctx) || return value
    return tag_outputs(ctx, value, pushforward, metadata(x, ctx))
end
# Intercept a binary call `f(x, y)` under `ctx`.
# If ChainRules has an `frule` for `f` at the untagged arguments, the rule is used
# directly:
#   - if neither argument carries a tangent, the primal result is returned untagged;
#   - otherwise a missing tangent is replaced by `Zero()` and the result is tagged
#     via `tag_outputs`.
# Without a rule, the call is recursed into.
function Cassette.overdub(ctx::DiffCtx, f, x, y)
    primal_x = untag(x, ctx)
    primal_y = untag(y, ctx)
    maybe_rule = frule(f, primal_x, primal_y)
    !(maybe_rule isa Nothing) || return Cassette.recurse(ctx, f, x, y)
    value, pushforward = maybe_rule
    any_tangent = hasmetadata(x, ctx) || hasmetadata(y, ctx)
    !any_tangent && return value
    # An argument without a tangent contributes a structural zero to the pushforward.
    tangent_x = hasmetadata(x, ctx) ? metadata(x, ctx) : Zero()
    tangent_y = hasmetadata(y, ctx) ? metadata(y, ctx) : Zero()
    return tag_outputs(ctx, value, pushforward, tangent_x, tangent_y)
end
"""
    tag_outputs(ctx, Ω, dΩ, Δ...)

Tag the primal output `Ω` of an `frule` with its tangent under `ctx`.

`dΩ` is the second element returned by `frule`. It is applied to the input tangents
`Δ...` (each either a `ctx` metadata value or `Zero()`), and the result is
materialized with `extern` before being attached as metadata.

NOTE(review): assumes the ChainRules version in use returns a *callable* pushforward
from `frule` (the older `(Ω, rule)` API) — confirm against the pinned ChainRules.
"""
tag_outputs(ctx, Ω, dΩ, Δ...) = tag(Ω, ctx, extern(dΩ(Δ...)))

# Two-output rules: `Ω` is a 2-tuple of primal outputs and `dΩ` holds one pushforward
# per output. Each output is tagged with its own tangent. The tuple is then rebuilt
# with `@overdub`, presumably so that its construction from tagged fields is tracked
# by the tagging context.
function tag_outputs(ctx, Ω::NTuple{2,Any}, dΩ::NTuple{2,Any}, Δ...)
    Ω₁, Ω₂ = Ω
    dΩ₁, dΩ₂ = dΩ
    t₁ = tag(Ω₁, ctx, extern(dΩ₁(Δ...)))
    t₂ = tag(Ω₂, ctx, extern(dΩ₂(Δ...)))
    return @overdub ctx tuple(t₁, t₂)
end
@testset "forward-mode D (Cassette + ChainRules)" begin
    # First derivative of a primitive with an frule, and a nested (second) derivative.
    @test D(sin, 1) === cos(1)
    @test D(x -> D(sin, x), 1) === -sin(1)
    # Product rule through a closure. Compared approximately: the exact floating-point
    # result depends on how the `*` rule orders its multiplications and additions.
    @test D(x -> sin(x) * cos(x), 1) ≈ cos(1)^2 - sin(1)^2
    # Nested differentiation where the inner closure captures the outer variable (the
    # classic perturbation-confusion check). `===` also asserts the result is an Int.
    @test D(x -> x * D(y -> x * y, 1), 2) === 4
    @test D(x -> x * D(y -> x * y, 2), 1) === 2
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment