Last active
May 2, 2020 17:41
-
-
Save mwarusz/fc0da42762d2737d0093fd15761ab144 to your computer and use it in GitHub Desktop.
GPUArrays and CuArrays patches for Broadcasted
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/mapreduce.jl b/src/mapreduce.jl | |
index 14bcfe1..ad2da8c 100644 | |
--- a/src/mapreduce.jl | |
+++ b/src/mapreduce.jl | |
@@ -132,14 +132,15 @@ function partial_mapreduce_grid(f, op, neutral, Rreduce, Rother, shuffle, R, As. | |
end | |
## COV_EXCL_STOP | |
- | |
-NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::AbstractArray...; init=nothing) where T | |
+NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::Base.AbstractArrayOrBroadcasted...; init=nothing) where T | |
# TODO: Broadcast-semantics after JuliaLang-julia#31020 | |
A = first(As) | |
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical")) | |
Base.check_reducedims(R, A) | |
- isempty(A) && return R | |
+ | |
# FIXME: isempty uses scalar indexing for Broadcasted; restore the early return below once that is fixed | |
+ #isempty(A) && return R | |
f = cufunc(f) | |
op = cufunc(op) | |
@@ -156,7 +157,7 @@ NVTX.@range function GPUArrays.mapreducedim!(f, op, R::CuArray{T}, As::AbstractA | |
# iteration domain, split in two: one part covers the dimensions that should | |
# be reduced, and the other covers the rest. combining both covers all values. | |
- Rall = CartesianIndices(A) | |
+ Rall = CartesianIndices(axes(A)) | |
Rother = CartesianIndices(R) | |
Rreduce = CartesianIndices(ifelse.(axes(A) .== axes(R), Ref(Base.OneTo(1)), axes(A))) | |
# NOTE: we hard-code `OneTo` (`first.(axes(A))` would work too) or we get a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl | |
index e8def4f..0ec91b2 100644 | |
--- a/src/host/mapreduce.jl | |
+++ b/src/host/mapreduce.jl | |
@@ -1,9 +1,11 @@ | |
# map-reduce | |
+const AbstractGPUArrayOrBroadcasted = Union{AbstractGPUArray, Base.AbstractBroadcasted} | |
+ | |
# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional | |
# argument `init` value to avoid eager initialization of `R` (if set to something). | |
-mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE | |
-Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A) | |
+mapreducedim!(f, op, R::AbstractGPUArray, As::Base.AbstractArrayOrBroadcasted...; init=nothing) = error("Not implemented") # COV_EXCL_LINE | |
+Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Base.AbstractArrayOrBroadcasted) = mapreducedim!(f, op, R, A) | |
neutral_element(op, T) = | |
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`. | |
@@ -18,7 +20,7 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T) | |
neutral_element(::typeof(Base.min), T) = typemax(T) | |
neutral_element(::typeof(Base.max), T) = typemin(T) | |
-function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing) | |
+function Base.mapreduce(f, op, As::AbstractGPUArrayOrBroadcasted...; dims=:, init=nothing) | |
# figure out the destination container type by looking at the initializer element, | |
# or by relying on inference to reason through the map and reduce functions. | |
if init === nothing |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment