Skip to content

Instantly share code, notes, and snippets.

@reachtarunhere
Created March 2, 2023 14:49
Show Gist options
  • Save reachtarunhere/899541077537db694638e7ce4c8620a2 to your computer and use it in GitHub Desktop.
Save reachtarunhere/899541077537db694638e7ce4c8620a2 to your computer and use it in GitHub Desktop.
using CUDA, Flux, BenchmarkTools
# Block-wide tree reduction of `buf[1:n]` with the binary operator `op`.
# Each thread must first store its partial into `buf[tid]`; `n` must be the number of
# threads in the block (blockDim().x) and EVERY thread of the block must call this,
# because it contains `sync_threads()` barriers. Returns the reduced value to all threads.
@inline function _block_reduce!(buf, op, tid, n)
    sync_threads()                       # make every thread's partial visible
    active = n
    while active > 1
        half = (active + 1) ÷ 2          # ceil(active / 2): handles non-power-of-two n
        # Thread t folds slot t + half into slot t. For t + half <= active we have
        # t <= active - half < half + 1, so no thread's write slot is another's read slot.
        if tid + half <= active
            @inbounds buf[tid] = op(buf[tid], buf[tid + half])
        end
        sync_threads()                   # uniform trip count, so all threads hit this
        active = half
    end
    return @inbounds buf[1]
end

"""
    block_wise_kernel_sm(input, output, col_size)

CUDA kernel computing a softmax over each column of `input`, written into `output`.

Block `b` (`blockIdx().x`) handles column `b`, addressed by linear index as the
contiguous range `(b - 1) * col_size .+ (1:col_size)`. The threads of a block
(`blockDim().x` of them) stride over that column.

The column is staged in a 1024-element Float16 shared buffer, so `col_size` must
be <= 1024. The store into that buffer is bounds-checked, so an oversized column
raises a kernel error rather than corrupting shared memory.

# Arguments
- `input`, `output`: device arrays indexed linearly; assumes column-major layout
  with at least `gridDim().x * col_size` elements — confirm at the call site.
- `col_size`: number of elements per column.
"""
function block_wise_kernel_sm(input, output, col_size) # 1.675 ms with 256 threads (original, pre-fix)
    col = blockIdx().x
    tid = threadIdx().x
    nthreads = blockDim().x
    offset = (col - 1) * col_size

    # Staging buffer for one column (limits col_size to 1024).
    sa = CuStaticSharedArray(Float16, 1024)
    # One reduction slot per thread; 1024 is the per-block thread limit on current
    # NVIDIA GPUs, so any legal launch fits.
    max_a = CuStaticSharedArray(Float16, 1024)
    # Sum is accumulated in Float32: Float16 loses most precision over ~1024 terms.
    sum_a = CuStaticSharedArray(Float32, 1024)

    # Pass 1: stage the column and compute this thread's partial maximum.
    # Threads with tid > col_size skip the loop and contribute the identity (-Inf16).
    local_max = typemin(Float16)
    for i in tid:nthreads:col_size
        sa[i] = input[offset + i]        # bounds-checked on purpose (see docstring)
        @inbounds local_max = max(local_max, sa[i])
    end
    max_a[tid] = local_max
    col_max = _block_reduce!(max_a, max, tid, nthreads)

    # Pass 2: exponentiate shifted by the column max (overflow-safe) and sum partials.
    # Each thread only touches its own indices i, so no barrier is needed around sa here.
    local_sum = 0.0f0
    for i in tid:nthreads:col_size
        @inbounds @fastmath e = exp(sa[i] - col_max)
        @inbounds sa[i] = e
        local_sum += e
    end
    sum_a[tid] = local_sum
    col_sum = _block_reduce!(sum_a, +, tid, nthreads)

    # Pass 3: normalize and write back (again only this thread's own indices).
    for i in tid:nthreads:col_size
        @inbounds output[offset + i] = sa[i] / col_sum
    end
    return nothing
end
"""
    time_it(kernel, input, output, n_threads)

Launch `kernel(input, output, size(input, 1))` on a 1-D grid with one block per
column of `input` and `n_threads` threads per block. The launch is wrapped in
`CUDA.@sync`, so the call returns only after the device work completes.

Despite its name, this function does no timing itself. It is presumably meant to be
wrapped in a BenchmarkTools macro such as `@btime` — confirm at the call site.
"""
function time_it(kernel, input, output, n_threads)
    column_length = size(input, 1)
    grid = (size(input, 2),)
    return CUDA.@sync @cuda threads=n_threads blocks=grid kernel(input, output, column_length)
end
"""
    naive_softmax(input)

Compute the softmax of `input` along its first dimension, so that each column of the
result sums to 1. This is the CPU/array reference for the column-wise GPU kernel.

The per-column maximum is subtracted before exponentiating. This leaves the result
unchanged mathematically but keeps `exp` from overflowing for large inputs.
"""
function naive_softmax(input)
    # `maximum` (not `max`) is the reduction that accepts `dims`.
    shifted = exp.(input .- maximum(input; dims=1))
    return shifted ./ sum(shifted; dims=1)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment