metal : revert the concurrnecy change because it was wrong

This commit is contained in:
Georgi Gerganov 2023-09-14 18:00:03 +03:00
parent 336afbcb76
commit e343b8b4d8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 6 additions and 7 deletions

View File

@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
ggml_metal_set_tensor(ctx_metal, input); ggml_metal_set_tensor(ctx_metal, input);
// warmup // warmup
ggml_metal_graph_compute(ctx_metal, &gf, false); ggml_metal_graph_compute(ctx_metal, &gf);
const int n_iter = 16; const int n_iter = 16;
@ -60,7 +60,7 @@ int main(int argc, char ** argv) {
// the actual inference happens here // the actual inference happens here
for (int i = 0; i < n_iter; ++i) { for (int i = 0; i < n_iter; ++i) {
ggml_metal_graph_compute(ctx_metal, &gf, false); ggml_metal_graph_compute(ctx_metal, &gf);
} }
const int64_t t1 = ggml_time_us(); const int64_t t1 = ggml_time_us();

View File

@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal // same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel // creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent); void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -610,15 +610,14 @@ void ggml_metal_graph_find_concurrency(
void ggml_metal_graph_compute( void ggml_metal_graph_compute(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf) {
bool concurrent) {
@autoreleasepool { @autoreleasepool {
// if there is ctx->concur_list, dispatch concurrently // if there is ctx->concur_list, dispatch concurrently
// else fallback to serial dispatch // else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;

View File

@ -3012,7 +3012,7 @@ static bool llama_eval_internal(
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (lctx.ctx_metal) { if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf, n_tokens > 1); ggml_metal_graph_compute(lctx.ctx_metal, gf);
} else { } else {
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
} }