Skip to content

Instantly share code, notes, and snippets.

@kririae
Created July 23, 2024 06:47
Show Gist options
  • Save kririae/0f554bca51d8d685178e7eb278faeeea to your computer and use it in GitHub Desktop.
Save kririae/0f554bca51d8d685178e7eb278faeeea to your computer and use it in GitHub Desktop.
// clang++ ./cujit.cpp -I/opt/cuda/include/ -std=c++20 -lcuda
#include <cuda.h>
#include <format>
#include <iostream>
#include <source_location>
#include <stdexcept>
#include <string>
#include <string_view>
// Checks the status code of a CUDA driver API call and reports a failure.
//
// Template parameter:
//   bShouldThrow - when true (the default) a failure throws
//                  std::runtime_error; when false the message is written to
//                  std::cerr and the function is declared noexcept.
// Parameters:
//   result - status code returned by the driver API call being checked.
//   loc    - call site used in the message; captured automatically through
//            the default argument.
//
// Note: in the non-throwing instantiation any exception escaping from message
// formatting (e.g. std::bad_alloc) hits the noexcept boundary and terminates.
template <bool bShouldThrow = true>
inline void
cudaCheck(CUresult result,
          std::source_location loc =
              std::source_location::current()) noexcept(not bShouldThrow) {
  if (result == CUDA_SUCCESS) [[likely]] {
    return;
  }

  // The lookups store NULL through the out-pointer when the error code is not
  // recognised. Handing a null `const char *` to std::format is undefined
  // behaviour, so substitute placeholder text whenever a lookup fails.
  const char *err_name = nullptr;
  const char *err_str = nullptr;
  if (cuGetErrorName(result, &err_name) != CUDA_SUCCESS ||
      err_name == nullptr) {
    err_name = "<unrecognized error code>";
  }
  if (cuGetErrorString(result, &err_str) != CUDA_SUCCESS ||
      err_str == nullptr) {
    err_str = "<no description available>";
  }

  auto const str = std::format(
      "cudaCheck(): CUDA API call error {:d} ({:s}): \"{:s}\" at {:s}:{:d}",
      static_cast<int>(result), err_name, err_str, loc.file_name(),
      loc.line());
  if constexpr (bShouldThrow) {
    throw std::runtime_error(str);
  } else {
    std::cerr << str << std::endl;
  }
}
// PTX source for the first interchangeable implementation of `func`.
//
// The module declares `vprintf` as an external function, defines an 18-byte
// global array `$str` whose 17 initialisers are the ASCII codes of
// "Hello from func1\n" (the remaining byte is left to default initialisation),
// and exports `.visible .func func()`. `func` converts the address of `$str`
// to a generic address and calls `vprintf` with it as the first parameter and
// 0 as the second.
//
// The text is handed to the driver at run time, so every byte inside the raw
// string literal (including its `//` lines) is program data.
constexpr std::string_view func1_ptx = R"(
.version 8.5
.target sm_86
.address_size 64
// .globl func
.extern .func (.param .b32 func_retval0) vprintf
(
.param .b64 vprintf_param_0,
.param .b64 vprintf_param_1
)
;
.global .align 1 .b8 $str[18] = {72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 102, 117, 110, 99, 49, 10};
.visible .func func()
{
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
mov.u64 %rd1, $str;
cvta.global.u64 %rd2, %rd1;
{ // callseq 0, 0
.reg .b32 temp_param_reg;
.param .b64 param0;
st.param.b64 [param0+0], %rd2;
.param .b64 param1;
st.param.b64 [param1+0], 0;
.param .b32 retval0;
call.uni (retval0),
vprintf,
(
param0,
param1
);
ld.param.b32 %r1, [retval0+0];
} // callseq 0
ret;
}
)";
// PTX source for the second interchangeable implementation of `func`.
//
// Same structure as the other `func` module: it declares `vprintf` as an
// external function, exports `.visible .func func()`, and calls `vprintf` with
// the generic address of `$str` and 0. The only difference is the message:
// the 17 initialisers of `$str` are the ASCII codes of "Hello from func2\n".
//
// The text is handed to the driver at run time, so every byte inside the raw
// string literal (including its `//` lines) is program data.
constexpr std::string_view func2_ptx = R"(
.version 8.5
.target sm_86
.address_size 64
// .globl func
.extern .func (.param .b32 func_retval0) vprintf
(
.param .b64 vprintf_param_0,
.param .b64 vprintf_param_1
)
;
.global .align 1 .b8 $str[18] = {72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 102, 117, 110, 99, 50, 10};
.visible .func func()
{
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
mov.u64 %rd1, $str;
cvta.global.u64 %rd2, %rd1;
{ // callseq 0, 0
.reg .b32 temp_param_reg;
.param .b64 param0;
st.param.b64 [param0+0], %rd2;
.param .b64 param1;
st.param.b64 [param1+0], 0;
.param .b32 retval0;
call.uni (retval0),
vprintf,
(
param0,
param1
);
ld.param.b32 %r1, [retval0+0];
} // callseq 0
ret;
}
)";
// PTX source for the entry point `kernel`.
//
// The module declares `func` as an external, parameterless function and leaves
// its definition to whichever other PTX module is linked in alongside it.
// `kernel` computes %ctaid.x * %ntid.x and compares it with -%tid.x; when the
// two differ the thread branches straight to `ret`. `func` is therefore called
// only by a thread for which %ctaid.x * %ntid.x + %tid.x == 0.
//
// The text is handed to the driver at run time, so every byte inside the raw
// string literal (including its `//` lines) is program data.
constexpr std::string_view kernel_ptx = R"(
.version 8.5
.target sm_86
.address_size 64
// .globl kernel
.extern .func func
()
;
.visible .entry kernel()
{
.reg .pred %p<2>;
.reg .b32 %r<6>;
mov.u32 %r1, %ctaid.x;
mov.u32 %r2, %ntid.x;
mul.lo.s32 %r3, %r1, %r2;
mov.u32 %r4, %tid.x;
neg.s32 %r5, %r4;
setp.ne.s32 %p1, %r3, %r5;
@%p1 bra $L__BB0_2;
{ // callseq 0, 0
.reg .b32 temp_param_reg;
call.uni
func,
(
);
} // callseq 0
$L__BB0_2:
ret;
}
)";
// Scans the command-line arguments for the "--use-func1" switch.
//
// argv[0] is skipped. Returns true as soon as an argument equal to the switch
// is found, false if none matches.
bool parseCommandLine(int argc, char *argv[]) {
  constexpr std::string_view kUseFunc1Switch = "--use-func1";

  bool switchSeen = false;
  int pos = 1;
  while (!switchSeen && pos < argc) {
    switchSeen = (std::string_view(argv[pos]) == kUseFunc1Switch);
    ++pos;
  }
  return switchSeen;
}
int main(int argc, char *argv[]) {
bool useFunction1 = parseCommandLine(argc, argv);
CUdevice device;
CUcontext context;
CUmodule module;
CUfunction kernel;
CUlinkState linkState;
try {
cudaCheck(cuInit(0));
cudaCheck(cuDeviceGet(&device, 0));
cudaCheck(cuCtxCreate(&context, 0, device));
// Create link state
cudaCheck(cuLinkCreate(0, nullptr, nullptr, &linkState));
// Add kernel PTX to link state
cudaCheck(cuLinkAddData(linkState, CU_JIT_INPUT_PTX,
const_cast<char *>(kernel_ptx.data()),
kernel_ptx.size(), "kernel", 0, nullptr, nullptr));
// Add function PTX to link state based on command line argument
if (useFunction1) {
cudaCheck(cuLinkAddData(linkState, CU_JIT_INPUT_PTX,
const_cast<char *>(func1_ptx.data()),
func1_ptx.size(), "func1", 0, nullptr, nullptr));
} else {
cudaCheck(cuLinkAddData(linkState, CU_JIT_INPUT_PTX,
const_cast<char *>(func2_ptx.data()),
func2_ptx.size(), "func2", 0, nullptr, nullptr));
}
void *cubin;
size_t cubinSize;
cudaCheck(cuLinkComplete(linkState, &cubin, &cubinSize));
cudaCheck(cuModuleLoadData(&module, cubin));
cudaCheck(cuModuleGetFunction(&kernel, module, "kernel"));
void *args[] = {};
cudaCheck(
cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, nullptr, args, nullptr));
cudaCheck(cuModuleUnload(module));
cudaCheck(cuLinkDestroy(linkState));
cudaCheck(cuCtxDestroy(context));
} catch (const std::exception &e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment