Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created March 2, 2020 18:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Roger-luo/9ef4c4e745668598f6b9f8297c217ed8 to your computer and use it in GitHub Desktop.
Save Roger-luo/9ef4c4e745668598f6b9f8297c217ed8 to your computer and use it in GitHub Desktop.
AutoPreallocation patch
using Zygote: @adjoint, _pullback, Context, cache
using AutoPreallocation
using Cassette
export expect, ∇expect, exact_expect, ∇exact_expect
using AutoPreallocation: RecordingCtx, ReplayCtx
# https://github.com/oxinabox/AutoPreallocation.jl/pull/9
# Pass-through rules shared by the recording and the replay context: for each
# function listed below, the overdub method simply calls the original function
# on the original arguments instead of overdubbing its body. The same set of
# methods is generated once per context type.
for Ctx in (:RecordingCtx, :ReplayCtx)
    @eval begin
        @inline Cassette.overdub(::$Ctx, ::typeof(Base.haskey), collection, key) =
            haskey(collection, key)
        @inline Cassette.overdub(::$Ctx, ::typeof(Broadcast.combine_axes), xs...) =
            Broadcast.combine_axes(xs...)
        @inline Cassette.overdub(::$Ctx, ::typeof(Zygote.trim), xs...) =
            Zygote.trim(xs...)
        @inline Cassette.overdub(::$Ctx, ::Type{Val}, x) = Val(x)
        @inline Cassette.overdub(::$Ctx, ::typeof(Base.reduced_indices), xs...) =
            Base.reduced_indices(xs...)
        @inline Cassette.overdub(::$Ctx, ::typeof(LinearAlgebra.gemv!), xs...) =
            LinearAlgebra.gemv!(xs...)
        @inline Cassette.overdub(::$Ctx, ::typeof(getindex), x::IdDict, key) =
            getindex(x, key)
    end
end
# Store `nothing` in the cache of `cx` under every element of `ps`, then hand
# `cx` back. The cache is looked up again for each element, so it is never
# touched when `ps` is empty.
function reset_cx!(cx::Context, ps)::Context
    for param in ps
        setindex!(cache(cx), nothing, param)
    end
    return cx
end
# Recording-context rule for `Zygote.gradient(f, ps::Params)`.
# The forward pass (`_pullback`) and the backward pass are both run under
# `ctx`; the gradients are read out of the Zygote context's cache afterwards.
@inline function Cassette.overdub(ctx::RecordingCtx, ::typeof(Zygote.gradient), f, ps::Params)
    zcx = Context()
    # Forward pass: yields the value of `f()` and its pullback closure.
    value, pullback = Cassette.overdub(ctx, _pullback, zcx, f)
    # Give every parameter a `nothing` slot in the cache before running backward.
    reset_cx!(zcx, ps)
    # Backward pass, seeded with the sensitivity of the forward value.
    seed = Zygote.sensitivity(value)
    Cassette.overdub(ctx, pullback, seed)
    return Zygote.Grads(zcx.cache)
end
# Replay-context rule for `Zygote.gradient(f, ps::Params)`.
# The forward pass (`_pullback`) and the backward pass are both run under
# `ctx`; the gradients are read out of the Zygote context's cache afterwards.
@inline function Cassette.overdub(ctx::ReplayCtx, ::typeof(Zygote.gradient), f, ps::Params)
    zcx = Context()
    # Forward pass: yields the value of `f()` and its pullback closure.
    value, pullback = Cassette.overdub(ctx, _pullback, zcx, f)
    # Give every parameter a `nothing` slot in the cache before running backward.
    reset_cx!(zcx, ps)
    # Backward pass, seeded with the sensitivity of the forward value.
    seed = Zygote.sensitivity(value)
    Cassette.overdub(ctx, pullback, seed)
    return Zygote.Grads(zcx.cache)
end
# Recording-context rule for `_accum_param(cx, x, Δ)`.
# If `x` has an entry in the cache of `cx`, that entry is replaced by
# `Zygote.accum(entry, Δ)`, with the accumulation itself run under `ctx`.
# Always returns `nothing`.
@inline function Cassette.overdub(ctx::RecordingCtx, ::typeof(_accum_param), cx::Context, x, Δ)
    grads = cache(cx)
    if haskey(grads, x)
        grads[x] = Cassette.overdub(ctx, Zygote.accum, grads[x], Δ)
    end
    return nothing
end
# Replay-context rule for `_accum_param(cx, x, Δ)`.
# If `x` has an entry in the cache of `cx`, that entry is replaced by
# `Zygote.accum(entry, Δ)`, with the accumulation itself run under `ctx`.
# Always returns `nothing`.
@inline function Cassette.overdub(ctx::ReplayCtx, ::typeof(_accum_param), cx::Context, x, Δ)
    grads = cache(cx)
    if haskey(grads, x)
        grads[x] = Cassette.overdub(ctx, Zygote.accum, grads[x], Δ)
    end
    return nothing
end
# preallocation patch for Flux
# Recording-context rule for `Flux.applychain(layers, x)`.
# Threads `x` through `layers` from first to last, calling each layer under
# `ctx`. With no layers, `x` is returned unchanged.
@inline function Cassette.overdub(ctx::RecordingCtx, ::typeof(Flux.applychain), layers, x)
    return foldl(layers; init = x) do acc, layer
        Cassette.overdub(ctx, layer, acc)
    end
end
# Replay-context rule for `Flux.applychain(layers, x)`.
# Threads `x` through `layers` from first to last, calling each layer under
# `ctx`. With no layers, `x` is returned unchanged.
@inline function Cassette.overdub(ctx::ReplayCtx, ::typeof(Flux.applychain), layers, x)
    return foldl(layers; init = x) do acc, layer
        Cassette.overdub(ctx, layer, acc)
    end
end
# Recording-context rule for calling a `Dense` layer `m` on `x`.
#
# Computes `σ.(W * x .+ b)` with two explicitly allocated buffers (one for the
# matrix product, one for the activated output) and registers both with
# `AutoPreallocation.record_alloc!`, product buffer first. The replay rule for
# `Dense` must request its buffers in that same order.
#
# The buffers are sized from `W` and `x` rather than copied from the shape of
# `b`: the result has `size(W, 1)` rows and keeps every trailing dimension of
# `x`. This lets a batched (matrix) input work; for a vector input the shape is
# the same as `b`'s whenever `length(b) == size(W, 1)`.
#
# Returns the activated output buffer.
@inline function Cassette.overdub(ctx::RecordingCtx, m::Dense, x::AbstractArray)
    W, b, σ = m.W, m.b, m.σ
    T = LinearAlgebra.promote_op(*, eltype(W), eltype(x))
    # Rows come from W, remaining (batch) dimensions come from x.
    out_dims = (size(W, 1), Base.tail(size(x))...)
    y1 = mul!(similar(b, T, out_dims), W, x)
    y2 = broadcast!(similar(b, T, out_dims), y1, b) do z, bias
        σ(z + bias)
    end
    AutoPreallocation.record_alloc!(ctx, y1)
    AutoPreallocation.record_alloc!(ctx, y2)
    return y2
end
# Replay-context rule for calling a `Dense` layer `m` on `x`.
#
# Takes the next two scheduled buffers from `ctx` (first the matrix-product
# buffer, then the output buffer), writes `W * x` into the first and
# `σ.(product .+ b)` into the second, and returns the second.
@inline function Cassette.overdub(ctx::ReplayCtx, m::Dense, x::AbstractArray)
    weight, bias, act = m.W, m.b, m.σ
    product = AutoPreallocation.next_scheduled_alloc!(ctx)
    output = AutoPreallocation.next_scheduled_alloc!(ctx)
    mul!(product, weight, x)
    activate = (z, c) -> act(z + c)
    broadcast!(activate, output, product, bias)
    return output
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment