Skip to content

Instantly share code, notes, and snippets.

@gasagna
Created August 21, 2016 15:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gasagna/c440339903190147757a69b5b092fba3 to your computer and use it in GitHub Desktop.
Save gasagna/c440339903190147757a69b5b092fba3 to your computer and use it in GitHub Desktop.
# Example implementation of an N-dimensional array with periodic indexing
import Base: getindex,
size,
linearindexing,
LinearSlow
# Wrapper around a parent array that records, in its type parameters, which
# dimensions are periodic and what the size of the parent array is.
#
# Type parameters are
# T : element type
# N : number of array dimensions
# P : tuple of N booleans, `true` for each periodic dimension
# S : tuple with the array size, precomputed to avoid repeated run-time
#     calculation
# A : parent array type
immutable PeriodicArray{T, N, P, S, A} <: AbstractArray{T, N}
# the wrapped parent array; all element reads go through this field
data::A
end
# Outer constructor.
# Wrap the array `data` and mark every dimension listed in `ps` as periodic.
# Throws an `ArgumentError` when an entry of `ps` lies outside `1:N` or when
# the same dimension is listed more than once.
function PeriodicArray{T, N}(data::AbstractArray{T, N}, ps::Int...)
    # validation is only required when periodic dimensions were requested
    if length(ps) > 0
        psmin, psmax = extrema(ps)
        if psmin < 1
            throw(ArgumentError("invalid periodic dimension $psmin"))
        end
        if psmax > N
            throw(ArgumentError("invalid periodic dimension $psmax"))
        end
        # every periodic dimension may be listed only once
        if !allunique(ps)
            throw(ArgumentError("repeated periodic dimension"))
        end
    end
    # one boolean flag per dimension: is dimension `d` periodic?
    P = ntuple(d -> d in ps, N)
    # the flags and the array size become part of the type
    return PeriodicArray{T, N, P, size(data), typeof(data)}(data)
end
# ~~~ Array interface ~~~
# The size is the one of the wrapped parent array.
Base.size(p::PeriodicArray) = size(p.data)
# Elements are addressed with one index per dimension, not a linear index.
Base.linearindexing(::PeriodicArray) = LinearSlow()
# Scalar indexing with `N` integer indices `I`.
#
# The generated body builds a linear index `ind` into the parent array from
# the `N` indices: an index along a periodic dimension is first mapped by
# `per2ind`, an index along any other dimension is used as given. The element
# `p.data[ind]` is returned.
#
# The method accepts exactly `N` indices (`Vararg{Int, N}`), so calls with a
# different number of indices are dispatched to the generic `AbstractArray`
# fallbacks instead of reaching this body with a tuple of the wrong length.
#
# NOTE: indices along non-periodic dimensions are not bounds checked one by
# one; the only check is whatever `p.data[ind]` performs on the final linear
# index. A per-dimension check was tried by the original author and left out
# because it was observed to cause large memory allocations. If it is
# reinstated, the condition must be parenthesised as
# `(I[i] < 1 || I[i] > S[i]) && throw(BoundsError())`, since `&&` binds
# tighter than `||`.
@generated function getindex{T, N, P, S}(p::PeriodicArray{T, N, P, S}, I::Vararg{Int, N})
    body = quote end
    # zero-dimensional array: there is a single element
    N == 0 && push!(body.args, :(ind = 1))
    # `stride` is the product of the sizes of all dimensions before `i`,
    # computed here, at code generation time
    stride = 1
    for i = 1:N
        # wrap the index only along periodic dimensions
        idx = P[i] ? :(per2ind($(S[i]), I[$i])) : :(I[$i])
        if i == 1
            push!(body.args, :(ind = $idx))
        else
            push!(body.args, :(ind += $stride * ($idx - 1)))
        end
        stride *= S[i]
    end
    push!(body.args, :(p.data[ind]))
    # no printing or other side effects here: only the expression is returned
    return body
end
# Map the index `i` along a periodic dimension of length `Si` onto `1:Si`.
# Indices already inside `1:Si` are returned unchanged.
# Removing @noinline results in huge memory allocations; why?
@noinline function per2ind(Si::Int, i::Int)
    if i < 1
        # rem(i, Si) lies in (-Si, 0] here, so the sum lies in 1:Si
        return Si + rem(i, Si)
    elseif i > Si
        # shift to zero-based, wrap, shift back to one-based
        return rem(i - 1, Si) + 1
    else
        return i
    end
end
# Benchmark kernel.
# Sum the entries `x[i, 1, 1]` over the first dimension of `x`, repeating the
# whole sweep 1000 times, and return the accumulated value.
function foo(x::AbstractArray)
    total = zero(eltype(x))
    for rep in 1:1000, i in 1:size(x, 1)
        total += x[i, 1, 1]
    end
    return total
end
# main benchmark
# Time `foo` on a plain array and on the same data wrapped in a
# `PeriodicArray`, and print both benchmark results.
using BenchmarkTools
a = randn(100000, 2, 4)
# wrap `a`, marking dimension 1 as periodic
p = PeriodicArray(a, 1)
# `$` interpolates the global variable into the benchmarked expression
println(@benchmark foo($a))
println(@benchmark foo($p))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment