Created June 23, 2023 18:11
Metal implementation of subgroupPartitionNV
// Created by Timothy Davison on 2023-06-21.
// This is a Metal implementation of subgroupPartitionNV. You use it to find a mask of
// the threads in a simd-group that hold the same value as the calling thread (a partition
// of the simd-group about a set of values).
// Feel free to use this in your code. Please share any fixes or ideas to make it faster.
// Khronos docs on subgroup partitioning:
// -
// Great talk on Vulkan subgroups by Daniel Koch @ NVIDIA
// -
// Vulkan subgroups correspond to simd-groups and quad-groups in Metal. The Metal library doesn't have a subgroup
// partitioning function, but we can build one from simd_shuffle and simd_ballot. Other platforms have hardware
// instructions for this; Apple GPUs don't (yet). This code requires n trips through the loop, where n is the
// number of partitions (distinct values among the active threads). For some classes of problems (e.g., coalescing
// device atomic-adds) even this looped version of subgroup partitioning can be a big win.
#include <metal_stdlib>
using namespace metal;
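// Returns a mask of the active threads in the simd-group with the same value, including
// the calling thread. Requires a Metal GPU that supports simd-permutations. The types
// supported are limited by the types supported by simd_shuffle.
// For floating-point T, avoid NaN: it never compares equal to itself, so its lane would
// never leave the unvisited mask and the loop would not terminate.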
// This is a Metal implementation of subgroupPartitionNV.
// See:
// -
// -
template<typename T>
static inline uint simd_partitionTD(T value) {
    // Mask of the active threads. We pop one group of equal values off this
    // mask per iteration. Every active thread computes the same mask, so the
    // loop stays in sync across the simd-group. Apple GPUs use 32-wide
    // simd-groups, so the low 32 bits hold every lane.
    uint unvisited = uint(simd_active_threads_mask().operator unsigned long());
    // Overwritten on the iteration that visits this thread's own value.
    uint result = 0;
    // Basic idea: loop through each thread in the group and see which other
    // threads have the same value. Whenever we visit a thread, we remove it from
    // the unvisited mask along with the other threads with the same value.
    // We use shuffle to share the active value with the group, and ballot
    // to find which threads also have that value.
    while( unvisited != 0 ) {
        // the lowest unvisited lane supplies the value for this iteration
        const int activeLane = ctz(unvisited);
        const T activeValue = simd_shuffle(value, activeLane);
        const auto vote = simd_ballot(activeValue == value);
        const uint v = uint(vote.operator unsigned long());
        if( activeValue == value ) {
            result = v;
        }
        // remove the vote mask from the unvisited mask
        unvisited &= ~v;
    }
    return result;
}
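
The shuffle-and-ballot loop can be checked on the CPU before running it on a GPU. The sketch below is a host-side C++ emulation of the same loop for a single 32-lane simd-group. It is not Metal code, and the names `kLanes`, `PartitionResult`, `ctz32`, and `emulate_partition` are hypothetical helpers introduced only for illustration. It assumes a 32-wide simd-group, models `simd_shuffle` as an array read and `simd_ballot` as a loop over the active lanes, and counts the loop trips so the "n trips for n partitions" claim can be observed directly.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed simd-group width (hypothetical constant for this host-side model).
constexpr std::size_t kLanes = 32;

// Per-lane result masks plus the number of loop iterations taken.
struct PartitionResult {
    std::array<std::uint32_t, kLanes> masks{}; // inactive lanes stay 0
    int trips = 0;
};

// Count trailing zeros; x must be non-zero.
inline int ctz32(std::uint32_t x) {
    assert(x != 0);
    int n = 0;
    while ((x & 1u) == 0u) {
        x >>= 1;
        ++n;
    }
    return n;
}

// Emulates the shuffle + ballot loop for all lanes of one simd-group at once.
// Bit i of `active` is set when lane i is active.
template <typename T>
PartitionResult emulate_partition(const std::array<T, kLanes>& values,
                                  std::uint32_t active) {
    PartitionResult out;
    std::uint32_t unvisited = active;
    while (unvisited != 0) {
        // The lowest unvisited lane supplies the value (models simd_shuffle).
        const std::size_t activeLane = static_cast<std::size_t>(ctz32(unvisited));
        const T activeValue = values[activeLane];

        // Collect the active lanes holding that value (models simd_ballot).
        std::uint32_t vote = 0;
        for (std::size_t lane = 0; lane < kLanes; ++lane) {
            const bool isActive = ((active >> lane) & 1u) != 0u;
            if (isActive && values[lane] == activeValue) {
                vote |= (std::uint32_t{1} << lane);
            }
        }

        // Every lane in the vote takes the vote as its partition mask.
        for (std::size_t lane = 0; lane < kLanes; ++lane) {
            if (((vote >> lane) & 1u) != 0u) {
                out.masks[lane] = vote;
            }
        }

        // Remove the whole partition from the unvisited mask.
        unvisited &= ~vote;
        ++out.trips;
    }
    return out;
}
```

With four active lanes holding 7, 7, 3, 7, lanes 0, 1, and 3 each get the mask 0xB, lane 2 gets 0x4, and the loop runs twice: one trip per distinct value.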