@timdecode
Created June 23, 2023 18:11
Metal implementation of subgroupPartitionNV
// Created by Timothy Davison on 2023-06-21.
//
// This is a Metal implementation of subgroupPartitionNV. Each thread gets back a mask of
// the threads in its simd-group (itself included) that hold the same value, i.e. a partition
// of the simd-group by value.
//
// Feel free to use this in your code. Please share any fixes or ideas to make it faster.
//
// Khronos docs on subgroup partitioning:
// - https://github.com/KhronosGroup/GLSL/blob/master/extensions/nv/GL_NV_shader_subgroup_partitioned.txt
//
// Great talk on Vulkan subgroups by Daniel Koch @ NVIDIA
// - https://www.youtube.com/watch?v=fP1Af0u097o&list=FLY3QhNiXc0I1GK_0eM-8J0w&index=2
//
// Vulkan subgroups correspond to simd-groups (and quad-groups) in Metal. The Metal library doesn't
// have a subgroup partitioning function, but we can build one from simd_shuffle and simd_ballot.
// Other platforms have hardware instructions for this; Apple GPUs don't have them (yet). This code
// makes n trips through its loop, where n is the number of partitions (the number of distinct
// values among the active threads). For some classes of problems (e.g., coalescing device
// atomic-adds), even this looped version of subgroup partitioning can be a big win.
#include <metal_stdlib>
using namespace metal;
// Returns a mask of the active threads in the simd-group (including the calling thread) that
// hold the same value. Requires a Metal GPU that supports simd-permutations. The supported
// types are limited to those supported by simd_shuffle.
//
// This is a Metal implementation of subgroupPartitionNV.
//
// See:
// - https://github.com/KhronosGroup/GLSL/blob/master/extensions/nv/GL_NV_shader_subgroup_partitioned.txt
// - https://www.youtube.com/watch?v=fP1Af0u097o&list=FLY3QhNiXc0I1GK_0eM-8J0w&index=2
template<typename T>
static inline uint simd_partitionTD(T value) {
    // Mask of the active threads. We pop unique values off this mask one at a
    // time; every active thread sees the same mask, so the loop stays uniform.
    uint unvisited = uint(simd_active_threads_mask().operator unsigned long());
    // Every active thread overwrites this when its value's turn comes up.
    uint result = 0;

    // Basic idea: loop through each thread in the group and see which other
    // threads have the same value. Whenever we visit a thread, we remove it from
    // the unvisited mask along with the other threads with the same value.
    //
    // We use shuffle to share the active value with the group, and ballot
    // to find which threads also have that value.
    while (unvisited != 0) {
        // The lowest unvisited lane leads this trip.
        const ushort activeLane = ushort(ctz(unvisited));
        const T activeValue = simd_shuffle(value, activeLane);
        const bool match = (activeValue == value);
        const uint v = uint(simd_ballot(match).operator unsigned long());
        if (match) {
            result = v;
        }
        // Remove this partition from the unvisited mask. Also clear the leader's
        // own bit: a value that is not equal to itself (a float NaN) gets an empty
        // ballot and would otherwise keep the loop spinning forever.
        unvisited &= ~(v | (1u << activeLane));
    }
    return result;
}