Skip to content

Instantly share code, notes, and snippets.

@ChrisRackauckas
Created June 22, 2021 13:54
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ChrisRackauckas/1665257903b62462f6f970682636c5a8 to your computer and use it in GitHub Desktop.
Save ChrisRackauckas/1665257903b62462f6f970682636c5a8 to your computer and use it in GitHub Desktop.
hasbranching for automatically specializing ReverseDiff tape compilation
using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
# When true, `_pass` prints the offending method and its lowered IR whenever it
# finds a conditional jump (`GotoIfNot`). Set to false to silence the diagnostic.
const printbranch = true
# Cassette context used by `hasbranching`. Its `metadata` is a Dict whose
# `:has_branching` entry is flipped to true when a method with a branch runs.
Cassette.@context HasBranchingCtx
# Catch-all for HasBranchingCtx.
# - If Cassette can step into `f`, recurse so that the context's pass gets to
#   inspect the callee's lowered code.
# - Otherwise hand the call to `Cassette.fallback`.
function Cassette.overdub(ctx::HasBranchingCtx, f, args...)
    Cassette.canrecurse(ctx, f, args...) || return Cassette.fallback(ctx, f, args...)
    return Cassette.recurse(ctx, f, args...)
end
# Treat every primitive that DiffRules knows a derivative for as opaque.
# - Under HasBranchingCtx such primitives are called directly instead of being
#   recursed into, so branches inside their implementations are not reported.
# - Entries whose module name is not visible from this module are skipped.
for (modname, fname, nargs) in DiffRules.diffrules()
    if isdefined(@__MODULE__, modname)
        @eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($modname.$fname), x::Vararg{Any, $nargs})
            return f(x...)
        end
    end
end
# Return `stmt` with the callee of its `:call`/`:invoke` expression wrapped in
# `:nooverdub`. The call may be the statement itself or the right-hand side of
# an `=` assignment. Any other statement is returned unchanged.
# NOTE(review): for `:invoke`, args[1] is a method instance rather than the
# callee. Lowered `code_info` is not expected to contain `:invoke` — confirm.
function _nooverdub_call(stmt)
    isassign = Meta.isexpr(stmt, :(=))
    call = isassign ? stmt.args[2] : stmt
    # Test the unwrapped call, not the assignment around it. Otherwise calls
    # whose result is stored into a slot would never be marked `:nooverdub`.
    Meta.isexpr(call, :call) || Meta.isexpr(call, :invoke) || return stmt
    wrapped = Expr(call.head, Expr(:nooverdub, call.args[1]), call.args[2:end]...)
    return isassign ? Expr(:(=), stmt.args[1], wrapped) : wrapped
end

"""
    _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)

Cassette pass applied to the lowered code (`reflection.code_info`) of every
method that `HasBranchingCtx` recurses into.

Methods without a conditional jump (`GotoIfNot`) are returned untouched. For a
method that contains one, the pass:

1. prints the method and its IR when `printbranch` is true;
2. inserts, before the original first statement, a `getfield` of the context's
   `metadata` followed by `setindex!(metadata, true, :has_branching)`, so the
   flag is set as soon as the method runs;
3. wraps the callee of every later call statement in `:nooverdub`, whether the
   call is bare or on the right-hand side of an assignment. Callees then run
   natively instead of being traced further, since the answer is already known.

Returns the (mutated) `CodeInfo`.
"""
function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> isa(x, GotoIfNot), ir.code) || return ir

    if printbranch
        msg = "GotoIfNot detected in $(reflection.method)\nir = $ir\n"
        # jl_safe_printf treats its argument as a printf *format* string.
        # Printed IR contains `%` (SSA values print as `%1`, `%2`, ...), so
        # escape it to print the text literally instead of parsing conversions.
        # NOTE(review): a raw C print is presumably used because Julia I/O that
        # may yield is not allowed while a generated function runs — confirm.
        ccall(:jl_safe_printf, Cvoid, (Cstring,), replace(msg, "%" => "%%"))
    end

    # Statement 1 is replaced by three statements:
    #   SSAValue(1): getfield(<context>, :metadata)
    #                setindex!(SSAValue(1), true, :has_branching)
    #                <original statement 1>
    # `:nooverdub` keeps these bookkeeping calls out of the context.
    Cassette.insert_statements!(
        ir.code, ir.codelocs,
        (stmt, i) -> i == 1 ? 3 : nothing,
        (stmt, i) -> Any[
            Expr(:call, Expr(:nooverdub, GlobalRef(Base, :getfield)), Expr(:contextslot), QuoteNode(:metadata)),
            Expr(:call, Expr(:nooverdub, GlobalRef(Base, :setindex!)), SSAValue(1), true, QuoteNode(:has_branching)),
            stmt,
        ],
    )

    # Skip the two bookkeeping statements inserted above. Rewrite every other
    # Expr statement one-for-one, so SSA numbering is unchanged.
    Cassette.insert_statements!(
        ir.code, ir.codelocs,
        (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing,
        (stmt, i) -> Any[_nooverdub_call(stmt)],
    )
    return ir
end
# Wrap the pass for use as `HasBranchingCtx(; pass = pass, ...)`.
const pass = Cassette.@pass _pass
"""
    hasbranching(f, x...) -> Bool

Return `true` if evaluating `f(x...)` under `HasBranchingCtx` executes a traced
method whose lowered code contains a conditional jump (`GotoIfNot`), and
`false` otherwise. The value of `f(x...)` itself is discarded.

Callees that have a dedicated `Cassette.overdub` method for this context bypass
the generic recursion, so their internals are generally not inspected.
Cassette's pre/post hooks are disabled for the run.
"""
function hasbranching(f, x...)
    flags = Dict{Symbol,Bool}(:has_branching => false)
    ctx = Cassette.disablehooks(HasBranchingCtx(; pass = pass, metadata = flags))
    Cassette.overdub(ctx, f, x...)
    return flags[:has_branching]
end
# Further functions that HasBranchingCtx calls directly, with any arguments,
# instead of tracing into them. Branches inside their implementations are
# therefore not reported.
for op in (+, *, Base.materialize, Base.literal_pow, Base.getindex, Core.Typeof)
    @eval Cassette.overdub(::HasBranchingCtx, ::typeof($op), x...) = $op(x...)
end
# A ternary lowers to a conditional jump (`GotoIfNot`), so this call is
# expected to report `true`.
hasbranching((x, y) -> (x < 0 ? -x : x) + exp(y), 1, 2)
# `ifelse` is an ordinary call rather than control flow, so this call is
# expected to report `false`.
hasbranching((x, y) -> ifelse(x < 0, -x, x) + exp(y), 1, 2)
using DiffEqFlux
# Override FastDense to exclude the branch from the check.
# - The weight product is computed here, outside the context.
# - Only the elementwise `f.σ` broadcast is traced under `ctx`.
# NOTE(review): only the first `f.out * f.in` entries of `p` are used. Any
# further parameters in `p` (e.g. a bias) are ignored. That is harmless for
# branch detection, but the returned value may differ from the real layer's.
function Cassette.overdub(ctx::HasBranchingCtx, f::FastDense, x, p)
    nweights = f.out * f.in
    weights = reshape(p[1:nweights], f.out, f.in)
    preact = weights * x
    return Cassette.@overdub ctx f.σ.(preact)
end
# Demo on a FastChain made of three stages:
# - an elementwise cube;
# - FastDense(2, 50, tanh);
# - FastDense(50, 2).
# Branches in FastDense's own call method are excluded by the override above.
u0 = Float32[2, 0]
dudt2 = FastChain(
    (x, p) -> x .^ 3,
    FastDense(2, 50, tanh),
    FastDense(50, 2),
)
p = initial_params(dudt2)
hasbranching(dudt2, u0, p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment