Skip to content

Instantly share code, notes, and snippets.

@Tokazama
Created August 20, 2021 05:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Tokazama/c396eb25bfb2dc3ad57cfb9df150f4ee to your computer and use it in GitHub Desktop.
Save Tokazama/c396eb25bfb2dc3ad57cfb9df150f4ee to your computer and use it in GitHub Desktop.
"""
    AccessStyle

Abstract supertype of the singleton trait values passed to `layout` to choose how an
array's elements are addressed: through a linear (integer) index or a Cartesian index.
"""
abstract type AccessStyle end

"Trait: address a single element through a linear (integer) index."
struct LinearElement <: AccessStyle
end

"Trait: address a single element through a Cartesian index."
struct CartesianElement <: AccessStyle
end

# NOTE(review): the two "collection" traits have no `layout` methods in this file;
# presumably they are meant for indexing many elements at once — confirm.
"Trait: linear addressing of a collection of elements."
struct LinearCollection <: AccessStyle
end

"Trait: Cartesian addressing of a collection of elements."
struct CartesianCollection <: AccessStyle
end
## ArrayIndex
# One-dimensional `StrideIndex` whose `getindex` maps a linear index `i` to
# `offset + i * stride`.
#
# The contiguous-axis type parameter is the *value* `nothing`, not the type `Nothing`.
# The `∘` methods for `StrideIndex` in this file test `C === nothing` and store
# `nothing` as a result. A `Nothing` type parameter would never match that test and
# would be used as a tuple position instead.
const LinearStrideIndex{S,O} = StrideIndex{1,(1,),nothing,Tuple{S},Tuple{O}}

"""
    LinearStrideIndex(stride, offset)

Build a `LinearStrideIndex` from a single `stride` and `offset` (each an `Int` or a
static integer).
"""
function LinearStrideIndex(stride::CanonicalInt, offset::CanonicalInt)
    return StrideIndex{1,(1,),nothing}((stride,), (offset,))
end
"""
    PermutedIndex{N,I1,I2}()
    PermutedIndex(A::PermutedDimsArray)
    PermutedIndex(A::MatAdjTrans)

Field-less `N`-dimensional index that stores a permutation `I1` and its companion `I2`
as type parameters. Both are asserted to be `NTuple{N,Int}`.
"""
struct PermutedIndex{N,I1,I2} <: ArrayIndex{N}
    function PermutedIndex{N,I1,I2}() where {N,I1,I2}
        # The type assertions reject permutations of the wrong length or element type.
        return new{N,I1::NTuple{N,Int},I2::NTuple{N,Int}}()
    end

    # Reuse the permutation and inverse permutation carried by a PermutedDimsArray.
    function PermutedIndex(::PermutedDimsArray{T,N,I1,I2}) where {T,N,I1,I2}
        return PermutedIndex{N,I1,I2}()
    end

    # A 2-D adjoint/transpose swaps its two axes; (2, 1) is its own inverse.
    function PermutedIndex(::MatAdjTrans)
        return PermutedIndex{2,(2,1),(2,1)}()
    end
end
"""
    SubIndex{N}(indices::Tuple)
    SubIndex(v::SubArray)

`N`-dimensional index that wraps the tuple of per-dimension indices of a view.
"""
struct SubIndex{N,I} <: ArrayIndex{N}
    indices::I

    function SubIndex{N}(idxs::Tuple) where {N}
        return new{N,typeof(idxs)}(idxs)
    end

    # The view's dimensionality becomes `N`; its stored index tuple is kept as-is.
    function SubIndex(v::SubArray{T,N}) where {T,N}
        return SubIndex{N}(getfield(v, :indices))
    end
end
"""
    ComposedIndex(i1, i2)

Lazy composition of two indices. Indexing with `i` evaluates `i2[i1[i]]`, so `i1` is
applied first. The dimensionality `N` is taken from `i1`, the index that receives the
caller's index.
"""
struct ComposedIndex{N,I1,I2} <: ArrayIndex{N}
    i1::I1
    i2::I2

    function ComposedIndex(inner::I1, outer::I2) where {I1,I2}
        dims = ndims(I1)
        return new{dims,I1,I2}(inner, outer)
    end
end
# Map an N-dimensional Cartesian index through a `PermutedIndex`. The coordinates are
# reordered with the stored second permutation `I2` and re-wrapped as an `NDIndex`.
# (Presumably `permute` gathers like Base's `genperm` — confirm.)
@inline function Base.getindex(
    ::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}
) where {N,I1,I2}
    reordered = permute(Tuple(i), Val(I2))
    return NDIndex(reordered)
end
# Map a linear index to a buffer position: `offset + i * stride`, using the single
# stored offset and stride.
@inline function Base.getindex(x::LinearStrideIndex, i::CanonicalInt)
    offset1 = getfield(offsets(x), 1)
    stride1 = getfield(strides(x), 1)
    return offset1 + i * stride1
end
# Return the second coordinate of a 2-D Cartesian index.
# NOTE(review): `ConjugateIndex` is not defined in this file as shown — confirm it exists.
Base.getindex(::ConjugateIndex, i::AbstractCartesianIndex{2}) = Tuple(i)[2]
# Index a `ComposedIndex` with a linear index: apply the inner index `i1`, then feed the
# result to the outer index `i2`.
#
# `@propagate_inbounds` lets a caller's `@inbounds` elide the bounds checks of both
# lookups. Do not add an inner `@inbounds` here: it would turn the checks off
# unconditionally, so an out-of-range `i` would go undetected even in checked code.
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt)
    return getfield(x, :i2)[getfield(x, :i1)[i]]
end
# Index a `ComposedIndex` with a Cartesian index: apply the inner index `i1`, then feed
# the result to the outer index `i2`.
#
# `@propagate_inbounds` lets a caller's `@inbounds` elide the bounds checks of both
# lookups. Do not add an inner `@inbounds` here: it would turn the checks off
# unconditionally, so an out-of-range `i` would go undetected even in checked code.
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex)
    return getfield(x, :i2)[getfield(x, :i1)[i]]
end
## composed
# Compose two array indices the way functions compose: `(x ∘ y)[i] == x[y[i]]`.
# `y` is applied first, so it becomes the inner (`i1`) index of the `ComposedIndex`.
function Base.:(∘)(x::ArrayIndex, y::ArrayIndex)
    inner, outer = y, x
    return ComposedIndex(inner, outer)
end
# Fold a `PermutedIndex` into a `StrideIndex`, producing a single `StrideIndex`.
# The rank tuple, strides and offsets are all reordered by `perm`. A concrete contiguous
# axis `C` moves to position `iperm[C]`. The sentinel values `nothing` and `-1` are
# passed through unchanged.
@inline function Base.:(∘)(
    x::StrideIndex{N,R,C}, y::PermutedIndex{N,perm,iperm}
) where {N,R,C,perm,iperm}
    is_sentinel = C === nothing || C === -1
    new_c = is_sentinel ? C : getfield(iperm, C)
    p = Val(perm)
    new_strides = permute(strides(x), p)
    new_offsets = permute(offsets(x), p)
    return StrideIndex{N,permute(R, p),new_c}(new_strides, new_offsets)
end
# Fold a `SubIndex` (the indices of a view) into the parent's `StrideIndex`.
#
# Contiguous-axis parameter of the result:
#   * `nothing` or `-1` on the parent is passed through unchanged. There is no concrete
#     parent axis to look up in `I` in that case.
#   * If the index on the contiguous axis is a unit range, that axis stays contiguous.
#     It ends up at the position it occupies in the view (via `_from_sub_dims`).
#   * Any other array, or an integer (which drops that axis from the view), gives `-1`.
#     A type cannot be both an `AbstractArray` and an `Integer`, so these two cases
#     must be joined with `||`.
#   * Anything else gives `nothing`.
# Strides are combined with each index's step (`maybe_static_step`, via `getmul`).
# Strides and offsets are both mapped onto the view's dimensions via `_to_sub_dims`.
#
# NOTE(review): offsets are only permuted. They are not shifted by each index's first
# element (e.g. `first(r)` for a range `r`). Confirm that the view's start position is
# accounted for elsewhere.
@inline function Base.:(∘)(
    x::StrideIndex{N,R,C}, y::SubIndex{Ns,I}
) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}}
    if C === nothing || C === -1
        c2 = C
    else
        # Type of the index used on the parent's contiguous axis.
        Ic = _get_tuple(I, static(C))
        if Ic <: AbstractUnitRange
            c2 = known(getfield(_from_sub_dims(static(N), I), C))
        elseif Ic <: AbstractArray || Ic <: Integer
            c2 = -1
        else
            c2 = nothing
        end
    end
    pdims = _to_sub_dims(I)
    return StrideIndex{Ns,permute(R, pdims),c2}(
        eachop(getmul, pdims, map(maybe_static_step, y.indices), strides(x)),
        permute(offsets(x), pdims),
    )
end
## layouts
# `Array` is the base case of `layout`: its own memory is the buffer.
# Linear access uses offset 0 and stride 1, so index `i` maps to `0 + i * 1 == i`.
function layout(A::Array, ::LinearElement)
    index = LinearStrideIndex(static(1), static(0))
    return (A, index)
end
# Cartesian access uses the strided index built from the array itself.
function layout(A::Array, ::CartesianElement)
    index = StrideIndex(A)
    return (A, index)
end
# Cartesian layout of a `PermutedDimsArray`: take the parent's buffer and Cartesian
# index, then fold the permutation into that index. The result addresses the parent's
# buffer directly.
#
# This must dispatch on the array, not on `PermutedIndex`. Both `parent(A)` and the
# `PermutedIndex(A)` constructor take a `PermutedDimsArray`, so a method on the index
# type could never run successfully.
@inline function layout(A::PermutedDimsArray, ::CartesianElement)
    buffer, index = layout(parent(A), CartesianElement())
    return buffer, (index ∘ PermutedIndex(A))
end
# Linear layout of a `Base.FastSubArray`. Base indexes such a view linearly as
# `parent[offset1 + stride1 * i]`. So the parent's linear index is composed with a
# `LinearStrideIndex` built from the view's own `stride1` and `offset1` fields.
@inline function layout(A::Base.FastSubArray, ::LinearElement)
    buffer, index = layout(parent(A), LinearElement())
    return buffer, (index ∘ LinearStrideIndex(A.stride1, A.offset1))
end
# Linear layout of a `Base.FastContiguousSubArray`. Such a view is contiguous, so its
# linear stride is the static 1 and only the view's `offset1` field is read. The
# parent's linear index is composed with that `LinearStrideIndex`.
@inline function layout(A::Base.FastContiguousSubArray, ::LinearElement)
    buffer, index = layout(parent(A), LinearElement())
    return buffer, (index ∘ LinearStrideIndex(static(1), A.offset1))
end
# Cartesian layout of any `SubArray`: reuse the parent's buffer and fold the view's
# indices (`SubIndex(A)`) into the parent's Cartesian index.
function layout(A::SubArray, ::CartesianElement)
    parent_buffer, parent_index = layout(parent(A), CartesianElement())
    composed = parent_index ∘ SubIndex(A)
    return parent_buffer, composed
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment