diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index a63b9b554..2977902bd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2409,6 +2409,14 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 #define USE_CUDA_GRAPH
 #endif
 
+struct ggml_graph_node_properties {
+    void * node_address;
+    int node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
 #ifdef USE_CUDA_GRAPH
 #define MAX_NODES_IN_CUDA_GRAPH 10000
 struct ggml_cuda_graph {
@@ -2416,15 +2424,42 @@ struct ggml_cuda_graph {
     cudaGraph_t graph = nullptr;
     cudaGraphExec_t instance = nullptr;
     size_t num_nodes = 0;
-    int softmax_ne0 = 0;
     cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
     cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
     bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    int number_consecutive_updates = 0;
+    ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH];
 };
 #endif
 
 const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
 
+GGML_CALL static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node;
+    graph_node_properties->node_op = node->op;
+    for(int i=0; i<GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for(int i=0; i<GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i];
+    }
+}
+
+GGML_CALL static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if(node != graph_node_properties->node_address) return false;
+    if(node->op != graph_node_properties->node_op) return false;
+    for(int i=0; i<GGML_MAX_DIMS; i++) {
+        if(node->ne[i] != graph_node_properties->ne[i]) return false;
+        if(node->nb[i] != graph_node_properties->nb[i]) return false;
+    }
+    for(int i=0; i<GGML_MAX_SRC; i++) {
+        if(node->src[i] != graph_node_properties->src_address[i]) return false;
+    }
+    return true;
+}
+
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
@@ -2446,9 +2481,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
     }
 
-    // Disable CUDA graphs in presence of env var or old GPU.
+    // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly.
     // Also disable for multi-gpu for now. TO DO investigate
-    if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || ggml_backend_cuda_get_device_count() > 1){
+    if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates ||
+        ggml_backend_cuda_get_device_count() > 1){
         use_cuda_graph = false;
     }
 
@@ -2456,20 +2492,25 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
         if(cuda_graph.instance == nullptr) cuda_graph_update_required=true;
 
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if(!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]);
+            }
+            if(!has_matching_properties) cuda_graph_update_required = true;
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]);
+        }
+
         // Loop over nodes in GGML graph to obtain info needed for CUDA graph
         int k=0;
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
-            // Identify if the graph needs to be updated for this token due to the number of elements changing
-            // (identified by inspecting soft max op parameters)
             if(node->op == GGML_OP_SOFT_MAX) {
                 if(node->src[1]->ne[1] > 1){
                     use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
                 }
-                if(node->src[0]->ne[0] != cuda_graph.softmax_ne0) {
-                    cuda_graph_update_required = true;
-                    cuda_graph.softmax_ne0 = node->src[0]->ne[0];
-                }
             }
             if(node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
@@ -2480,6 +2521,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 }
             }
         }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if(cuda_graph_update_required) {
+            cuda_graph.number_consecutive_updates++;
+        }
+        else {
+            cuda_graph.number_consecutive_updates = 0;
+        }
+        if (cuda_graph.number_consecutive_updates >= 4) cuda_graph.disable_due_to_too_many_updates = true;
     }
 
     if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
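
For readers following the patch, here is a minimal, self-contained sketch of the two mechanisms it introduces: comparing each node's properties against those recorded for the previous token to decide whether the captured CUDA graph must be updated (replacing the old `softmax_ne0` heuristic), and the consecutive-update counter that disables CUDA graphs outright when the workload keeps changing. All `sketch_*` names and the shrunken structs are hypothetical stand-ins for ggml's real `ggml_tensor`, `GGML_MAX_DIMS` and `GGML_MAX_SRC`; the sketch compiles as plain C++ with no CUDA or ggml dependency.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstddef>

#define SKETCH_MAX_DIMS 4   // stand-in for GGML_MAX_DIMS
#define SKETCH_MAX_SRC  10  // stand-in for GGML_MAX_SRC

// Shrunken stand-in for ggml_tensor: only the fields the comparison inspects.
struct sketch_tensor {
    int     op;                          // operation id (e.g. GGML_OP_SOFT_MAX)
    int64_t ne[SKETCH_MAX_DIMS];         // elements per dimension
    size_t  nb[SKETCH_MAX_DIMS];         // strides in bytes
    sketch_tensor * src[SKETCH_MAX_SRC]; // source tensors
};

// Mirrors ggml_graph_node_properties from the patch.
struct sketch_node_properties {
    void *  node_address;
    int     node_op;
    int64_t ne[SKETCH_MAX_DIMS];
    size_t  nb[SKETCH_MAX_DIMS];
    void *  src_address[SKETCH_MAX_SRC];
};

// Record a node's properties so the next token can compare against them.
static void set_properties(sketch_tensor * node, sketch_node_properties * p) {
    p->node_address = node;
    p->node_op      = node->op;
    for (int i = 0; i < SKETCH_MAX_DIMS; i++) {
        p->ne[i] = node->ne[i];
        p->nb[i] = node->nb[i];
    }
    for (int i = 0; i < SKETCH_MAX_SRC; i++) {
        p->src_address[i] = node->src[i];
    }
}

// Any change of address, op, shape, stride or source pointer forces an update.
static bool has_matching_properties(sketch_tensor * node, sketch_node_properties * p) {
    if (node != p->node_address) return false;
    if (node->op != p->node_op)  return false;
    for (int i = 0; i < SKETCH_MAX_DIMS; i++) {
        if (node->ne[i] != p->ne[i]) return false;
        if (node->nb[i] != p->nb[i]) return false;
    }
    for (int i = 0; i < SKETCH_MAX_SRC; i++) {
        if (node->src[i] != p->src_address[i]) return false;
    }
    return true;
}

int main() {
    sketch_tensor node = {};
    node.op    = 1;
    node.ne[0] = 128;

    sketch_node_properties props = {};
    set_properties(&node, &props);

    // Token N+1: nothing changed -> graph can be replayed as-is (prints 1).
    printf("match, no change: %d\n", has_matching_properties(&node, &props));

    // Token N+2: a dimension grew (e.g. a KV-cache view) -> update (prints 0).
    node.ne[0] = 129;
    printf("match, ne changed: %d\n", has_matching_properties(&node, &props));

    // Hysteresis from the patch: 4 consecutive updates disable graphs for good.
    int  consecutive = 0;
    bool disabled    = false;
    const bool update_required[6] = { true, true, true, true, false, false };
    for (bool u : update_required) {
        consecutive = u ? consecutive + 1 : 0;
        if (consecutive >= 4) disabled = true;
    }
    printf("disabled after update burst: %d\n", disabled); // prints 1
    return 0;
}
```

The per-token scan is O(n_nodes) pointer and integer comparisons, which is cheap next to kernel launches, and because the properties are stored unconditionally every token, the comparison baseline always tracks the most recent graph. Note that `disable_due_to_too_many_updates` is sticky within the shown patch: once set it is never cleared, so a use-case that keeps reshaping the graph falls back to normal stream execution permanently rather than paying repeated re-capture costs.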