This gist is for an article I wrote on Medium (https://medium.com/p/4f4590cfd5d3).
//
// ParallelRadixSort.metal
//
// Created by Matthew Kieber-Emmons on 08/29/22.
// Copyright © 2022 Matthew Kieber-Emmons. All rights reserved.
// This work is for educational purposes only and cannot be used without consent.
//
#include <metal_stdlib>
using namespace metal;
////////////////////////////////////////////////////////////////
// MARK: - Compilation Constants
// these constants are typically provided at library generation but we have sensible defaults here
////////////////////////////////////////////////////////////////
#ifndef THREADS_PER_THREADGROUP
#define THREADS_PER_THREADGROUP (256)
#endif
#ifndef VALUES_PER_THREAD
#define VALUES_PER_THREAD (4)
#endif
#ifndef EXECUTION_WIDTH
#define EXECUTION_WIDTH (32)
#endif
#ifndef LIBRARY_RADIX
#define LIBRARY_RADIX (32)
#endif
////////////////////////////////////////////////////////////////
// MARK: - Function Constants
// these constants control the code paths at pipeline creation
////////////////////////////////////////////////////////////////
constant int LOCAL_ALGORITHM [[function_constant(0)]];
constant int GLOBAL_ALGORITHM [[function_constant(1)]];
#define SORT_GLOBAL_ALGORITHM_LSD (0)
constant bool DISABLE_BOUNDS_CHECK [[function_constant(2)]];
constant bool ASCENDING [[function_constant(3)]];
#define SORT_DIRECTION_ASCENDING (0)
#define SORT_DIRECTION_DESCENDING (1)
constant int SORT_OPTIONS [[function_constant(4)]];
////////////////////////////////////////////////////////////////
// MARK: - Helpers
////////////////////////////////////////////////////////////////
static constexpr bool IsPowerOfTwo(uint32_t x){
return ((x != 0) && !(x & (x - 1)));
}
static constexpr ushort RadixToBits(ushort n) {
return (n-1<2)?1:
(n-1<4)?2:
(n-1<8)?3:
(n-1<16)?4:
(n-1<32)?5:
(n-1<64)?6:
(n-1<128)?7:
(n-1<256)?8:
(n-1<512)?9:
(n-1<1024)?10:
(n-1<2048)?11:
(n-1<4096)?12:
(n-1<8192)?13:
(n-1<16384)?14:
(n-1<32768)?15:0;
}
template <ushort RADIX, typename T> static inline ushort
ValueToKeyAtBit(T value, ushort current_bit){
return (value >> current_bit) & (RADIX - 1);
}
template <ushort RADIX> static inline ushort
ValueToKeyAtBit(int32_t value, ushort current_bit){
return ( (as_type<uint32_t>(value) ^ (1U << 31)) >> current_bit) & (RADIX - 1);
}
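// Example: as_type<uint32_t>(-1) == 0xFFFFFFFF, which XORs to 0x7FFFFFFF, while 0 maps to 0x80000000;
// flipping the sign bit therefore makes signed values compare correctly when treated as unsigned keys.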
template <ushort RADIX, typename T> static inline ushort
ValueToKeyAtDigit(T value, ushort current_digit){
ushort bits_to_shift = RadixToBits(RADIX) * current_digit;
return ValueToKeyAtBit<RADIX>(value, bits_to_shift);
}
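// Example: with RADIX = 16 (4 bits per digit), ValueToKeyAtDigit<16>(0xAB, 1) extracts bits 4..7 and returns 0xA.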
///////////////////////////////////////////////////////////////////////////////
// MARK: - Load and Store Functions
///////////////////////////////////////////////////////////////////////////////
// blocked read into registers i.e. ABCDEFGH -> AB, CD, EF, GH
template<ushort LENGTH, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[LENGTH],
const device T* input_data,
const ushort local_id) {
for (ushort i = 0; i < LENGTH; i++){
value[i] = input_data[local_id * LENGTH + i];
}
}
// blocked read into registers with bounds checking
template<ushort LENGTH, typename T> static void
LoadBlockedLocalFromGlobal(thread T (&value)[LENGTH],
const device T* input_data,
const ushort local_id,
const uint n,
const T substitution_value) {
for (ushort i = 0; i < LENGTH; i++){
value[i] = (local_id * LENGTH + i < n) ? input_data[local_id * LENGTH + i] : substitution_value;
}
}
// striped read into registers i.e. ABCDEFGH -> AE, BF, CG, DH
template<ushort LENGTH, typename T> static void
LoadStripedLocalFromGlobal(thread T (&value)[LENGTH],
const device T* input_data,
const ushort local_id,
const ushort local_size) {
for (ushort i = 0; i < LENGTH; i++){
value[i] = input_data[local_id + i * local_size];
}
}
// striped read into registers with bounds checking
template<ushort LENGTH, typename T> static void
LoadStripedLocalFromGlobal(thread T (&value)[LENGTH],
const device T* input_data,
const ushort local_id,
const ushort local_size,
const uint n,
const T substitution_value){
// this is a striped read into registers with a substitution value for out-of-bounds elements
for (ushort i = 0; i < LENGTH; i++){
value[i] = (local_id + i * local_size < n) ? input_data[local_id + i * local_size] : substitution_value;
}
}
///////////////////////////////////////////////////////////////////////////////
// MARK: - Prefix Scan Functions
///////////////////////////////////////////////////////////////////////////////
template <typename T>
struct SumOp {
inline T operator()(thread const T& a, thread const T& b) const{return a + b;}
inline T operator()(threadgroup const T& a, thread const T& b) const{return a + b;}
inline T operator()(threadgroup const T& a, threadgroup const T& b) const{return a + b;}
inline T operator()(volatile threadgroup const T& a, volatile threadgroup const T& b) const{return a + b;}
constexpr T identity(){return static_cast<T>(0);}
};
template <typename T>
struct MaxOp {
inline T operator()(thread const T& a, thread const T& b) const{return max(a,b);}
inline T operator()(threadgroup const T& a, thread const T& b) const{return max(a,b);}
inline T operator()(threadgroup const T& a, threadgroup const T& b) const{return max(a,b);}
inline T operator()(volatile threadgroup const T& a, volatile threadgroup const T& b) const{return max(a,b);}
constexpr T identity(){ return metal::numeric_limits<T>::min(); }
};
#define SCAN_TYPE_INCLUSIVE (0)
#define SCAN_TYPE_EXCLUSIVE (1)
template<ushort LENGTH, int SCAN_TYPE, typename BinaryOp, typename T>
static inline T ThreadScan(threadgroup T* values, BinaryOp Op){
for (ushort i = 1; i < LENGTH; i++){
values[i] = Op(values[i],values[i - 1]);
}
T result = values[LENGTH - 1];
if (SCAN_TYPE == SCAN_TYPE_EXCLUSIVE){
for (ushort i = LENGTH - 1; i > 0; i--){
values[i] = values[i - 1];
}
values[0] = 0;
}
return result;
}
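// Example: ThreadScan<3, SCAN_TYPE_INCLUSIVE> over {3,1,4} leaves {3,4,8}; the exclusive variant leaves {0,3,4}.
// Both return the aggregate (8) so callers can combine per-thread chunks into a larger scan.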
template<ushort LENGTH, typename BinaryOp, typename T> static inline void
ThreadUniformApply(thread T* values, T uni, BinaryOp Op){
for (ushort i = 0; i < LENGTH; i++){
values[i] = Op(values[i],uni);
}
}
template<ushort LENGTH, typename BinaryOp, typename T> static inline void
ThreadUniformApply(threadgroup T* values, T uni, BinaryOp Op){
for (ushort i = 0; i < LENGTH; i++){
values[i] = Op(values[i],uni);
}
}
template <int SCAN_TYPE, typename BinaryOp, typename T> static inline T
SimdgroupScan(T value, ushort local_id, BinaryOp Op){
const ushort lane_id = local_id % 32;
T temp = simd_shuffle_up(value, 1);
if (lane_id >= 1) value = Op(value,temp);
temp = simd_shuffle_up(value, 2);
if (lane_id >= 2) value = Op(value,temp);
temp = simd_shuffle_up(value, 4);
if (lane_id >= 4) value = Op(value,temp);
temp = simd_shuffle_up(value, 8);
if (lane_id >= 8) value = Op(value,temp);
temp = simd_shuffle_up(value, 16);
if (lane_id >= 16) value = Op(value,temp);
if (SCAN_TYPE == SCAN_TYPE_EXCLUSIVE){
temp = simd_shuffle_up(value, 1);
value = (lane_id == 0) ? 0 : temp;
}
return value;
}
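// Kogge-Stone style scan across the 32-lane simdgroup: five shuffle-up steps with strides 1,2,4,8,16.
// The exclusive variant shifts by one lane and substitutes 0 in lane 0, which assumes 0 is the operator
// identity (true for the non-negative sums and max-of-indices scanned in this file).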
template<ushort BLOCK_SIZE, int SCAN_TYPE, typename BinaryOp, typename T> static T
ThreadgroupPrefixScanStoreSum(T value, thread T& inclusive_sum, threadgroup T* shared, const ushort local_id, BinaryOp Op) {
shared[local_id] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (local_id < 32){
T partial_sum = ThreadScan<BLOCK_SIZE / 32, SCAN_TYPE>(&shared[local_id * (BLOCK_SIZE / 32)], Op);
T prefix = SimdgroupScan<SCAN_TYPE_EXCLUSIVE>(partial_sum, local_id, Op);
ThreadUniformApply<BLOCK_SIZE / 32>(&shared[local_id * (BLOCK_SIZE / 32)], prefix, Op);
if (local_id == 31) shared[0] = prefix + partial_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (SCAN_TYPE == SCAN_TYPE_INCLUSIVE) value = (local_id == 0) ? value : shared[local_id];
else value = (local_id == 0) ? 0 : shared[local_id];
inclusive_sum = shared[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
return value;
}
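// Raking scan: the first 32 threads each serially scan a contiguous chunk of BLOCK_SIZE/32 entries,
// a simdgroup exclusive scan combines the chunk totals, and each chunk's prefix is applied back over it.
// Thread 31 then overwrites shared[0] with the threadgroup total so it can be broadcast via inclusive_sum,
// which is why thread 0's own result is special-cased above instead of being read from shared.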
template<ushort BLOCK_SIZE, int SCAN_TYPE, typename BinaryOp, typename T> static T
ThreadgroupPrefixScan(T value, threadgroup T* shared, const ushort local_id, BinaryOp Op) {
// load values into shared memory
shared[local_id] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
// rake over shared mem
if (local_id < 32){
T partial_sum = ThreadScan<BLOCK_SIZE / 32, SCAN_TYPE>(&shared[local_id * (BLOCK_SIZE / 32)], Op);
T prefix = SimdgroupScan<SCAN_TYPE_EXCLUSIVE>(partial_sum, local_id, Op);
ThreadUniformApply<BLOCK_SIZE / 32>(&shared[local_id * (BLOCK_SIZE / 32)], prefix, Op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
value = shared[local_id];
threadgroup_barrier(mem_flags::mem_threadgroup);
return value;
}
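// Same raking scan as above, without broadcasting the threadgroup total.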
///////////////////////////////////////////////////////////////////////////////
// MARK: - Discontinuous regions functions
///////////////////////////////////////////////////////////////////////////////
template <ushort BLOCK_SIZE, typename T> static uchar
FlagHeadDiscontinuity(const T value, threadgroup T* shared, const ushort local_id){
shared[local_id] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
uchar result = (local_id == 0) ? 1 : shared[local_id] != shared[local_id - 1];
threadgroup_barrier(mem_flags::mem_threadgroup);
return result;
}
template <ushort BLOCK_SIZE, typename T> static uchar
FlagTailDiscontinuity(const T value, threadgroup T* shared, const ushort local_id){
shared[local_id] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
uchar result = (local_id == BLOCK_SIZE - 1) ? 1 : shared[local_id] != shared[local_id + 1];
threadgroup_barrier(mem_flags::mem_threadgroup);
return result;
}
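// Example: for per-thread keys {5,5,7,7,7}, head flags are {1,0,1,0,0} and tail flags are {0,1,0,0,1}.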
///////////////////////////////////////////////////////////////////////////////
// MARK: - Sorting Functions
///////////////////////////////////////////////////////////////////////////////
template <ushort BLOCK_SIZE, typename T> static T
SortByBit(const T value, threadgroup uint* shared, const ushort local_id, const uchar current_bit){
// extract the value of the digit
uchar mask = ValueToKeyAtBit<2>(value, current_bit);
// 2-way scan
uchar2 partial_sum;
uchar2 scan = {0};
scan[mask] = 1;
scan = ThreadgroupPrefixScanStoreSum<BLOCK_SIZE, SCAN_TYPE_EXCLUSIVE>(scan,
partial_sum,
reinterpret_cast<threadgroup uchar2*>(shared),
local_id,
SumOp<uchar2>());
// make offsets from the partial sums
ushort2 offset;
offset[0] = 0;
offset[1] = offset[0] + partial_sum[0];
shared[scan[mask] + offset[mask]] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
// read new value from shared
T result = shared[local_id];
threadgroup_barrier(mem_flags::mem_threadgroup);
return result;
}
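// SortByBit is a stable split: values whose current bit is 0 are packed before those whose bit is 1,
// preserving relative order within each group, which is what makes the per-bit passes a radix sort.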
template <ushort BLOCK_SIZE, typename T> static T
SortByTwoBits(const T value, threadgroup uint* shared, const ushort local_id, const uchar current_bit){
uchar mask = ValueToKeyAtBit<4>(value, current_bit);
// 4-way scan
uchar4 partial_sum;
uchar4 scan = {0};
scan[mask] = 1;
scan = ThreadgroupPrefixScanStoreSum<BLOCK_SIZE, SCAN_TYPE_EXCLUSIVE>(scan,
partial_sum,
reinterpret_cast<threadgroup uchar4*>(shared),
local_id,
SumOp<uchar4>());
// make offsets from the partial sums
ushort4 offset;
offset[0] = 0;
offset[1] = offset[0] + partial_sum[0];
offset[2] = offset[1] + partial_sum[1];
offset[3] = offset[2] + partial_sum[2];
shared[scan[mask] + offset[mask]] = value;
threadgroup_barrier(mem_flags::mem_threadgroup);
// read new value from shared
T result = shared[local_id];
threadgroup_barrier(mem_flags::mem_threadgroup);
return result;
}
template <ushort BLOCK_SIZE, typename T, ushort RADIX> static T
PartialRadixSort(const T value, threadgroup uint* shared, const ushort local_id, const ushort current_digit){
T result = value;
ushort current_bit = current_digit * RadixToBits(RADIX);
const ushort last_bit = min(current_bit + RadixToBits(RADIX), (ushort)sizeof(T) * 8);
while (current_bit < last_bit){
if (last_bit - current_bit > 1){
result = SortByTwoBits<BLOCK_SIZE>(result, shared, local_id, current_bit);
current_bit += 2;
}else{
result = SortByBit<BLOCK_SIZE>(result, shared, local_id, current_bit);
current_bit += 1;
}
}
return result;
}
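// Example: for a full 5-bit digit (RADIX = 32) this performs two 2-bit splits followed by one 1-bit split,
// sorting the threadgroup's values by the current digit while keeping the order from earlier digits stable.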
///////////////////////////////////////////////////////////////////////////////
// MARK: - Kernels
///////////////////////////////////////////////////////////////////////////////
template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, ushort RADIX, typename T> kernel void
MakeHistogramOfPlaceValuesKernel(device uint* output_data,
device const T* input_data,
constant uint& n,
constant uint& current_digit,
uint group_id [[threadgroup_position_in_grid]],
uint grid_size [[threadgroups_per_grid]],
ushort local_id [[thread_position_in_threadgroup]]) {
static_assert((BLOCK_SIZE % 32) == 0, "ERROR - BLOCK_SIZE must be a multiple of the execution width");
static_assert(IsPowerOfTwo(RADIX), "ERROR - RADIX must be a power of 2");
uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
// load data into registers
T values[GRAIN_SIZE];
if (DISABLE_BOUNDS_CHECK){
LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id);
} else {
LoadBlockedLocalFromGlobal(values, &input_data[base_id], local_id, n - base_id, numeric_limits<T>::max());
}
// zero out the shared memory
threadgroup uint histogram[RADIX];
for (ushort i = 0; i < (RADIX + BLOCK_SIZE - 1) / BLOCK_SIZE; i++){
if (local_id + i * BLOCK_SIZE < RADIX) histogram[local_id + i * BLOCK_SIZE] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// iterate over values to update the histogram using an atomic add operation
volatile threadgroup atomic_uint* atomic_histogram = reinterpret_cast<volatile threadgroup atomic_uint*>(histogram);
for (ushort i = 0; i < GRAIN_SIZE; i++){
uchar key = ValueToKeyAtDigit<RADIX>(values[i], current_digit);
if (DISABLE_BOUNDS_CHECK){
atomic_fetch_add_explicit(&atomic_histogram[key], 1, memory_order_relaxed);
} else {
uint32_t predicate = (base_id + local_id * GRAIN_SIZE + i < n ) ? 1 : 0; // for blocked reading
atomic_fetch_add_explicit(&atomic_histogram[key], predicate, memory_order_relaxed);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// store histogram to global in column major format (striped)
for (ushort i = 0; i < (RADIX + BLOCK_SIZE - 1) / BLOCK_SIZE; i++){
if (local_id + i * BLOCK_SIZE < RADIX){
output_data[grid_size * (local_id + i * BLOCK_SIZE) + group_id] = histogram[local_id + i * BLOCK_SIZE];
}
}
}
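// The histogram is written in column-major (striped) order: all threadgroups' counts for digit value 0,
// then for digit value 1, and so on. An exclusive prefix sum over this flat RADIX * threadgroup_count
// buffer (performed outside this file) yields, for every (digit value, threadgroup) pair, the global
// scatter offset consumed as offsets_data by the reorder kernel below.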
template [[host_name("make_histogram_int32")]] kernel void
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,int>(device uint*, device const int*,constant uint&,constant uint&, uint,uint,ushort);
template [[host_name("make_histogram_ushort")]] kernel void // host_name assumed; the header line for this instantiation appears to be missing
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,ushort>(device uint*, device const ushort*,constant uint&,constant uint&, uint,uint,ushort);
template [[host_name("make_histogram_uint32")]] kernel void
MakeHistogramOfPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,uint>(device uint*, device const uint*,constant uint&,constant uint&, uint,uint,ushort);
template<ushort BLOCK_SIZE, ushort GRAIN_SIZE, ushort RADIX, typename T> kernel void
ReorderByPlaceValuesKernel(device T* output_data,
device const T* input_data,
constant uint& n,
device const uint* offsets_data,
constant uint& current_digit,
uint group_id [[threadgroup_position_in_grid]],
uint grid_size [[threadgroups_per_grid]],
ushort local_id [[thread_position_in_threadgroup]]) {
uint base_id = group_id * BLOCK_SIZE * GRAIN_SIZE;
// load data into registers
T values[GRAIN_SIZE];
if (DISABLE_BOUNDS_CHECK){
LoadStripedLocalFromGlobal(values, &input_data[base_id], local_id, BLOCK_SIZE);
} else {
LoadStripedLocalFromGlobal(values, &input_data[base_id], local_id, BLOCK_SIZE, n - base_id, metal::numeric_limits<T>::max());
}
// sort striped values by threadgroup
threadgroup uint shared_data[BLOCK_SIZE];
for (ushort i = 0; i < GRAIN_SIZE; i++){
values[i] = PartialRadixSort<BLOCK_SIZE, T, RADIX>(values[i], shared_data, local_id, current_digit);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup uint global_offset[RADIX];
if (local_id < RADIX){
global_offset[local_id] = offsets_data[grid_size * local_id + group_id];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// write to global using offsets in the histogram
uint indexes[GRAIN_SIZE];
for (ushort i = 0; i < GRAIN_SIZE; i++){
// get local offset by scan of head flags of the range of digits
uchar key = ValueToKeyAtDigit<RADIX>(values[i], current_digit);
uchar flag = FlagHeadDiscontinuity<BLOCK_SIZE>(key, reinterpret_cast<threadgroup uchar*>(shared_data), local_id);
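// Rank of each value within its run of equal keys: head flags mark run starts, so an inclusive max-scan
// of (flag ? local_id : 0) gives every thread the index of its run's first element, and subtracting that
// from local_id gives the local offset. Example: keys {5,5,7,7,7} -> head flags {1,0,1,0,0} ->
// scanned {0,0,2,2,2} -> local offsets {0,1,0,1,2}.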
ushort local_offset = local_id - ThreadgroupPrefixScan<BLOCK_SIZE, SCAN_TYPE_INCLUSIVE>(flag ? (ushort)local_id : (ushort)0,
reinterpret_cast<threadgroup ushort*>(shared_data),
local_id,
MaxOp<T>());
indexes[i] = local_offset + global_offset[key];
threadgroup_barrier(mem_flags::mem_threadgroup);
// update the global offsets - put flags into registers, then update indexes and offsets
flag = FlagTailDiscontinuity<BLOCK_SIZE>(key, reinterpret_cast<threadgroup uchar*>(shared_data), local_id);
if (flag){
global_offset[key] += local_offset + 1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// scatter to global
if (DISABLE_BOUNDS_CHECK){
for (ushort i = 0; i < GRAIN_SIZE; i++){
output_data[indexes[i]] = values[i];
}
} else {
for (ushort i = 0; i < GRAIN_SIZE; i++){
if (indexes[i] < n) {
output_data[indexes[i]] = values[i];
}
}
}
}
template [[host_name("reorder_int32")]] kernel void
ReorderByPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,int> (device int*, device const int*, constant uint&, device const uint*, constant uint&, uint, uint, ushort);
template [[host_name("reorder_uint32")]] kernel void
ReorderByPlaceValuesKernel<THREADS_PER_THREADGROUP,VALUES_PER_THREAD,LIBRARY_RADIX,uint> (device uint*, device const uint*, constant uint&, device const uint*, constant uint&, uint, uint, ushort);
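// A typical host-side driver (not part of this gist) would, for each digit of the key: dispatch the
// histogram kernel, exclusive-scan the resulting RADIX * threadgroup_count buffer to produce offsets_data,
// dispatch the reorder kernel, and swap the input/output buffers before the next pass.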