Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created August 16, 2019 20:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Roger-luo/1dc955595eb40cef9a77b0e145ece153 to your computer and use it in GitHub Desktop.
Save Roger-luo/1dc955595eb40cef9a77b0e145ece153 to your computer and use it in GitHub Desktop.
Alloc.jl in Cassette
module HoistMem
export hoist_alloc, Buffer
using Cassette, LinearAlgebra
using Cassette: @context, overdub
# Cassette context used by `hoist_alloc`. Its `metadata` holds the `Buffer` from which
# intercepted array allocations are carved (see `ctx.metadata` in the overdubs below).
@context BuffCtx
"""
    Buffer

Fixed-capacity byte arena used as a bump allocator.

- `buf`: backing storage that allocated arrays point into.
- `offset`: number of bytes of `buf` already handed out. Allocation advances it and
  `clear!` resets it to zero.
"""
mutable struct Buffer
buf::Vector{UInt8}
offset::UInt
end
"""
    Buffer(n::Integer)

Create an empty `Buffer` (offset `0`) backed by `n` uninitialized bytes.
Accepts any integer type, e.g. `Buffer(1024)` or `Buffer(UInt(1024))`.
"""
Buffer(n::Integer) = Buffer(Vector{UInt8}(undef, n), 0)
"""
    copy(b::Buffer)

Return an independent `Buffer` holding a copy of `b`'s bytes and the same offset.
"""
function Base.copy(b::Buffer)
    bytes = copy(b.buf)
    return Buffer(bytes, b.offset)
end
# Explicit no-op pre-call hook for `BuffCtx`: nothing is recorded before overdubbed calls.
function Cassette.prehook(::BuffCtx, f, xs...)
    return nothing
end
"""
    alloc(b::Buffer, ::Type{Array{T,N}}, d::NTuple{N,Int}) -> Array{T,N}

Carve an uninitialized `Array{T,N}` of size `d` out of `b` by bump allocation, and return
an `Array` that wraps that memory (it does not own it).

The start address is rounded up to the alignment of `T`. This keeps consecutive
allocations of different element types (e.g. `UInt8` then `Float64`) from yielding
misaligned arrays. `b.offset` is only advanced once the request is known to fit, so a
failed allocation leaves `b` unchanged.

The returned array aliases `b.buf`. It is only valid while `b.buf` is alive, and its bytes
are handed out again after `clear!(b)`.

# Throws
- `ArgumentError` if `T` is not an `isbitstype` (raw bytes cannot hold references), or if
  any dimension in `d` is negative.
- `ErrorException("Alloc: Out of memory")` if the request does not fit in `b`.
"""
function alloc(b::Buffer, ::Type{Array{T,N}}, d::NTuple{N,Int}) where {T,N}
    isbitstype(T) || throw(ArgumentError(
        "Alloc: element type $T is not an isbits type and cannot live in a Buffer"))
    all(x -> x >= 0, d) || throw(ArgumentError("Alloc: invalid array dimensions $d"))
    base = UInt(pointer(b.buf))
    # Zero-size types may report alignment 0; clamp so the mask below stays valid.
    align = UInt(max(1, Base.datatype_alignment(T)))
    # Round the absolute start address up to a multiple of `align` (a power of two).
    start = ((base + b.offset + align - 1) & ~(align - 1)) - base
    stop = start + UInt(sizeof(T) * prod(d))
    stop > length(b.buf) && error("Alloc: Out of memory")
    # Commit only after the fit check so a failed request does not corrupt `b`.
    b.offset = stop
    ptr = convert(Ptr{T}, pointer(b.buf) + start)
    return unsafe_wrap(Array, ptr, d)
end
"""
    clear!(b::Buffer) -> Buffer

Reset `b`'s offset to zero so its whole storage can be handed out again, and return `b`.
Arrays previously carved from `b` still point at the same bytes, so later allocations
reuse (and may overwrite) that memory.
"""
clear!(b::Buffer) = (b.offset = 0; b)
"""
    hoist_alloc(f, b::Buffer)

Run `f()` under the `BuffCtx` Cassette context with `b` as its metadata, so that the array
allocations this module intercepts are served from `b`. The buffer is reset with
`clear!(b)` before each run. Returns whatever `f()` returns.

NOTE: arrays carved from `b` alias its memory and may be overwritten by the next
`hoist_alloc` call on the same buffer.
"""
function hoist_alloc(f, b::Buffer)
    ctx = BuffCtx(metadata = clear!(b))
    return overdub(ctx, f)
end
# Functions executed natively (`f(xs...)`) under `BuffCtx` instead of being recursed into
# by Cassette. Allocations made inside them are therefore NOT redirected to the buffer.
# In-place kernels such as `mul!`, `copyto!` and `fill!` write into arrays the caller
# already owns.
#
# Every entry must be defined or imported in this module when it loads; an unknown name
# here raises `UndefVarError` and the module fails to load. To pass through a
# project-specific kernel, add a method of the same form for it, e.g.
# `Cassette.overdub(::BuffCtx, f::typeof(mykernel!), xs...) = f(xs...)`.
for F in (LinearAlgebra.mul!, Base.promote_op, size, Base.to_shape,
          Broadcast.broadcasted, Broadcast.instantiate, Broadcast.preprocess,
          Broadcast.combine_eltypes, copyto!, Broadcast.copyto_nonleaf!,
          Broadcast.axes, Base.getindex, Base.setindex!, Base.fill!)
    @eval @inline Cassette.overdub(::BuffCtx, f::typeof($F), xs...) = f(xs...)
end
# Serve `similar(bc, T)` for a `DefaultArrayStyle` broadcast from the context's `Buffer`.
# This is presumably reached when a broadcast result is materialized (`copy(bc)`).
# Element types that are not `isbitstype` cannot be stored in raw bytes: abstract types
# have no definite `sizeof`, and other non-isbits types hold references. Those fall back
# to an ordinary heap allocation instead of erroring or producing invalid arrays.
@inline function Cassette.overdub(ctx::BuffCtx, ::typeof(similar),
                                  bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}},
                                  ::Type{T}) where {T, N}
    isbitstype(T) || return similar(bc, T)
    return alloc(ctx.metadata, Array{T, N}, length.(Broadcast.axes(bc)))
end
# Serve `Array{T,N}(undef, dims...)` (e.g. `Vector{T}(undef, n)`) from the context's
# `Buffer`. Only the splatted-`Int` form matches this method; other constructor forms
# (such as a `dims` tuple) are not intercepted here.
# Non-isbits element types are allocated normally on the heap, because raw buffer bytes
# cannot safely hold the references such arrays store.
@inline function Cassette.overdub(ctx::BuffCtx, ::Type{Array{T, N}}, ::UndefInitializer,
                                  d::Vararg{Int, N}) where {T, N}
    isbitstype(T) || return Array{T, N}(undef, d...)
    return alloc(ctx.metadata, Array{T, N}, d)
end
export mprofile
# Cassette context used by `mprofile`. Its `metadata` is a `Ref` that the prehook below
# increments with the byte count of intercepted array constructions.
@context ProfileCtx
# NOTE(review): disabled hook; its intent is not stated in this file — delete or document.
# Cassette.prehook(cx::ProfileCtx, f::typeof(Core.apply_type), xs...) = nothing
# Before each `Array{T,N}(undef, dims...)` call made under `ProfileCtx`, add the bytes of
# element storage it requests to the `Ref` held in the context's metadata.
# Only `isbitstype` element types are counted, since for them `sizeof(T)` is the true
# per-element size. Abstract types (`Any`, `Real`, ...) have no definite size, so `sizeof`
# would throw and abort the profiled function. Other non-isbits types store references, so
# `sizeof(T)` would misstate their footprint. The total therefore undercounts such arrays.
function Cassette.prehook(cx::ProfileCtx, ::Type{Array{T, N}}, ::UndefInitializer,
                          d::Vararg{Int, N}) where {T, N}
    isbitstype(T) || return
    cx.metadata[] += sizeof(T) * prod(d)
    return
end
"""
    mprofile(f)

Run `f()` under the `ProfileCtx` Cassette context with a fresh `Ref(0)` counter as its
metadata. Log the final counter value with `@info`, then return `f()`'s result.
"""
function mprofile(f)
    counter = Ref(0)
    result = overdub(ProfileCtx(metadata = counter), f)
    @info "allocated $(counter[]) bytes"
    return result
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment