Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created August 14, 2019 02:02
Show Gist options
  • Save Roger-luo/a7cddd8cf5902c5b43e09ea16acc74da to your computer and use it in GitHub Desktop.
Save Roger-luo/a7cddd8cf5902c5b43e09ea16acc74da to your computer and use it in GitHub Desktop.
using Test
using IRTools, LinearAlgebra, InteractiveUtils
using IRTools: IR, Branch, BasicBlock, return!, blocks, block,
Pipe, var, arguments, xcall, finish, argnames!,
slots!, pis!, inlineable!
"""
    VecArray{T, D, ElementType, N, S}

A vector of `ElementType` values ("lanes") packed into one array `storage::S` with
element type `T` and `N` dimensions. The leading `D = ndims(ElementType)` dimensions
of `storage` hold a single lane; the trailing dimension(s) index the batch.

NOTE: do not restrict `ElementType`, since it can be either Number/Array/Any.
"""
struct VecArray{T, D, ElementType, N, S <: AbstractArray{T, N}} <: AbstractVector{ElementType}
    storage::S

    # Wrap existing `storage` whose lanes represent values of type `E`
    # (e.g. `Matrix{Float64}`, or `Float64` for scalar lanes where `D == 0`).
    function VecArray(::Type{E}, storage::AbstractArray{T, N}) where {T, E, N}
        return new{eltype(E), ndims(E), E, N, typeof(storage)}(storage)
    end

    # Allocate `storage = f(T, sz..., n)` (e.g. `zeros`, `rand`): `n` lanes of size `sz`.
    function VecArray{T}(f, sz, n) where T
        storage = f(T, sz..., n)
        # A lane-shaped array is allocated only to learn its concrete type.
        # TODO: replace this with type inference when inferable.
        data = similar(storage, ntuple(k -> size(storage, k), ndims(storage) - 1))
        return new{T, ndims(data), typeof(data), ndims(storage), typeof(storage)}(storage)
    end
end

# Build a batch from an array constructor `f` (e.g. `zeros`, `rand`): `f(sz...)`
# determines the lane type and `f(sz..., n)` allocates the storage for `n` lanes.
function VecArray(f, sz, n)
    data = f(sz...)
    storage = f(sz..., n)
    return VecArray(typeof(data), storage)
end

# Batch shape: the trailing `N - D` dimensions of the storage.
Base.size(X::VecArray{T, D, E, N}) where {T, D, E, N} = ntuple(k -> size(X.storage, k + D), Val(N - D))

# Array-valued lane `idx`: a view over the leading `D` dimensions (no copy).
Base.getindex(X::VecArray{T, D}, idx::Int) where {T, D} = view(X.storage, ntuple(_ -> Colon(), Val(D))..., idx)

# Scalar lanes (`D == 0`): return the element itself rather than a 0-dimensional
# view, so that indexing agrees with `eltype` (and `sum`, `collect`, ... yield numbers).
Base.getindex(X::VecArray{T, 0}, idx::Int) where {T} = X.storage[idx]
# Build the IR call expression `Base.getindex(x, i...)` using IRTools' `xcall`.
function xgetindex(x, i...)
    return xcall(Base, :getindex, x, i...)
end
"""
    check_batchsize(xs...)

Return the batch size shared by every `VecArray` in `xs`, i.e. the size of the last
dimension of its storage. Arguments that are not `VecArray`s are ignored; if there
are none, the result is `1`. Throw a `DimensionMismatch` when any two `VecArray`s
have different batch sizes.
"""
check_batchsize(x::VecArray{T, D, E, N}) where {T, D, E, N} = size(x.storage, N)
check_batchsize(x::VecArray, xs...) = _check_batchsize(check_batchsize(x), xs...)
check_batchsize(x, xs...) = check_batchsize(xs...)
check_batchsize() = 1

# Compare every remaining `VecArray` against the reference batch size `B`.
# Recurse through the *whole* argument list, not just up to the first match.
function _check_batchsize(B, x::VecArray, xs...)
    check_batchsize(x) == B || throw(DimensionMismatch("batch dimension mismatch"))
    return _check_batchsize(B, xs...)
end
_check_batchsize(B, x, xs...) = _check_batchsize(B, xs...)
_check_batchsize(B) = B
# Rewrite `ir` (the IR of a method whose result type is a `Number`) so that it
# computes a single batch lane.
#
# Two arguments are spliced into the IR: a new one at position 1 and the lane
# index `batch_idx` at position 3. This presumably yields the `#self#, f, k, xs...`
# layout that `spmd_lane` later assigns with `argnames!`. Every call whose first
# argument is in `batched_args` is redirected to
# `spmd_lane(callee, batch_idx, args...)`.
#
# NOTE(review): `spmd_lane` is resolved in `Main`, so this only works when the
# code is loaded into `Main` — confirm before moving it into a module.
function transform(::Type{<:Number}, ir::IR, batched_args)
    lane = Pipe(ir)
    IRTools.argument!(lane, at = 1)  # new leading argument slot; its value is never read here
    batch_idx = IRTools.argument!(lane, at = 3)
    for (v, stmt) in lane
        ex = stmt.expr
        # Zero-argument calls have no `args[2]`; guard before indexing.
        if IRTools.isexpr(ex, :call) && length(ex.args) >= 2 && ex.args[2] in batched_args
            lane[v] = xcall(Main, :spmd_lane, ex.args[1], batch_idx, ex.args[2:end]...)
        end
    end
    return finish(lane)
end
# function spmd!(out::AbstractArray{T, N}, f, ::Type{AT}, xs...) where {T, D, N, AT <: AbstractArray{T, D}}
# end
# Generic fallback: evaluate `f(xs...)` for batch lane `k`, where each `VecArray`
# argument stands for its `k`-th lane.
#
# At compile time the IR of `f` specialised on the per-lane element types is
# fetched, and `transform` redirects calls on batched arguments back into
# `spmd_lane`. The result is then re-shaped to this method's own signature
# (`#self#, f, k, xs...`). Lane primitives (`getindex`, `checksquare`) have
# their own dedicated `spmd_lane` methods.
@inline @generated function spmd_lane(f, k, xs...)
    # Positions (within xs) of the batched arguments, and each argument's lane type.
    batched_args = Int[]
    element_types = Any[]
    for (pos, x) in enumerate(xs)
        if x <: VecArray
            push!(batched_args, pos)
            push!(element_types, eltype(x))
        else
            push!(element_types, x)
        end
    end
    T = Tuple{f, element_types...}
    ret_T = Core.Compiler.return_type(f.instance, Tuple{element_types...})
    m = IRTools.meta(T)
    m === nothing && return :(error("cannot find signature $($T)"))
    # `.+ 1`: in `f`'s own IR, argument 1 is the function itself.
    ir = transform(ret_T, IR(m), var.(batched_args .+ 1))
    argnames!(m, Symbol("#self#"), :f, :k, :xs)
    ir = IRTools.varargs!(m, ir, 3)
    ir = slots!(pis!(inlineable!(ir)))
    return IRTools.update!(m.code, ir)
end
# Lane primitive: element `idx...` of lane `k`; the storage's trailing index selects the lane.
@inline function spmd_lane(::typeof(getindex), k, x::VecArray, idx...)
    return x.storage[idx..., k]
end

# Lane primitive: `checksquare` of lane `k`. The `[:, :, k]` view assumes
# 3-dimensional storage (a batch of matrices).
@inline function spmd_lane(::typeof(LinearAlgebra.checksquare), k, x::VecArray)
    lane = @view x.storage[:, :, k]
    return LinearAlgebra.checksquare(lane)
end
# Compute `spmd_lane(f, k, xs...)` for every index `k` of `out`, store each result
# in `out[k]`, and return `out` wrapped as a `VecArray` whose lanes are of type `T`.
function spmd!(out::AbstractVector{T}, f, ::Type{T}, xs...) where T
    foreach(eachindex(out)) do lane
        out[lane] = spmd_lane(f, lane, xs...)
    end
    return VecArray(T, out)
end
# Apply `f` lane-by-lane over the batch dimension of its `VecArray` arguments.
# At compile time the per-lane argument types are collected to infer the result
# type `R`. At run time the batch size is checked with `check_batchsize`, and
# `spmd!` fills a fresh `Vector{R}` of that length.
@generated function spmd(f, xs...)
    lane_types = Any[]
    for x in xs
        push!(lane_types, x <: VecArray ? eltype(x) : x)
    end
    R = Core.Compiler.return_type(f.instance, Tuple{lane_types...})
    return quote
        B = check_batchsize(xs...)
        spmd!(Vector{$R}(undef, B), f, $R, xs...)
    end
end
# Smoke tests: with no VecArray the batch size is 1; a 10-lane VecArray gives 10.
@test check_batchsize(1, 1) == 1
@test check_batchsize(1, VecArray(rand, (2, 2), 10)) == 10
# Manual walk-through of the IR rewrite for `tr` on one lane of a batch of ten
# random 2×2 Float64 matrices (eltype(vA) == Matrix{Float64}).
vA = VecArray{Float64}(rand, (2, 2), 10)
T = Tuple{typeof(tr), eltype(vA)}
m = IRTools.meta(T)
# var(2) is taken to be `tr`'s matrix argument (var(1) presumably being `tr`
# itself, matching the `.+ 1` offset used by spmd_lane) — confirm.
batched_args = [var(2)]
ir = transform(Float64, IR(m), batched_args)
# NOTE(review): the IR returned by varargs! is discarded; presumably this line
# only exists to display the result interactively.
IRTools.varargs!(m, ir, 3)
# Benchmark the generic `spmd` path on a batch of 1000 random 2×2 matrices.
using BenchmarkTools
@benchmark spmd(tr, A) setup=(A=VecArray(Matrix{Float64}, rand(2, 2, 1000)))
"""
    batched_tr(A::AbstractArray{T, 3})

Return a new vector whose `k`-th entry is the trace of the square matrix `A[:, :, k]`.
"""
batched_tr(A::AbstractArray{T, 3}) where T = batched_tr!(A, fill!(similar(A, (size(A, 3),)), zero(T)))

"""
    batched_tr!(A::AbstractArray{T, 3}, B::AbstractVector{T})

Add the trace of each matrix `A[:, :, k]` to `B[k]` and return `B`. `B` is
accumulated into, not overwritten; pass a zero-filled vector to obtain the traces.

The argument checks live in `@boundscheck`, so an inlined `@inbounds` caller may
elide them. They raise:
- `ErrorException` if the matrices are not square or `length(B) != size(A, 3)`;
- `ArgumentError` if `A` or `B` has non-1-based indices (the loops assume 1-based indexing).
"""
function batched_tr!(A::AbstractArray{T, 3}, B::AbstractVector{T}) where T
    @boundscheck begin
        size(A, 1) == size(A, 2) || error("Expect a square matrix")
        size(A, 3) == length(B) || error("Batch size mismatch")
        # The @inbounds loop below indexes with 1:size(...), which is only safe for 1-based arrays.
        Base.require_one_based_indexing(A, B)
    end
    n = size(A, 1)
    for k in 1:size(A, 3)
        # Accumulate in a local: A and B may alias, which would force a memory
        # round-trip of B[k] on every iteration.
        acc = B[k]
        @inbounds for i in 1:n
            acc += A[i, i, k]
        end
        B[k] = acc
    end
    return B
end
# Hand-written lane kernel for `tr`, i.e. what the `spmd_lane` IR rewrite is meant
# to produce: the trace of lane `k` of `A`, accumulated from `zero(T)` via the
# `checksquare` and `getindex` lane primitives.
function m_spmd_lane(::typeof(tr), k, A::VecArray{T}) where T
    dim = spmd_lane(LinearAlgebra.checksquare, k, A)
    trace = zero(T)
    for j = 1:dim
        trace = trace + spmd_lane(Base.getindex, k, A, j, j)
    end
    return trace
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment