using CUDA, Flux, BenchmarkTools
function block_wise_kernel_sm(input, output, col_size) # 1.675 ms with 256 threads
    col = blockIdx().x
    stride = blockDim().x
    tid = threadIdx().x
    sa = CuStaticSharedArray(Float16, 1024) # one column staged in shared memory; assumes col_size <= 1024
    # per-thread partial maxima and sums; the tree reductions below assume blockDim().x == 64
    max_a = CuStaticSharedArray(Float16, 64)
    sum_a = CuStaticSharedArray(Float16, 64)
    @inbounds max_a[tid] = typemin(Float16)
    @inbounds sum_a[tid] = zero(Float16)
    # stage this block's column into shared memory
    for i in tid:stride:col_size
        @inbounds sa[i] = input[(col - 1) * col_size + i]
    end
    sync_threads()
    # max capture: each thread scans its strided slice of the column
    for i in tid:stride:col_size
        @inbounds e = sa[i]
        if e > max_a[tid]
            @inbounds max_a[tid] = e
        end
    end
    sync_threads()
    # tree-reduce the per-thread maxima into max_a[64]
    indexer = 2
    while indexer <= 64
        if tid % indexer == 0
            @inbounds max_a[tid] = max(max_a[tid], max_a[tid - (indexer ÷ 2)])
        end
        sync_threads()
        indexer *= 2
    end
    real_max = max_a[64]
    sync_threads()
    # exponentiate, shifted by the max for numerical stability, accumulating per-thread sums
    for i in tid:stride:col_size
        @inbounds @fastmath temp = exp(sa[i] - real_max)
        @inbounds sum_a[tid] += temp
        @inbounds sa[i] = temp
    end
    sync_threads()
    # tree-reduce the per-thread sums into sum_a[64]
    indexer = 2
    while indexer <= 64
        if tid % indexer == 0
            @inbounds sum_a[tid] += sum_a[tid - (indexer ÷ 2)]
        end
        sync_threads()
        indexer *= 2
    end
    real_sum = sum_a[64]
    # normalize, then write the column back to global memory
    for i in tid:stride:col_size
        @inbounds @fastmath sa[i] = sa[i] / real_sum
    end
    sync_threads()
    for i in tid:stride:col_size
        @inbounds output[(col - 1) * col_size + i] = sa[i]
    end
    return nothing
end
function time_it(kernel, input, output, n_threads)
    # one block per column, one kernel launch over the whole matrix
    CUDA.@sync @cuda threads=n_threads blocks=(size(input, 2),) kernel(input, output, size(input, 1))
end
function naive_softmax(input)
    temp = exp.(input .- maximum(input, dims=1))
    temp ./ sum(temp, dims=1)
end