Skip to content

Instantly share code, notes, and snippets.

@mwarusz
Created August 23, 2019 18:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mwarusz/9d62fed08ae583cd62d963c24a535bd3 to your computer and use it in GitHub Desktop.
Save mwarusz/9d62fed08ae583cd62d963c24a535bd3 to your computer and use it in GitHub Desktop.
gpuify_cthulhu.patch
diff --git a/Project.toml b/Project.toml
index 104d043..d15c4ad 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.2.8"
[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
+Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
diff --git a/src/GPUifyLoops.jl b/src/GPUifyLoops.jl
index b8c67bc..335cce7 100644
--- a/src/GPUifyLoops.jl
+++ b/src/GPUifyLoops.jl
@@ -19,6 +19,7 @@ export CPU, CUDA, Device
using StaticArrays
using Requires
+using Cthulhu
export @setup, @loop, @synchronize
export @scratch, @shmem
@@ -85,6 +86,7 @@ function split_kwargs(kwargs)
call_kws = [:blocks, :threads, :shmem, :stream, :config]
compiler_kwargs = []
call_kwargs = []
+ descend = false
for kwarg in kwargs
key, val = kwarg
if isa(key, Symbol)
@@ -92,6 +94,8 @@ function split_kwargs(kwargs)
push!(compiler_kwargs, kwarg)
elseif key in call_kws
push!(call_kwargs, kwarg)
+ elseif key == :descend
+ descend = val # honor the caller's value so `descend=false` keeps descending off
else
throw(ArgumentError("unknown keyword argument '$key'"))
end
@@ -99,7 +103,7 @@ function split_kwargs(kwargs)
throw(ArgumentError("non-symbolic keyword '$key'"))
end
end
- return compiler_kwargs, call_kwargs
+ return compiler_kwargs, call_kwargs, descend
end
@init @require CUDAnative="be33ccc6-a3ff-5ff2-a52e-74243cff1e17" begin
@@ -116,7 +120,7 @@ end
global const CUDANativeVersion = version_check()
function launch(::CUDA, f::F, args...; kwargs...) where F
- compiler_kwargs, call_kwargs = split_kwargs(kwargs)
+ compiler_kwargs, call_kwargs, descend_into = split_kwargs(kwargs)
args = (ctx, f, args...)
GC.@preserve args begin
kernel_args = map(cudaconvert, args)
@@ -130,7 +134,11 @@ end
maxthreads = CUDAnative.maxthreads(kernel)
config = launch_config(f, maxthreads, args...; call_kwargs...)
- kernel(kernel_args...; config...)
+ if descend_into
+ @descend kernel(kernel_args...; config...)
+ else
+ kernel(kernel_args...; config...)
+ end
end
return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment