diff --git a/examples/metal/metal.cpp b/examples/metal/metal.cpp index 50a651fa1..c05a4fa93 100644 --- a/examples/metal/metal.cpp +++ b/examples/metal/metal.cpp @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { ggml_metal_set_tensor(ctx_metal, input); // warmup - ggml_metal_graph_compute(ctx_metal, &gf, false); + ggml_metal_graph_compute(ctx_metal, &gf); const int n_iter = 16; @@ -60,7 +60,7 @@ int main(int argc, char ** argv) { // the actual inference happens here for (int i = 0; i < n_iter; ++i) { - ggml_metal_graph_compute(ctx_metal, &gf, false); + ggml_metal_graph_compute(ctx_metal, &gf); } const int64_t t1 = ggml_time_us(); diff --git a/ggml-metal.h b/ggml-metal.h index 4e36cc129..fca28d37e 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); // same as ggml_graph_compute but uses Metal // creates gf->n_threads command buffers in parallel -void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent); +void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); #ifdef __cplusplus } diff --git a/ggml-metal.m b/ggml-metal.m index 7ec31c21b..b438b83f9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -610,15 +610,14 @@ void ggml_metal_graph_find_concurrency( void ggml_metal_graph_compute( struct ggml_metal_context * ctx, - struct ggml_cgraph * gf, - bool concurrent) { + struct ggml_cgraph * gf) { @autoreleasepool { // if there is ctx->concur_list, dispatch concurrently // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; diff --git a/llama.cpp b/llama.cpp index 80b9993f4..907d130f9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3012,7 +3012,7 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, gf, n_tokens > 1); + ggml_metal_graph_compute(lctx.ctx_metal, gf); } else { ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); }