Skip to content

Instantly share code, notes, and snippets.

@Jutho
Created November 4, 2014 22:24
Show Gist options
  • Save Jutho/e10e39a91e1a473315ab to your computer and use it in GitHub Desktop.
Save Jutho/e10e39a91e1a473315ab to your computer and use it in GitHub Desktop.
A staged function approach to permutedims!
using Base.Cartesian
# Return the array whose linear indices address the elements of `A`, together with the
# 1-based linear index of A's first element in it. A view (SubArray) is unwrapped to its
# parent; its `first_index` field is already a 1-based index into that parent, so it is
# used as-is. Any other strided array is its own storage and starts at index 1.
_permutedims_storage(A::SubArray) = (A.parent, A.first_index)
_permutedims_storage(A::AbstractArray) = (A, 1)

# permutedimsnew!(P, B, perm)
#
# Write the dimension permutation of `B` into the preallocated destination `P`, i.e.
# P[i_1, ..., i_N] = B[j] where j[perm[d]] = i_d, so size(P, d) == size(B, perm[d]).
# Returns `P` (the object passed in, even when it is a view).
#
# Throws an ErrorException if `perm` does not have length N or is not a permutation, and a
# DimensionMismatch if `P` does not have the permuted size of `B`.
# NOTE(review): no check is made that P and B do not share memory; aliasing inputs will
# give wrong results.
function permutedimsnew!{T1,T2,N}(P::StridedArray{T1,N},B::StridedArray{T2,N},perm)
    length(perm) == N ||
        error("expected permutation of size $N, but length(perm)=$(length(perm))")
    isperm(perm) || error("input is not a permutation")
    dims = size(P)
    for i = 1:N
        dims[i] == size(B, perm[i]) ||
            throw(DimensionMismatch("destination tensor of incorrect size"))
    end
    # Strides are taken from the arrays as given (views included), before unwrapping.
    # Reordering B's strides by `perm` lets the kernel step through P and B with the same
    # loop counters.
    stridesB = strides(B)[perm]
    stridesP = strides(P)
    # Work on the backing storage with explicit start indices instead of rebinding the
    # arguments, so `P` itself can be returned unchanged.
    dataB, startB = _permutedims_storage(B)
    dataP, startP = _permutedims_storage(P)
    permutedimsblock!(dataP, dataB, dims, startP, stridesP, startB, stridesB)
    return P
end
# permutedimsblock!(P, B, dims, startP, stridesP, startB, stridesB)
#
# Kernel for permutedimsnew!. For every multi-index (i_1, ..., i_N) with 1 <= i_d <= dims[d]
# it copies
#     B[startB + sum_d (i_d - 1) * stridesB[d]]  into  P[startP + sum_d (i_d - 1) * stridesP[d]]
# using linear indexing on both arrays, and returns P.
#
# Blocks of at most 1024 elements are copied with a plain N-fold loop. Larger blocks are
# split in two along their largest dimension and each half is handled recursively.
# Presumably this is meant to keep each leaf copy within a small region of both arrays —
# TODO confirm with benchmarks.
#
# P and B may have any number of dimensions. For example, the parent of an N-dimensional
# view may have a different rank; only `dims` and the stride tuples must have length N.
# The copy runs under @inbounds, so the caller must guarantee that every computed linear
# index is valid for P and B.
stagedfunction permutedimsblock!{N}(P::AbstractArray, B::AbstractArray,
                                    dims::NTuple{N,Int},
                                    startP::Int, stridesP::NTuple{N,Int},
                                    startB::Int, stridesB::NTuple{N,Int})
    # The generated method is specialized on N; the Cartesian macros are expanded here
    # into straight-line loop code for that N.
    ex = macroexpand(quote
        if prod(dims) <= 1024
            # Leaf block. indX_d holds the linear index at loop depth d. Entering depth d
            # copies it down to indX_{d-1} (PRE); finishing one iteration at depth d
            # advances it by that dimension's stride (POST). indX_N is seeded with the
            # block's start index.
            @nexprs 1 d->(indP_{$N} = startP)
            @nexprs 1 d->(indB_{$N} = startB)
            @nloops($N, i, d->1:dims[d],
                d->(indB_{d-1} = indB_{d}; indP_{d-1} = indP_{d}), # PRE
                d->(indB_{d} += stridesB[d]; indP_{d} += stridesP[d]), # POST
                @inbounds P[indP_0] = B[indB_0])
        else
            # Find the largest dimension (the first one on ties)...
            dcut = 1
            dimcut = dims[1]
            for d = 2:$N
                if dims[d] > dimcut
                    dcut = d
                    dimcut = dims[d]
                end
            end
            # ...and split it into floor(dimcut/2) and the remainder. prod(dims) > 1024
            # implies dimcut >= 2, so both halves are non-empty and strictly smaller,
            # which guarantees the recursion terminates.
            newdim = dimcut >> 1
            newdims = @ntuple $N d->(d == dcut ? newdim : dims[d])
            permutedimsblock!(P, B, newdims, startP, stridesP, startB, stridesB)
            # The second half starts `newdim` steps further along dimension dcut.
            startP = startP + newdim*stridesP[dcut]
            startB = startB + newdim*stridesB[dcut]
            newdim = dimcut - newdim
            newdims = @ntuple $N d->(d == dcut ? newdim : dims[d])
            permutedimsblock!(P, B, newdims, startP, stridesP, startB, stridesB)
        end
        return P
    end)
    return ex
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment