Skip to content

Instantly share code, notes, and snippets.

@dmikushin
Created December 18, 2023 09:43
Show Gist options
  • Save dmikushin/2b4faba0d160ce40074bc53f6f063a0a to your computer and use it in GitHub Desktop.
Save dmikushin/2b4faba0d160ce40074bc53f6f063a0a to your computer and use it in GitHub Desktop.
Example of using .branchtargets in PTX

Example of using .branchtargets in PTX

This code snippet demonstrates how to jump to a label by its index using the brx.idx instruction.

Unlike a regular if (a == 2) { ... } else { ... } chain (or a switch (...) { ... } statement), the brx.idx solution avoids multiple setp.ne.s32 comparisons.

Building

> make
nvcc -arch=sm_50 branchtargets.cu -o branchtargets

Usage

> ./branchtargets 0
Testing tgt = 0
tgt = 0
BLK0
> ./branchtargets 1
Testing tgt = 1
tgt = 1
BLK1
> ./branchtargets 2
Testing tgt = 2
tgt = 2
BEXIT
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Inline-asm operand constraint used for pointer arguments.
// "l" requests a 64-bit register and is selected when pointers are 64-bit
// (64-bit MSVC, LP64 targets, or NVRTC builds); otherwise "r" requests a
// 32-bit register. The kernel below uses it to pass `tgt` as the address
// operand of a PTX load.
#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
#define __LDG_PTR "l"
#else
#define __LDG_PTR "r"
#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
// Demonstrates an indexed branch in PTX: prints the value stored at `tgt`,
// then uses `brx.idx` to jump to one of three labels (BLK0, BLK1, BEXIT)
// selected by that value. Each label prints its own name and returns.
//
// tgt: pointer to the 32-bit branch index; it is dereferenced both from C++
//      and from the inline PTX load below.
//
// NOTE(review): the index is not range-checked here. The PTX ISA documents
// brx.idx as undefined when the index is >= the number of listed targets,
// so callers are expected to pass 0, 1 or 2 - confirm against the PTX docs.
// NOTE(review): the control flow is expressed only inside the asm strings;
// the compiler is not told about the labels or the returns, so the statement
// order below must be kept as is. Inspect the generated PTX when changing it.
__global__ void kernel(uint32_t* tgt)
{
// NOTE(review): `%d` is used with a uint32_t argument; `%u` would match the type.
printf("tgt = %d\n", *tgt);
// Declare a 32-bit PTX register that will hold the branch index.
asm volatile(".reg .u32 r_tgt;");
// Load the 32-bit value at address `tgt` into r_tgt; __LDG_PTR picks the
// operand constraint matching the pointer width.
asm volatile("ld.u32 r_tgt, [%0];" :: __LDG_PTR(tgt));
// Declare the list of labels that the indexed branch may jump to.
asm volatile("ts: .branchtargets BLK0, BLK1, BEXIT;");
// Branch to the label at position r_tgt in the list `ts`.
asm volatile("brx.idx r_tgt, ts;");
// Target for index 0.
asm volatile("BLK0:");
printf("BLK0\n");
// Return from the kernel so execution does not run on into the next label.
asm volatile("ret;\n");
// Target for index 1.
asm volatile("BLK1:");
printf("BLK1\n");
asm volatile("ret;\n");
// Target for index 2.
asm volatile("BEXIT:");
printf("BEXIT\n");
asm volatile("ret;\n");
}
// Reference version of the three-way dispatch written as an if / else-if
// chain, for comparison with the brx.idx-based `kernel`.
//
// tgt: pointer to the 32-bit selector. 0 prints "BLK0", 1 prints "BLK1",
//      2 prints "BEXIT"; any other value prints nothing.
__global__ void kernel2(uint32_t* tgt)
{
    // Read the selector once instead of re-loading it for every comparison.
    const uint32_t idx = *tgt;

    if (idx == 0)
        printf("BLK0\n");
    else if (idx == 1)
        printf("BLK1\n");
    else if (idx == 2)
        // Same message as the other kernels print for index 2 (was "BLK2").
        printf("BEXIT\n");
}
// Reference version of the three-way dispatch written as a switch statement,
// for comparison with the brx.idx-based `kernel`.
//
// tgt: pointer to the 32-bit selector. 0 prints "BLK0", 1 prints "BLK1",
//      2 prints "BEXIT"; any other value prints nothing.
__global__ void kernel3(uint32_t* tgt)
{
    switch (*tgt)
    {
    case 0:
        printf("BLK0\n");
        break;  // without the break, execution would fall through and print every later message too
    case 1:
        printf("BLK1\n");
        break;
    case 2:
        printf("BEXIT\n");
        break;
    default:
        // Unknown selector: nothing to print.
        break;
    }
}
// Terminates the program with a diagnostic when a CUDA runtime call failed.
// `what` names the call being checked so the message points at the culprit.
static void checkCuda(cudaError_t err, const char* what)
{
    if (err != cudaSuccess)
    {
        fprintf(stderr, "CUDA error: code = %d (%s) in %s\n",
                static_cast<int>(err), cudaGetErrorString(err), what);
        exit(-1);
    }
}

// Entry point: parses the branch index from the command line, copies it to
// the device, runs `kernel` with a single thread and waits for it to finish.
//
// Usage: branchtargets <tgt>, where <tgt> is 0, 1 or 2.
// Returns 0 on success; a non-zero status on bad arguments or any CUDA error.
int main(int argc, char* argv[])
{
    // `kernel` lists exactly three branch targets; per the PTX ISA an index
    // outside that list makes brx.idx undefined, so reject it on the host.
    constexpr long kNumTargets = 3;

    if (argc < 2)
    {
        fprintf(stderr, "Usage: %s <tgt: 0..%ld>\n",
                argc > 0 ? argv[0] : "branchtargets", kNumTargets - 1);
        return EXIT_FAILURE;
    }

    char* end = nullptr;
    const long parsed = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || parsed < 0 || parsed >= kNumTargets)
    {
        fprintf(stderr, "Invalid tgt '%s': expected an integer in 0..%ld\n",
                argv[1], kNumTargets - 1);
        return EXIT_FAILURE;
    }

    const uint32_t val = static_cast<uint32_t>(parsed);
    printf("Testing tgt = %u\n", val);

    uint32_t* tgt = nullptr;
    checkCuda(cudaMalloc(&tgt, sizeof(uint32_t)), "cudaMalloc");
    // Copy exactly one uint32_t (the original used sizeof(int)).
    checkCuda(cudaMemcpy(tgt, &val, sizeof(uint32_t), cudaMemcpyHostToDevice), "cudaMemcpy");

    kernel<<<1, 1>>>(tgt);
    // Launch-configuration errors are only reported through cudaGetLastError().
    checkCuda(cudaGetLastError(), "kernel launch");
    // Errors raised while the kernel runs surface at the next synchronization;
    // the sync also flushes the device-side printf output.
    checkCuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize");

    checkCuda(cudaFree(tgt), "cudaFree");
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment