# Cached convolution layers for Flux, by @staticfloat (April 10, 2019).
# https://gist.github.com/staticfloat/1b2ea67b084765358841d74581294b03
#
# Wraps `Conv`/`ConvTranspose` so that the output, gradient, and im2col
# scratch buffers are allocated once and reused across calls.

using Flux, NNlib, Zygote

struct CachedConv
    conv::Conv
    cache::Ref{Tuple}
end

# Start with an empty cache; buffers are allocated lazily on first call.
CachedConv(c::Conv) = CachedConv(c, Ref{Tuple}(()))
Flux.@treelike CachedConv

function (m::CachedConv)(x::AbstractArray)
    # Has the user changed batch size on us? If so, clear our cache and re-up!
    # (`dx` is cached at index 2 and always matches the shape of the last input.)
    if !isempty(m.cache[]) && size(m.cache[][2], ndims(x)) != size(x, ndims(x))
        m.cache[] = ()
    end

    if isempty(m.cache[])
        cdims = DenseConvDims(x, m.conv.weight; stride=m.conv.stride,
                              padding=m.conv.pad, dilation=m.conv.dilation)

        # Preallocate the output, the input/weight gradient buffers, and the
        # im2col scratch space.
        y = similar(x, promote_type(eltype(x), eltype(m.conv.weight)),
                    NNlib.output_size(cdims)..., NNlib.channels_out(cdims),
                    size(x, ndims(x)))
        dx = similar(x)
        dw = similar(m.conv.weight)
        col = similar(x, NNlib.im2col_dims(cdims))

        m.cache[] = (y, dx, dw, col, cdims)
    end

    y, dx, dw, col, cdims = m.cache[]
    σ, b = m.conv.σ, reshape(m.conv.bias, map(_ -> 1, NNlib.stride(cdims))..., :, 1)
    return σ.(cached_conv!(y, x, m.conv.weight, cdims; col=col, dx=dx, dw=dw) .+ b)
end

# Forward: a plain `NNlib.conv!` into the preallocated output `y`. The `dx`/`dw`
# kwargs are unused here; they exist so the adjoint below can capture them.
cached_conv!(y, x, w, cdims; col=nothing, dx=nothing, dw=nothing, kwargs...) =
    NNlib.conv!(y, x, w, cdims; col=col, kwargs...)

Zygote.@adjoint function cached_conv!(y, x, w, cdims; col=nothing, dx=nothing, dw=nothing, kwargs...)
    NNlib.conv!(y, x, w, cdims; col=col, kwargs...)
    # Backward: write the input/weight gradients into the cached buffers.
    back = Δ -> begin
        NNlib.∇conv_data!(dx, Δ, w, cdims; kwargs...)
        NNlib.∇conv_filter!(dw, x, Δ, cdims; kwargs...)
        # One gradient slot per positional argument (y, x, w, cdims).
        return (nothing, dx, dw, nothing)
    end
    return y, back
end

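# A quick way to exercise the cached layer (a sketch, assuming the Zygote-era
# Flux this gist targets; the layer shape and input size are illustrative, not
# from the gist itself): the first call populates the cache, later calls reuse
# it, and gradients land in the cached dx/dw buffers.
function _demo_cached_conv()
    c = CachedConv(Conv((3, 3), 1 => 4, relu))
    x = rand(Float32, 8, 8, 1, 2)
    c(x)                                   # first call allocates y/dx/dw/col
    c(x)                                   # second call reuses the cache
    return Zygote.gradient(x -> sum(c(x)), x)
end
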
struct CachedConvTranspose
    conv::ConvTranspose
    cache::Ref{Tuple}
end

CachedConvTranspose(c::ConvTranspose) = CachedConvTranspose(c, Ref{Tuple}(()))
Flux.@treelike CachedConvTranspose

function (m::CachedConvTranspose)(y::AbstractArray)
    # Has the user changed batch size on us? If so, clear our cache and re-up!
    # (`dy` is cached at index 1 and always matches the shape of the last input.)
    if !isempty(m.cache[]) && size(m.cache[][1], ndims(y)) != size(y, ndims(y))
        m.cache[] = ()
    end

    if isempty(m.cache[])
        # A transposed convolution is the data-gradient of a normal convolution,
        # so the input here plays the role of the conv output in `cdims`.
        cdims = Flux.conv_transpose_dims(m.conv, y)

        dy = similar(y)
        dx = similar(y, NNlib.input_size(cdims)..., NNlib.channels_in(cdims),
                     size(y, ndims(y)))
        dw = similar(m.conv.weight)
        col = similar(dx, NNlib.im2col_dims(cdims))

        m.cache[] = (dy, dx, dw, col, cdims)
    end

    dy, dx, dw, col, cdims = m.cache[]
    σ, b = m.conv.σ, reshape(m.conv.bias, map(_ -> 1, NNlib.stride(cdims))..., :, 1)
    return σ.(cached_∇conv_data!(dx, y, m.conv.weight, cdims; col=col, dy=dy, dw=dw) .+ b)
end

# Forward: the transposed convolution itself, via `NNlib.∇conv_data!` into the
# preallocated `dx`, reusing the cached `col` scratch. `dy`/`dw` are unused
# here; the adjoint below captures them.
cached_∇conv_data!(dx, y, w, cdims; col=nothing, dy=nothing, dw=nothing, kwargs...) =
    NNlib.∇conv_data!(dx, y, w, cdims; col=col, kwargs...)

Zygote.@adjoint function cached_∇conv_data!(dx, y, w, cdims; col=nothing, dy=nothing, dw=nothing, kwargs...)
    fwd = NNlib.∇conv_data!(dx, y, w, cdims; col=col, kwargs...)
    # Backward: the gradient of a transposed convolution w.r.t. its input is a
    # normal convolution; write the results into the cached buffers.
    back = Δ -> begin
        NNlib.conv!(dy, Δ, w, cdims; kwargs...)
        NNlib.∇conv_filter!(dw, Δ, y, cdims; kwargs...)
        # One gradient slot per positional argument (dx, y, w, cdims).
        return (nothing, dy, dw, nothing)
    end
    return fwd, back
end

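# The transposed layer can be exercised the same way (again a sketch with
# illustrative shapes; `_demo_cached_conv_transpose` and its sizes are not
# part of the original gist).
function _demo_cached_conv_transpose()
    ct = CachedConvTranspose(ConvTranspose((3, 3), 4 => 1))
    y = rand(Float32, 6, 6, 4, 2)
    ct(y)                                  # first call allocates dy/dx/dw/col
    ct(y)                                  # second call reuses the cache
    return Zygote.gradient(y -> sum(ct(y)), y)
end
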
# Recursively replace Conv/ConvTranspose layers in a model with their cached
# equivalents.
alloc_cache(conv::Conv) = CachedConv(conv)
alloc_cache(conv::ConvTranspose) = CachedConvTranspose(conv)
alloc_cache(model) = Flux.mapchildren(alloc_cache, model)

# Drop any cached buffers, e.g. to free memory or before changing input sizes.
function clear_cache(m::Union{CachedConv,CachedConvTranspose})
    m.cache[] = ()
    return m
end
clear_cache(model) = Flux.mapchildren(clear_cache, model)

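# A minimal end-to-end sketch (hypothetical model and sizes, not part of the
# gist above): `alloc_cache` swaps every Conv in a model for its cached
# equivalent, and Zygote differentiates straight through the caches.
function _demo_cached_model()
    model = Chain(Conv((3, 3), 1 => 16, relu),
                  Conv((3, 3), 16 => 8, relu))
    cached = alloc_cache(model)

    x = rand(Float32, 28, 28, 1, 4)        # W × H × C × N
    cached(x)                              # first call populates the caches
    cached(x)                              # later calls reuse them
    gs = Zygote.gradient(x -> sum(cached(x)), x)

    clear_cache(cached)                    # drop buffers, e.g. before resizing
    return gs
end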