Created August 23, 2019 18:46
-
-
Save mwarusz/9d62fed08ae583cd62d963c24a535bd3 to your computer and use it in GitHub Desktop.
gpuify_cthulhu.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/Project.toml b/Project.toml | |
index 104d043..d15c4ad 100644 | |
--- a/Project.toml | |
+++ b/Project.toml | |
@@ -5,6 +5,7 @@ version = "0.2.8" | |
[deps] | |
Cassette = "7057c7e9-c182-5462-911a-8362d720325c" | |
+Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" | |
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | |
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | |
diff --git a/src/GPUifyLoops.jl b/src/GPUifyLoops.jl | |
index b8c67bc..335cce7 100644 | |
--- a/src/GPUifyLoops.jl | |
+++ b/src/GPUifyLoops.jl | |
@@ -19,6 +19,7 @@ export CPU, CUDA, Device | |
using StaticArrays | |
using Requires | |
+using Cthulhu | |
export @setup, @loop, @synchronize | |
export @scratch, @shmem | |
@@ -85,6 +86,7 @@ function split_kwargs(kwargs) | |
call_kws = [:blocks, :threads, :shmem, :stream, :config] | |
compiler_kwargs = [] | |
call_kwargs = [] | |
+ descend = false | |
for kwarg in kwargs | |
key, val = kwarg | |
if isa(key, Symbol) | |
@@ -92,6 +94,8 @@ function split_kwargs(kwargs) | |
push!(compiler_kwargs, kwarg) | |
elseif key in call_kws | |
push!(call_kwargs, kwarg) | |
+ elseif key == :descend | |
+ descend = true | |
else | |
throw(ArgumentError("unknown keyword argument '$key'")) | |
end | |
@@ -99,7 +103,7 @@ function split_kwargs(kwargs) | |
throw(ArgumentError("non-symbolic keyword '$key'")) | |
end | |
end | |
- return compiler_kwargs, call_kwargs | |
+ return compiler_kwargs, call_kwargs, descend | |
end | |
@init @require CUDAnative="be33ccc6-a3ff-5ff2-a52e-74243cff1e17" begin | |
@@ -116,7 +120,7 @@ end | |
global const CUDANativeVersion = version_check() | |
function launch(::CUDA, f::F, args...; kwargs...) where F | |
- compiler_kwargs, call_kwargs = split_kwargs(kwargs) | |
+ compiler_kwargs, call_kwargs, descend_into = split_kwargs(kwargs) | |
args = (ctx, f, args...) | |
GC.@preserve args begin | |
kernel_args = map(cudaconvert, args) | |
@@ -130,7 +134,11 @@ end | |
maxthreads = CUDAnative.maxthreads(kernel) | |
config = launch_config(f, maxthreads, args...; call_kwargs...) | |
- kernel(kernel_args...; config...) | |
+ if descend_into | |
+ @descend kernel(kernel_args...; config...) | |
+ else | |
+ kernel(kernel_args...; config...) | |
+ end | |
end | |
return nothing | |
end |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.