Last active
August 31, 2018 11:11
-
-
Save ssz66666/e175488374bc52ecd0b8a7c53b655d8f to your computer and use it in GitHub Desktop.
Toy example with traits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using SimpleTraits | |
# Root of the toy type hierarchy (plays the role of `AbstractArray`).
abstract type MyInterface{T} end

# Abstract supertype for GPU-backed structs (plays the role of `GPUArray`).
abstract type MyGPUStruct{T} <: MyInterface{T} end
# Concrete GPU-backed struct (plays the role of `CuArray`).
struct MyStruct{T} <: MyGPUStruct{T}
    name::String  # label returned by `get_name`
    x::T          # value of the parametric type `T`
end
# A second concrete GPU-backed struct (like `JLArray`, perhaps).
struct MyStruct2{T} <: MyGPUStruct{T}
    name::String  # label returned by `get_name`
    x::T          # value of the parametric type `T`
end
# Wrapper around any `MyInterface{T}` (plays the role of `LinearAlgebra.Transpose`).
# This is the wrapper listed in `wrappers`.
struct MyWrapper{T, AT<:MyInterface{T}} <: MyInterface{T}
    parent::AT  # the wrapped object
end
# Any other wrapper, i.e. one that we want to fall back to the base methods.
struct MyWrapper2{T, AT<:MyInterface{T}} <: MyInterface{T}
    parent2::AT  # the wrapped object
end
# methods in Base to create wrappers, like view, transpose and adjoint | |
# Wrap `x` in a `MyWrapper`.
function wrap1(x::MyInterface)
    return MyWrapper(x)
end

# Wrap `x` in a `MyWrapper2`.
function wrap2(x::MyInterface)
    return MyWrapper2(x)
end
# Generic fallback: announce itself on stdout and return a placeholder name.
function get_name(::MyInterface)
    println("base fallback")
    return "fallback"
end
# a method that does expect a wrapper type directly | |
# Leaf structs print their type tag and return the stored name.
function get_name(o::MyStruct)
    println("MyStruct")
    return o.name
end

function get_name(o::MyStruct2)
    println("MyStruct2")
    return o.name
end

# Wrappers print their tag and forward the query to the wrapped object.
function get_name(o::MyWrapper)
    println("Wrapper")
    return get_name(o.parent)
end

function get_name(o::MyWrapper2)
    println("Wrapper2")
    return get_name(o.parent2)
end
# a method that does not expect a wrapper type directly | |
# Generic `foo`: prints a marker line and returns `nothing`.
function foo(o::MyInterface)
    println("generalised foo")
    return nothing
end
# A trait that marks an array as GPU-compatible
@traitdef IsOnGPU{A}
# helpers | |
## define GPU backend arrays and supported wrappers | |
# Types for which an `is_bottom` method returning `true` is generated.
bottoms = Set{Type}(Type[MyStruct, MyStruct2])
# Types for which an `is_wrapper` method returning `true` is generated.
wrappers = Set{Type}(Type[MyWrapper])
## generate `is_wrapper` and `is_bottom` functions here | |
## for `digtype` | |
# A `DataType` has no `where` layers left to strip.
unpack_unionall(x::DataType) = x

# Peel every `where` layer off a `UnionAll`, then hand the innermost body to the
# `DataType` method (any other kind of body has no matching method).
function unpack_unionall(t::UnionAll)
    inner = t.body
    while inner isa UnionAll
        inner = inner.body
    end
    return unpack_unionall(inner)
end
# TODO: learn julia metaprogramming and see | |
# if I can make this shorter | |
# generate methods that look like p(Type{t{params...}}) = true | |
"""
    make_type_predicate(p::Symbol, set::Set{Type})

For every type `t` in `set`, define (through `eval`) a method of the function named `p`:

    p(::Type{t{params...}}) where {params...} = true

Free type parameters of `t` become `where` variables under their bare names. Parameters
that are already fixed (e.g. the `Int` in `MyStruct{Int}`) are spliced in as values, so
the generated method matches exactly that instantiation.

The type is referenced by `nameof`, so it must be resolvable under that name in the
module where `eval` runs.
"""
function make_type_predicate(p::Symbol, set::Set{Type})
    for t in set
        wt = unpack_unionall(t)
        vars = Symbol[]  # names of the `where` variables
        args = Any[]     # what goes between the braces of `t{...}`
        for prm in wt.parameters
            if prm isa TypeVar
                # Use the bare name: `Symbol(prm)` would stringify the bounds as well,
                # giving symbols such as Symbol("AT<:MyInterface{T}").
                push!(vars, prm.name)
                push!(args, prm.name)
            else
                push!(args, prm)
            end
        end
        target = isempty(args) ? nameof(wt) : Expr(:curly, nameof(wt), args...)
        sig = Expr(:call, p, Expr(:(::), Expr(:curly, :Type, target)))
        lhs = isempty(vars) ? sig : Expr(:where, sig, vars...)
        eval(Expr(:(=), lhs, true))
    end
    return nothing
end
# Generate the `true`-returning `is_wrapper` / `is_bottom` methods from the
# `wrappers` and `bottoms` sets.
function gen_type_predicates()
    for (pred, types) in ((:is_wrapper, wrappers), (:is_bottom, bottoms))
        make_type_predicate(pred, types)
    end
    return nothing
end
# Defaults: a type is neither a wrapper nor a bottom unless a generated method says so.
function is_wrapper(::Type{T}) where {T}
    return false
end

function is_bottom(::Type{T}) where {T}
    return false
end

# A bare `TypeVar` is never a bottom type.
function is_bottom(::TypeVar)
    return false
end

# Add the specialised `true` methods for the registered types.
gen_type_predicates()
# Return the parent type `AT` recorded in a wrapper type.
@inline unwrap_type(::Type{MyWrapper{T,AT}}) where {T,AT} = AT
@inline unwrap_type(::Type{MyWrapper2{T,AT}}) where {T,AT} = AT
# If the set of wrappers needs to change later, `digtype` must be
# marked as impure, which leads to slower (but correct) dynamic dispatch.
# Changing bottoms (GPU backends) is fine because that is handled in SimpleTraits.trait.
# Making `digtype` a pure function improves performance a lot.
# Strip wrapper layers from `X`: while `is_wrapper` holds for the current type,
# replace it with `unwrap_type` of itself; return the first type that is not a wrapper.
Base.@pure @inline function digtype(::Type{X}) where {X}
    inner = X
    while is_wrapper(inner)
        inner = unwrap_type(inner)
    end
    return inner
end
# `X` has the `IsOnGPU` trait exactly when the type found by `digtype` is a bottom type.
function SimpleTraits.trait(::Type{IsOnGPU{X}}) where {X}
    if is_bottom(digtype(X))
        return IsOnGPU{X}
    else
        return Not{IsOnGPU{X}}
    end
end
# probably change this to use a more sensible type
# The backend of `X` is the type `digtype` finds underneath its wrappers.
function backend(::Type{X}) where {X}
    return digtype(X)
end
# The GPUTarget type is essential for getting dispatch to work;
# it should be generated from the `wrappers` set.
# Expand to
#
#     const union_type_name{TypeVar1, ..., TypeVarn} = Union{...}
#
# `types` is evaluated at macro-expansion time. Each resulting type is applied to
# `n` fresh type variables, printed with `show`, and parsed back into an expression
# that becomes one member of the `Union`.
macro union_param(n, union_type_name, types)
    ftvars = [Symbol("TypeVar", i) for i in 1:n]
    applied = map(t -> Core.apply_type(t, map(TypeVar, ftvars)...), eval(types))
    # TODO: fix this
    # there must be a better way!!!!
    members = [Base.Meta.parse(sprint(show, t)) for t in applied]
    lhs = Expr(:curly, union_type_name, ftvars...)
    rhs = Expr(:curly, :Union, members...)
    return esc(Expr(:const, Expr(:(=), lhs, rhs)))
end
# should be equivalent to
# GPUTarget{T} = Union{MyGPUStruct{T},MyWrapper{T}}
@union_param 1 GPUTarget [MyGPUStruct, wrappers...]
# GPU-capable argument: hand over to `foo_gpu`, selected by the argument's backend type.
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; IsOnGPU{A}} = foo_gpu(backend(A), x)
# Not GPU-capable: explicitly invoke the generic `MyInterface` method.
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; !IsOnGPU{A}} = invoke(foo, Tuple{MyInterface}, x)
# GPU-capable wrapper: hand over to `get_name_gpu`, selected by the backend type.
@traitfn get_name(x::A) where {T, AT<:GPUTarget{T}, A<:MyWrapper{T,AT}; IsOnGPU{A}} = get_name_gpu(backend(A), x)
# Not GPU-capable: explicitly invoke the plain `MyWrapper` method.
@traitfn get_name(x::A) where {T, AT<:GPUTarget{T}, A<:MyWrapper{T,AT}; !IsOnGPU{A}} = invoke(get_name, Tuple{MyWrapper{_T,_AT} where {_T, _AT<:MyInterface{_T}}}, x)
# GPU-specialised `get_name` for wrappers: print a marker, then query the parent.
function get_name_gpu(::Type{<:GPUTarget{T}}, x::MyWrapper{T,AT}) where {T,AT}
    println("GPU specialised get_name")
    return get_name(x.parent)
end
# GPU `foo` for any backend type that is a `GPUTarget`.
function foo_gpu(::Type{<:GPUTarget{T}}, x) where {T}
    println("GPU specialised foo")
    return nothing
end

# More specific GPU `foo` for the `MyStruct` backend.
function foo_gpu(::Type{MyStruct{T}}, x) where {T}
    println("MyStruct specialised foo")
    return nothing
end
#= | |
function f(n;x=MyStruct("foo",1)) | |
for i = 1:n | |
x = MyWrapper(x) | |
end | |
x | |
end | |
=# | |
# Build the wrapper chain x -> y -> z -> w -> v -> u on top of `x`, then call
# `get_name` on every link (each followed by a blank line) and `foo` on x..v.
function _run_scenario(x)
    y = wrap1(x)  # supported wrapper around the GPU struct
    z = wrap1(y)  # supported wrapper, nested once more
    w = wrap2(z)  # non-compatible wrapper: expected to fall back to base
    v = wrap1(w)  # wraps a non-compatible wrapper: expected to fall back again
    u = wrap1(v)
    for o in (x, y, z, w, v, u)
        get_name(o)
        println()
    end
    for o in (x, y, z, w, v)
        foo(o)
    end
    return nothing
end

"""
    test()

Run the dispatch demonstration once for a `MyStruct` and once for a `MyStruct2`.

Each scenario builds its own wrapper chain. Previously the second scenario never
defined `u`, and because `begin` blocks do not introduce a new scope, its
`get_name(u)` call silently reused the `MyStruct`-based `u` from the first scenario.
"""
function test()
    for bottom in (MyStruct("x", 1), MyStruct2("x", 1))
        _run_scenario(bottom)
    end
    return nothing
end

test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment