// Gist by @maxwellpirtle (last active August 8, 2022)
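/// Integer ceiling division: returns the smallest integer greater than
/// or equal to a / b (e.g. divceil(6, 4) == 2). Assumes b > 0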
ushort
divceil(ushort a, ushort b) {
    return a % b == 0 ? a / b : a / b + 1;
}
/// Quickly samples a (2N + 1) x (2N + 1) region for each
/// thread
///
/// This is a generalization of the sampling scheme
/// described in "Discover Advances in Metal for A15 Bionic",
/// which reduces the number of texture samples required to
/// apply a 5x5 compute kernel, to arbitrary
/// (2N + 1) x (2N + 1) windows. The sampling uses the
/// simdgroup shuffle instructions introduced in Metal 2.4
/// to share sample data among threads. Note that with a chunk
/// width of four the corresponding quadgroup instructions could
/// be used, but for simplicity we always use the simdgroup
/// variant instead
///
/// @param threads_per_threadgroup: The dimensions of the threadgroup;
/// its x component sets the width of each simd chunk
/// @param thread_position_in_grid: The position of the thread in
/// the dispatch
/// @param threads_per_simdgroup: The number of threads
/// in each simdgroup
/// @param texture: The texture to sample a (2N + 1) x (2N + 1)
/// region from
/// @param sampler: The sampler used to read values from the texture
///
/// @returns: A (2N + 1) x (2N + 1) matrix of `float` or `half` containing
/// the sampled texture values relative to the pixel at grid position
/// `thread_position_in_grid`, where the "center" (i.e. element at position
/// (N, N)) corresponds to the pixel at `thread_position_in_grid`, the value
/// at (N + 1, N) corresponds to the pixel to the right of
/// `thread_position_in_grid`, etc.
///
/// Invariant: The width of the threadgroup divides the number of threads in each
/// simdgroup and is at most one simdgroup wide; i.e.,
///
/// threads_per_simdgroup >= threads_per_threadgroup.x &&
/// threads_per_simdgroup % threads_per_threadgroup.x == 0
///
template<texture_channel sample_channel, typename xhalf, int N = 3>
matrix<xhalf, 2 * N + 1>
simd_fast_sample(ushort2 threads_per_threadgroup,
                 ushort2 thread_position_in_grid,
                 ushort threads_per_simdgroup,
                 texture2d<xhalf, access::sample> texture,
                 sampler sampler)
{
    constexpr int K = 2 * N + 1;
    matrix<xhalf, K> result;
    // Note that if 2N / threads_per_simd_chunk > 5, we'll index
    // out of bounds. Ideally we'd prevent this with a compile-time
    // error, but a size of 5 suffices for most cases; e.g.,
    // with a simdgroup chunk width of 4, a 5x5 matrix
    // suffices up to N = 10, or a 21 x 21 read!
    matrix<xhalf, 5> sampleMap;
    constexpr int2 globalOffset = int2(N, N);
    const ushort threads_per_simd_chunk = threads_per_threadgroup.x;
    const ushort simd_chunks_per_simdgroup = threads_per_simdgroup / threads_per_simd_chunk;
    const ushort simd_chunk_samplesX = 1 + divceil((ushort)(2 * N), threads_per_simd_chunk);
    const ushort simd_chunk_samplesY = 1 + divceil((ushort)(2 * N), simd_chunks_per_simdgroup);
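    // For illustration (assuming a 32-wide simdgroup and an 8-wide
    // threadgroup, i.e. 4 chunks per simdgroup), the default N = 3 gives
    //   simd_chunk_samplesX = 1 + divceil(6, 8) = 2
    //   simd_chunk_samplesY = 1 + divceil(6, 4) = 3
    // so each thread issues 2 * 3 = 6 texture samples to cover its
    // 7 x 7 window rather than 49 independent reads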
    // Step 1: Fill the sample map based on the simd chunk size
    for (int i = 0; i < simd_chunk_samplesX; i++) {
        for (int j = 0; j < simd_chunk_samplesY; j++) {
            const int2 chunkSampleOffset = int2(j * threads_per_simd_chunk, i * simd_chunks_per_simdgroup);
            // Metal matrices are specified as column-major
            sampleMap[j][i] = texture.sample(sampler, float2(thread_position_in_grid), chunkSampleOffset - globalOffset).r;
        }
    }
    // Step 2. Fill the first row of the matrix
    // for subsequent processing by the remainder
    // of the algorithm. simd_shuffle_and_fill_down(data, fill, delta, modulo)
    // behaves as if `fill` were appended to `data` within each
    // `modulo`-wide segment of the simdgroup and the combined
    // vector were shifted down by `delta` lanes
    for (int j = 0; j < K; j++) {
        const int sampleMapLocX = 2 * j / threads_per_simd_chunk;
        result[j][0] = simd_shuffle_and_fill_down(sampleMap[sampleMapLocX][0],
                                                  sampleMap[sampleMapLocX + 1][0],
                                                  j, threads_per_simd_chunk);
    }
    for (int i = 1; i < K; i++) {
        for (int j = 0; j < K; j++) {
            const int sampleMapLocX = 2 * j / threads_per_simd_chunk;
            const int sampleMapLocY = 2 * i / simd_chunks_per_simdgroup;
            // Each subsequent row lives one full chunk further down the
            // simdgroup, so shift the previous row down by a whole chunk
            // width, filling from the next row of samples
            const xhalf fill = simd_shuffle_and_fill_down(sampleMap[sampleMapLocX][sampleMapLocY + 1],
                                                          sampleMap[sampleMapLocX + 1][sampleMapLocY + 1],
                                                          j, threads_per_simd_chunk);
            result[j][i] = simd_shuffle_and_fill_down(result[j][i - 1], fill, threads_per_simd_chunk);
        }
    }
    return result;
}
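
// ---------------------------------------------------------------------------
// Example usage: an illustrative sketch of one way to invoke
// `simd_fast_sample`. The kernel name, resource bindings, the
// `texture_channel::red` template argument, and the box-filter reduction
// are assumptions, not requirements; any reduction over the returned window
// could be substituted. The dispatch must satisfy the invariant documented
// above, e.g. an 8 x 4 threadgroup on hardware with a 32-wide simdgroup.
// ---------------------------------------------------------------------------
kernel void
example_box_filter(texture2d<half, access::sample> src [[texture(0)]],
                   texture2d<half, access::write> dst [[texture(1)]],
                   sampler samp [[sampler(0)]],
                   ushort2 threads_per_threadgroup [[threads_per_threadgroup]],
                   ushort2 thread_position_in_grid [[thread_position_in_grid]],
                   ushort threads_per_simdgroup [[threads_per_simdgroup]])
{
    constexpr int N = 3;
    constexpr int K = 2 * N + 1;
    // Gather the 7 x 7 neighborhood with the shared-sampling scheme above
    const auto window = simd_fast_sample<texture_channel::red, half, N>(
        threads_per_threadgroup, thread_position_in_grid, threads_per_simdgroup, src, samp);
    // Reduce the window to a single value (a plain box average here)
    half sum = 0.0h;
    for (int j = 0; j < K; j++) {
        for (int i = 0; i < K; i++) {
            sum += window[j][i];
        }
    }
    const half average = sum / half(K * K);
    dst.write(half4(average), uint2(thread_position_in_grid));
}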