Skip to content

Instantly share code, notes, and snippets.

@dmikushin
Created March 23, 2021 12:19
Show Gist options
  • Save dmikushin/7871337859128b111333a2025a7df5a5 to your computer and use it in GitHub Desktop.
Save dmikushin/7871337859128b111333a2025a7df5a5 to your computer and use it in GitHub Desktop.
Retrieving a CUDA device function's address from a PTX jump table, showing that the "address" is just a relative offset
#include <cstdio>
#include <cstdlib>
#include <cuda.h>
#ifndef __USE_POSIX
#define __USE_POSIX // HOST_NAME_MAX
#endif
#include <limits.h>
#include <pthread.h>
#include <sys/types.h>
#include <unistd.h>
// Evaluate a CUDA driver API call exactly once. On any result other than
// CUDA_SUCCESS, print the numeric error code, this host's name and the call
// site to stderr, then terminate the process with exit(-1).
//  - The local is named cu_err_ (not err) so that an argument expression
//    mentioning a caller variable called `err` is not captured by the
//    macro's own, still-uninitialized local.
//  - HOST_NAME_MAX excludes the terminating NUL, and gethostname() need not
//    NUL-terminate a truncated name, so reserve one extra byte and terminate
//    explicitly.
#define CU_ERR_CHECK(x) \
do { \
    CUresult cu_err_ = (x); \
    if (cu_err_ != CUDA_SUCCESS) { \
        char hostname[HOST_NAME_MAX + 1] = ""; \
        gethostname(hostname, sizeof(hostname)); \
        hostname[HOST_NAME_MAX] = '\0'; \
        fprintf(stderr, "CUDA driver error %d on %s at " \
            "%s:%d\n", (int)cu_err_, hostname, __FILE__, __LINE__); \
        exit(-1); \
    } \
} while (0)
// Load the jump-table module, launch its single-thread "kernel" entry, and
// print the 64-bit value the kernel stores into the device buffer. That value
// is the content of the module's jump table entry for the device function.
//
// Usage: jumptable [module-path]
//   module-path defaults to "jumptable.ptx" (JIT-compiled by the driver);
//   cuModuleLoad also accepts a cubin, e.g. "jumptable.cubin".
//
// Every driver call goes through CU_ERR_CHECK, which exits on failure.
int main(int argc, char* argv[])
{
    // Optional override of the module file; keeps the original default.
    const char* modulePath = (argc > 1) ? argv[1] : "jumptable.ptx";

    CU_ERR_CHECK(cuInit(0));
    CUcontext ctx;
    CU_ERR_CHECK(cuCtxCreate(&ctx, 0, 0));
    CUmodule module;
    CU_ERR_CHECK(cuModuleLoad(&module, modulePath));

    // Get the function pointer retrieval kernel handle.
    CUfunction kernel;
    CU_ERR_CHECK(cuModuleGetFunction(&kernel, module, "kernel"));

    // Device buffer that receives one pointer-sized value from the kernel.
    CUdeviceptr dfptrs;
    CU_ERR_CHECK(cuMemAlloc(&dfptrs, sizeof(void*)));

    // Kernel parameters in declaration order: (.u64 ptr, .b32 r).
    int zero = 0;
    void* params[] = { (void*)&dfptrs, &zero };
    // Grid 1x1x1, block 1x1x1, no dynamic shared memory, default stream.
    CU_ERR_CHECK(cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0, params, NULL));
    // The launch is asynchronous; synchronize so execution errors surface here.
    CU_ERR_CHECK(cuCtxSynchronize());

    // Copy the device function "address" to the host and free the buffer.
    void* fptrs[1];
    CU_ERR_CHECK(cuMemcpyDtoH(&fptrs[0], dfptrs, sizeof(void*)));
    CU_ERR_CHECK(cuMemFree(dfptrs));

    printf("fptr = %p\n", fptrs[0]); // original author observed 0x90

    // Release driver resources explicitly rather than relying on process exit.
    CU_ERR_CHECK(cuModuleUnload(module));
    CU_ERR_CHECK(cuCtxDestroy(ctx));
    return 0;
}
// Minimal PTX module containing three things:
//  - a device function foo;
//  - a one-entry jump table holding foo's address;
//  - an entry point "kernel" that calls foo through the table and writes the
//    table entry's value to global memory for the host to read.
// PTX ISA 7.1, target sm_61, 64-bit addressing.
.version 7.1
.target sm_61
.address_size 64
// foo: takes one .b32 argument r (not referenced) and always returns 1 in rv.
.visible .func (.reg .b32 rv) foo (.reg .b32 r) { mov.b32 rv, 1; ret; }
// One-entry jump table in the .global state space, initialized with foo.
// Per this gist's description, the value read from it appears to be a
// relative offset rather than an absolute device address.
.global .u64 jmptbl[1] = { foo };
// kernel(ptr, r): calls foo indirectly through jmptbl[0] with argument r,
// then stores the 64-bit table entry itself (not foo's result) to *ptr.
.visible .entry kernel(.param .u64 ptr, .param .b32 r)
{
// %b32r0 = r, %b32r2 = foo's return value (never used); %b32r1 is unused.
// %u64r0 = jmptbl[0], %u64r1 = ptr.
.reg .b32 %b32r<3>;
.reg .u64 %u64r<2>;
ld.param.b32 %b32r0, [r];
ld.global.u64 %u64r0, [jmptbl];
// Indirect call through %u64r0, with jmptbl named as the list of possible
// call targets (presumably also what keeps foo referenced in the output).
call (%b32r2), %u64r0, (%b32r0), jmptbl;
ld.param.u64 %u64r1, [ptr];
st.global.u64 [%u64r1], %u64r0;
// NOTE(review): there is no explicit `ret;` before the closing brace.
// ptxas apparently accepts the fall-through; compiler-emitted PTX normally
// ends entries with `ret;` -- confirm before relying on this.
}
# Build the host program. jumptable.cu loads jumptable.ptx at runtime; the
# cubin below is assembled from the same PTX, presumably as an offline check
# that the PTX is valid for sm_61.
# all/clean name no files, so mark them phony: a stray file called "clean"
# or "all" must not make the rule look up to date.
.PHONY: all clean

all: jumptable

jumptable.cubin: jumptable.ptx
	nvcc -arch=sm_61 $< -cubin -o $@

jumptable: jumptable.cu jumptable.cubin
	nvcc -arch=sm_61 $< -o $@ -lcuda

# Only plain files are removed; -f keeps clean quiet when they are absent.
clean:
	rm -f jumptable jumptable.cubin
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment