diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 07a3bc32e..670ba78a0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2405,23 +2405,33 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +#if (CUDART_VERSION >= 12000) +#define USE_CUDA_GRAPH +#endif + +#ifdef USE_CUDA_GRAPH +#define MAX_NODES_IN_CUDA_GRAPH 10000 struct ggml_cudaGraph { int count=0; cudaGraph_t graph = nullptr; cudaGraphExec_t instance = nullptr; size_t numNodes = 0; int softmax_ne0 = 0; + cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; + CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; + cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; }; +#endif GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); +#ifdef USE_CUDA_GRAPH // Objects required for CUDA Graph -#define MAX_NODES_IN_CUDA_GRAPH 10000 - static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory) - bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations. + static ggml_cudaGraph cudaGraph; + bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations. char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; bool cudaGraphUpdateRequired = false; // pointer to CUDA cpy kernel, which is required to identify @@ -2458,6 +2468,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); } +#else + bool useCudaGraph = false; + bool cudaGraphUpdateRequired = false; +#endif + // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. // With use of CUDA graphs, the execution will be performed by the graph launch. if(!useCudaGraph || cudaGraphUpdateRequired) { @@ -2486,6 +2501,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } + #ifdef USE_CUDA_GRAPH if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); } @@ -2498,26 +2514,26 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Perform update to graph (if required for this token), and change copy parameter (required for every token) - cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; - CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; - cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; - if(cudaGraphUpdateRequired) { // Extract nodes from graph if(cudaGraph.numNodes == 0) { CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); } - CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes)); + CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes)); // Loop over nodes, and extract kernel parameters fro each node for(size_t i=0; istream())); } cudaGraph.count++; +#endif return GGML_STATUS_SUCCESS; }