Created
August 21, 2016 15:54
-
-
Save gasagna/c440339903190147757a69b5b092fba3 to your computer and use it in GitHub Desktop.
Example implementation of an N-dimensional array with periodic (wrap-around) indexing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: getindex, | |
size, | |
linearindexing, | |
LinearSlow | |
# Array wrapper whose indexing wraps around along selected dimensions.
#
# Type parameters:
#   T : element type
#   N : number of array dimensions
#   P : N-tuple of Bools; P[d] == true marks dimension d as periodic
#   S : size tuple of the parent array, stored as a type parameter so that
#       strides can be computed when indexing code is generated, avoiding
#       repeated run-time calculation
#   A : parent array type (restricted to AbstractArray subtypes)
immutable PeriodicArray{T, N, P, S, A<:AbstractArray} <: AbstractArray{T, N}
    data::A   # wrapped parent array that actually stores the elements
end
# Outer constructor.
#   data : array to wrap
#   ps   : dimensions (between 1 and N, no repeats) along which indexing
#          wraps around; with no `ps` the wrapper is non-periodic everywhere
# Throws ArgumentError for a dimension outside 1:N or a repeated dimension.
function PeriodicArray{T, N}(data::AbstractArray{T, N}, ps::Int...)
    if length(ps) > 0
        lo, hi = extrema(ps)
        lo < 1 && throw(ArgumentError("invalid periodic dimension $lo"))
        hi > N && throw(ArgumentError("invalid periodic dimension $hi"))
        length(unique(ps)) != length(ps) &&
            throw(ArgumentError("repeated periodic dimension"))
    end
    # periodic[d] is true exactly when dimension d was listed in `ps`
    periodic = ntuple(d -> d in ps, N)
    return PeriodicArray{T, N, periodic, size(data), typeof(data)}(data)
end
# ~~~ Array interface ~~~
# The wrapper reports the size of its parent array.
size(p::PeriodicArray) = size(p.data)
# Native indexing uses N Cartesian indices. The trait is declared on the type,
# as the AbstractArray interface prescribes; Base's instance-level
# `linearindexing(A)` forwards to this method, so both forms return LinearSlow().
linearindexing{A<:PeriodicArray}(::Type{A}) = LinearSlow()
# Index a PeriodicArray with exactly N integer indices.
#
# Periodic dimensions wrap their index into 1:S[d] through `per2ind`.
# Non-periodic dimensions are bounds-checked inside `@boundscheck`, so an
# `@inbounds` caller can elide the checks.
#
# The indices are folded into a single column-major linear index into
# `p.data`, with the strides (products of leading sizes from `S`) emitted as
# literals.
#
# Other index counts (e.g. a single linear index) are not matched here. They
# go to Base's generic AbstractArray fallback, which converts them to N
# indices before calling this method.
@generated function getindex{T, N, P, S}(p::PeriodicArray{T, N, P, S}, I::Vararg{Int, N})
    # Request inlining (the `Expr(:meta, :inline)` form used in generated code)
    # so `@boundscheck` can be removed under a caller's `@inbounds`.
    # NOTE(review): this may also be what avoids the large allocations the
    # original comments reported; confirm with `@allocated`.
    body = Expr(:block, Expr(:meta, :inline))

    # Bounds checks for non-periodic dimensions only. Periodic dimensions never
    # go out of bounds.
    # Without these checks, an out-of-range index along a non-periodic
    # dimension would be folded into a neighbouring slice of `data` instead of
    # raising an error.
    # Note the explicit parentheses: `a || b && c` parses as `a || (b && c)`.
    checks = Expr(:block)
    for d = 1:N
        if !P[d]
            push!(checks.args, :((1 ≤ I[$d] ≤ $(S[d])) || throw(BoundsError(p, I))))
        end
    end
    isempty(checks.args) || push!(body.args, :(@boundscheck $checks))

    # Linear index: ind = 1 + sum over d of stride_d * (j_d - 1), where j_d is
    # the (possibly wrapped) index along dimension d and stride_d = prod(S[1:d-1]).
    push!(body.args, :(ind = 1))
    stride_d = 1
    for d = 1:N
        j = P[d] ? :(per2ind($(S[d]), I[$d])) : :(I[$d])
        push!(body.args, :(ind += $stride_d * ($j - 1)))
        stride_d *= S[d]
    end
    push!(body.args, :(p.data[ind]))
    return body
end
# Wrap an arbitrary integer index `i` into the range 1:Si. For Si ≥ 1 this
# gives the same result as `mod1(i, Si)`.
# NOTE(review): the original author observed large allocations when this
# function was allowed to inline, and asked why. `@noinline` is kept as-is.
@noinline function per2ind(Si::Int, i::Int)
    if i < 1
        # For Si ≥ 1, rem(i, Si) lies in (-Si, 0] here, so the result is in (0, Si]
        return Si + rem(i, Si)
    elseif i > Si
        return rem(i - 1, Si) + 1
    else
        return i
    end
end
# Benchmark kernel: sums x[k, 1, 1] over k = 1:size(x, 1), repeated 1000
# times, accumulating in the element type of `x`.
function foo(x::AbstractArray)
    total = zero(eltype(x))
    for _ = 1:1000, k = 1:size(x, 1)
        total += x[k, 1, 1]
    end
    return total
end
# Main benchmark: time the same reduction on a plain Array and on a
# PeriodicArray wrapping it that is periodic along dimension 1 only.
# Results are printed in that order.
using BenchmarkTools
a = randn(100000, 2, 4)
p = PeriodicArray(a, 1)
for arr in (a, p)
    println(@benchmark foo($arr))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment