@jaredhoberock · Created July 15, 2014
CUB __host__ __device__ sort
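This snippet demonstrates a radix sort that can be invoked from both host and device code through a single __host__ __device__ entry point, cub_sort_n. Both paths dispatch to cub::DeviceRadixSort::SortKeys; the device path relies on CUDA Dynamic Parallelism, so it only works when compiling for sm_35 or newer with relocatable device code enabled, and otherwise reports cudaErrorNotSupported.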
// Enable CUB's device-callable (CUDA Dynamic Parallelism) code paths when
// compiling for the host, or for sm_35+ with relocatable device code.
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 350 && defined(__CUDACC_RDC__))
#  define CUB_CDP 1
#endif

// HAS_CUDART is 1 when the CUDA runtime API is callable from the current
// compilation path: always on the host, and on the device only for sm_35+
// with relocatable device code enabled.
#if defined(__CUDACC__)
#  if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 350 && defined(__CUDACC_RDC__))
#    define HAS_CUDART 1
#  else
#    define HAS_CUDART 0
#  endif
#else
#  define HAS_CUDART 0
#endif
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <thrust/tuple.h>
#include <cub/device/device_radix_sort.cuh>
#include <cassert>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>
// A simple RAII wrapper around a cudaMalloc'd buffer, usable from both the
// host and (via the CUDA Dynamic Parallelism device runtime) the device.
template<typename T>
class temporary_array
{
  public:
    __host__ __device__
    temporary_array(size_t n)
      : m_size(n)
    {
#if HAS_CUDART
      size_t num_bytes = n * sizeof(T);
      cudaError_t error = cudaMalloc(&m_ptr, num_bytes);
      if(error)
      {
        printf("CUDA error in temporary_array ctor: %s\n", cudaGetErrorString(error));
      }
#endif
    }

    __host__ __device__
    ~temporary_array()
    {
#if HAS_CUDART
      cudaError_t error = cudaFree(m_ptr);
      if(error)
      {
        printf("CUDA error in temporary_array dtor: %s\n", cudaGetErrorString(error));
      }
#endif
    }

    __host__ __device__
    T* data() const
    {
      return m_ptr;
    }

    __host__ __device__
    T* begin() const
    {
      return data();
    }

    __host__ __device__
    T* end() const
    {
      return begin() + size();
    }

    __host__ __device__
    size_t size() const
    {
      return m_size;
    }

  private:
    T* m_ptr;
    size_t m_size;
};
// Dispatches cub::DeviceRadixSort::SortKeys from either the host or the
// device. The nested "workaround" struct gives the host and device code
// paths distinct, correctly-decorated static functions.
template<typename Key>
__host__ __device__
cudaError_t cub_sort_keys_wrapper(void* d_temp_storage,
                                  size_t& temp_storage_bytes,
                                  cub::DoubleBuffer<Key>& d_keys,
                                  int num_items,
                                  int begin_bit = 0,
                                  int end_bit = sizeof(Key) * 8,
                                  cudaStream_t stream = 0,
                                  bool debug_synchronous = false)
{
  struct workaround
  {
    __host__
    static cudaError_t host_path(void* d_temp_storage,
                                 size_t& temp_storage_bytes,
                                 cub::DoubleBuffer<Key>& d_keys,
                                 int num_items,
                                 int begin_bit,
                                 int end_bit,
                                 cudaStream_t stream,
                                 bool debug_synchronous)
    {
      return cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, stream, debug_synchronous);
    }

    __device__
    static cudaError_t device_path(void* d_temp_storage,
                                   size_t& temp_storage_bytes,
                                   cub::DoubleBuffer<Key>& d_keys,
                                   int num_items,
                                   int begin_bit,
                                   int end_bit,
                                   cudaStream_t stream,
                                   bool debug_synchronous)
    {
#if HAS_CUDART
      return cub::DeviceRadixSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, stream, debug_synchronous);
#else
      // device-side launch requires sm_35+ and relocatable device code
      return cudaErrorNotSupported;
#endif
    }
  };

#ifndef __CUDA_ARCH__
  return workaround::host_path(d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, stream, debug_synchronous);
#else
  return workaround::device_path(d_temp_storage, temp_storage_bytes, d_keys, num_items, begin_bit, end_bit, stream, debug_synchronous);
#endif
}
// Computes how much temporary storage cub_sort_n needs: space for the second
// half of the key double buffer plus CUB's own scratch space. Following CUB's
// convention, calling SortKeys with a null d_temp_storage pointer only
// queries the required number of scratch bytes.
template<typename T>
__host__ __device__
thrust::tuple<size_t, size_t, size_t> compute_temporary_storage_requirements_for_radix_sort_n(size_t n)
{
  cub::DoubleBuffer<T> dummy;

  // measure the number of additional temporary storage bytes required
  size_t num_additional_temp_storage_bytes = 0;
  cudaError_t error = cub_sort_keys_wrapper(0, num_additional_temp_storage_bytes, dummy, static_cast<int>(n));
  if(error)
  {
#if HAS_CUDART
    printf("CUDA error after sort(0): %s\n", cudaGetErrorString(error));
#endif
  }

  // XXX the additional temporary storage bytes
  //     must be allocated on a 16-byte aligned address
  struct __align__(16) aligned_type {};
  size_t num_double_buffer_bytes = n * sizeof(T);
  size_t num_aligned_double_buffer_bytes = thrust::detail::util::round_i(num_double_buffer_bytes, sizeof(aligned_type));
  size_t num_aligned_total_temporary_storage_bytes = num_aligned_double_buffer_bytes + num_additional_temp_storage_bytes;

  return thrust::make_tuple(num_aligned_total_temporary_storage_bytes, num_aligned_double_buffer_bytes, num_additional_temp_storage_bytes);
}
// Sorts n keys in place, callable from host or device.
template<typename T>
__host__ __device__
cudaError_t cub_sort_n(T* first, size_t n)
{
  cudaError_t result = cudaErrorNotSupported;

  if(n > 1)
  {
    // compute temporary storage requirements
    size_t num_temporary_storage_bytes = 0;
    size_t offset_to_additional_temp_storage = 0;
    size_t num_additional_temp_storage_bytes = 0;
    thrust::tie(num_temporary_storage_bytes, offset_to_additional_temp_storage, num_additional_temp_storage_bytes) =
      compute_temporary_storage_requirements_for_radix_sort_n<T>(n);

    // allocate storage
    temporary_array<char> temporary_storage(num_temporary_storage_bytes);

    // set up the double buffer: the ping buffer is the input array, the pong
    // buffer lives at the front of the temporary storage
    cub::DoubleBuffer<T> double_buffer;
    double_buffer.d_buffers[0] = thrust::raw_pointer_cast(&*first);
    double_buffer.d_buffers[1] = reinterpret_cast<T*>(temporary_storage.data());

    result = cub_sort_keys_wrapper(thrust::raw_pointer_cast(temporary_storage.data() + offset_to_additional_temp_storage),
                                   num_additional_temp_storage_bytes,
                                   double_buffer,
                                   static_cast<int>(n));
    if(result != cudaSuccess) return result;

    // if the sorted result ended up in the pong buffer, copy it back
    if(double_buffer.selector != 0)
    {
      T* temp_ptr = reinterpret_cast<T*>(double_buffer.d_buffers[1]);
#ifndef __CUDA_ARCH__
      result = cudaMemcpy(first, temp_ptr, sizeof(T) * n, cudaMemcpyDeviceToDevice);
#else
      // wait for the device-launched sort to finish before copying
      result = cudaDeviceSynchronize();
      memcpy(first, temp_ptr, sizeof(T) * n);
#endif
    }
  }

  return result;
}
// A kernel which exercises the device-side sort path.
template<typename T>
__global__ void sort_kernel(T* data, size_t n)
{
  cudaError_t error = cub_sort_n(data, n);
  if(error)
  {
    printf("CUDA error: %s\n", cudaGetErrorString(error));
  }
}
int main()
{
  size_t n = 3;
  std::vector<int> ref(n);

  // generate unsorted values
  std::generate(ref.begin(), ref.end(), rand);
  thrust::device_vector<int> data = ref;

  // sort from the host
  cudaError_t error = cub_sort_n(thrust::raw_pointer_cast(data.data()), data.size());
  if(error)
  {
    std::cerr << "CUDA error: " << cudaGetErrorString(error) << std::endl;
    std::exit(-1);
  }

  std::vector<int> h_data(n);
  thrust::copy(data.begin(), data.end(), h_data.begin());

  std::sort(ref.begin(), ref.end());
  assert(ref == h_data);
  std::cout << "Host sort OK!" << std::endl;

  // sort again, this time from the device
  std::generate(ref.begin(), ref.end(), rand);
  data = ref;
  sort_kernel<<<1,1>>>(thrust::raw_pointer_cast(data.data()), data.size());

  thrust::copy(data.begin(), data.end(), h_data.begin());
  std::sort(ref.begin(), ref.end());
  assert(ref == h_data);
  std::cout << "Everything OK!" << std::endl;

  return 0;
}
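A plausible build line, assuming the file is saved as cub_sort.cu and CUB's headers live at /path/to/cub (both names are placeholders); the device-side path needs sm_35, relocatable device code, and the device runtime library:

    nvcc -arch=sm_35 -rdc=true -I/path/to/cub cub_sort.cu -o cub_sort -lcudadevrt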