Allow number of nodes in CUDA graph to change (#7738)

Previously the code would have failed to cope in the case that the
number of nodes changes in an existing CUDA graph. This fixes the
issue by removing an unnecessary conditional.
This commit is contained in:
agray3 2024-06-04 21:06:49 +01:00 committed by GitHub
parent 1442677f92
commit b90dc566c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
if (cuda_graph_update_required) { if (cuda_graph_update_required) {
// Extract nodes from graph // Extract nodes from graph
if (cuda_ctx->cuda_graph->num_nodes == 0) { // First call with null argument gets number of nodes in graph
// First call with null argument gets number of nodes in graph CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
}
// Subsequent call with non-null argument gets nodes // Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);