As I was trying to understand compute shaders, I had trouble grasping how control flow was handled, but I think I finally got it, so here it is in case my understanding helps someone. This should be mostly API-agnostic.
For me, it clicked when I tried to think about how I would write what a single wave/warp does using pseudo-SIMD instructions and execution masks.
If you need a refresher on the confusing vocabulary, I'd recommend the Compute Shader 101 Glossary.
Let's start with a simple `if`/`else`.
```
if x == y {
    z *= 2;
} else {
    z *= 3;
}
```
Of course, you need a mask that corresponds to the result of the `if`'s condition.
And if there is an `else` branch, that branch's mask is the inverse of the condition, ANDed with the mask that was active when entering the `if`.
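To make this concrete, here is a small worked example in Python. It is only a sketch: masks are plain lists of booleans (one per lane), and the 4-lane wave and the values are arbitrary choices for illustration. Note that lane 3 is inactive but its `x` and `y` happen to be equal, which is exactly why the condition has to be restricted to the active lanes.

```python
# Masks are lists of booleans, one entry per lane (a 4-lane wave for brevity).
start_mask = [True, True, True, False]  # lane 3 is inactive when entering the `if`
xs = [1, 2, 3, 4]
ys = [1, 5, 3, 4]

# Condition mask: only lanes that are active can be true.
if_mask = [active and x == y for active, x, y in zip(start_mask, xs, ys)]
# Else mask: inverted condition, restricted to the lanes active at the start.
else_mask = [active and not cond for active, cond in zip(start_mask, if_mask)]

print(if_mask)    # [True, False, True, False]
print(else_mask)  # [False, True, False, False]
```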
```
// `xs`, `ys`, and `zs` contain the values of `x`, `y`, and `z` for the lanes active here.
// `start_mask` is the execution mask at the start.

// First we do the comparison.
// Here `simd_equal` returns a mask that is false on inactive lanes.
// (If it did not, you would need an additional `mask_and` with `start_mask`.)
if_mask = simd_equal(xs, ys, mask: start_mask)
else_mask = mask_and(mask_not(if_mask), start_mask)

// If the comparison was false for all active lanes, we can skip the `if` branch.
// (For a very simple `if` branch, not branching might be best, though.)
if mask_all_false(if_mask) {
    goto else_start
}

// `if` branch
// Here everything is executed with `if_mask` as the execution mask.
// Do the multiplication on the lanes for which the comparison was true.
simd_set(zs, simd_mul(zs, 2, mask: if_mask), mask: if_mask)

// If the comparison was true for all active lanes, we can skip the `else` branch.
if mask_all_false(else_mask) {
    goto else_end
}

else_start:
// `else` branch
// Here everything is executed with `else_mask` as the execution mask.
// Do the multiplication on the lanes for which the comparison was false.
simd_set(zs, simd_mul(zs, 3, mask: else_mask), mask: else_mask)

else_end:
// Here everything should once again be executed with `start_mask` as the execution mask.
```
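To check that the masked version really behaves like each lane running the scalar code on its own, here is a minimal Python simulation. `simd_set` and `mask_all_false` mirror the pseudo-instructions above, but they are hypothetical helpers written for this sketch, not a real API; the `goto`s become plain `if` checks.

```python
def mask_all_false(mask):
    """True when no lane is enabled in the mask."""
    return not any(mask)

def simd_set(dst, values, mask):
    """Write `values` only into the lanes enabled in `mask`."""
    for lane, enabled in enumerate(mask):
        if enabled:
            dst[lane] = values[lane]

def masked_if_else(xs, ys, zs, start_mask):
    """Masked version of: if x == y { z *= 2 } else { z *= 3 } for a whole wave."""
    zs = list(zs)
    if_mask = [m and x == y for m, x, y in zip(start_mask, xs, ys)]
    else_mask = [m and not c for m, c in zip(start_mask, if_mask)]
    if not mask_all_false(if_mask):    # skip the `if` branch when no lane takes it
        simd_set(zs, [z * 2 for z in zs], if_mask)
    if not mask_all_false(else_mask):  # skip the `else` branch when no lane takes it
        simd_set(zs, [z * 3 for z in zs], else_mask)
    return zs

def scalar_if_else(x, y, z):
    """What a single lane would do on its own."""
    return z * 2 if x == y else z * 3
```

Because `if_mask` and `else_mask` never overlap, the `else` branch only ever reads `z` values that the `if` branch left untouched.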
And now loops. Simple loops with a fixed number of iterations should probably be unrolled, but here we are interested in more complex ones.
```
loop {
    x *= 2;
    if x < 42 {
        y -= 2;
        continue;
    }
    if y == 99 {
        break;
    }
    y += 1;
}
```
You need to keep two masks around:
- one that expresses which lanes have not broken out of the loop yet.
- one that expresses which lanes are still active in the current iteration of the loop. Lanes that broke out of the loop are not iterating anymore, so this is a subset of the other mask.
```
// `xs` and `ys` contain the values of `x` and `y` for all active lanes.
// `start_mask` is the execution mask at the start.
loop_mask = start_mask

loop_start:
current_iteration_mask = loop_mask

// Starting multiplication.
simd_set(xs, simd_mul(xs, 2, mask: current_iteration_mask), mask: current_iteration_mask)

// First `if`: comparing each lane's `x` to 42.
if1_mask = simd_less_than(xs, 42, mask: current_iteration_mask)
if mask_all_false(if1_mask) {
    goto if1_end
}
// Subtraction.
simd_set(ys, simd_sub(ys, 2, mask: if1_mask), mask: if1_mask)
// The first tricky part: handling the `continue`.
// The active lanes should stop taking part in the current iteration,
// so we remove them from `current_iteration_mask`.
current_iteration_mask = mask_and(current_iteration_mask, mask_not(if1_mask))
// If no lane is taking part in the current iteration anymore, we can start the next one.
if mask_all_false(current_iteration_mask) {
    goto loop_start
}
if1_end:

// Second `if`: comparing each lane's `y` to 99.
if2_mask = simd_equal(ys, 99, mask: current_iteration_mask)
// The second tricky part: handling the `break`.
// The active lanes should stop taking part in the whole loop,
// so we remove them from both `loop_mask` and `current_iteration_mask`.
loop_mask = mask_and(loop_mask, mask_not(if2_mask))
// If no lane is taking part in the loop anymore, we can leave the loop.
if mask_all_false(loop_mask) {
    goto loop_end
}
current_iteration_mask = mask_and(current_iteration_mask, mask_not(if2_mask))
// If no lane is taking part in the current iteration anymore, we can start the next one.
if mask_all_false(current_iteration_mask) {
    goto loop_start
}
if2_end:

// Addition.
simd_set(ys, simd_add(ys, 1, mask: current_iteration_mask), mask: current_iteration_mask)
goto loop_start // Next iteration.

loop_end:
// Here everything is once again executed with `start_mask` as the execution mask.
```
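The same kind of Python sketch works for the loop. Again, the helpers are hypothetical stand-ins for the pseudo-instructions, and `continue` and `break` in the Python `while` loop play the role of `goto loop_start` and `goto loop_end`.

```python
def mask_all_false(mask):
    return not any(mask)

def and_not(a, b):
    """Lanes set in `a` but not in `b`, i.e. mask_and(a, mask_not(b))."""
    return [x and not y for x, y in zip(a, b)]

def simd_set(dst, values, mask):
    """Write `values` only into the lanes enabled in `mask`."""
    for lane, enabled in enumerate(mask):
        if enabled:
            dst[lane] = values[lane]

def scalar_loop(x, y):
    """The original loop, run by a single lane on its own."""
    while True:
        x *= 2
        if x < 42:
            y -= 2
            continue
        if y == 99:
            break
        y += 1
    return x, y

def masked_loop(xs, ys, start_mask):
    """Masked version of the same loop for a whole wave."""
    xs, ys = list(xs), list(ys)
    loop_mask = list(start_mask)                  # lanes that have not hit `break`
    while True:                                   # loop_start
        it_mask = list(loop_mask)                 # current_iteration_mask
        simd_set(xs, [x * 2 for x in xs], it_mask)
        if1_mask = [m and x < 42 for m, x in zip(it_mask, xs)]
        if not mask_all_false(if1_mask):
            simd_set(ys, [y - 2 for y in ys], if1_mask)
            it_mask = and_not(it_mask, if1_mask)  # `continue`: leave this iteration only
            if mask_all_false(it_mask):
                continue                          # goto loop_start
        # if1_end
        if2_mask = [m and y == 99 for m, y in zip(it_mask, ys)]
        loop_mask = and_not(loop_mask, if2_mask)  # `break`: leave the whole loop
        if mask_all_false(loop_mask):
            break                                 # goto loop_end
        it_mask = and_not(it_mask, if2_mask)
        if mask_all_false(it_mask):
            continue                              # goto loop_start
        simd_set(ys, [y + 1 for y in ys], it_mask)
    return xs, ys                                 # loop_end
```

With `x = 50, y = 97` on one lane and `x = 5, y = 100` on another, the lanes need 3 and 9 iterations respectively; the wave keeps iterating until the last lane breaks, with finished lanes masked off. An inactive lane is never touched, even if its values would make the scalar loop run forever (for example `x = 0`).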
Functions are executed more or less as if their code were inlined at the call site, with the call site's current execution mask.
The most complex part is handling `return`s, especially as they can be inside an `if` in nested loops. The idea is close to the loop's `break` handling, but when you hit a `return` you have to modify and check the whole stack of active masks inside the function: the returning lanes must be removed from every enclosing loop and iteration mask, as well as from the function's own mask.
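Here is a sketch of `return` handling in the same style, for a hypothetical function with a `return` inside an `if` inside a single loop (deeper nesting just adds more masks to clear). The mask stack is `func_mask` (lanes that have not returned), `loop_mask`, and the current iteration's mask; a `return` clears the returning lanes from all three. The function and helper names are made up for this illustration.

```python
def and_not(a, b):
    """Lanes set in `a` but not in `b`."""
    return [x and not y for x, y in zip(a, b)]

def scalar_f(x):
    """Hypothetical function with a `return` inside an `if` inside a loop."""
    while True:
        if x > 100:
            return x
        if x % 2 == 1:
            break
        x += 10
    return -x

def masked_f(xs, start_mask):
    """Masked version: the call site's mask becomes the function's mask."""
    xs = list(xs)
    results = [None] * len(xs)         # value returned by each lane
    func_mask = list(start_mask)       # lanes that have not returned yet
    loop_mask = list(func_mask)        # lanes still inside the loop
    while True:                        # loop_start
        it_mask = list(loop_mask)      # lanes active in this iteration
        # if x > 100 { return x; }
        ret_mask = [m and x > 100 for m, x in zip(it_mask, xs)]
        for lane, returning in enumerate(ret_mask):
            if returning:
                results[lane] = xs[lane]
        # A `return` removes the lanes from EVERY mask on the stack.
        func_mask = and_not(func_mask, ret_mask)
        loop_mask = and_not(loop_mask, ret_mask)
        it_mask = and_not(it_mask, ret_mask)
        if not any(func_mask):
            return results             # every lane has returned
        if not any(loop_mask):
            break                      # goto loop_end
        # if x % 2 == 1 { break; }
        brk_mask = [m and x % 2 == 1 for m, x in zip(it_mask, xs)]
        loop_mask = and_not(loop_mask, brk_mask)
        it_mask = and_not(it_mask, brk_mask)
        if not any(loop_mask):
            break                      # goto loop_end
        if not any(it_mask):
            continue                   # goto loop_start
        for lane, active in enumerate(it_mask):
            if active:
                xs[lane] += 10
    # loop_end: only lanes still in func_mask run `return -x`.
    for lane, active in enumerate(func_mask):
        if active:
            results[lane] = -xs[lane]
    return results
```

If the `return` only cleared the loop masks and not `func_mask`, a lane starting at `x = 150` would leave the loop, run `return -x` afterwards, and end up with -150 instead of 150.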
In the examples above, the lower-level version only handles one wave/warp at a time. If you were to use SIMD on a CPU, that would probably not be a good use of your cache. The GPU being able to freely switch between waves/warps when they are waiting for memory or a barrier is what makes SIMT different from SIMD on your CPU.
I said this post was mostly API-agnostic, but CUDA on recent Nvidia hardware (Volta and later) has a significant difference: independent thread scheduling.
You can find a detailed explanation on Nvidia's blog (search for "Independent thread scheduling"), but I will still try to explain it in my own words.
Instead of having one program counter per warp, you have one per lane. Each time the GPU runs the warp, it picks one of the program counters of the warp's lanes that are not waiting, and executes that instruction for all the lanes that share that program counter.
When encountering a condition on which not all lanes agree, the lanes' program counters will diverge (though they might not for very simple branches). Once the lanes of a warp have diverged, the only way to be sure they reconverge is an explicit barrier (in CUDA, for example, `__syncwarp()`). Even if the `if` and `else` branches of an `if` seem to have the same length, that does not mean they will take the same time: in assembly they might look quite different, and some memory accesses might take longer anyway.
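To see what per-lane program counters imply, here is a toy Python model. Everything in it is an assumption made for illustration: the tiny instruction set, the opcode names, and the `pick` policy that chooses the next program counter. It does not describe how Nvidia hardware actually schedules, and lanes never wait here.

```python
# Hypothetical toy instruction set: (opcode, operand).
PROGRAM = [
    ("bne", 3),  # 0: if x != y, jump to 3
    ("mul", 2),  # 1: z *= 2
    ("jmp", 4),  # 2: jump to 4
    ("mul", 3),  # 3: z *= 3
    ("add", 1),  # 4: z += 1 (code after the if/else)
]

def run_warp(xs, ys, zs, pick):
    """Toy model of independent thread scheduling: one program counter per lane.

    At each step, `pick` chooses one of the pending program counters, and that
    instruction runs for every lane sitting at that program counter.
    """
    zs = list(zs)
    pcs = [0] * len(xs)          # one program counter per lane
    trace = []                   # (pc, lanes that executed it)
    while True:
        pending = {pc for pc in pcs if pc < len(PROGRAM)}
        if not pending:
            return zs, trace
        pc = pick(pending)
        lanes = [i for i, p in enumerate(pcs) if p == pc]
        trace.append((pc, lanes))
        op, arg = PROGRAM[pc]
        for i in lanes:
            if op == "bne":
                pcs[i] = arg if xs[i] != ys[i] else pc + 1
            elif op == "jmp":
                pcs[i] = arg
            elif op == "mul":
                zs[i] *= arg
                pcs[i] = pc + 1
            elif op == "add":
                zs[i] += arg
                pcs[i] = pc + 1
```

With `xs = [1, 2, 3, 4]`, `ys = [1, 0, 3, 0]` and every `z` starting at 1, both `pick=min` and `pick=max` end with `z = [3, 4, 3, 4]`. With `pick=min`, the `z += 1` at index 4 runs once for all four lanes; with `pick=max`, it runs twice, once per group. Whether the lanes happen to reconverge depends on the scheduler, which is why you need an explicit barrier if your code relies on it.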
If your data processing does not require interaction between lanes (via memory or subgroup operations), a good understanding of the execution model might not seem that important, but optimizing will still require some of it. And if you want to write more complex shaders, it is probably going to be invaluable.
I am still a beginner in this domain, but I hope this write-up might be of some help.