Skip to content

Instantly share code, notes, and snippets.

@ardfork
Created April 22, 2024 15:03
Show Gist options
Save ardfork/a223a10d20961707e7b5f3ee0b76c7d5 to your computer and use it in GitHub Desktop.
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7da06124..778ae0be 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2405,9 +2405,10 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
-#if (CUDART_VERSION >= 12000)
+// The CUDA graph path is gated on the CUDA 12 runtime; HIP builds are let in as well
+// because ggml-cuda/common.cuh maps the graph API onto its HIP equivalents.
+#if (CUDART_VERSION >= 12000) || defined(GGML_USE_HIPBLAS)
#define USE_CUDA_GRAPH
#endif
#ifdef USE_CUDA_GRAPH
#define MAX_NODES_IN_CUDA_GRAPH 10000
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 481065b2..2142679d 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -117,6 +117,29 @@
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#define CUDA_KERNEL_NODE_PARAMS_v2 hipKernelNodeParams
+#define CUresult hipError_t
+#define cuGetErrorString hipDrvGetErrorString
+#define cuGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cuGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaGraph_t hipGraph_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaStreamCaptureModeGlobal hipStreamCaptureModeGlobal
+#define cudaStreamEndCapture hipStreamEndCapture
#else
#include <cuda_runtime.h>
#include <cuda.h>
@@ -208,14 +231,12 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
-#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);
return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
-#endif
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -389,6 +410,22 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}
+
+// HIP-side stand-in for the result struct taken by the three-argument cudaGraphExecUpdate(),
+// so callers written against that signature also build when GGML_USE_HIPBLAS is defined.
+struct cudaGraphExecUpdateResultInfo {
+ cudaGraphNode_t errorFromNode; // hipGraphExecUpdate has no output for this; the wrapper below sets it to nullptr
+ cudaGraphNode_t errorNode; // receives hipGraphExecUpdate's error-node output
+ cudaGraphExecUpdateResult result; // receives hipGraphExecUpdate's update-result output
+};
+
+// Adapts the struct-based cudaGraphExecUpdate() signature to hipGraphExecUpdate(), which reports
+// the error node and the update result through two separate output pointers.
+static __host__ __forceinline__ cudaError_t cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, cudaGraphExecUpdateResultInfo* resultInfo ) {
+ // Clear the field HIP never writes so callers cannot read an indeterminate node handle.
+ resultInfo->errorFromNode = nullptr;
+ return hipGraphExecUpdate(hGraphExec, hGraph, &resultInfo->errorNode, &resultInfo->result);
+}
#endif // defined(GGML_USE_HIPBLAS)
// TODO: move to ggml-common.h
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment