Created
September 3, 2018 09:54
-
-
Save ssz66666/904e3355f88092070d441dcffa57b4d7 to your computer and use it in GitHub Desktop.
Partial code implementing more generic wrapper support using reflection
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Helpers for `collect_union`: accumulate the members of a (possibly nested) `Union`
# into `set` and return `set`.
# `Union{}` has no members, so it contributes nothing.
_collect_union(::Core.TypeofBottom, set) = set
# Any other non-`Union` type is a single member.
_collect_union(t::Type, set) = push!(set, t)
# Recurse into both halves so the result does not depend on how the union is nested.
_collect_union(t::Union, set) = _collect_union(t.b, _collect_union(t.a, set))

"""
    collect_union(t) -> Set{Type}

Return the member types of the union `t` as a `Set{Type}`.

A non-union type yields a one-element set, and `Union{}` yields an empty set.
"""
collect_union(t) = _collect_union(t, Set{Type}())
_unpack_abstract(::Type{T}) where {T} = isabstracttype(T) ? Union{map(_unpack_abstract,subtypes(T))...} : T | |
# Query the methods of `Base.parent` to find candidate wrapper types: the argument type
# of every method except the one taking `AbstractArray`, with abstract argument types
# expanded into the leaves of their subtype tree.
const supported_wrappers = let
    # A method defined with a `where` clause has a `UnionAll` signature, which has no
    # `parameters` field: unwrap it, take the argument type, and re-wrap that in the
    # same type variables. Plain `Tuple` signatures pass through unchanged.
    argtypes = [
        Base.rewrap_unionall(Base.unwrap_unionall(m.sig).parameters[2], m.sig)
        for m in methods(Base.parent)
    ]
    # `!=` (type equality) rather than `!==`: a re-wrapped type may equal
    # `AbstractArray` without being the identical object.
    candidates = filter(t -> t != AbstractArray, argtypes)
    collect_union(Union{map(_unpack_abstract, candidates)...})
end
# dummy type to help guess the type parameter representing | |
# the parent array type | |
struct _TV{M} end | |
"""
    _unwind_unionall(t, n = 0)

Return `n` plus the number of nested `UnionAll` layers wrapping `t`.

For an unparameterized wrapper such as `Array` this is its number of type parameters.
"""
_unwind_unionall(t::Type, n = 0) = n
_unwind_unionall(t::UnionAll, n = 0) = _unwind_unionall(t.body, n + 1)
"""
    _rewind_unionall!(t, tvs)

Wrap `t` in one `UnionAll` per type variable in `tvs` and return the result, which is
`t where {tvs[1], tvs[2], ...}`.

`tvs[1]` becomes the outermost binding, and `t` is returned unchanged when `tvs` is
empty. `tvs` is emptied as a side effect.
"""
function _rewind_unionall!(t, tvs)
    # foldr applies UnionAll(tv, acc) starting from the last variable, so the first
    # variable ends up outermost.
    wrapped = foldr(UnionAll, tvs; init = t)
    empty!(tvs)
    return wrapped
end
"""
    _guess_parent_param_pos(wp) -> Union{Int,Nothing}

Guess which type parameter of the wrapper type `wp` represents its parent array type.

Each type parameter `i` of `wp` is replaced by a type variable with upper bound
`_TV{i}`. Return-type inference is then run on the `Base.parent` method selected for
`wp`. If the inferred return type is `_TV{i}`, return `i`.

Otherwise return `nothing`. This covers inference failure, a `Union{}` result, and any
other type.
"""
function _guess_parent_param_pos(wp)
    # number of type parameters of the UnionAll wrapper
    nparam = _unwind_unionall(wp, 0)
    # populate the UnionAll by applying type vars with upper bound _TV{i}, e.g.
    # SubArray{T1,T2,T3,T4,T5} where {T1<:_TV{1},T2<:_TV{2},T3<:_TV{3},T4<:_TV{4},T5<:_TV{5}}
    typevs = [TypeVar(Symbol("T$i"), Union{}, _TV{i}) for i = 1:nparam]
    # NOTE(review): `Core.apply_type` is expected to reject a non-UnionAll type called
    # with no parameters, so such types are left as-is (they cannot yield a `_TV`).
    populated = nparam == 0 ? wp : Core.apply_type(wp, typevs...)
    # rewrap the populated type back to a UnionAll
    rewrapped = _rewind_unionall!(populated, typevs)
    ms = methods(Base.parent, [wp])
    isempty(ms) && return nothing
    # NOTE(review): this assumes the first listed method is the one dispatched for `wp`;
    # `which(Base.parent, Tuple{wp})` may be more precise — confirm.
    m = first(ms)
    # NOTE(review): internal compiler API; its signature is not stable across versions.
    pt = Core.Compiler.typeinf_type(m, Tuple{rewrapped}, Core.svec(),
                                    Core.Compiler.Params(UInt(m.min_world)))
    # Require a concrete `_TV{i}`. This rejects `nothing` (inference failure) and
    # `Union{}`, which is `<: _TV` but has no `parameters`. It also rejects the bare
    # `_TV` UnionAll.
    if pt isa DataType && pt <: _TV
        n = first(pt.parameters)
        n isa Int && return n
    end
    return nothing
end

# Iterate over a snapshot: unsupported wrappers are removed from `supported_wrappers`
# inside the loop, and a `Set` must not be mutated while it is being iterated.
for wp in collect(supported_wrappers)
    local n = _guess_parent_param_pos(wp)
    if n === nothing
        # we can't handle it automatically, drop it
        pop!(supported_wrappers, wp)
        continue
    end
    # used by other macro calls to generate appropriate signatures
    # to allow overriding functions directly dispatching on wrapper types
    @eval parent_type_param_pos(::Type{$wp}) = $n
    # generate functions of the form:
    # parent_type(::Type{<:wp{[(n - 1) * <:Any], PT}}) where PT = PT
    local sig = Expr(:curly, wp, fill(:(<:Any), n - 1)..., :PT)
    @eval parent_type(::Type{<:$sig}) where PT = PT
    #= currently disabled:
    @eval _trecurse(::Type{<:$(Expr(
        :curly,
        :($wp),
        fill(:(<:Any), n - 1)...,
        :PT
    ))}) where PT = _trecurse(PT)
    =#
end
# A deliberately broad ("hacky") Union covering GPUArray and every wrapper type left in
# `supported_wrappers`. It must stay strictly narrower than AbstractArray. It also
# matches wrappers of non-GPU parents, so traits are used to identify and redirect
# incorrectly overridden functions back.
const GPUDestArray = Union{GPUArray, supported_wrappers...}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment