Created
August 7, 2021 22:54
-
-
Save jenkspt/459c9712257bec9544a7c149b7a33ba4 to your computer and use it in GitHub Desktop.
Hack to get a sortperm working for CUDA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Base: zero, one | |
"""
    uint_type(t)

Return the integer type whose size in bytes equals `sizeof(t)`.

Despite the name, the returned type is *signed* (`Int16`, `Int32` or `Int64`).
Only sizes of 2, 4 and 8 bytes are supported.

# Throws
- `ArgumentError` if `sizeof(t)` is not 2, 4 or 8.
"""
function uint_type(t)
    s = sizeof(t)
    if s == 2
        return Int16
    elseif s == 4
        return Int32
    elseif s == 8
        return Int64
    else
        # Throw a real exception type rather than a bare String so callers can
        # catch it selectively and get a proper error display.
        throw(ArgumentError("Invalid size: $s bytes (expected 2, 4 or 8)"))
    end
end
# The CUDA quicksort kernel needs `zero` and `one` to be defined for 2-tuples in order to sort them
# Additive identity for a 2-tuple type: the tuple of each element type's `zero`.
# Uses `zero(T)` instead of `T(0)` so element types without an integer-accepting
# constructor still work; qualified with `Base.` so the method extension is explicit.
Base.zero(::Type{Tuple{T1,T2}}) where {T1,T2} = (zero(T1), zero(T2))
# Multiplicative identity for a 2-tuple type: the tuple of each element type's `one`.
# Uses `one(T)` instead of `T(1)` so element types without an integer-accepting
# constructor still work; qualified with `Base.` so the method extension is explicit.
Base.one(::Type{Tuple{T1,T2}}) where {T1,T2} = (one(T1), one(T2))
"""
    _sortperm(a::AbstractVector{<:Number})

Return the indices that put `a` in sorted order, computed by sorting
`(value, index)` tuples in place.

A `2 × length(a)` integer buffer is allocated with `similar(a, ...)`: row 1 holds
the bit patterns of `a`, row 2 holds the indices `1:length(a)`. Each column is then
reinterpreted as a `Tuple{eltype(a), IdxT}` and the tuples are sorted with `sort!`.
The result is a view of the buffer's second row, so its element type is the
integer type returned by `uint_type(eltype(a))`, not `Int`. `a` itself is only read.
"""
function _sortperm(a::AbstractVector{<:Number})
    ValT = eltype(a)
    IdxT = uint_type(ValT)
    n = length(a)

    # Row 1: raw bits of the values; row 2: their original positions.
    buffer = similar(a, IdxT, (2, n))
    buffer[1, :] .= reinterpret(IdxT, a)
    buffer[2, :] .= IdxT(1):IdxT(n)

    # View every column as one (value, index) tuple and sort those tuples in place.
    keyed = reshape(reinterpret(Tuple{ValT,IdxT}, buffer), n)
    sort!(keyed)

    # Back to the integer layout; the second row now holds the permutation.
    flat = reinterpret(IdxT, reshape(keyed, 1, n))
    return @view flat[2, :]
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment