-
-
Save maxwindiff/768a9f1ac532e549ff9b012891b23eb7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Revise, Metal, BenchmarkTools

# Benchmark inputs: 8192 * 8192 Float32 ones, once as a flat vector (`a` on the host,
# `da` as its MtlArray copy) and once as an 8192x8192 matrix (`b` / `db`).
a = ones(Float32, 8192 * 8192);
da = MtlArray(a);
b = ones(Float32, 8192, 8192);
db = MtlArray(b);
# julia> @btime sum(a) | |
# 5.574 ms (1 allocation: 16 bytes) | |
# 6.7108864f7 | |
# julia> @btime sum(da) | |
# 6.042 ms (754 allocations: 20.80 KiB) | |
# 6.7108864f7 | |
# julia> @btime myreduce(+, da) | |
# 1.749 ms (579 allocations: 16.38 KiB) | |
# 6.7108864f7 | |
"""
    reduce_warp(op, val)

Combine `val` across a SIMD group using `simd_shuffle_down`.

The shuffle distance doubles each round (1, 2, 4, 8, 16); in every round the calling
lane's value is merged, via `op`, with the value obtained from the shuffle. The caller in
this file only consumes the result held by the first thread of the threadgroup.
NOTE(review): the bound of 32 presumably matches the SIMD-group width of the target
GPU — confirm before running on hardware with a different width.
"""
@inline function reduce_warp(op, val)
    delta = UInt32(1)
    while delta < 32
        neighbour = simd_shuffle_down(val, delta)
        val = op(val, neighbour)
        delta *= UInt32(2)
    end
    return val
end
"""
    reduce_group(op, in, out, ::Val{stride})

Metal kernel that reduces one tile of `in` per threadgroup with `op` and stores the
tile's result in `out[threadgroup_position_in_grid_1d()]`.

Each thread folds `stride` consecutive elements of `in` (starting at
`(thread_position_in_grid_1d() - 1) * stride + 1`), the per-thread results are combined
through a 1024-slot threadgroup scratch array by repeated halving, and the last 32 values
are combined with `reduce_warp`.

# Arguments
- `op`: binary operator. Values are combined in tree order, so it should be associative
  and commutative.
- `in`: input device array; every launched thread reads `stride` elements without bounds
  checks, so the launch must cover only valid indices.
- `out`: output device array with one slot per threadgroup.
- `Val{stride}`: number of elements folded by each thread (must be at least 1).

# Launch requirements
As written, the halving loop and the final 32-lane step only combine every thread's value
when the threadgroup size is a power of two between 32 and 1024.
"""
function reduce_group(
    op, in::MtlDeviceArray{T}, out::MtlDeviceArray{T}, ::Val{stride}
) where {T, stride}
    tid = thread_position_in_threadgroup_1d()
    blockSize = threads_per_threadgroup_1d()
    shared = MtlThreadGroupArray(T, 1024)

    @inbounds begin
        # Fold this thread's `stride` elements in a local. The accumulator is seeded with
        # the first element rather than a literal 0, so operators whose identity is not
        # zero (`*`, `min`, `max`, ...) give the right answer.
        base = (thread_position_in_grid_1d() - 1) * stride
        acc::T = in[base + 1]
        i = base + 2
        while i <= base + stride
            acc = op(acc, in[i])
            i += 1
        end
        shared[tid] = acc
        threadgroup_barrier(Metal.MemoryFlagThreadGroup)

        # Halve the number of live values until 32 remain. The barrier sits outside the
        # `tid` test so that every thread of the group reaches it.
        offset::UInt32 = 512
        while offset > 16
            if blockSize >= 2 * offset
                if tid <= offset
                    # Combine with `op` (not a hard-coded `+`).
                    shared[tid] = op(shared[tid], shared[tid + offset])
                end
                threadgroup_barrier(Metal.MemoryFlagThreadGroup)
            end
            offset = offset >> 1
        end

        # Finish the remaining values inside the first SIMD group.
        if simdgroup_index_in_threadgroup() == 1
            shared[tid] = reduce_warp(op, shared[tid])
        end

        # The first thread of the group publishes the tile result.
        if tid == 1
            out[threadgroup_position_in_grid_1d()] = shared[tid]
        end
    end
    return
end
"""
    myreduce(op, a::MtlArray, stride=4)

Reduce all elements of `a` with the binary operator `op` and return the result.

The input is split into tiles of `1024 * stride` elements. Every complete tile is reduced
on the GPU by one threadgroup of 1024 threads running `reduce_group`; the per-tile
results are then reduced by calling `myreduce` again. Whatever does not fill a complete
tile (a short input, the trailing remainder, or the last few partial results) is copied
to the host and folded there with `reduce(op, ...)`, so the kernel is never launched over
indices that do not exist and `op` is used for every combination made here.

# Arguments
- `op`: binary operator; results are combined in an unspecified order, so it should be
  associative and commutative.
- `a`: device array to reduce (any length, any number of dimensions).
- `stride`: number of elements each GPU thread folds; must be a positive integer.

# Throws
- `ArgumentError` if `stride < 1`.
- Whatever `reduce(op, ...)` throws for an empty collection when `a` is empty.
"""
function myreduce(op, a::MtlArray{T}, stride=4) where {T}
    stride >= 1 || throw(ArgumentError("stride must be a positive integer, got $stride"))

    threads = 1024
    tile = threads * stride
    n = length(a)
    groups = div(n, tile)  # number of complete tiles

    # Less than one full tile: not worth (or safe for) a kernel launch.
    if groups == 0
        return reduce(op, Array(a))
    end

    partials = similar(a, groups)
    @metal threads=threads grid=groups reduce_group(op, a, partials, Val(stride))

    # Reduce the per-tile results, keeping the caller's stride.
    result = myreduce(op, partials, stride)

    # Fold in the elements after the last complete tile, if any.
    covered = groups * tile
    if covered < n
        tail = Array(a[(covered + 1):n])
        result = op(result, reduce(op, tail))
    end
    return result
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment