Last active
September 27, 2017 08:38
-
-
Save meggart/3539a06f65f82d5149efcf90c37276b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    broadcast_reduce(f, op, v0, A, Bs...)

Reduce `f` applied elementwise over `A, Bs...` with the binary operator `op`,
starting from `v0`. This is analogous to `mapreduce(f, op, v0, ...)`, the only
difference being that singleton dimensions are expanded as in `broadcast`. The
broadcast result is never materialised: values are read from the inputs on the fly.
"""
function broadcast_reduce(f, op, v0, A, Bs::Vararg{Any,N}) where N
    shape = Base.Broadcast.broadcast_indices(A, Bs...)
    iter = CartesianRange(shape)
    # Per-argument index maps used to translate a broadcast-level index back into
    # an index of each input (see `Base.Broadcast.newindex`).
    keeps, Idefaults = Base.Broadcast.map_newindexer(shape, A, Bs)
    # `N` is taken from the method signature, so `Val{N}()` is a compile-time
    # constant; building it from `length(Bs)` at runtime would force a dynamic
    # dispatch into the generated kernel.
    return _broadcast_reduce(f, op, v0, keeps, Idefaults, A, Bs, Val{N}(), iter)
end
# Generated kernel behind `broadcast_reduce`. For N extra arguments it unrolls, at
# compile time, the per-array bookkeeping for the 1 + N inputs `A, Bs...`.
#
# Arguments (as prepared by `broadcast_reduce`):
# - `keeps`, `Idefaults`: one entry per input, handed to `Base.Broadcast.newindex`
#   to translate a broadcast-level index into an index of that input.
# - `iter`: Cartesian iterator over the broadcast shape.
# Returns the left fold `op(...op(op(v0, f(vals...)), f(vals...))...)` over `iter`.
@generated function _broadcast_reduce(f, op, v0, keeps::K, Idefaults::ID, A::AT, Bs::BT,
                                      ::Val{N}, iter) where {K,ID,AT,BT,N}
    narrays = N + 1
    return quote
        $(Expr(:meta, :inline))
        # Unpack inputs and their index maps into numbered locals (A_1, keep_1, ...).
        A_1 = A
        @nexprs $N j->(A_{j+1} = Bs[j])
        @nexprs $narrays j->begin
            keep_j = keeps[j]
            Idefault_j = Idefaults[j]
        end
        acc = v0
        for I in iter
            # Translate the broadcast index into each input's own index.
            @nexprs $narrays j->(I_j = Base.Broadcast.newindex(I, keep_j, Idefault_j))
            # Relies on `iter` spanning the broadcast shape of the inputs, so the
            # translated indices are in bounds.
            @nexprs $narrays j->(@inbounds val_j = Base.Broadcast._broadcast_getindex(A_j, I_j))
            acc = op(acc, @ncall $narrays f val)
        end
        acc
    end
end
"""
    BroadIter{N,K,ID,AT}

An iterator that walks several arrays in lockstep, like `zip`, but with
singleton dimensions expanded according to the broadcast rules. Each iteration
yields a tuple holding one element from each wrapped array.

Fields:
- `keeps`, `Idefaults`: per-array index maps from `Base.Broadcast.map_newindexer`.
- `A`: tuple of the wrapped arrays.
- `iter`: Cartesian range over the `N`-dimensional broadcast shape.
"""
struct BroadIter{N,K,ID,AT}
    keeps::K
    Idefaults::ID
    A::AT
    iter::CartesianRange{CartesianIndex{N}}
end
using Base.Cartesian | |
# `next` for `BroadIter`: returns the tuple of current elements (one per wrapped
# array) and the following state. The state `s` is used directly as the
# broadcast-level index and is advanced through the underlying range `broit.iter`.
@generated function Base.next(broit::BroadIter{N,K,ID,AT}, s) where {N,K,ID,AT}
    # Number of wrapped arrays, read from the tuple type of the `A` field.
    NC = length(AT.parameters)
    return quote
        $(Expr(:meta, :inline))
        # Translate the broadcast index into each wrapped array's own index.
        @nexprs $NC i->(I_i = Base.Broadcast.newindex(s, broit.keeps[i], broit.Idefaults[i]))
        @inbounds vals = @ntuple $NC i->(Base.Broadcast._broadcast_getindex(broit.A[i], I_i))
        vals, next(broit.iter, s)[2]
    end
end
# Iteration state is delegated to the underlying Cartesian range `ii.iter`.
Base.start(ii::BroadIter) = start(ii.iter)
Base.done(ii::BroadIter, s) = done(ii.iter, s)
# `iteratorsize` defaults to `HasLength()`, so `collect` and similar consumers
# call `length`; it equals the number of positions in the broadcast shape.
Base.length(ii::BroadIter) = length(ii.iter)
"""
    BroadIter(A, Bs...)

Construct a `BroadIter` over `A` and `Bs...`, iterating over their common
broadcast shape.
"""
function BroadIter(A, Bs...)
    arrays = (A, Bs...)
    dims = Base.Broadcast.broadcast_indices(arrays...)
    keeps, Idefaults = Base.Broadcast.map_newindexer(dims, A, Bs)
    return BroadIter(keeps, Idefaults, arrays, CartesianRange(dims))
end
using BenchmarkTools | |
# Benchmark inputs: x2 and x3 have singleton dimensions that broadcast against x1.
x1 = rand(200, 5000, 30)
x2 = rand(1, 5000, 1)
x3 = rand(1, 1, 1)
# Baseline: materialise the broadcast product, then sum it. Globals are
# `$`-interpolated so the timing measures the computation, not the overhead of
# accessing non-constant global variables.
@btime sum($x1 .* $x2 .* $x3)
it = BroadIter(x1, x2, x3)
@btime mapreduce(x -> x[1] * x[2] * x[3], +, $it)
"""
    sum_iter(x1, x2, x3)

Sum the broadcast product of `x1`, `x2` and `x3` by walking a `BroadIter`,
without allocating the product array. Accumulation starts from `0.0`.
"""
function sum_iter(x1, x2, x3)
    total = 0.0
    for triple in BroadIter(x1, x2, x3)
        a, b, c = triple
        total = total + a * b * c
    end
    return total
end
# Iterator-based and fused-reduction variants. Arguments are `$`-interpolated so
# the timings exclude the overhead of accessing non-constant globals.
@btime sum_iter($x1, $x2, $x3)
@btime broadcast_reduce(*, +, 0.0, $x1, $x2, $x3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment