Skip to content

Instantly share code, notes, and snippets.

@JKrehl
Last active March 3, 2017 23:11
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JKrehl/a697c17d00a927766b9631c5a68bb8db to your computer and use it in GitHub Desktop.
Save JKrehl/a697c17d00a927766b9631c5a68bb8db to your computer and use it in GitHub Desktop.
an in-place circshift! implementation
using Base.Cartesian
# Return (at code-generation time) an expression that loops over dimensions
# `i, i-1, …, 1` of `dest` (dimension `i` outermost, dimension 1 innermost) and
# evaluates `ex` once per visited index. For each dimension `d` the generated code
# binds a loop index `x_d` and a partner index `y_d`, and splits `1:size(dest, d)`
# at `s = splits[d]` into the segments `1:s` and `s+1:size(dest, d)`:
#
# * `reverse == false`: `y_d = shifts[d] + x_d` on the first segment and
#   `y_d = x_d - s` on the second. When `s == size(dest, d) - shifts[d]` this maps
#   `x_d` to `x_d + shifts[d]` modulo the size (the generated code reads `dest`,
#   `shifts` and the `splits` tuple).
# * `reverse == true`: each segment is mirrored onto itself, i.e.
#   `y_d = s - x_d + 1` on the first and `y_d = s + size(dest, d) - x_d + 1` on the
#   second segment (`shifts` is not read).
#
# With `half == true` only the first half of each segment is iterated, and the
# self-mirrored middle index of an odd-length segment is visited once with the lower
# dimensions halved again, so a swap placed in `ex` exchanges every mirrored pair
# exactly once. `splits` is the *name* of the split-point tuple in the generated
# code. For `i == 0`, `ex` itself is returned.
function _bifurcated_loop(i, ex, reverse, splits=:splits, half=reverse)
    i == 0 && return ex
    ix = Symbol("x_", i)
    iy = Symbol("y_", i)
    isize = :(size(dest, $i))
    isplit = :($splits[$i])
    ishift = :(shifts[$i])
    # Last index visited in each segment: the whole segment, or only its first half.
    seg1_end = half ? :(div($isplit, 2)) : isplit
    seg2_end = half ? :(div($isplit + $isize, 2)) : isize
    # Partner index for an index in the first / second segment.
    partner1 = reverse ? :($isplit - $ix + 1) : :($ishift + $ix)
    partner2 = reverse ? :($isplit + $isize - $ix + 1) : :($ix - $isplit)
    seg1 = quote
        for $ix in 1:$seg1_end
            $iy = $partner1
            $(_bifurcated_loop(i - 1, ex, reverse, splits, false))
        end
    end
    seg2 = quote
        for $ix in ($isplit + 1):$seg2_end
            $iy = $partner2
            $(_bifurcated_loop(i - 1, ex, reverse, splits, false))
        end
    end
    if !half
        # Full traversal: no middle-element handling is needed, so none is emitted
        # (instead of dead `if false && …` branches).
        return quote
            $seg1
            $seg2
        end
    end
    # Half mode: an odd-length segment has a middle index that is its own mirror.
    # Visit it once and halve the lower dimensions so their pairs are not swapped twice.
    return quote
        $seg1
        if isodd($isplit)
            $ix = $iy = $seg1_end + 1
            $(_bifurcated_loop(i - 1, ex, reverse, splits, true))
        end
        $seg2
        if isodd($isize - $isplit)
            $ix = $iy = $seg2_end + 1
            $(_bifurcated_loop(i - 1, ex, reverse, splits, true))
        end
    end
end
# Out-of-place kernel: write `src` circularly shifted by `shifts` into `dest` and
# return `dest`. Element `src[x_1, …, x_N]` is stored at `dest[y_1, …, y_N]`, with
# `y_d = x_d + shifts[d]` wrapped into `1:size(dest, d)`.
# Preconditions (established by `circshift!`): `dest` and `src` have the same
# 1-based axes and `0 <= shifts[d] < size(dest, d)`. The loops run under
# `@inbounds`, so violating these is not caught here.
# Ported from the pre-1.0 `f{T,N}(…)` method syntax to `where {T,N}`.
@generated function _circshift!(dest::AbstractArray{T,N}, src::AbstractArray{T,N},
                                shifts::Base.DimsInteger{N}) where {T,N}
    # dest[y_1, …, y_N] = src[x_1, …, x_N]; the x/y indices are bound by the loop nest.
    assign = :(@nref($N, dest, y) = @nref($N, src, x))
    return quote
        # Along each dimension, indices 1:splits[d] move right by shifts[d]; the
        # remaining tail wraps around to the front.
        splits = map(-, size(dest), shifts)
        @inbounds $(_bifurcated_loop(N, assign, false))
        dest
    end
end
# In-place kernel: rotate `dest` by `shifts` without a second array and return `dest`.
# It uses the three-reversal rotation, applied to all dimensions at once:
#   1. along every dimension, reverse `1:splits[d]` and `splits[d]+1:size(dest, d)`
#      (a simultaneous reversal of each block);
#   2. reverse the whole array (split points all zero).
# Both passes swap element pairs, visiting each pair once.
# Preconditions (established by `circshift!`): `dest` is 1-based and
# `0 <= shifts[d] < size(dest, d)`. The loops run under `@inbounds`.
# Ported from the pre-1.0 `f{T,N}(…)` method syntax to `where {T,N}`.
@generated function _circshift!(dest::AbstractArray{T,N},
                                shifts::Base.DimsInteger{N}) where {T,N}
    # Swap dest[x_1, …, x_N] with dest[y_1, …, y_N].
    swap = :((@nref($N, dest, x), @nref($N, dest, y)) =
                 (@nref($N, dest, y), @nref($N, dest, x)))
    return quote
        # A zero rotation would otherwise cost two full reversal passes.
        all(iszero, shifts) && return dest
        splits = map(-, size(dest), shifts)
        @inbounds $(_bifurcated_loop(N, swap, true))
        # All-zero split points: the second pass reverses the entire array.
        splits0 = @ntuple($N, i -> 0)
        @inbounds $(_bifurcated_loop(N, swap, true, :splits0))
        dest
    end
end
"""
    circshift!(dest, src, shifts::Tuple)

Circularly shift `src` by `shifts[d]` positions along each dimension `d`, store the
result in `dest`, and return `dest`. An element at index `x` along dimension `d` ends
up at `x + shifts[d]`, wrapping around. Shifts are first reduced with `mod`, so
negative and oversized shifts are allowed. If `dest === src`, the array is rotated
in place by swapping elements.

`dest` and `src` must not share memory unless they are the same object
(`dest === src`); partially overlapping arrays (e.g. overlapping views) are not
detected.

# Throws
- `DimensionMismatch` if `length(shifts) != ndims(dest)`, or if `dest` and `src`
  have different axes.
- `ArgumentError` if `dest` is not 1-based. The element loops run over
  `1:size(dest, d)` with bounds checking disabled.
"""
function circshift!(dest, src, shifts::Tuple)
    if length(shifts) != ndims(dest)
        throw(DimensionMismatch("need one shift per dimension: got $(length(shifts)) " *
                                "shifts for a $(ndims(dest))-dimensional array"))
    end
    if dest !== src && axes(dest) != axes(src)
        throw(DimensionMismatch("dest and src must have the same axes, got " *
                                "$(axes(dest)) and $(axes(src))"))
    end
    # The kernels index 1:size(dest, d) under @inbounds; offset axes would be read
    # and written out of bounds.
    if any(ax -> first(ax) != 1, axes(dest))
        throw(ArgumentError("circshift! only supports arrays with 1-based indexing"))
    end
    # Nothing to move. Returning early also avoids `mod(shift, 0)`, which throws a
    # DivideError for zero-length dimensions.
    isempty(dest) && return dest
    reduced = map(mod, shifts, size(dest))
    if dest === src
        _circshift!(dest, reduced)
    else
        _circshift!(dest, src, reduced)
    end
    return dest
end
# Shift every dimension of `src` by the same integer amount `shifts`, writing into
# `dest` (which may be `src` itself). Returns `dest`.
# NOTE: Base's integer form `circshift(A, k::Integer)` shifts only the first
# dimension; this method shifts all of them.
# `ntuple` replaces the pre-1.0 `(fill(shifts, n)...)`, which is a syntax error in
# Julia ≥ 1.0 and also allocated a temporary vector.
circshift!(dest, src, shifts::Integer) =
    circshift!(dest, src, ntuple(_ -> shifts, ndims(dest)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment