Skip to content

Instantly share code, notes, and snippets.

@ssz66666
Last active August 31, 2018 11:11
Show Gist options
  • Save ssz66666/e175488374bc52ecd0b8a7c53b655d8f to your computer and use it in GitHub Desktop.
Save ssz66666/e175488374bc52ecd0b8a7c53b655d8f to your computer and use it in GitHub Desktop.
Toy example with traits
using SimpleTraits
# Root abstract type of the toy hierarchy, parameterised by `T`.
abstract type MyInterface{T} end # like AbstractArray
# Abstract supertype shared by the GPU-backed structs declared below.
abstract type MyGPUStruct{T} <: MyInterface{T} end # like GPUArray
# Concrete GPU-backed struct (the CuArray analogue of this toy example).
# `name` is the label handed back by `get_name`; `x` is a payload of type `T`.
struct MyStruct{T} <: MyGPUStruct{T}
    name::String
    x::T
end
# Second concrete GPU-backed struct (perhaps the JLArray analogue).
# Same layout as `MyStruct`: a `name` label and a payload `x` of type `T`.
struct MyStruct2{T} <: MyGPUStruct{T}
    name::String
    x::T
end
# Wrapper around any `MyInterface{T}` value (the LinearAlgebra.Transpose analogue).
# The wrapped object lives in `parent`; its type is recorded as the parameter `AT`.
struct MyWrapper{T,AT<:MyInterface{T}} <: MyInterface{T}
    parent::AT
end
# A second wrapper type, standing for any wrapper that should fall back to the
# base methods. The wrapped object lives in `parent2`; its type is `AT`.
struct MyWrapper2{T,AT<:MyInterface{T}} <: MyInterface{T}
    parent2::AT
end
# Wrapper-producing entry points, standing in for Base functions such as
# `view`, `transpose` and `adjoint`.
function wrap1(x::MyInterface)
    return MyWrapper(x)
end

function wrap2(x::MyInterface)
    return MyWrapper2(x)
end
# Generic `get_name` for any `MyInterface`: report that the base fallback was
# reached and return the placeholder name.
get_name(::MyInterface) = begin
    println("base fallback")
    return "fallback"
end
# `get_name` is the example of a function that does have methods written for
# the wrapper types directly. Every method prints which one ran; the structs
# return their own `name`, the wrappers recurse into the object they wrap.
function get_name(o::MyStruct)
    println("MyStruct")
    return o.name
end

function get_name(o::MyStruct2)
    println("MyStruct2")
    return o.name
end

function get_name(o::MyWrapper)
    println("Wrapper")
    return get_name(o.parent)
end

function get_name(o::MyWrapper2)
    println("Wrapper2")
    return get_name(o.parent2)
end
# `foo` is the example of a function that has no wrapper-specific methods:
# a single generic method covers every `MyInterface`.
function foo(o::MyInterface)
    return println("generalised foo")
end
# Trait that marks a type `A` as GPU-compatible. Which types have the trait is
# decided by the `SimpleTraits.trait` method defined further down in this file.
@traitdef IsOnGPU{A}
# helpers
## Registries read when the type predicates are generated:
## `bottoms` lists the GPU backend types, `wrappers` the supported wrapper types.
bottoms = Set{Type}(Type[MyStruct, MyStruct2])
wrappers = Set{Type}(Type[MyWrapper])
## generate `is_wrapper` and `is_bottom` functions here
## for `digtype`
# Strip every `where` layer from a type and return the `DataType` underneath.
# A `DataType` is returned unchanged.
unpack_unionall(x::DataType) = x
function unpack_unionall(t::UnionAll)
    inner = t.body
    # Peel nested `UnionAll`s (one per type parameter) before dispatching again.
    while inner isa UnionAll
        inner = inner.body
    end
    return unpack_unionall(inner)
end
# Define one method of the predicate named `p` for every type in `set`.
# For a type `t` whose parameters are `params...` the generated method reads
#     p(::Type{t{params...}}) where {params...} = true
# TODO: a quoted expression with interpolation would probably make this shorter.
function make_type_predicate(p::Symbol, set::Set{Type})
    for t in set
        body_type = unpack_unionall(t)
        tparams = map(Symbol, body_type.parameters)
        # t{params...}
        applied = Expr(:curly, Symbol(body_type.name), tparams...)
        # p(::Type{t{params...}})
        signature = Expr(:call, p, Expr(:(::), Expr(:curly, :Type, applied)))
        # p(::Type{t{params...}}) where {params...} = true
        eval(Expr(:(=), Expr(:where, signature, tparams...), true))
    end
end
# Generate the `true` methods of `is_wrapper` and `is_bottom` from the
# `wrappers` and `bottoms` registries, in that order.
function gen_type_predicates()
    for (predicate, registry) in ((:is_wrapper, wrappers), (:is_bottom, bottoms))
        make_type_predicate(predicate, registry)
    end
end
# Default answers: a type is neither a wrapper nor a bottom unless a more
# specific method has been generated for it from the registries.
function is_wrapper(::Type{T}) where {T}
    return false
end

function is_bottom(::Type{T}) where {T}
    return false
end

# A bare `TypeVar` is never a bottom either.
function is_bottom(::TypeVar)
    return false
end

# Add the specific `true` methods for the registered types.
gen_type_predicates()
# Remove one wrapper layer at the type level: given a fully parameterised
# `MyWrapper{T,AT}` or `MyWrapper2{T,AT}` type, return the wrapped type `AT`.
@inline function unwrap_type(::Type{MyWrapper{T,AT}}) where {T,AT}
    return AT
end

@inline function unwrap_type(::Type{MyWrapper2{T,AT}}) where {T,AT}
    return AT
end
# Peel wrapper layers off the type `X` for as long as `is_wrapper` accepts the
# current type, and return the first type it rejects.
#
# Author's notes on `@pure`:
# - if the set of wrappers has to change later, `digtype` must be marked
#   impure, leading to slower (but correct) dynamic dispatch;
# - changing the bottoms (GPU backends) is fine, because that is handled in
#   `SimpleTraits.trait`;
# - making `digtype` a pure function improves the performance a lot.
Base.@pure @inline function digtype(::Type{X}) where {X}
    peeled = X
    while true
        is_wrapper(peeled) || break
        peeled = unwrap_type(peeled)
    end
    return peeled
end
# Trait membership: `X` has `IsOnGPU` exactly when the type left after
# peeling its wrappers (`digtype(X)`) is accepted by `is_bottom`.
function SimpleTraits.trait(::Type{IsOnGPU{X}}) where {X}
    if is_bottom(digtype(X))
        return IsOnGPU{X}
    else
        return Not{IsOnGPU{X}}
    end
end
# Backend tag handed to the `*_gpu` functions; currently just `digtype(X)`.
# TODO: probably change this to use a more sensible type.
function backend(::Type{X}) where {X}
    return digtype(X)
end
# GPUTarget type is very important to be able to get dispatch to work
# should be generated using `wrappers` set
#
# `@union_param(n, Name, types)` defines, in the calling module,
#
#     const Name{TypeVar1, ..., TypeVarn} =
#         Union{t1{TypeVar1, ..., TypeVarn}, t2{TypeVar1, ..., TypeVarn}, ...}
#
# where `t1, t2, ...` are the elements obtained by evaluating `types`.
#
# Arguments:
# - `n`: integer literal, the number of type parameters of the alias.
# - `union_type_name`: name of the alias to define.
# - `types`: expression yielding an iterable of parametric types. It is
#   evaluated at macro-expansion time in the calling module, so everything it
#   mentions must already be defined there.
#
# The evaluated type objects are embedded directly into the generated
# expression, so the result does not depend on how the types are printed or
# on their names being visible at the call site.
macro union_param(n, union_type_name, types)
    n isa Integer ||
        throw(ArgumentError("@union_param: `n` must be an integer literal, got $(repr(n))"))
    ftvars = [Symbol("TypeVar", i) for i in 1:n]
    # Evaluated eagerly because `types` may contain splats of runtime
    # collections (e.g. `[A, registry...]`) that only exist as values.
    member_types = Core.eval(__module__, types)
    # t{TypeVar1, ..., TypeVarn} for every member, with `t` embedded as a value.
    members = [Expr(:curly, t, ftvars...) for t in member_types]
    alias = Expr(:curly, union_type_name, ftvars...)
    return esc(Expr(:const, Expr(:(=), alias, Expr(:curly, :Union, members...))))
end
# Define `GPUTarget` with one type parameter as the union of `MyGPUStruct`
# and every type currently in `wrappers`, all applied to that parameter.
# should be equivalent to
# GPUTarget{T} = Union{MyGPUStruct{T},MyWrapper{T}}
@union_param(1,GPUTarget,[MyGPUStruct,wrappers...])
# Trait-dispatched `foo` for `GPUTarget` types that have the `IsOnGPU` trait:
# forward to `foo_gpu`, passing `backend(A)` as the first argument so that
# `foo_gpu` can dispatch on it.
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; IsOnGPU{A}} = begin
foo_gpu(backend(A),x)
end
# Same signature for types without the trait: explicitly invoke the generic
# `foo(::MyInterface)` method.
@traitfn foo(x::A) where {T, A<:GPUTarget{T}; !IsOnGPU{A}} = begin
invoke(foo,Tuple{MyInterface},x)
end
# Trait-dispatched `get_name` for a `MyWrapper` whose wrapped type is itself a
# `GPUTarget`. With the `IsOnGPU` trait: forward to `get_name_gpu`, passing
# `backend(A)` as the first argument.
@traitfn get_name(x::A) where {T, AT<: GPUTarget{T}, A<:MyWrapper{T,AT}; IsOnGPU{A}} = begin
get_name_gpu(backend(A),x)
end
# Without the trait: explicitly invoke the `get_name` method whose signature
# is the plain `MyWrapper` type (both parameters left free).
@traitfn get_name(x::A) where {T, AT<: GPUTarget{T}, A<:MyWrapper{T,AT}; !IsOnGPU{A}} = begin
invoke(get_name,Tuple{MyWrapper{_T,_AT} where {_T, _AT<:MyInterface{_T}} },x)
end
# GPU flavour of `get_name` for a `MyWrapper`. The first argument is the
# backend type and is used for dispatch only. Prints a marker, then continues
# with `get_name` on the wrapped object.
function get_name_gpu(::Type{<:GPUTarget{T}}, x::MyWrapper{T,AT}) where {T,AT}
    println("GPU specialised get_name")
    return get_name(x.parent)
end
# GPU flavour of `foo`. The first argument is the backend type and is used
# for dispatch only; `x` is not inspected. Generic method for any `GPUTarget`
# backend:
function foo_gpu(::Type{<:GPUTarget{T}}, x) where {T}
    return println("GPU specialised foo")
end

# Method specialised for the `MyStruct` backend.
function foo_gpu(::Type{MyStruct{T}}, x) where {T}
    return println("MyStruct specialised foo")
end
#=
function f(n;x=MyStruct("foo",1))
for i = 1:n
x = MyWrapper(x)
end
x
end
=#
# Build a chain of wrappers on top of the base value `x` and print which
# `get_name` / `foo` methods are selected for each link.
# Every link is derived from `x`, so nothing leaks from one chain to the next.
function _exercise_wrapper_chain(x)
    y = wrap1(x)  # dispatch to gpu method
    z = wrap1(y)  # dispatch to gpu method
    w = wrap2(z)  # fall back to base for non-compatible wrappers
    v = wrap1(w)  # fall back to base again
    u = wrap1(v)
    for link in (x, y, z, w, v, u)
        get_name(link)
        println()
    end
    # `foo` is exercised on every link except the outermost one, `u`.
    for link in (x, y, z, w, v)
        foo(link)
    end
    return nothing
end

# Run the dispatch demonstration once per GPU backend struct.
# Both runs go through the same helper, so the `MyStruct2` run checks its own
# outermost wrapper `u` instead of a leftover value from the `MyStruct` run.
function test()
    _exercise_wrapper_chain(MyStruct("x", 1))   # uses gpu method
    _exercise_wrapper_chain(MyStruct2("x", 1))  # uses gpu method
    return nothing
end
test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment