# Originally published as a GitHub Gist by @maxwindiff (last active February 26, 2023):
# https://gist.github.com/maxwindiff/768a9f1ac532e549ff9b012891b23eb7
using Revise, Metal, BenchmarkTools
# Benchmark inputs: 8192 * 8192 Float32 ones, stored both as a flat vector (`a`) and as an
# 8192x8192 matrix (`b`), each also copied into a Metal GPU array (`da`, `db`).
a = ones(Float32, 8192 * 8192);
da = MtlArray(a);
b = ones(Float32, 8192, 8192);
db = MtlArray(b);

# Timings recorded by the author. The globals were not `$`-interpolated in `@btime`,
# so these may include a little untyped-global overhead.
# julia> @btime sum(a)
# 5.574 ms (1 allocation: 16 bytes)
# 6.7108864f7
# julia> @btime sum(da)
# 6.042 ms (754 allocations: 20.80 KiB)
# 6.7108864f7
# julia> @btime myreduce(+, da)
# 1.749 ms (579 allocations: 16.38 KiB)
# 6.7108864f7
"""
    reduce_warp(op, val)

Combine `val` across a SIMD group with `op` by shuffling down at distances
1, 2, 4, 8 and 16 and folding each shuffled value into the running result.

Assumes 32-lane SIMD groups. Lane 0 ends up with the combination of lanes 0-31, in
lane order. The other lanes hold partial or meaningless values and should be ignored.
"""
@inline function reduce_warp(op, val)
    acc = val
    for k in 0:4
        delta = 0x00000001 << k   # 1, 2, 4, 8, 16 as UInt32
        acc = op(acc, simd_shuffle_down(acc, delta))
    end
    return acc
end
"""
    reduce_group(op, in, out, ::Val{stride})

Metal kernel that reduces one slice of `in` per threadgroup with the binary operator `op`
and stores the group's partial result in `out[threadgroup_position_in_grid_1d()]`.

Grid thread `k` (1-based) owns the chunk `in[(k-1)*stride+1 : min(k*stride, length(in))]`
and folds it into one value. Threads whose chunk would start past the end of `in` sit
out. The group then combines its per-thread values pairwise in threadgroup memory until
at most 32 remain. It finishes with a SIMD-group shuffle reduction (`reduce_warp`), or
with a serial fold by thread 1 when fewer than 32 values exist.

Assumptions (not checked here):
- `op` is associative and commutative. The tree pairs slot `t` with slot `t + offset`,
  so operands are not combined in index order.
- At most 1024 threads per threadgroup (size of `shared` and the first tree offset),
  and 32-lane SIMD groups (what `reduce_warp` reduces over).
- The grid is sized so every threadgroup owns at least one chunk. A group that owns
  none returns without writing its `out` slot.
"""
function reduce_group(op, in::MtlDeviceArray{T}, out::MtlDeviceArray{T}, ::Val{stride}) where {T, stride}
    tid = thread_position_in_threadgroup_1d()
    blockSize = threads_per_threadgroup_1d()
    shared = MtlThreadGroupArray(T, 1024)

    # Number of threads in this group that own at least one input element. Only the
    # last group of the grid can be partially populated.
    nchunks = cld(length(in), stride)
    group_start = (threadgroup_position_in_grid_1d() - 1) * blockSize
    n = min(blockSize, nchunks - group_start)
    # `n` is the same for every thread of the group, so no barrier is half-skipped.
    n <= 0 && return

    @inbounds begin
        # Fold this thread's chunk in a register. Seeding with the first element means
        # no identity value for `op` is needed.
        if tid <= n
            base = (thread_position_in_grid_1d() - 1) * stride
            stop = min(base + stride, length(in))
            acc::T = in[base + 1]
            i = base + 2
            while i <= stop
                acc = op(acc, in[i])
                i += 1
            end
            shared[tid] = acc
        end
        threadgroup_barrier(Metal.MemoryFlagThreadGroup)

        # Pairwise tree in threadgroup memory: fold slot `tid + offset` into slot `tid`
        # until at most 32 valid slots remain. `n > offset` is group-uniform, so every
        # thread reaches the same barriers.
        offset::UInt32 = 512
        while offset > 16
            if n > offset
                if tid <= offset && tid + offset <= n
                    shared[tid] = op(shared[tid], shared[tid + offset])
                end
                threadgroup_barrier(Metal.MemoryFlagThreadGroup)
            end
            offset = offset >> 1
        end

        if n >= 32
            # Exactly slots 1..32 are valid now; the first SIMD group reduces them.
            if simdgroup_index_in_threadgroup() == 1
                shared[tid] = reduce_warp(op, shared[tid])
            end
        elseif tid == 1
            # Fewer than 32 values: a shuffle would read lanes holding no data,
            # so fold them serially instead.
            total::T = shared[1]
            j = 2
            while j <= n
                total = op(total, shared[j])
                j += 1
            end
            shared[1] = total
        end

        if tid == 1
            out[threadgroup_position_in_grid_1d()] = shared[1]
        end
    end
    return
end
"""
    myreduce(op, a::MtlArray{T}, stride=4) where {T}

Reduce all elements of the GPU array `a` with the binary operator `op` and return the
result on the host.

Each GPU thread handles one chunk of `stride` consecutive elements. Each threadgroup of
up to 1024 threads produces one partial result via `reduce_group`. The partials are then
reduced again: recursively on the GPU while there are at least 32 of them, otherwise on
the host. `op` should be associative and commutative, because partials are not combined
in index order.

Throws `ArgumentError` if `a` is empty, if `stride < 1`, or if `length(a)` is not a
multiple of `stride`.
"""
function myreduce(op, a::MtlArray{T}, stride=4) where {T}
    n = length(a)
    n > 0 || throw(ArgumentError("reducing over an empty collection is not allowed"))
    stride >= 1 || throw(ArgumentError("stride must be at least 1, got $stride"))
    n % stride == 0 || throw(ArgumentError(
        "length(a) = $n is not a multiple of stride = $stride; not supported yet"))

    # One thread per `stride`-element chunk; never launch threads that own no chunk,
    # otherwise they would read past the end of `a`.
    nchunks = n ÷ stride
    threads = min(nchunks, 1024)
    groups = cld(nchunks, threads)
    b = similar(a, groups)
    # NOTE(review): `grid` is the threadgroup count in the Metal.jl version this was
    # written against; newer Metal.jl releases spell this keyword `groups` — confirm.
    @metal threads=threads grid=groups reduce_group(op, a, b, Val(stride))

    if groups < 32
        # Few partials: copy them to the host and finish with `op` itself. This also
        # avoids GPU scalar indexing.
        return reduce(op, Array(b))
    else
        # Keep `stride` if it still divides the number of partials, otherwise use 1.
        return myreduce(op, b, groups % stride == 0 ? stride : 1)
    end
end
# (End of gist.)