Skip to content

Instantly share code, notes, and snippets.

@Tokazama
Created June 3, 2021 09:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Tokazama/8e0017a0084cfcbfaa556d754aebfa5a to your computer and use it in GitHub Desktop.
Save Tokazama/8e0017a0084cfcbfaa556d754aebfa5a to your computer and use it in GitHub Desktop.
Partial conceptualization of layouts interface
@static if VERSION < v"1.6"
    # Minimal backport of `Base.ComposedFunction`, which only exists from Julia 1.6 on.
    # Calling the result applies `inner` to the arguments and `outer` to that result.
    struct ComposedFunction{Outer,Inner} <: Function
        outer::Outer
        inner::Inner
        function ComposedFunction{Outer,Inner}(outer, inner) where {Outer,Inner}
            return new{Outer,Inner}(outer, inner)
        end
        function ComposedFunction(outer, inner)
            return new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner)
        end
    end
    function (composed::ComposedFunction)(args...; kwargs...)
        inner_result = composed.inner(args...; kwargs...)
        return composed.outer(inner_result)
    end
end
""" LinkedIndex """
struct LinkedIndex{N,I<:Tuple} <: ArrayIndex{N}
index::I
LinkedIndex(index::Tuple{A,Vararg{Any}}) where {A} = new{ndims(A),typeof(index)}(index)
end
"""
    link_index(f::Tuple)
    link_index(x, f::Tuple)

Apply every callable in `f` to `x`, in order, and wrap the resulting tuple in a
`LinkedIndex`. The one-argument form returns `Base.Fix2(link_index, f)`, i.e. a function
that still waits for `x`.
"""
link_index(f::Tuple) = Base.Fix2(link_index, f)
link_index(x, f::Tuple) = LinkedIndex(_link_index(x, f))

# Evaluate each callable of `fs` on `x`, preserving order; an empty `fs` yields `()`.
_link_index(x, fs::Tuple) = map(g -> g(x), fs)
"""
PermutedIndex{N,perm}()
Subtypes of `ArrayIndex` that is responsible for permuting each index prior to accessing
parent indices.
"""
struct PermutedIndex{N,perm} <: ArrayIndex{N} end
""" CartesianSubIndex """
struct CartesianSubIndex{N,I} <: ArrayIndex{N}
indices::I
end
"""
LinearNDIndex
A linear representation of a multidimensional index. Unlike `CartesianIndices`, indexing
`LinearNDIndex` by an `StaticInt` can produce a static type (e.g., `NDIndex`).
"""
struct LinearNDIndex{O1,O<:Tuple{Vararg{CanonicalInt,N}},S<:Tuple{Vararg{CanonicalInt,N}}} <: VectorIndex
offset1::O1
offsets::O
size::S
LinearNDIndex(a) = LinearNDIndex(offset1(a), offsets(a), size(a))
end
"""
NDLinearIndex
A multidimensional representation of a linear index.
"""
struct NDLinearIndex{N,O1,O<:Tuple{Vararg{CanonicalInt,N}},S<:Tuple{Vararg{CanonicalInt,N}}} <: ArrayIndex{N}
offset1::O1
offsets::O
size::S
NDLinearIndex(a) = NDLinearIndex(offset1(a), offsets(a), size(a))
end
""" OffsetIndex """
struct OffsetIndex{O<:CanonicalInt} <: VectorIndex
offset1::O
end
""" LinearStrideIndex """
struct LinearStrideIndex{S1<:CanonicalInt} <: VectorIndex
stride1::S1
end
# A 1-D index built from a stride stage followed by an offset stage.
# NOTE(review): previously spelled `StrideIndex{S1}`. `StrideIndex` is the N-dimensional
# strided layout elsewhere in this file, while `S1` is a single stride type, so the 1-D
# `LinearStrideIndex{S1}` is used instead — confirm.
const LinearSubIndex{S1,O1} = LinkedIndex{1,Tuple{LinearStrideIndex{S1},OffsetIndex{O1}}}
""" TransposedVectorIndex """
struct TransposedVectorIndex <: ArrayIndex{2} end
"""
AccessStyle
"""
abstract type AccessStyle end
struct LinearAccess <: AccessStyle end
struct CartesianAccess <: AccessStyle end
struct UnorderedAccess <: AccessStyle end
struct UnknownAccess <: AccessStyle end
struct MultiAccess{U} <: AccessStyle end
const LinearOrCartesianAccessAccess = MultiAccess{Union{LinearAccess,CartesianAccess}}
"""
access_output(::Type{T}, ::AccessStyle) -> AccessStyle
"""
access_output(::Type{T}, ::AccessStyle) = UnknownAccess()
function access_output(::Type{T}, ::LinearAccess) where {T}
return _index2access(IndexStyle(IndexLinear(), IndexStyle(T)))
end
access_output(::Type{<:VecAdjTrans}, ::CartesianAccess) where {T} = LinearAccess()
function access_output(::Type{T}, ::CartesianAccess) where {T}
return _index2access(IndexStyle(IndexCartesian(), IndexStyle(T)))
end
access_output(::Type{<:Union{PermutedDimsArray,Adjoint,Transpose}}, ::UnorderedAccess) = UnorderedAccess()
access_output(::Type{T}, ::UnorderedAccess) where {T} = _index2access(IndexStyle(T))
access_output(::Type{<:}, ::UnorderedAccess) = UnorderedAccess()
access_output(::Type{<:Transpose}, ::UnorderedAccess) = UnorderedAccess()
# Translate a Base `IndexStyle` trait into the matching `AccessStyle`; any style other
# than `IndexLinear`/`IndexCartesian` maps to `UnknownAccess()`.
_index2access(::IndexStyle) = UnknownAccess()
_index2access(::IndexCartesian) = CartesianAccess()
_index2access(::IndexLinear) = LinearAccess()
# `SubArray` rules:
# - Cartesian access passes straight through to the parent.
# - Linear access depends on the index tuple type (see `_view_access`).
# - A view indexed only by `Slice`s keeps unordered access.
access_output(::Type{<:SubArray}, ::CartesianAccess) = CartesianAccess()
function access_output(::Type{<:SubArray{<:Any,<:Any,<:Any,Inds}}, ::LinearAccess) where {Inds}
    return _view_access(Inds)
end
function access_output(
    ::Type{<:SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{Slice}}}}, ::UnorderedAccess
)
    return UnorderedAccess()
end
# Access style a view's parent receives under linear access, given the view's index tuple
# type. Leading scalar (`Real`) indices, and a leading `Slice` that is followed by another
# `Slice`, are peeled off recursively. Range patterns ending in scalars allow linear or
# Cartesian access; anything else falls back to Cartesian access.
_view_access(::Type{Tuple{}}) = LinearOrCartesianAccess()
_view_access(::Type{<:Tuple{Vararg{Any}}}) = CartesianAccess()
_view_access(::Type{<:Tuple{AbstractArray,Vararg{Any}}}) = CartesianAccess()
_view_access(::Type{<:Tuple{AbstractRange,Vararg{Real}}}) = LinearOrCartesianAccess()
_view_access(::Type{<:Tuple{Slice,Slice,Vararg{Real}}}) = LinearOrCartesianAccess()
_view_access(::Type{<:Tuple{Slice,AbstractUnitRange,Vararg{Real}}}) = LinearOrCartesianAccess()
function _view_access(::Type{Idx}) where {Idx<:Tuple{Real,Vararg{Any}}}
    return _view_access(Base.tuple_type_tail(Idx))
end
function _view_access(::Type{Idx}) where {Idx<:Tuple{Slice,Slice,Vararg{Any}}}
    return _view_access(Base.tuple_type_tail(Idx))
end
"""
is_native_access(::Type{T}, s::AccessStyle) -> StaticBool
Returns `static(true)` if the access style `s` is the native access style for `T`.
If `static(false)` is returned then accessing an instance of `T` in the style of `s` is
either completely incompatible or requires an initial transformation the acccessor.
"""
is_native_access(::A, ::A) where {A<:AccessStyle} = static(true)
is_native_access(::A1, ::A2) where {A1<:AccessStyle,A2<:AccessStyle} = static(false)
is_native_access(::A1, ::A2) where {A1<:AccessStyle,A2<:AccessStyle} = static(false)
function is_native_access(::Type{T}, s::AccessStyle) where {T}
return is_native_access(access_input(T), s)
end
################
### getindex ###
################
# Permute the components of `i` by `P` and return them as an `NDIndex`.
function Base.getindex(::PermutedIndex{N,P}, i::AbstractCartesianIndex{N}) where {N,P}
    permuted = permute(Tuple(i), Val(P))
    return NDIndex(permuted)
end
# Re-map the view index `i` through the stored parent indices of `x`.
function Base.getindex(x::CartesianSubIndex{N}, i::AbstractCartesianIndex{N}) where {N}
    parent_inds = getfield(x, :indices)
    view_inds = Tuple(i)
    return _reindex(parent_inds, view_inds)
end
"""
    _reindex(subinds::Tuple, inds::Tuple)

Map `inds`, an index into a view, to the parent indices, given the view's parent indices
`subinds` (one entry per parent dimension). Entries are handled as follows:

- An `Integer` entry is emitted unchanged and does not consume an element of `inds`, so
  it is kept even when it trails every element of `inds`.
- A `Slice` entry passes the next element of `inds` through.
- Any other entry is an index collection that is looked up with the next element of
  `inds`. If its elements are `AbstractCartesianIndex`es, they are splatted into several
  parent indices.

NOTE(review): assumes every non-scalar entry consumes exactly one element of `inds`
(multidimensional index arrays are not handled) — confirm.
"""
@generated function _reindex(subinds::S, inds::I) where {S,I}
    out = Expr(:tuple)
    inds_i = 1
    for subinds_i in 1:fieldcount(S)
        subinds_type = fieldtype(S, subinds_i)
        if subinds_type <: Integer
            # dropped dimension: its fixed parent index is emitted as-is
            push!(out.args, :(getfield(subinds, $subinds_i)))
        elseif subinds_type <: Slice
            push!(out.args, :(getfield(inds, $inds_i)))
            inds_i += 1
        else
            # look the index collection up with the view's index, not with another subindex
            lookup = :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))
            if eltype(subinds_type) <: AbstractCartesianIndex
                push!(out.args, Expr(:..., :(Tuple($lookup))))
            else
                push!(out.args, lookup)
            end
            inds_i += 1
        end
    end
    return Expr(:block, Expr(:meta, :inline), out)
end
# Shift the linear index by `offset1(x)` and split it into per-dimension indices.
@inline function Base.getindex(x::LinearNDIndex, i::CanonicalInt)
    rel = i - offset1(x)
    subs = _lin2subs(offsets(x), size(x), rel)
    return NDIndex(subs)
end
# Split the relative linear index `i` into per-dimension indices. The first dimension
# takes the remainder modulo its length `first(s)` (it varies fastest), plus its offset
# `first(o)`. The quotient is carried into the remaining dimensions.
@inline function _lin2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i::CanonicalInt)
    len = first(s)
    inext = div(i, len)
    return (i - len * inext + first(o), _lin2subs(tail(o), tail(s), inext)...)
end
# Last dimension: whatever remains of `i`, shifted by its offset. It is returned as a
# 1-tuple so that every method produces a tuple. Previously a bare scalar was returned and
# only worked because numbers iterate when splatted.
_lin2subs(o::Tuple{Any}, s::Tuple{Any}, i::CanonicalInt) = (i + first(o),)
# Convert the Cartesian index `i` into a linear index, using the per-dimension offsets and
# sizes of `x` and its linear offset `offset1(x)`.
# NOTE(review): `CartesianToLinearIndex` is not defined in the visible part of this file;
# `NDLinearIndex` carries the same `offset1`/`offsets`/`size` fields — confirm the
# intended type.
@inline function Base.getindex(x::CartesianToLinearIndex{N}, i::AbstractCartesianIndex{N}) where {N}
    inds = Tuple(i)  # previously `Tuple(arg)`, which referenced an undefined variable
    o = offsets(x)
    s = size(x)
    return first(inds) + (offset1(x) - first(o)) + _subs2lin(first(s), tail(s), tail(o), tail(inds))
end
# Accumulate the linear contribution of the remaining dimensions. `str` is the stride of
# the current dimension, i.e. the product of the sizes of all preceding dimensions.
@inline function _subs2lin(str, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, i::Tuple{Any,Vararg})
    return ((first(i) - first(o)) * str) + _subs2lin(str * first(s), tail(s), tail(o), tail(i))
end
_subs2lin(str, s::Tuple{Any}, o::Tuple{Any}, i::Tuple{Any}) = (first(i) - first(o)) * str
# one-dimensional index: no dimensions remain after the first, so they contribute nothing
_subs2lin(str, ::Tuple{}, ::Tuple{}, ::Tuple{}) = static(0)
# trailing indices can only be 1 or 1:1
_subs2lin(str, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)
# Shift `i` by the stored offset.
function Base.getindex(x::OffsetIndex, i::CanonicalInt)
    return offset1(x) + i
end
# A transposed vector is indexed by its column: keep only the second component.
function Base.getindex(::TransposedVectorIndex, i::AbstractCartesianIndex{2})
    return Tuple(i)[2]
end
##############
### layout ###
##############
# Layout of `T` under unordered access. There is none when `parent_type(T) <: T`.
# If the parent is still accessed without order, its layout is combined with `parent`;
# otherwise it is combined with `T` itself.
function layout(::Type{T}, ::UnorderedAccess) where {T}
    P = parent_type(T)
    P <: T && return nothing
    parent_access = access_output(T, UnorderedAccess())
    if parent_access !== UnorderedAccess()
        return combine_layouts(T, layout(P, parent_access))
    end
    return combine_layouts(parent, layout(P, parent_access))
end
"""
    layout(::Type{T}, s::AccessStyle)

Build the index layout for accessing a `T` in style `s`. It recurses into `parent_type(T)`
with the style returned by `access_output(T, s)` and then combines the parent's layout
with `T`. Returns `nothing` when `parent_type(T) <: T`.
"""
function layout(::Type{T}, s::AccessStyle) where {T}
    P = parent_type(T)
    P <: T && return nothing
    parent_layout = layout(P, access_output(T, s))
    return _maybe_combine_layouts(T, parent_layout)
end
# Combine `T` with a parent layout, passing `nothing` (no layout) through untouched.
_maybe_combine_layouts(::Type{T}, index) where {T} = combine_layouts(T, index)
# don't bother combining if we have nothing
_maybe_combine_layouts(::Type{T}, ::Nothing) where {T} = nothing
# Expose the outermost layout so that `T` can try to combine with it. If that combination
# yields `nothing`, the whole layout is `nothing` rather than a composition whose outer
# stage is `nothing`.
function _maybe_combine_layouts(::Type{T}, lyt::ComposedFunction) where {T}
    outer = _maybe_combine_layouts(T, lyt.outer)
    return outer === nothing ? nothing : ComposedFunction(outer, lyt.inner)
end
""" combine_layouts """
combine_layouts(::typeof(parent), lyt) = ComposedFunction(parent, lyt)
function combine_layouts(::Type{X}, ::Type{Y}) where {X<:ArrayIndex,Y<:ArrayIndex}
return _compose_if_compatible_access(is_native_access(Y, access_output(X)), X, Y)
end
_compose_if_compatible_access(::True, x, y) = ComposedFunction(x, y)
_compose_if_compatible_access(::False, x, y) = nothing
# Wrapper arrays over a parent with layout `Y`:
# - Strided layouts (`Y<:StrideIndex`) are reused unchanged.
# - Any other `ArrayIndex` gets a leading index transformation linked in front of the
#   parent's layout.
# NOTE(review): `ComposedFunction(parent, Y)` calls `Y` first and `parent` second;
# confirm this order is intended.
function combine_layouts(::Type{<:VecAdjTrans}, ::Type{Y}) where {Y<:StrideIndex}
    return Y
end
function combine_layouts(::Type{<:VecAdjTrans}, ::Type{Y}) where {Y<:ArrayIndex}
    stages = (TransposedVectorIndex, ComposedFunction(parent, Y))
    return link_index(stages)
end
function combine_layouts(::Type{<:MatAdjTrans}, ::Type{Y}) where {Y<:StrideIndex}
    return Y
end
function combine_layouts(::Type{<:MatAdjTrans}, ::Type{Y}) where {Y<:ArrayIndex}
    stages = (PermutedIndex{2,(2, 1)}, ComposedFunction(parent, Y))
    return link_index(stages)
end
function combine_layouts(::Type{<:PermutedDimsArray}, ::Type{Y}) where {Y<:StrideIndex}
    return Y
end
function combine_layouts(::Type{<:PermutedDimsArray{<:Any,N,I}}, ::Type{Y}) where {N,I,Y<:ArrayIndex}
    stages = (PermutedIndex{N,I}, ComposedFunction(parent, Y))
    return link_index(stages)
end
# Reuse the parent's strided layout type `Y` unchanged for a `ReshapedArray`.
# The unused type variable `X` was removed; it triggered a "declares type variable but
# does not use it" warning.
combine_layouts(::Type{<:ReshapedArray}, ::Type{Y}) where {Y<:StrideIndex} = Y
# Combine a `SubArray` with its parent's `ArrayIndex` layout `Y`.
# - 1-D parent layout and a linearly indexable view (`L`): a fast contiguous view only
#   needs an offset; any other view needs a stride followed by an offset.
# - Otherwise the view's indices `I` are re-mapped with a `CartesianSubIndex`.
function combine_layouts(::Type{X}, ::Type{Y}) where {N,I,L,X<:SubArray{<:Any,N,<:Any,I,L},Y<:ArrayIndex}
    if ndims(Y) === 1 && L
        if X <: Base.FastContiguousSubArray
            return link_index((OffsetIndex{Int}, ComposedFunction(parent, Y)))
        end
        # NOTE(review): previously `StrideIndex{Int}`. The 1-D `LinearStrideIndex` defined in
        # this file is the stride type meant for a single dimension — confirm.
        stride_then_offset = LinkedIndex{1,Tuple{LinearStrideIndex{Int},OffsetIndex{Int}}}
        return link_index((stride_then_offset, ComposedFunction(parent, Y)))
    end
    return link_index((CartesianSubIndex{N,I}, ComposedFunction(parent, Y)))
end
# A `SubArray` over a strided parent keeps the parent's layout `Y` when its index types
# `I` are known to preserve strides. Otherwise the view's indices are applied first
# through a `CartesianSubIndex`.
function combine_layouts(::Type{<:SubArray{<:Any,N,<:Any,I}}, ::Type{Y}) where {N,I,Y<:StrideIndex}
    known(stride_preserving_index(I)) && return Y
    return ComposedFunction(CartesianSubIndex{N,I}, Y)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment