Skip to content

Instantly share code, notes, and snippets.

@alexmorley
Last active December 13, 2017 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexmorley/e585df0d8d857d7c9e4a5af75df43d00 to your computer and use it in GitHub Desktop.
Save alexmorley/e585df0d8d857d7c9e4a5af75df43d00 to your computer and use it in GitHub Desktop.
# mapslices generalized to functions that take more than one array argument
import Base.Slice
import Base._unsafe_getindex!
"""
    mapslices2(f, A::AbstractArray, B::AbstractArray, dims::AbstractVector)

Two-array variant of `mapslices`: apply `f(Aslice, Bslice)` to every pair of
corresponding slices of `A` and `B` taken along the dimensions listed in `dims`, and
collect the results into a single array.

# Arguments
- `f`: function of two slices. Its first result determines the element type and the
  extent of the output along `dims`.
- `A`, `B`: arrays of identical `size`.
- `dims`: dimensions spanned by each slice. If empty, `map(f, A, B)` is returned.

# Returns
An array `R` created with `similar` from the first result of `f`. Along the dimensions
not in `dims` it has the indices of `A`; along `dims` it has the indices of the first
result, padded with singleton dimensions when that result has fewer dimensions than
`length(dims)`. A result that is not an array (or is zero-dimensional) is wrapped in a
one-element vector.

# Throws
- `AssertionError` if `size(A) != size(B)` (only checked when `dims` is non-empty).

# Notes
When both first slices are `StridedArray`s and the first result is a `Number` or an
array of `Number`s, the same two slice buffers are overwritten and passed to `f` for
every subsequent slice, so `f` must not keep references to its arguments.
"""
function mapslices2(f, A::AbstractArray, B::AbstractArray, dims::AbstractVector)
    if isempty(dims)
        return map(f, A, B)
    end
    # Both arrays are indexed with the same index list built from `A` below.
    @assert size(A) == size(B)

    dimsA = [indices(A)...]
    ndimsA = ndims(A)
    alldims = [1:ndimsA;]

    # Dimensions that are iterated over (everything that is not part of a slice).
    otherdims = setdiff(alldims, dims)

    # `idx` selects the current slice: a full `Slice` along `dims`, a scalar elsewhere.
    idx = Any[first(ind) for ind in indices(A)]
    itershape = tuple(dimsA[otherdims]...)
    for d in dims
        idx[d] = Slice(indices(A, d))
    end

    # Apply the function to the first slice in order to determine the next steps
    Aslice = A[idx...]
    Bslice = B[idx...]
    r1 = f(Aslice, Bslice)

    # In some cases, we can re-use the first slices for a dramatic performance
    # increase. The slices themselves must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123). Both slices are refilled in place by the inner loop,
    # so both have to qualify.
    safe_for_reuse = isa(Aslice, StridedArray) && isa(Bslice, StridedArray) &&
        (isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

    # determine result size and allocate
    Rsize = copy(dimsA)
    # TODO: maybe support removing dimensions
    if !isa(r1, AbstractArray) || ndims(r1) == 0
        r1 = [r1]
    end
    nextra = max(0, length(dims) - ndims(r1))
    if eltype(Rsize) == Int
        Rsize[dims] = [size(r1)..., ntuple(d -> 1, nextra)...]
    else
        # `OneTo` is not exported from Base, so it must be qualified here.
        Rsize[dims] = [indices(r1)..., ntuple(d -> Base.OneTo(1), nextra)...]
    end
    R = similar(r1, tuple(Rsize...,))

    # `ridx` is the destination counterpart of `idx` inside `R`.
    ridx = Any[map(first, indices(R))...]
    for d in dims
        ridx[d] = indices(R, d)
    end

    R[ridx...] = r1

    nidx = length(otherdims)
    # The first slice has already been processed above, so skip it.
    indexes = Iterators.drop(CartesianRange(itershape), 1)
    return inner_mapslices2!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx, Aslice,
                             Bslice, A, B, f, R)
end
"""
    inner_mapslices2!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx, Aslice,
                      Bslice, A, B, f, R)

Inner loop of `mapslices2`: for every index `I` in `indexes`, point `idx`/`ridx` at the
corresponding slice, evaluate `f` on that slice of `A` and `B`, and store the result in
`R[ridx...]`. Returns `R`.

# Arguments
- `safe_for_reuse`: if `true`, `Aslice` and `Bslice` are overwritten in place with each
  new slice and handed to `f`; otherwise fresh slices `A[idx...]` and `B[idx...]` are
  created for every call.
- `indexes`: iterable of Cartesian indices (objects with an `.I` tuple) over the
  non-slice dimensions. `mapslices2` passes it with the first element already dropped.
- `nidx`, `otherdims`: number of, and list of, the non-slice dimensions.
- `idx`, `ridx`: mutable index lists into the inputs and into `R`; modified in place.
- `Aslice`, `Bslice`: slice buffers, only used when `safe_for_reuse` is `true`.
- `A`, `B`: source arrays. `f`: function of two slices. `R`: output array, mutated.
"""
function inner_mapslices2!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx, Aslice,
                           Bslice, A, B, f, R)
    if safe_for_reuse
        # when f returns an array, the R[ridx...] = f(Aslice, Bslice) line copies
        # elements, so we can reuse Aslice and Bslice
        for I in indexes
            # `replace_tuples!` is the helper defined in this file; the previously
            # referenced `replace_tuples2!` does not exist.
            replace_tuples!(nidx, idx, ridx, otherdims, I)
            _unsafe_getindex!(Aslice, A, idx...)
            _unsafe_getindex!(Bslice, B, idx...)
            R[ridx...] = f(Aslice, Bslice)
        end
    else
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for I in indexes
            replace_tuples!(nidx, idx, ridx, otherdims, I)
            R[ridx...] = f(A[idx...], B[idx...])
        end
    end

    return R
end
"""
    mapslicesN(f, AN::Array{T,1}, dims::AbstractVector) where T<:AbstractArray

N-array variant of `mapslices`: apply `f(slices...)` to the corresponding slices of all
arrays in `AN`, taken along the dimensions listed in `dims`, and collect the results into
a single array.

# Arguments
- `f`: function taking one slice per array in `AN`. Its first result determines the
  element type and the extent of the output along `dims`.
- `AN`: vector of arrays of identical `size`; `AN[1]` provides the indices.
- `dims`: dimensions spanned by each slice. If empty, `map(f, AN...)` is returned.

# Returns
An array `R` created with `similar` from the first result of `f`. Along the dimensions
not in `dims` it has the indices of `AN[1]`; along `dims` it has the indices of the first
result, padded with singleton dimensions when that result has fewer dimensions than
`length(dims)`. A result that is not an array (or is zero-dimensional) is wrapped in a
one-element vector.

# Throws
- `AssertionError` if the arrays in `AN` do not all have the same `size` (only checked
  when `dims` is non-empty).

# Notes
When every first slice is a `StridedArray` and the first result is a `Number` or an
array of `Number`s, the same slice buffers are overwritten and passed to `f` for every
subsequent slice, so `f` must not keep references to its arguments.
"""
function mapslicesN(f, AN::Array{T,1}, dims::AbstractVector) where T<:AbstractArray
    if isempty(dims)
        return map(f, AN...)
    end

    A = AN[1]
    # Every array is indexed with the index list built from `A`, and the reuse path
    # reads them through `_unsafe_getindex!`, so their sizes have to agree
    # (same requirement as in `mapslices2`).
    @assert all(X -> size(X) == size(A), AN)

    dimsA = [indices(A)...]
    ndimsA = ndims(A)
    alldims = [1:ndimsA;]

    # Dimensions that are iterated over (everything that is not part of a slice).
    otherdims = setdiff(alldims, dims)

    # `idx` selects the current slice: a full `Slice` along `dims`, a scalar elsewhere.
    idx = Any[first(ind) for ind in indices(A)]
    itershape = tuple(dimsA[otherdims]...)
    for d in dims
        idx[d] = Slice(indices(A, d))
    end

    # Apply the function to the first slice in order to determine the next steps
    ANslices = [x[idx...] for x in AN]
    r1 = f(ANslices...)

    # In some cases, we can re-use the first slices for a dramatic performance
    # increase. The slices themselves must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123). Every slice is refilled in place by the inner loop,
    # so every slice has to qualify.
    safe_for_reuse = all(s -> isa(s, StridedArray), ANslices) &&
        (isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

    # determine result size and allocate
    Rsize = copy(dimsA)
    # TODO: maybe support removing dimensions
    if !isa(r1, AbstractArray) || ndims(r1) == 0
        r1 = [r1]
    end
    nextra = max(0, length(dims) - ndims(r1))
    if eltype(Rsize) == Int
        Rsize[dims] = [size(r1)..., ntuple(d -> 1, nextra)...]
    else
        # `OneTo` is not exported from Base, so it must be qualified here.
        Rsize[dims] = [indices(r1)..., ntuple(d -> Base.OneTo(1), nextra)...]
    end
    R = similar(r1, tuple(Rsize...,))

    # `ridx` is the destination counterpart of `idx` inside `R`.
    ridx = Any[map(first, indices(R))...]
    for d in dims
        ridx[d] = indices(R, d)
    end

    R[ridx...] = r1

    nidx = length(otherdims)
    # The first slice has already been processed above, so skip it.
    indexes = Iterators.drop(CartesianRange(itershape), 1)
    return inner_mapslicesN!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx,
                             ANslices, AN, f, R)
end
"""
    inner_mapslicesN!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx, ANslices,
                      AN, f, R)

Inner loop of `mapslicesN`: for every index `I` in `indexes`, point `idx`/`ridx` at the
corresponding slice, evaluate `f` on that slice of every array in `AN`, and store the
result in `R[ridx...]`. Returns `R`.

# Arguments
- `safe_for_reuse`: if `true`, the buffers in `ANslices` are overwritten in place with
  each new slice and splatted into `f`; otherwise fresh slices `x[idx...]` are created
  for every call.
- `indexes`: iterable of Cartesian indices (objects with an `.I` tuple) over the
  non-slice dimensions. `mapslicesN` passes it with the first element already dropped.
- `nidx`, `otherdims`: number of, and list of, the non-slice dimensions.
- `idx`, `ridx`: mutable index lists into the inputs and into `R`; modified in place.
- `ANslices`: one slice buffer per array in `AN`, only used when `safe_for_reuse` is
  `true`.
- `AN`: source arrays. `f`: function of one slice per array. `R`: output array, mutated.
"""
function inner_mapslicesN!(safe_for_reuse, indexes, nidx, idx, otherdims, ridx, ANslices,
                           AN, f, R)
    if safe_for_reuse
        # when f returns an array, the R[ridx...] = f(ANslices...) line copies
        # elements, so we can reuse the slice buffers
        for I in indexes
            # `replace_tuples!` is the helper defined in this file; the previously
            # referenced `replace_tuples2!` does not exist.
            replace_tuples!(nidx, idx, ridx, otherdims, I)
            for (ind, slice) in enumerate(ANslices)
                _unsafe_getindex!(slice, AN[ind], idx...)
            end
            R[ridx...] = f(ANslices...)
        end
    else
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for I in indexes
            replace_tuples!(nidx, idx, ridx, otherdims, I)
            R[ridx...] = f([x[idx...] for x in AN]...)
        end
    end

    return R
end
"""
    replace_tuples!(nidx, idx, ridx, otherdims, I)

Copy the first `nidx` components of the Cartesian index `I` into both index lists:
component `k` is written to position `otherdims[k]` of `ridx` and of `idx`.
Both lists are modified in place; nothing is returned.
"""
function replace_tuples!(nidx, idx, ridx, otherdims, I)
    k = 1
    while k <= nidx
        dim = otherdims[k]
        coord = I.I[k]
        # Same write order as a chained assignment: destination list first.
        ridx[dim] = coord
        idx[dim] = coord
        k += 1
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment