# Partial conceptualization of layouts interface
#
# Source: GitHub gist Tokazama/8e0017a0084cfcbfaa556d754aebfa5a (created June 3, 2021 09:39).
# GitHub flagged that this file may contain bidirectional Unicode text, which can be
# interpreted or compiled differently than it appears; review it in an editor that reveals
# hidden Unicode characters.
# Backport of `Base.ComposedFunction` (added in Julia 1.6), which this file uses to represent
# layouts built from several steps.
if VERSION < v"1.6"
    struct ComposedFunction{O,I} <: Function
        outer::O
        inner::I
        ComposedFunction{O,I}(outer, inner) where {O,I} = new{O,I}(outer, inner)
        # `Core.Typeof` records a type argument `X` as `Type{X}` rather than `DataType`.
        function ComposedFunction(outer, inner)
            return new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner)
        end
    end
    # Calling the composition applies `inner` first and passes its result to `outer`.
    (c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...))
end
"""
    LinkedIndex(index::Tuple)

`ArrayIndex` that holds a non-empty tuple of component indices (as produced by
`link_index`). Its dimensionality `N` is `ndims(A)`, where `A` is the type of the first
component.
"""
struct LinkedIndex{N,I<:Tuple} <: ArrayIndex{N}
    index::I
    function LinkedIndex(index::Tuple{A,Vararg{Any}}) where {A}
        return new{ndims(A),typeof(index)}(index)
    end
end
"""
    link_index(f::Tuple) -> Base.Fix2
    link_index(x, f::Tuple) -> LinkedIndex

Apply every callable in `f` to `x` and wrap the results, in the same order, in a
`LinkedIndex`. The one-argument form returns the partially applied function
`x -> link_index(x, f)`.
"""
link_index(f::Tuple) = Base.Fix2(link_index, f)
link_index(x, f::Tuple) = LinkedIndex(_link_index(x, f))

# Call each element of the tuple `f` with `x`; `map` over a `Tuple` returns a `Tuple`
# (and `()` for an empty `f`).
_link_index(x, f::Tuple) = map(g -> g(x), f)
"""
    PermutedIndex{N,perm}()

Subtype of `ArrayIndex{N}` responsible for permuting each `N`-dimensional index by `perm`
before the parent indices are accessed.
"""
struct PermutedIndex{N,perm} <: ArrayIndex{N} end
"""
    CartesianSubIndex{N,I}(indices)

`ArrayIndex{N}` that maps an `N`-dimensional index of a view onto its parent by
reindexing through the stored parent `indices` (of type `I`, e.g. the index tuple type of
a `SubArray`).
"""
struct CartesianSubIndex{N,I} <: ArrayIndex{N}
    indices::I
end
"""
    LinearNDIndex(a)
    LinearNDIndex(offset1, offsets, size)

A linear representation of a multidimensional index. Unlike `CartesianIndices`, indexing
`LinearNDIndex` by a `StaticInt` can produce a static type (e.g., `NDIndex`).

`offset1` is subtracted from a linear index before it is split into subscripts; `offsets`
and `size` hold one entry per dimension (`N` of them). The one-argument form reads all
three from `a`.
"""
struct LinearNDIndex{N,O1,O<:Tuple{Vararg{CanonicalInt,N}},S<:Tuple{Vararg{CanonicalInt,N}}} <: VectorIndex
    offset1::O1
    offsets::O
    size::S
end
# Defined outside the struct so the default three-argument constructor is still generated
# (an inner constructor would suppress it).
LinearNDIndex(a) = LinearNDIndex(offset1(a), offsets(a), size(a))
"""
    NDLinearIndex(a)
    NDLinearIndex(offset1, offsets, size)

A multidimensional representation of a linear index.

`offset1` is the linear offset and `offsets`/`size` hold one entry per dimension (`N` of
them). The one-argument form reads all three from `a`.
"""
struct NDLinearIndex{N,O1,O<:Tuple{Vararg{CanonicalInt,N}},S<:Tuple{Vararg{CanonicalInt,N}}} <: ArrayIndex{N}
    offset1::O1
    offsets::O
    size::S
end
# Defined outside the struct so the default three-argument constructor is still generated
# (an inner constructor would suppress it).
NDLinearIndex(a) = NDLinearIndex(offset1(a), offsets(a), size(a))
"""
    OffsetIndex(offset1)

Linear `VectorIndex` that stores a single offset `offset1`; its `getindex` method adds
that offset to the requested index.
"""
struct OffsetIndex{O<:CanonicalInt} <: VectorIndex
    offset1::O
end
"""
    LinearStrideIndex(stride1)

Linear `VectorIndex` that stores a single stride `stride1`.
"""
struct LinearStrideIndex{S1<:CanonicalInt} <: VectorIndex
    stride1::S1
end
# Linear index into a strided view: a one-dimensional stride (`S1`) linked with an offset
# (`O1`). `LinearStrideIndex{S1}` is used because `StrideIndex{S1}` would place the stride
# type in `StrideIndex`'s dimension parameter.
const LinearSubIndex{S1,O1} = LinkedIndex{1,Tuple{LinearStrideIndex{S1},OffsetIndex{O1}}}
"""
    TransposedVectorIndex()

Two-dimensional `ArrayIndex` for an adjoint/transposed vector; its `getindex` method maps
a 2-D Cartesian index to the last component, i.e. the position in the underlying vector.
"""
struct TransposedVectorIndex <: ArrayIndex{2} end
"""
    AccessStyle

Abstract supertype of singleton traits describing how an array is accessed:

- `LinearAccess()`: the counterpart of `IndexLinear()`.
- `CartesianAccess()`: the counterpart of `IndexCartesian()`.
- `UnknownAccess()`: fallback when no style can be determined.
- `UnorderedAccess()`: NOTE(review) presumably access in which the order of elements does
  not matter — confirm.
- `MultiAccess{U}()`: several styles at once, given as the `Union` `U`.
"""
abstract type AccessStyle end
struct LinearAccess <: AccessStyle end
struct CartesianAccess <: AccessStyle end
struct UnorderedAccess <: AccessStyle end
struct UnknownAccess <: AccessStyle end
struct MultiAccess{U} <: AccessStyle end

"""
    LinearOrCartesianAccess()

Access style that supports both linear and Cartesian access.
"""
const LinearOrCartesianAccess = MultiAccess{Union{LinearAccess,CartesianAccess}}
# Original misspelled name, kept as an alias for backward compatibility.
const LinearOrCartesianAccessAccess = LinearOrCartesianAccess
"""
    access_output(::Type{T}, s::AccessStyle) -> AccessStyle

Return the access style with which the parent of `T` is accessed when an instance of `T`
is accessed in style `s`; `layout` uses it as the style for the parent's layout. Styles
without a more specific method fall back to `UnknownAccess()`.
"""
access_output(::Type{T}, ::AccessStyle) where {T} = UnknownAccess()

# Linear access stays `LinearAccess` only if `IndexStyle(T)` is `IndexLinear()`;
# otherwise it becomes `CartesianAccess`.
function access_output(::Type{T}, ::LinearAccess) where {T}
    return _index2access(IndexStyle(IndexLinear(), IndexStyle(T)))
end

# A Cartesian index into an adjoint/transposed vector selects a single position of the
# parent vector, so the parent is accessed linearly.
access_output(::Type{<:VecAdjTrans}, ::CartesianAccess) = LinearAccess()

# Combining with `IndexCartesian()` always yields `IndexCartesian()`.
function access_output(::Type{T}, ::CartesianAccess) where {T}
    return _index2access(IndexStyle(IndexCartesian(), IndexStyle(T)))
end
# Permuted, adjoint and transposed wrappers pass unordered access through unchanged.
# (This also covers `Transpose`, which previously had a redundant separate method.)
function access_output(
    ::Type{<:Union{PermutedDimsArray,Adjoint,Transpose}}, ::UnorderedAccess
)
    return UnorderedAccess()
end

# Any other type turns unordered access into the style implied by its own `IndexStyle`.
access_output(::Type{T}, ::UnorderedAccess) where {T} = _index2access(IndexStyle(T))

# NOTE(review): an incomplete `access_output(::Type{<:}, ::UnorderedAccess)` method (a
# syntax error) was removed here; its intended wrapper type is unknown.
# Translate a Base `IndexStyle` into the corresponding `AccessStyle`.
_index2access(::IndexLinear) = LinearAccess()
_index2access(::IndexCartesian) = CartesianAccess()
_index2access(::IndexStyle) = UnknownAccess()
# A view whose every index is a `Slice` keeps unordered access.
function access_output(
    ::Type{<:SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{Slice}}}}, ::UnorderedAccess
)
    return UnorderedAccess()
end

# Cartesian access into a view is Cartesian access into its parent.
access_output(::Type{<:SubArray}, ::CartesianAccess) = CartesianAccess()

# Linear access into a view depends on the view's index types `I`.
function access_output(::Type{<:SubArray{<:Any,<:Any,<:Any,I}}, ::LinearAccess) where {I}
    return _view_access(I)
end
# Access style for linear indexing into a `SubArray` whose index tuple type is `I`.
# `LinearOrCartesianAccess()` is returned only for patterns made of leading scalars and
# slices followed by at most one range and then only scalars; every other pattern yields
# `CartesianAccess()`.

# No indices left.
_view_access(::Type{Tuple{}}) = LinearOrCartesianAccess()
# Leading scalar indices are skipped.
function _view_access(::Type{I}) where {I<:Tuple{Real,Vararg{Any}}}
    return _view_access(Base.tuple_type_tail(I))
end
# A run of slices may continue with further slices.
function _view_access(::Type{I}) where {I<:Tuple{Slice,Slice,Vararg{Any}}}
    return _view_access(Base.tuple_type_tail(I))
end
# A unit range may follow slices when every remaining index is scalar.
_view_access(::Type{I}) where {I<:Tuple{Slice,AbstractUnitRange,Vararg{Real}}} = LinearOrCartesianAccess()
# Resolves the ambiguity between the two methods above (slices followed only by scalars).
_view_access(::Type{I}) where {I<:Tuple{Slice,Slice,Vararg{Real}}} = LinearOrCartesianAccess()
# Any range followed only by scalars.
_view_access(::Type{I}) where {I<:Tuple{AbstractRange,Vararg{Real}}} = LinearOrCartesianAccess()
# Every other combination requires Cartesian access...
_view_access(::Type{I}) where {I<:Tuple{Vararg{Any}}} = CartesianAccess()
# ...as does a leading general array index.
_view_access(::Type{I}) where {I<:Tuple{AbstractArray,Vararg{Any}}} = CartesianAccess()
"""
    is_native_access(::Type{T}, s::AccessStyle) -> StaticBool

Returns `static(true)` if the access style `s` is the native access style for `T`.
If `static(false)` is returned then accessing an instance of `T` in the style of `s` is
either completely incompatible or requires an initial transformation of the accessor.

When given two `AccessStyle`s, returns `static(true)` only if they are the same style.
The `Type` form compares `s` against `access_input(T)`.
"""
is_native_access(::A, ::A) where {A<:AccessStyle} = static(true)
is_native_access(::AccessStyle, ::AccessStyle) = static(false)
function is_native_access(::Type{T}, s::AccessStyle) where {T}
    return is_native_access(access_input(T), s)
end
################ | |
### getindex ### | |
################ | |
# Permute the components of the Cartesian index `i` by `perm`.
function Base.getindex(x::PermutedIndex{N,perm}, i::AbstractCartesianIndex{N}) where {N,perm}
    return NDIndex(permute(Tuple(i), Val(perm)))
end
# Map an index of the view onto parent indices through the stored sub-indices.
# NOTE(review): this returns a plain tuple, whereas the `PermutedIndex` method wraps its
# result in `NDIndex` — confirm which is intended.
function Base.getindex(x::CartesianSubIndex{N}, i::AbstractCartesianIndex{N}) where {N}
    return _reindex(x.indices, Tuple(i))
end
"""
    _reindex(subinds::Tuple, inds::Tuple) -> Tuple

Map the index `inds` of a view onto parent indices, one entry of `subinds` (the view's
parent indices) at a time:

- an `Integer` entry is emitted as-is and consumes no element of `inds`;
- a `Slice` entry passes the next element of `inds` through unchanged;
- any other entry is indexed by the next element of `inds`; if its elements are
  `AbstractCartesianIndex`es, their components are splatted into the result.

Processing stops early only if a non-integer entry is reached after `inds` is exhausted.
"""
@generated function _reindex(subinds::S, inds::I) where {S,I}
    NS = fieldcount(S)
    NI = fieldcount(I)
    out = Expr(:tuple)
    inds_i = 1
    for subinds_i in 1:NS
        subinds_type = fieldtype(S, subinds_i)
        if subinds_type <: Integer
            # A dropped (integer-indexed) dimension consumes no view index, so trailing
            # integers are still emitted after `inds` runs out.
            push!(out.args, :(getfield(subinds, $subinds_i)))
            continue
        end
        # Every other kind of parent index consumes exactly one view index.
        inds_i > NI && break
        if subinds_type <: Slice
            push!(out.args, :(getfield(inds, $inds_i)))
        elseif eltype(subinds_type) <: AbstractCartesianIndex
            # One Cartesian element expands into several parent dimensions.
            push!(out.args, :(Tuple(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)]))...))
        else
            push!(out.args, :(@inbounds(getfield(subinds, $subinds_i)[getfield(inds, $inds_i)])))
        end
        inds_i += 1
    end
    return Expr(:block, Expr(:meta, :inline), out)
end
# Convert the linear index `i` (shifted by `offset1(x)` to zero-based) into an `NDIndex`.
@inline function Base.getindex(x::LinearNDIndex, i::CanonicalInt)
    return NDIndex(_lin2subs(offsets(x), size(x), i - offset1(x)))
end
# Split the zero-based linear offset `i` into one subscript per dimension, where `s` holds
# the dimension lengths and `o` the first index of each dimension. Always returns a tuple,
# including in the one-dimensional case. `i` is left unannotated so any integer-like
# offset (e.g. a static integer) is accepted.
@inline function _lin2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
    len = first(s)
    inext = div(i, len)
    return (i - len * inext + first(o), _lin2subs(Base.tail(o), Base.tail(s), inext)...)
end
# The last dimension takes whatever offset remains.
_lin2subs(o::Tuple{Any}, s::Tuple{Any}, i) = (i + first(o),)
# Convert the Cartesian index `i` into a linear index. The first dimension has unit
# stride; `_subs2lin` accumulates the contribution of the remaining dimensions.
# NOTE(review): `CartesianToLinearIndex` is not defined in this file; it may be meant to
# be `NDLinearIndex`, which has the same `offset1`/`offsets`/`size` fields — confirm.
@inline function Base.getindex(x::CartesianToLinearIndex{N}, i::AbstractCartesianIndex{N}) where {N}
    inds = Tuple(i)
    o = offsets(x)
    s = size(x)
    return first(inds) + (offset1(x) - first(o)) + _subs2lin(first(s), tail(s), tail(o), tail(inds))
end
# Sum `(i[k] - o[k]) * stride` over the given dimensions. The stride starts at `str` and is
# multiplied by each successive length in `s` before moving to the next dimension.
@inline function _subs2lin(str, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, i::Tuple{Any,Vararg})
    rest = _subs2lin(str * first(s), Base.tail(s), Base.tail(o), Base.tail(i))
    return ((first(i) - first(o)) * str) + rest
end
_subs2lin(str, s::Tuple{Any}, o::Tuple{Any}, i::Tuple{Any}) = (first(i) - first(o)) * str
# One-dimensional case: no dimensions follow the first, so nothing is added.
_subs2lin(str, ::Tuple{}, ::Tuple{}, ::Tuple{}) = static(0)
# A trailing index beyond the last dimension can only be in bounds as `1` or `1:1`, so it
# contributes nothing.
_subs2lin(str, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)
# Shift the linear index `i` by the stored offset.
Base.getindex(x::OffsetIndex, i::CanonicalInt) = offset1(x) + i
# For a 1×n adjoint/transposed vector, the last component is the position in the vector.
Base.getindex(x::TransposedVectorIndex, i::AbstractCartesianIndex{2}) = last(Tuple(i))
############## | |
### layout ### | |
############## | |
# Layout of `T` under `UnorderedAccess`. If `T` passes unordered access on to its parent,
# the parent's layout is combined with `parent`; otherwise `T` gets the chance to combine
# with the parent's layout. Returns `nothing` when `parent_type(T) <: T` or when the parent
# has no layout (so `nothing` is never handed to `combine_layouts`).
function layout(::Type{T}, ::UnorderedAccess) where {T}
    parent_type(T) <: T && return nothing
    access = access_output(T, UnorderedAccess())
    parent_layout = layout(parent_type(T), access)
    parent_layout === nothing && return nothing
    if access === UnorderedAccess()
        return combine_layouts(parent, parent_layout)
    else
        # Same combination path as `layout(::Type{T}, ::AccessStyle)`.
        return _maybe_combine_layouts(T, parent_layout)
    end
end
"""
    layout(::Type{T}, s::AccessStyle)

Compute the layout used to access an instance of `T` in style `s`: recurse into
`parent_type(T)` with style `access_output(T, s)` and combine the result with `T`.
Returns `nothing` when `parent_type(T) <: T` or when the parent has no layout.
"""
function layout(::Type{T}, s::AccessStyle) where {T}
    parent_type(T) <: T && return nothing
    parent_layout = layout(parent_type(T), access_output(T, s))
    return _maybe_combine_layouts(T, parent_layout)
end
# Combine `T` with the parent layout `index`.
_maybe_combine_layouts(::Type{T}, index) where {T} = combine_layouts(T, index)
# Don't bother combining if the parent has no layout.
_maybe_combine_layouts(::Type{T}, ::Nothing) where {T} = nothing
# Expose the outermost layout of a composition so that `T` can try to combine with it,
# keeping the inner part unchanged.
function _maybe_combine_layouts(::Type{T}, lyt::ComposedFunction) where {T}
    return ComposedFunction(_maybe_combine_layouts(T, lyt.outer), lyt.inner)
end
"""
    combine_layouts(outer, inner)

Combine the layout `outer` of a wrapper (a type, or `parent`) with the layout `inner` of
its parent. Some methods return `nothing` when the two cannot be combined.

Note that `ComposedFunction(outer, inner)(x)` evaluates `outer(inner(x))`.
NOTE(review): several methods build `ComposedFunction(parent, Y)`, which evaluates
`parent(Y(x))`; if `Y(parent(x))` is intended, the arguments are reversed — confirm.
"""
combine_layouts(::typeof(parent), lyt) = ComposedFunction(parent, lyt)

# Two `ArrayIndex` layouts are composed only if `Y` natively accepts the access style
# produced by `X`; otherwise they cannot be combined.
function combine_layouts(::Type{X}, ::Type{Y}) where {X<:ArrayIndex,Y<:ArrayIndex}
    return _compose_if_compatible_access(is_native_access(Y, access_output(X)), X, Y)
end
_compose_if_compatible_access(::True, x, y) = ComposedFunction(x, y)
_compose_if_compatible_access(::False, x, y) = nothing
# Adjoint/transposed vectors reuse a strided parent layout as-is; other parent layouts
# are linked behind a `TransposedVectorIndex`.
combine_layouts(::Type{<:VecAdjTrans}, ::Type{Y}) where {Y<:StrideIndex} = Y
function combine_layouts(::Type{<:VecAdjTrans}, ::Type{Y}) where {Y<:ArrayIndex}
    return link_index((TransposedVectorIndex, ComposedFunction(parent, Y)))
end
# Adjoint/transposed matrices reuse a strided parent layout as-is; other parent layouts
# are linked behind a `(2, 1)` permutation.
combine_layouts(::Type{<:MatAdjTrans}, ::Type{Y}) where {Y<:StrideIndex} = Y
function combine_layouts(::Type{<:MatAdjTrans}, ::Type{Y}) where {Y<:ArrayIndex}
    return link_index((PermutedIndex{2,(2, 1)}, ComposedFunction(parent, Y)))
end
# `PermutedDimsArray` reuses a strided parent layout as-is; other parent layouts are
# linked behind its permutation `I` (the array's third type parameter).
combine_layouts(::Type{<:PermutedDimsArray}, ::Type{Y}) where {Y<:StrideIndex} = Y
function combine_layouts(::Type{<:PermutedDimsArray{<:Any,N,I}}, ::Type{Y}) where {N,I,Y<:ArrayIndex}
    return link_index((PermutedIndex{N,I}, ComposedFunction(parent, Y)))
end
# Reshaped arrays reuse a strided parent layout as-is.
combine_layouts(::Type{<:ReshapedArray}, ::Type{Y}) where {Y<:StrideIndex} = Y
# Layout of a view `X` over a parent with layout `Y`. When `ndims(Y) === 1` and the view
# supports fast linear indexing (`SubArray`'s `L` parameter), the view needs only an offset
# (contiguous views) or a stride plus an offset (`LinearSubIndex`). Otherwise it reindexes
# through its Cartesian sub-indices.
function combine_layouts(::Type{X}, ::Type{Y}) where {N,I,L,X<:SubArray{<:Any,N,<:Any,I,L},Y<:ArrayIndex}
    if ndims(Y) === 1 && L
        if X <: Base.FastContiguousSubArray
            return link_index((OffsetIndex{Int}, ComposedFunction(parent, Y)))
        else
            return link_index((LinearSubIndex{Int,Int}, ComposedFunction(parent, Y)))
        end
    else
        return link_index((CartesianSubIndex{N,I}, ComposedFunction(parent, Y)))
    end
end
# A view whose indices preserve strides reuses the parent's `StrideIndex` unchanged;
# otherwise that layout is composed with the view's `CartesianSubIndex{N,I}`.
function combine_layouts(::Type{<:SubArray{<:Any,N,<:Any,I}}, ::Type{Y}) where {N,I,Y<:StrideIndex}
    known(stride_preserving_index(I)) && return Y
    return ComposedFunction(CartesianSubIndex{N,I}, Y)
end
# End of gist. (GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.")