Skip to content

Instantly share code, notes, and snippets.

@mwarusz
Last active May 2, 2020 17:41
Show Gist options
Save mwarusz/fc0da42762d2737d0093fd15761ab144 to your computer and use it in GitHub Desktop.
GPUArrays.jl and CuArrays.jl patches that let `mapreduce` / `mapreducedim!` accept `Broadcasted` arguments
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 14bcfe1..ad2da8c 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -132,14 +132,15 @@ function partial_mapreduce_grid(f, op, neutral, Rreduce, Rother, shuffle, R, As.
end
## COV_EXCL_STOP
-
-NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::AbstractArray...; init=nothing) where T
+NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::Base.AbstractArrayOrBroadcasted...; init=nothing) where T
# TODO: Broadcast-semantics after JuliaLang-julia#31020
A = first(As)
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))
Base.check_reducedims(R, A)
- isempty(A) && return R
+
+ # FIXME: Fix isempty to not use scalar indexing for Broadcasted
+ #isempty(A) && return R
f = cufunc(f)
op = cufunc(op)
@@ -156,7 +157,7 @@ NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::AbstractA
# iteration domain, split in two: one part covers the dimensions that should
# be reduced, and the other covers the rest. combining both covers all values.
- Rall = CartesianIndices(A)
+ Rall = CartesianIndices(axes(A))
Rother = CartesianIndices(R)
Rreduce = CartesianIndices(ifelse.(axes(A) .== axes(R), Ref(Base.OneTo(1)), axes(A)))
# NOTE: we hard-code `OneTo` (`first.(axes(A))` would work too) or we get a
diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl
index e8def4f..0ec91b2 100644
--- a/src/host/mapreduce.jl
+++ b/src/host/mapreduce.jl
@@ -1,9 +1,11 @@
# map-reduce
+const AbstractGPUArrayOrBroadcasted = Union{AbstractGPUArray, Base.AbstractBroadcasted}
+
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
# argument `init` value to avoid eager initialization of `R` (if set to something).
-mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
-Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
+mapreducedim!(f, op, R::AbstractGPUArray, As::Base.AbstractArrayOrBroadcasted...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
+Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Base.AbstractArrayOrBroadcasted) = mapreducedim!(f, op, R, A)
neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
@@ -18,7 +20,7 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)
-function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
+function Base.mapreduce(f, op, As::AbstractGPUArrayOrBroadcasted...; dims=:, init=nothing)
# figure out the destination container type by looking at the initializer element,
# or by relying on inference to reason through the map and reduce functions.
if init === nothing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment