-
-
Save ardfork/a223a10d20961707e7b5f3ee0b76c7d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7da06124..778ae0be 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2405,9 +2405,11 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
-#if (CUDART_VERSION >= 12000)
+// CUDA graphs target the CUDA 12 cudaGraphExecUpdate(exec, graph, resultInfo*) form; CUDA 11 only
+// has the older 4-argument form. HIP builds get a matching shim in ggml-cuda/common.cuh.
+#if (CUDART_VERSION >= 12000) || defined(GGML_USE_HIPBLAS)
 #define USE_CUDA_GRAPH
 #endif
 
 #ifdef USE_CUDA_GRAPH
 #define MAX_NODES_IN_CUDA_GRAPH 10000
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 481065b2..2142679d 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -117,6 +117,29 @@
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#define CUDA_KERNEL_NODE_PARAMS_v2 hipKernelNodeParams // types
+#define CUresult hipError_t
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaGraph_t hipGraph_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure // enum values
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaStreamCaptureModeGlobal hipStreamCaptureModeGlobal
+#define cuGetErrorString hipDrvGetErrorString // driver-API entry points
+#define cuGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cuGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaGraphGetNodes hipGraphGetNodes // runtime graph/stream entry points
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaStreamEndCapture hipStreamEndCapture
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -208,14 +231,12 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
-#if !defined(GGML_USE_HIPBLAS)
 static const char * cu_get_error_str(CUresult err) {
-    const char * err_str;
+    const char * err_str = nullptr; // NULL unless the lookup fills it in (CUDA writes NULL for unknown codes)
     cuGetErrorString(err, &err_str);
-    return err_str;
+    return err_str != nullptr ? err_str : "unknown error";
 }
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
-#endif
 
 #if CUDART_VERSION >= 11100
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -389,6 +410,21 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+// Stand-in for CUDA 12's cudaGraphExecUpdateResultInfo: hipGraphExecUpdate reports its
+// results through separate out-parameters, which the shim below writes into these fields.
+struct cudaGraphExecUpdateResultInfo {
+    cudaGraphNode_t errorFromNode;
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result;
+};
+
+// Adapts the 3-argument cudaGraphExecUpdate form to hipGraphExecUpdate. HIP has no
+// errorFromNode output, so it is cleared here instead of being left indeterminate.
+static __host__ __forceinline__ cudaError_t cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, cudaGraphExecUpdateResultInfo * resultInfo) {
+    resultInfo->errorFromNode = nullptr;
+    return hipGraphExecUpdate(hGraphExec, hGraph, &resultInfo->errorNode, &resultInfo->result);
+}
 #endif // defined(GGML_USE_HIPBLAS)
 
 // TODO: move to ggml-common.h
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.