From 18d1c140471da9443db9e0b67f61ccf540e113c0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 15:01:34 -0400 Subject: [PATCH] llama : minimize swaps when reordering logits This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models. --- llama.cpp | 50 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7c6afa7d1..d44dfe7b2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19828,33 +19828,43 @@ void llama_synchronize(struct llama_context * ctx) { static void llama_reorder_outputs(struct llama_context * ctx) { std::vector & out_ids = ctx->sbatch.out_ids; if (!out_ids.empty()) { - std::vector logits_tmp; - std::vector embd_tmp; uint32_t n_vocab = ctx->model.hparams.n_vocab; uint32_t n_embd = ctx->model.hparams.n_embd; int32_t n_outputs = ctx->n_outputs; GGML_ASSERT((size_t) n_outputs == out_ids.size()); - // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) - for (int32_t i = 1; i < n_outputs; ++i) { - int32_t j = i; - size_t out_id_tmp = out_ids[i]; - while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } - if (i - j == 0) { continue; } - memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); - out_ids[j] = out_id_tmp; + { + bool is_already_sorted = true; + for (int32_t i = 0; i < n_outputs - 1; ++i) { + if (out_ids[i] > out_ids[i + 1]) { + is_already_sorted = false; + break; + } + } + if (is_already_sorted) { + out_ids.clear(); + return; + } + } + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); if (ctx->logits_size > 0) { - // only allocate once something needs to be moved - if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } - memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); - memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); - memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } } if (ctx->embd_size > 0) { - // only allocate once something needs to be moved - if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } - memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); - memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); - memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } } } std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);