Skip to content

Instantly share code, notes, and snippets.

@ssz66666
Created September 3, 2018 09:54
Show Gist options
  • Save ssz66666/904e3355f88092070d441dcffa57b4d7 to your computer and use it in GitHub Desktop.
Save ssz66666/904e3355f88092070d441dcffa57b4d7 to your computer and use it in GitHub Desktop.
Partial code for implementing more generic wrapper-type support using reflection
# Flatten a (possibly nested) `Union` type into a `Set{Type}` of its member types.
# A non-Union type yields a one-element set; the empty union `Union{}` yields an
# empty set.
_collect_union(t::Type, set) = push!(set, t)
# Recurse into both halves of the Union. Julia normalizes unions so that
# normally only `t.b` is itself a Union, but recursing into `t.a` as well costs
# nothing and keeps this correct regardless of nesting.
_collect_union(t::Union, set) = _collect_union(t.b, _collect_union(t.a, set))
# `Union{}` has no members. Without this method it would match `::Type` above and
# be pushed into the set as if it were a member type.
_collect_union(::Core.TypeofBottom, set) = set
collect_union(t) = _collect_union(t, Set{Type}())
# Expand an abstract type into the Union of all its non-abstract descendants,
# walking the type tree via `subtypes`. Any non-abstract type is returned as is;
# an abstract type with no subtypes becomes `Union{}`.
function _unpack_abstract(::Type{T}) where {T}
    isabstracttype(T) || return T
    leaves = [_unpack_abstract(S) for S in subtypes(T)]
    return Union{leaves...}
end
# Declared type of the first argument of the `Base.parent` method `m`.
# A method declared with a `where` clause has a `UnionAll` signature, which has
# no `.parameters` field. So the signature is unwrapped first, and the argument
# type is rewrapped in the same type variables. A bare type-variable argument
# (e.g. `parent(x::T) where T<:Foo`) is replaced by its upper bound.
# For an ordinary `Tuple{...}` signature this is exactly `m.sig.parameters[2]`.
function _parent_arg_type(m::Method)
    argtype = Base.unwrap_unionall(m.sig).parameters[2]
    argtype isa TypeVar && return argtype.ub
    return Base.rewrap_unionall(argtype, m.sig)
end

# Query the methods of `Base.parent` to find potential wrapper types.
# - The method taking plain `AbstractArray` (presumably Base's identity fallback)
#   is skipped.
# - Abstract argument types are expanded into their non-abstract subtypes.
# - The result is flattened into a Set of individual types.
const supported_wrappers = collect_union(
    Union{map(_unpack_abstract,
              filter(t -> t != AbstractArray,
                     map(_parent_arg_type, methods(Base.parent))))...}
)
# Marker type used as a probe. Each type parameter `i` of a wrapper is replaced by
# a type variable bounded above by `_TV{i}`. The return type that inference
# reports for `Base.parent` then names the parameter that carries the parent
# array type.
struct _TV{M}
end
# Count the `where` (UnionAll) layers wrapping the type `t` and add them to `n`.
# For an unparameterized name such as `SubArray`, this is its number of type
# parameters. A type that is not a UnionAll contributes nothing.
_unwind_unionall(t::Type, n) = t isa UnionAll ? _unwind_unionall(t.body, n + 1) : n
# Wrap `t` in one UnionAll per type variable in `tvs`, consuming `tvs`.
# Variables are popped from the end, so the last one becomes the innermost
# `where` and the first one the outermost. `tvs` is left empty.
function _rewind_unionall!(t, tvs)
    isempty(tvs) && return t
    return _rewind_unionall!(UnionAll(pop!(tvs), t), tvs)
end
# Guess which type parameter of the wrapper type `wp` holds its parent array type.
#
# Every type parameter Ti of `wp` is replaced by a TypeVar with upper bound
# `_TV{i}`, e.g.
#   SubArray{T1,T2,T3,T4,T5} where {T1<:_TV{1},T2<:_TV{2},T3<:_TV{3},T4<:_TV{4},T5<:_TV{5}}
# Return-type inference is then run on the first `Base.parent` method matching
# `wp`. If the inferred type is a single concrete marker `_TV{i}`, position `i`
# is returned (an `Int`). Otherwise the result is `nothing`.
function _guess_parent_param_pos(wp)
    # number of type parameters of the UnionAll type
    nparam = _unwind_unionall(wp, 0)
    typevs = [TypeVar(Symbol("T$i"), Union{}, _TV{i}) for i = 1:nparam]
    populated = Core.apply_type(wp, typevs...)
    # rewrap the populated type back to a UnionAll (this empties `typevs`)
    rewrapped = _rewind_unionall!(populated, typevs)
    m = first(methods(Base.parent, [wp]))
    pt = Core.Compiler.typeinf_type(m, Tuple{rewrapped}, Core.svec(),
                                    Core.Compiler.Params(UInt(m.min_world)))
    # Only a concrete `_TV{i}` (a DataType) singles out one parameter. A bare
    # `pt <: _TV` test is not enough:
    # - `Union{}` is a subtype of every type;
    # - a Union of several markers and the UnionAll `_TV` also pass it;
    # - `first(pt.parameters)` would throw on all of these.
    # Presumably inference can also yield a non-type (e.g. `nothing`) on failure,
    # which `isa DataType` rejects as well.
    if !(pt isa DataType && pt <: _TV)
        return nothing
    end
    pos = first(pt.parameters)
    return pos isa Int ? pos : nothing
end

# Iterate over a snapshot. Wrappers that cannot be handled are removed from
# `supported_wrappers`, and a Set must not be mutated while it is being iterated.
for wp in collect(supported_wrappers)
    local n = _guess_parent_param_pos(wp)
    if n === nothing
        # we can't handle it automatically, drop it
        pop!(supported_wrappers, wp)
        continue
    end
    # used by other macro calls to generate appropriate signatures
    # to allow overriding functions directly dispatching on wrapper types
    @eval parent_type_param_pos(::Type{$wp}) = $n
    # generate methods of the form
    #   parent_type(::Type{<:wp{<:Any, ..., <:Any, PT}}) where PT = PT
    # with n - 1 `<:Any` placeholders before PT
    @eval parent_type(::Type{<:$(Expr(
        :curly,
        :($wp),
        fill(:(<:Any), n - 1)...,
        :PT
    ))}) where PT = PT
    #= disabled draft (presumably recursive unwrapping of nested wrappers):
    @eval _trecurse(::Type{<:$(Expr(
        :curly,
        :($wp),
        fill(:(<:Any),n-1)...,
        :PT
    ))}) where PT = _trecurse(PT)
    =#
end
# A deliberately broad Union: `GPUArray` plus every wrapper type remaining in
# `supported_wrappers`. It must stay strictly narrower than `AbstractArray`.
# Traits are used elsewhere to detect functions that were overridden
# incorrectly through this Union and to redirect them.
const GPUDestArray = let wrappers = collect(supported_wrappers)
    Union{GPUArray, wrappers...}
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment