Skip to content

Instantly share code, notes, and snippets.

@meggart
Last active September 27, 2017 08:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save meggart/3539a06f65f82d5149efcf90c37276b7 to your computer and use it in GitHub Desktop.
Save meggart/3539a06f65f82d5149efcf90c37276b7 to your computer and use it in GitHub Desktop.
"""
    broadcast_reduce(f, op, v0, A, Bs...)

Apply `f` elementwise to `A, Bs...` and reduce the results with the binary
operator `op`, starting from the initial value `v0`.

This should behave like `mapreduce`, with the only difference that singleton
dimensions are expanded as in `broadcast`. The broadcasted array is never
materialized.

# Examples
```julia
julia> broadcast_reduce(*, +, 0.0, [1.0 2.0; 3.0 4.0], [10.0 20.0])
160.0
```
"""
function broadcast_reduce(f, op, v0, A, Bs...)
    # Common broadcast shape of all arguments.
    shape = Base.Broadcast.broadcast_indices(A, Bs...)
    iter = CartesianRange(shape)
    # Per-argument data used by the kernel to reverse-broadcast an index of
    # `shape` into an index of each (possibly singleton-dimensioned) argument.
    keeps, Idefaults = Base.Broadcast.map_newindexer(shape, A, Bs)
    # Lift the number of extra arguments into the type domain so the generated
    # kernel can be unrolled over them.
    return _broadcast_reduce(f, op, v0, keeps, Idefaults, A, Bs, Val{length(Bs)}(), iter)
end
# Kernel behind `broadcast_reduce`.
#
# `N` is the number of arrays in `Bs`, so the generated body is unrolled over
# `nargs = N + 1` arguments (`A` followed by `Bs...`). For every index `I` of
# `iter`, each argument's index is reverse-broadcast via `keeps`/`Idefaults`,
# the values are passed to `f`, and the result is folded into the accumulator
# with `op`, starting from `v0`.
#
# The Base.Cartesian macros are fully qualified, so the generated code does not
# depend on a `using Base.Cartesian` being in effect when it is expanded.
@generated function _broadcast_reduce(f, op, v0, keeps::K, Idefaults::ID, A::AT, Bs::BT,
                                      ::Val{N}, iter) where {K,ID,AT,BT,N}
    nargs = N + 1
    quote
        $(Expr(:meta, :inline))
        # destructure the arguments, keeps and default-index tuples
        A_1 = A
        Base.Cartesian.@nexprs $N i->(A_{i+1} = Bs[i])
        Base.Cartesian.@nexprs $nargs i->(keep_i = keeps[i])
        Base.Cartesian.@nexprs $nargs i->(Idefault_i = Idefaults[i])
        acc = v0
        for I in iter
            # reverse-broadcast the indices
            Base.Cartesian.@nexprs $nargs i->(I_i = Base.Broadcast.newindex(I, keep_i, Idefault_i))
            # extract array values
            Base.Cartesian.@nexprs $nargs i->(@inbounds val_i = Base.Broadcast._broadcast_getindex(A_i, I_i))
            # call function and reduce
            acc = op(acc, Base.Cartesian.@ncall $nargs f val)
        end
        acc
    end
end
"""
    BroadIter{N,K,ID,AT}

An iterator type that iterates over multiple arrays like `zip`, but with
singleton dimensions expanded according to the broadcast rules. Each iteration
yields a tuple with one value per wrapped array.

Usually constructed as `BroadIter(A, Bs...)`.
"""
struct BroadIter{N,K,ID,AT}
    keeps::K                                 # per-argument `keep` data from `map_newindexer`
    Idefaults::ID                            # per-argument default indices from `map_newindexer`
    A::AT                                    # tuple of the wrapped arrays
    iter::CartesianRange{CartesianIndex{N}}  # indices of the common broadcast shape
end
using Base.Cartesian
# One iteration step of a `BroadIter`: reverse-broadcast the current index `s`
# into each wrapped array, return the tuple of their values, and advance the
# underlying CartesianRange state.
@generated function Base.next(broit::BroadIter{N,K,ID,AT}, s) where {N,K,ID,AT}
    # number of wrapped arrays, known statically from the tuple type `AT`
    NC = length(AT.parameters)
    quote
        $(Expr(:meta, :inline))
        Base.Cartesian.@nexprs $NC i->(I_i = Base.Broadcast.newindex(s, broit.keeps[i], broit.Idefaults[i]))
        @inbounds vals = Base.Cartesian.@ntuple $NC i->(Base.Broadcast._broadcast_getindex(broit.A[i], I_i))
        return vals, next(broit.iter, s)[2]
    end
end
# The iteration state is the state of the underlying CartesianRange over the
# broadcast shape, so start/done simply delegate to it.
Base.start(ii::BroadIter) = start(ii.iter)
Base.done(ii::BroadIter, s) = done(ii.iter, s)
# `BroadIter` uses the default `iteratorsize` (`HasLength()`), which requires
# `length`. It is the number of indices in the broadcast shape.
Base.length(ii::BroadIter) = length(ii.iter)
# Outer constructor: compute the common broadcast shape of `A, Bs...` and the
# per-argument index mappers, then wrap them with the arrays in a `BroadIter`.
function BroadIter(A, Bs...)
    arrays = (A, Bs...)
    inds = Base.Broadcast.broadcast_indices(A, Bs...)
    ks, defaults = Base.Broadcast.map_newindexer(inds, A, Bs)
    return BroadIter(ks, defaults, arrays, CartesianRange(inds))
end
using BenchmarkTools
# Benchmark inputs: x2 and x3 have singleton dimensions that broadcast against x1.
x1 = rand(200, 5000, 30)
x2 = rand(1, 5000, 1)
x3 = rand(1, 1, 1)
# Baseline: materialize the broadcast product, then sum it. Globals are
# interpolated with `$` so the timings exclude untyped-global access overhead.
@btime sum($x1 .* $x2 .* $x3)
it = BroadIter(x1, x2, x3)
@btime mapreduce(x -> x[1] * x[2] * x[3], +, $it)
# Sum of the elementwise product of x1, x2 and x3 (with broadcasting),
# computed by walking a BroadIter instead of allocating the product array.
function sum_iter(x1, x2, x3)
    total = 0.0
    for vals in BroadIter(x1, x2, x3)
        a1, a2, a3 = vals
        total += a1 * a2 * a3
    end
    return total
end
# Allocation-free alternatives; globals are interpolated with `$` so only the
# computation itself is timed.
@btime sum_iter($x1, $x2, $x3)
@btime broadcast_reduce(*, +, 0.0, $x1, $x2, $x3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment