Created
May 29, 2024 11:03
-
-
Save lw/e0cfe0b5b1d0effd74f0d45ffb8bd6bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudaTypedefs.h>

#include <cstddef>
#include <string>
template <typename T> | |
T get_cuda_driver_symbol(const char* name) { | |
void* fn = nullptr; | |
#if CUDA_VERSION >= 12000 | |
enum cudaDriverEntryPointQueryResult query_result; | |
C10_CUDA_CHECK( | |
cudaGetDriverEntryPoint(name, &fn, cudaEnableDefault, &query_result)); | |
TORCH_CHECK( | |
query_result == cudaDriverEntryPointSuccess, | |
"Querying the ", | |
name, | |
" symbol from the CUDA driver failed with error ", | |
query_result); | |
#else // CUDA_VERSION < 12000 | |
C10_CUDA_CHECK(cudaGetDriverEntryPoint(name, &fn, cudaEnableDefault)); | |
#endif // CUDA_VERSION | |
TORCH_CHECK( | |
fn != nullptr, | |
"Querying the ", | |
name, | |
" symbol from the CUDA driver returned a null pointer"); | |
return reinterpret_cast<T>(fn); | |
} | |
/**
 * Reports a CUDA driver API result code through TORCH_CHECK, decorating the
 * message with the driver's symbolic name and description for that code.
 *
 * @param result  result code returned by a driver API call
 * @param fnName  name of the driver function that produced `result`; used
 *                only in the error message
 *
 * The check condition is `result == CUDA_SUCCESS`, so the check only fails
 * for a non-success code. cuGetErrorName / cuGetErrorString are resolved once
 * (function-local statics) on the first call.
 */
void raise_cuda_driver_error(CUresult result, const char* fnName) {
  static PFN_cuGetErrorName lookup_error_name =
      get_cuda_driver_symbol<PFN_cuGetErrorName>("cuGetErrorName");
  static PFN_cuGetErrorString lookup_error_string =
      get_cuda_driver_symbol<PFN_cuGetErrorString>("cuGetErrorString");

  // Scratch pointer filled in by the driver; only read after a successful
  // lookup.
  const char* text;

  // Start from the fallback and overwrite it when the lookup succeeds.
  std::string error_name = "UNKNOWN";
  if (lookup_error_name(result, &text) == CUDA_SUCCESS) {
    error_name = text;
  }

  std::string error_string = "???";
  if (lookup_error_string(result, &text) == CUDA_SUCCESS) {
    error_string = text;
  }

  TORCH_CHECK(
      result == CUDA_SUCCESS,
      "Calling ",
      fnName,
      " from the CUDA driver failed with error ",
      error_name,
      " (code ",
      result,
      "): ",
      error_string);
}
template<size_t tile_a, size_t tile_b, typename dtype> | |
CUtensorMap make_tensor_map(void* ptr, int prob_a, int prob_b) { | |
static PFN_cuTensorMapEncodeTiled my_cu_tensor_map_encode_tiled = | |
get_cuda_driver_symbol<PFN_cuTensorMapEncodeTiled>("cuTensorMapEncodeTiled"); | |
CUtensorMap tensor_map = {0}; | |
CUresult res = my_cu_tensor_map_encode_tiled( | |
&tensor_map, | |
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, | |
2, | |
ptr, | |
dims, | |
strides, | |
box, | |
elem_strides, | |
CU_TENSOR_MAP_INTERLEAVE_NONE, | |
swizzle_mode, | |
CU_TENSOR_MAP_L2_PROMOTION_L2_128B, | |
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); | |
if (res != CUDA_SUCCESS) { | |
raise_cuda_driver_error(res, "cuTensorMapEncodeTiles"); | |
} | |
return tensor_map; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment