mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
llama : minimize swaps when reordering logits
This reduces overhead when running HellaSwag on thousands of sequences with very small (100k-parameter) Mamba models.
This commit is contained in:
parent
72eea49224
commit
18d1c14047
50
llama.cpp
50
llama.cpp
@ -19828,33 +19828,43 @@ void llama_synchronize(struct llama_context * ctx) {
|
|||||||
static void llama_reorder_outputs(struct llama_context * ctx) {
|
static void llama_reorder_outputs(struct llama_context * ctx) {
|
||||||
std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
|
std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
|
||||||
if (!out_ids.empty()) {
|
if (!out_ids.empty()) {
|
||||||
std::vector<float> logits_tmp;
|
|
||||||
std::vector<float> embd_tmp;
|
|
||||||
uint32_t n_vocab = ctx->model.hparams.n_vocab;
|
uint32_t n_vocab = ctx->model.hparams.n_vocab;
|
||||||
uint32_t n_embd = ctx->model.hparams.n_embd;
|
uint32_t n_embd = ctx->model.hparams.n_embd;
|
||||||
int32_t n_outputs = ctx->n_outputs;
|
int32_t n_outputs = ctx->n_outputs;
|
||||||
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
||||||
// insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove)
|
{
|
||||||
for (int32_t i = 1; i < n_outputs; ++i) {
|
bool is_already_sorted = true;
|
||||||
int32_t j = i;
|
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
||||||
size_t out_id_tmp = out_ids[i];
|
if (out_ids[i] > out_ids[i + 1]) {
|
||||||
while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; }
|
is_already_sorted = false;
|
||||||
if (i - j == 0) { continue; }
|
break;
|
||||||
memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0]));
|
}
|
||||||
out_ids[j] = out_id_tmp;
|
}
|
||||||
|
if (is_already_sorted) {
|
||||||
|
out_ids.clear();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: is there something more efficient which also minimizes swaps?
|
||||||
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
||||||
|
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
||||||
|
int32_t j_min = i;
|
||||||
|
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
||||||
|
if (out_ids[j] < out_ids[j_min]) {
|
||||||
|
j_min = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (j_min == i) { continue; }
|
||||||
|
std::swap(out_ids[i], out_ids[j_min]);
|
||||||
if (ctx->logits_size > 0) {
|
if (ctx->logits_size > 0) {
|
||||||
// only allocate once something needs to be moved
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
||||||
if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); }
|
std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
|
||||||
memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float));
|
}
|
||||||
memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float));
|
|
||||||
memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float));
|
|
||||||
}
|
}
|
||||||
if (ctx->embd_size > 0) {
|
if (ctx->embd_size > 0) {
|
||||||
// only allocate once something needs to be moved
|
for (uint32_t k = 0; k < n_embd; k++) {
|
||||||
if (embd_tmp.empty()) { embd_tmp.resize(n_embd); }
|
std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
|
||||||
memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float));
|
}
|
||||||
memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float));
|
|
||||||
memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
|
std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
|
||||||
|
Loading…
Reference in New Issue
Block a user